您当前的位置:首页 > 电脑百科 > 程序开发 > 语言 > Python

用Pytorch基于MNIST实现手写数字识别

时间:2020-09-27 10:29:24  来源:  作者:

代码的基本结构还是延续我通过深度学习神经网络,基于MNIST实现手写数字识别 的结构,只是神经网络部分使用了Pytorch的API。

有一些地方要多说一点,但是不展开:

1、激活函数选用了ReLU,而非之前的sigmoid,二者的不同,网上文章很多,有机会总结一下。

2、可以跟前文的代码进行比较看,主要看train、query两个方法,感受一下Pytorch的封装。

3、用Pytorch构建的神经网络,在训练、测试时采用对应的模式train()、eval(),主要是对BN、Dropout层进行设置,具体的情况有机会详细说一下。

4、本次还是使用了200个隐藏层节点,学习率0.1,使用了6W+条数据用以训练,1W+条数据用以测试,激活函数分别用ReLU,Sigmoid训练了7个世代,结果如下:

Sigmoid 7世代
准确率=94.21%
该循环程序运行时间: 404.66520285606384
Relu 7世代
准确率=98.15%
该循环程序运行时间: 396.1038899421692

之前代码,7个世代结果:

准确率=97.26%
该循环程序运行时间: 314.45252776145935

代码如下:

# -*- coding: utf-8 -*-
#!/usr/bin/Python3import numpyimport torchimport torch.optim as optim
from torch import nn
from time import *
class NNW(torch.nn.Module):
    def __init__(self):        super(NNW, self).__init__()
        self.inputLinear=torch.nn.Linear(784,200)
        self.hidden1Linear=torch.nn.Linear(200,200)
        self.outLinear=torch.nn.Linear(200,10)
            def forward(self,x):        #relu激活函数        x=torch.nn.functional.relu(self.inputLinear(x))        x=torch.nn.functional.relu(self.hidden1Linear(x))        return self.outLinear(x)
        #sigmoid激活函数        #x=torch.sigmoid(self.inputLinear(x))        #x=torch.sigmoid(self.hidden1Linear(x))        #return torch.sigmoid(self.outLinear(x))
begin_time = time()nnw=NNW()#损失函数loss=torch.nn.MSELoss()#随机梯度下降optimizer=optim.SGD(nnw.parameters(),0.1)
def train(inputs,targets):    #前向传播    output = nnw(inputs)    loss_result=loss(output,targets)    #反向传播    optimizer.zero_grad()       loss_result.backward()    optimizer.step()def query(inputs):    out=nnw(inputs)
    return out
#从文件取出训练数据trainFile=open("mnist_train.csv","r")
trains=trainFile.readlines()trainFile.close#从文件取出测试数据testFile=open("mnist_test.csv","r")
tests=testFile.readlines()testFile.close#训练模式,启用BN层、Dropout层nnw.train()for size in range(1):
    print("第{}次训练".format(size+1))
    for data in trains:
        allVals=data.split(",")
        inputs_list=numpy.asfarray(allVals[1:])/255.0*0.99+0.01
                targets_list=numpy.zeros(10)+0.01
        targets_list[int(allVals[0])]=0.99        
        inputs=torch.autograd.Variable(torch.tensor(inputs_list))        targets=torch.autograd.Variable(torch.tensor(targets_list))        train(inputs.float(),targets.float())score=[]#测试模式,固定住BN层和Dropout层,使用已经训练好的值nnw.eval()for data in tests:
    allVals=data.split(",")
    realNum=int(allVals[0])
    inputs_list=numpy.asfarray(allVals[1:])/255.0*0.99+0.01
    inputs = torch.autograd.Variable(torch.tensor(inputs_list))    result=query(inputs.float())    outNum=numpy.argmax(result.cpu().detach().numpy())    #print("真实值{},结果值{}".format(realNum,outNum))
    if(outNum==realNum):
        score.Append(1)
    else:
        score.append(0)
scoreArr=numpy.asarray(score)print("准确率={}%".format(scoreArr.sum()/scoreArr.size*100))
end_time = time()run_time = end_time-begin_timeprint ('该循环程序运行时间:',run_time)


Tags:MNIST   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,如有任何标注错误或版权侵犯请与我们联系(Email:2595517585@qq.com),我们将及时更正、删除,谢谢。
▌相关推荐
MNIST 这里就不多展开了,我们上几期的文章都是使用此数据集进行的分享。手写字母识别EMNIST数据集Extended MNIST (EMNIST), 因为 MNIST 被大家熟知,所以这里就推出了 EMNIST...【详细内容】
2021-09-08  Tags: MNIST  点击:(182)  评论:(0)  加入收藏
代码的基本结构还是延续我通过深度学习神经网络,基于MNIST实现手写数字识别 的结构,只是神经网络部分使用了Pytorch的API。有一些地方要多说一点,但是不展开:1、激活函数选用了R...【详细内容】
2020-09-27  Tags: MNIST  点击:(104)  评论:(0)  加入收藏
前文我们介绍了如何在Windows环境下安装TensorFlow(人工智能学习入门之TensorFlow2.2版本安装(Windows10) )。学习知识最好的方法就是实践了,因此接下来我们通过实操的方式来学习...【详细内容】
2020-08-20  Tags: MNIST  点击:(86)  评论:(0)  加入收藏
▌简易百科推荐
Python 是一个很棒的语言。它是世界上发展最快的编程语言之一。它一次又一次地证明了在开发人员职位中和跨行业的数据科学职位中的实用性。整个 Python 及其库的生态系统使...【详细内容】
2021-12-27  IT资料库    Tags:Python 库   点击:(1)  评论:(0)  加入收藏
菜单驱动程序简介菜单驱动程序是通过显示选项列表从用户那里获取输入并允许用户从选项列表中选择输入的程序。菜单驱动程序的一个简单示例是 ATM(自动取款机)。在交易的情况下...【详细内容】
2021-12-27  子冉爱python    Tags:Python   点击:(1)  评论:(0)  加入收藏
有不少同学学完Python后仍然很难将其灵活运用。我整理15个Python入门的小程序。在实践中应用Python会有事半功倍的效果。01 实现二元二次函数实现数学里的二元二次函数:f(x,...【详细内容】
2021-12-22  程序汪小成    Tags:Python入门   点击:(32)  评论:(0)  加入收藏
Verilog是由一个个module组成的,下面是其中一个module在网表中的样子,我只需要提取module名字、实例化关系。module rst_filter ( ...); 端口声明... wire定义......【详细内容】
2021-12-22  编程啊青    Tags:Verilog   点击:(7)  评论:(0)  加入收藏
运行环境 如何从 MP4 视频中提取帧 将帧变成 GIF 创建 MP4 到 GIF GUI ...【详细内容】
2021-12-22  修道猿    Tags:Python   点击:(5)  评论:(0)  加入收藏
面向对象:Object Oriented Programming,简称OOP,即面向对象程序设计。类(Class)和对象(Object)类是用来描述具有相同属性和方法对象的集合。对象是类的具体实例。比如,学生都有...【详细内容】
2021-12-22  我头秃了    Tags:python   点击:(9)  评论:(0)  加入收藏
所谓内置函数,就是Python提供的, 可以直接拿来直接用的函数,比如大家熟悉的print,range、input等,也有不是很熟,但是很重要的,如enumerate、zip、join等,Python内置的这些函数非常...【详细内容】
2021-12-21  程序员小新ds    Tags:python初   点击:(5)  评论:(0)  加入收藏
Hi,大家好。我们在接口自动化测试项目中,有时候需要一些加密。今天给大伙介绍Python实现各种 加密 ,接口加解密再也不愁。目录一、项目加解密需求分析六、Python加密库PyCrypto...【详细内容】
2021-12-21  Python可乐    Tags:Python   点击:(7)  评论:(0)  加入收藏
借助pyautogui库,我们可以轻松地控制鼠标、键盘以及进行图像识别,实现自动抢课的功能1.准备工作我们在仓库里提供了2个必须的文件,包括: auto_get_lesson_pic_recognize.py:脚本...【详细内容】
2021-12-17  程序员道道    Tags:python   点击:(13)  评论:(0)  加入收藏
前言越来越多开发者表示,自从用了Python/Pandas,Excel都没有打开过了,用Python来处理与可视化表格就是四个字——非常快速!下面我来举几个明显的例子1.删除重复行和空...【详细内容】
2021-12-16  查理不是猹    Tags:Python   点击:(20)  评论:(0)  加入收藏
最新更新
栏目热门
栏目头条