您当前的位置:首页 > 电脑百科 > 人工智能

使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

时间:2021-09-08 09:32:01  来源:  作者:人工智能研究所

MNIST 这里就不多展开了,我们上几期的文章都是使用此数据集进行的分享。

使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

手写字母识别

EMNIST数据集

Extended MNIST (EMNIST), 因为 MNIST 被大家熟知,所以这里就推出了 EMNIST ,一个在手写字体分类任务中更有挑战的 Benchmark 。此数据集当然也包含了手写数字的数据集

在数据集接口上,此数据集的处理方式跟 MNIST 保持一致,也是为了方便已经熟悉 MNIST 的我们去使用,这里着重介绍一下 EMNIST 的分类方式。

分类方式

EMNIST 主要分为以下 6 类:

By_Class : 共 814255 张,62 类,与 NIST 相比重新划分类训练集与测试机的图片数

By_Merge: 共 814255 张,47 类, 与 NIST 相比重新划分类训练集与测试机的图片数

Balanced : 共 131600 张,47 类, 每一类都包含了相同的数据,每一类训练集 2400 张,测试集 400 张

Digits :共 28000 张,10 类,每一类包含相同数量数据,每一类训练集 24000 张,测试集 4000 张

Letters : 共 103600 张,37 类,每一类包含相同数据,每一类训练集 2400 张,测试集 400 张

MNIST : 共 70000 张,10 类,每一类包含相同数量数据(注:这里虽然数目和分类都一样,但是图片的处理方式不一样,EMNIST 的 MNIST 子集数字占的比重更大)

这里为什么后面的分类不是26+26?其主要原因是一些大小写字母比较类似的字母就合并了,比如C等等

使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

 

代码实现手写字母训练神经网络

由于EMNIST数据集与MNIST类似,我们直接使用MNIST的训练代码进行此神经网络的训练

import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision  # 数据库模块
import matplotlib.pyplot as plt
# torch.manual_seed(1)  # reproducible
EPOCH = 1  # 训练整批数据次数,训练次数越多,精度越高,为了演示,我们训练5次
BATCH_SIZE = 50  # 每次训练的数据集个数
LR = 0.001  # 学习效率
DOWNLOAD_MNIST = Ture  # 如果你已经下载好了EMNIST数据就设置 False

# EMNIST 手写字母 训练集
train_data = torchvision.datasets.EMNIST(
    root='./data',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download = DOWNLOAD_MNIST,
    split = 'letters' 
)
# EMNIST 手写字母 测试集
test_data = torchvision.datasets.EMNIST(
    root='./data',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=False,
    split = 'letters'     
)
# 批训练 50samples, 1 channel, 28x28 (50, 1, 28, 28)
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 每一步 loader 释放50个数据用来学习
# 为了演示, 我们测试时提取2000个数据先
# shape from (2000, 28, 28) to (2000, 1, 28, 28), value in range(0,1)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000] / 255.  
test_y = test_data.targets[:2000]
#test_x = test_x.cuda() # 若有cuda环境,取消注释
#test_y = test_y.cuda() # 若有cuda环境,取消注释

首先,我们下载EMNIST数据集,这里由于我们分享过手写数字的部分,这里我们按照手写字母的部分进行神经网络的训练,其split为letters,这里跟MNIST数据集不一样的地方便是多了一个split标签,备注我们需要那个分类的数据

torchvision.datasets.EMNIST(root: str, split: str, **kwargs: Any)

root ( string ) –
数据集所在EMNIST/processed/training.pt 和 EMNIST/processed/test.pt存在的根目录。
split(字符串) 
  -该数据集具有6个不同的拆分:byclass,bymerge, balanced,letters,digits和mnist。此参数指定使用哪一个。
train ( bool , optional )
– 如果为 True,则从 中创建数据集training.pt,否则从test.pt.
download ( bool , optional ) 
– 如果为 true,则从 Internet 下载数据集并将其放在根目录中。如果数据集已经下载,则不会再次下载。
transform ( callable , optional ) 
– 一个函数/转换,它接收一个 PIL 图像并返回一个转换后的版本。例如,transforms.RandomCrop
target_transform ( callable , optional ) 
– 一个接收目标并对其进行转换的函数/转换。

可视化数据集

然后我们可视化一下此数据集,看看此数据集什么样子

def get_mApping(num, with_type='letters'):
    """
    根据 mapping,由传入的 num 计算 UTF8 字符
    """
    if with_type == 'byclass':
        if num <= 9:
            return chr(num + 48)  # 数字
        elif num <= 35:
            return chr(num + 55)  # 大写字母
        else:
            return chr(num + 61)  # 小写字母
    elif with_type == 'letters':
        return chr(num + 64) + " / " + chr(num + 96)  # 大写/小写字母
    elif with_type == 'digits':
        return chr(num + 96)
    else:
        return num

figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(train_data), size=(1,)).item()
    img, label = train_data[sample_idx]
    print(label)
    figure.add_subplot(rows, cols, i)
    plt.title(get_mapping(label))
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

字符编码表(UTF-8)

使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

字符编码表

首先我们建立一个字符编码表的规则,这是由于此数据集的label,是反馈一个字母的字符编码,我们只有通过此字符编码表才能对应起来神经网络识别的是那个字母

使用EMNIST数据集训练第一个pytorch CNN手写字母识别神经网络

 

通过观察其数据集,此字母总是感觉怪,这是由于EMNIST数据集左右翻转图片后,又进行了图片的逆时针旋转90度,小编认为是进行图片数据的特征增强,目前没有一个文件来介绍为什么这样做,大家可以讨论

搭建神经网络

神经网络的搭建,我们直接使用MNIST神经网络

# 定义神经网络
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Sequential(  # input shape (1, 28, 28)
            nn.Conv2d(
                in_channels=1,  # 输入通道数
                out_channels=16,  # 输出通道数
                kernel_size=5,   # 卷积核大小
                stride=1,  #卷积步数
                padding=2,  # 如果想要 con2d 出来的图片长宽没有变化, 
                            # padding=(kernel_size-1)/2 当 stride=1
            ),  # output shape (16, 28, 28)
            nn.ReLU(),  # activation
            nn.MaxPool2d(kernel_size=2),  # 在 2x2 空间里向下采样, output shape (16, 14, 14)
        )
        self.conv2 = nn.Sequential(  # input shape (16, 14, 14)
            nn.Conv2d(16, 32, 5, 1, 2),  # output shape (32, 14, 14)
            nn.ReLU(),  # activation
            nn.MaxPool2d(2),  # output shape (32, 7, 7)
        )
        self.out = nn.Linear(32 * 7 * 7, 37)  # 全连接层,A/Z,a/z一共37个类

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # 展平多维的卷积图成 (batch_size, 32 * 7 * 7)
        output = self.out(x)
        return output

以上代码在介绍手写数字时有详细的介绍,大家可以参考文章开头的文章链接,这里不再详细介绍了,这里需要注意的是神经网络有37个分类

神经网络的训练

cnn = CNN() # 创建CNN
# cnn = cnn.cuda() # 若有cuda环境,取消注释
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR)  # optimize all cnn parameters
loss_func = nn.CrossEntropyLoss()  # the target label is not one-hotted
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader):  # 每一步 loader 释放50个数据用来学习
        #b_x = b_x.cuda() # 若有cuda环境,取消注释
        #b_y = b_y.cuda() # 若有cuda环境,取消注释
        output = cnn(b_x)  # 输入一张图片进行神经网络训练
        loss = loss_func(output, b_y)  # 计算神经网络的预测值与实际的误差
        optimizer.zero_grad() #将所有优化的torch.Tensors的梯度设置为零
        loss.backward()  # 反向传播的梯度计算
        optimizer.step()  # 执行单个优化步骤
        if step % 50 == 0: # 我们每50步来查看一下神经网络训练的结果
            test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            # 若有cuda环境,使用84行,注释82行
            # pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze()
            accuracy = float((pred_y == test_y).sum()) / float(test_y.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.data, 
                '| test accuracy: %.2f' % accuracy)

神经网络的训练部分也是往期代码,这里建议大家训练的EPOCH设置的大点,毕竟37个分类若只是训练一两次,其精度也很难保证

测试神经网络与保存模型

# test 神经网络
test_output = cnn(test_x[:10])
pred_y = torch.max(test_output, 1)[1].data.squeeze()
# 若有cuda环境,使用92行,注释90行
#pred_y = torch.max(test_output, 1)[1].cuda().data.squeeze()
print(pred_y, 'prediction number')
print(test_y[:10], 'real number')
# save CNN
# 仅保存CNN参数,速度较快
torch.save(cnn.state_dict(), './model/CNN_letter.pk')
# 保存CNN整个结构
#torch.save(cnn(), './model/CNN.pkl')

训练完成后,我们可以检测一下神经网络的训练结果,这里需要注意的是,pred_y与test_y只是字母的编码,并不是真实的字母,最后我们保存一下神经网络的模型,方便下期进行手写字母的识别篇



Tags:神经网络   点击:()  评论:()
声明:本站部分内容及图片来自互联网,转载是出于传递更多信息之目的,内容观点仅代表作者本人,如有任何标注错误或版权侵犯请与我们联系(Email:2595517585@qq.com),我们将及时更正、删除,谢谢。
▌相关推荐
SimpleAI.人工智能、机器学习、深度学习还是遥不可及?来这里看看吧~ 从基本的概念、原理、公式,到用生动形象的例子去理解,到动手做实验去感知,到著名案例的学习,到用所学来实现...【详细内容】
2021-10-19  Tags: 神经网络  点击:(47)  评论:(0)  加入收藏
MNIST 这里就不多展开了,我们上几期的文章都是使用此数据集进行的分享。手写字母识别EMNIST数据集Extended MNIST (EMNIST), 因为 MNIST 被大家熟知,所以这里就推出了 EMNIST...【详细内容】
2021-09-08  Tags: 神经网络  点击:(182)  评论:(0)  加入收藏
理解什么是人工智能,以及机器学习和深度学习如何影响它,是一种不同凡响的体验。在 Mate Labs 我们有一群自学有成的工程师,希望本文能够分享一些学习的经验和捷径,帮助机器学习...【详细内容】
2021-06-09  Tags: 神经网络  点击:(128)  评论:(0)  加入收藏
资料来源:getwallpapers.com 深度学习是机器学习中重要分支之一。它的目的是教会计算机做那些对于人类来说相当自然的事情。深度学习也是无人驾驶汽车背后的一项关键性技术,...【详细内容】
2021-04-13  Tags: 神经网络  点击:(168)  评论:(0)  加入收藏
私有虚拟网络(VPN)是在公用网络基础之上建立的私有加密通信隧道网络,企业对于自管辖网络中个人使用VPN软件行为具有监管责任,但技术上却很难识别VPN的加密与通信方式,因此利用人工智能(AI)领域的神经网络技术从网络流量中识...【详细内容】
2021-01-07  Tags: 神经网络  点击:(193)  评论:(0)  加入收藏
深层神经网络的模型概括,过度拟合和正则化方法的挑战> Source 在完成了与神经网络有关的多个AI项目之后,我意识到模型的概括能力对于AI项目的成功至关重要。 我想写这篇文章来...【详细内容】
2020-10-30  Tags: 神经网络  点击:(101)  评论:(0)  加入收藏
本文最初发表于 Towards Data Science 博客,经原作者 Andre Ye 授权,InfoQ 中文站翻译并分享。卷积神经网络(Convolutional Nerual Network,CNN)构成了图像识别的基础,这无疑是深...【详细内容】
2020-10-16  Tags: 神经网络  点击:(115)  评论:(0)  加入收藏
本报告讨论了非常厉害模型优化技术 —— 知识蒸馏,并给大家过了一遍相关的TensorFlow的代码。...【详细内容】
2020-09-25  Tags: 神经网络  点击:(92)  评论:(0)  加入收藏
算法算法从1950年代的早期研究开始,机器学习的所有工作似乎都随着神经网络的创建而积累起来。 从逻辑回归到支持向量机,相继提出了新算法之后的算法,但是从字面上看,神经网络是...【详细内容】
2020-09-17  Tags: 神经网络  点击:(90)  评论:(0)  加入收藏
卷积神经网络(CNN)广泛应用于深度学习和计算机视觉算法中。虽然很多基于CNN的算法符合行业标准,可以嵌入到商业产品中,但是标准的CNN算法仍然有局限性,在很多方面还可以改进。这篇文章讨论了语义分割和编码器-解码器架构...【详细内容】
2020-09-17  Tags: 神经网络  点击:(93)  评论:(0)  加入收藏
▌简易百科推荐
作为数据科学家或机器学习从业者,将可解释性集成到机器学习模型中可以帮助决策者和其他利益相关者有更多的可见性并可以让他们理解模型输出决策的解释。在本文中,我将介绍两个...【详细内容】
2021-12-17  deephub    Tags:AI   点击:(15)  评论:(0)  加入收藏
基于算法的业务或者说AI的应用在这几年发展得很快。但是,在实际应用的场景中,我们经常会遇到一些非常奇怪的偏差现象。例如,Facebook将黑人标记为灵长类动物、城市图像识别系统...【详细内容】
2021-11-08  数据学习DataLearner    Tags:机器学习   点击:(32)  评论:(0)  加入收藏
11月2日召开的世界顶尖科学家数字未来论坛上,2013年诺贝尔化学奖得主迈克尔·莱维特、2014年诺贝尔生理学或医学奖得主爱德华·莫索尔、2007年图灵奖得主约瑟夫·斯发斯基、1986年图灵奖得主约翰·霍普克罗夫特、2002...【详细内容】
2021-11-03  张淑贤  证券时报  Tags:人工智能   点击:(39)  评论:(0)  加入收藏
鉴于物联网设备广泛部署、5G快速无线技术闪亮登场,把计算、存储和分析放在靠近数据生成的地方来处理,让边缘计算有了用武之地。 边缘计算正在改变全球数百万个设备处理和传输...【详细内容】
2021-10-26    计算机世界  Tags:边缘计算   点击:(45)  评论:(0)  加入收藏
这是几位机器学习权威专家汇总的725个机器学习术语表,非常全面了,值得收藏! 英文术语 中文翻译 0-1 Loss Function 0-1损失函数 Accept-Reject Samplin...【详细内容】
2021-10-21  Python部落    Tags:机器学习   点击:(43)  评论:(0)  加入收藏
要开始为开源项目做贡献,有一些先决条件:1. 学习一门编程语言:由于在开源贡献中你需要编写代码才能参与开发,你需要学习任意一门编程语言。根据项目的需要,在后期学习另一种语言...【详细内容】
2021-10-20  TSINGSEE青犀视频    Tags:机器学习   点击:(37)  评论:(0)  加入收藏
SimpleAI.人工智能、机器学习、深度学习还是遥不可及?来这里看看吧~ 从基本的概念、原理、公式,到用生动形象的例子去理解,到动手做实验去感知,到著名案例的学习,到用所学来实现...【详细内容】
2021-10-19  憨昊昊    Tags:神经网络   点击:(47)  评论:(0)  加入收藏
语言是人类思维的基础,当计算机具备了处理自然语言的能力,才具有真正智能的想象。自然语言处理(Natural Language Processing, NLP)作为人工智能(Artificial Intelligence, AI)的核心技术之一,是用计算机来处理、理解以及运...【详细内容】
2021-10-11    36氪  Tags:NLP   点击:(48)  评论:(0)  加入收藏
边缘计算是什么?近年来,物联网设备数量呈线性增长趋势。根据艾瑞测算, 2020年,中国物联网设备的数量达74亿,预计2025年突破150亿个。同时,设备本身也变得越来越智能化,AI与互联网在...【详细内容】
2021-09-22  汉智兴科技    Tags:   点击:(54)  评论:(0)  加入收藏
说起人工智能,大家总把它和科幻电影中的机器人联系起来,而实际上这些科幻场景与现如今的人工智能没什么太大关系。人工智能确实跟人类大脑很相似,但它们的显著差异在于人工智能...【详细内容】
2021-09-17  异步社区    Tags:人工智能   点击:(57)  评论:(0)  加入收藏
最新更新
栏目热门
栏目头条