手写数字识别任务第一次训练(结果不好)

云深之无迹 2021-01-27 00:00

数字识别是计算机从纸质文档、照片或其他来源接收、理解并识别可读的数字的能力,目前比较受关注的是手写数字识别。手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别、手写邮政编码识别,大大缩短了业务处理时间,提升了工作效率和质量。另一个重要的原因是,对于编程来说入门是打印一个HelloWorld,但是深度学习入门就是实现一个手写数字的识别~



图1:手写数字识别任务示意图
  • 任务输入:一系列手写数字图片,其中每张图片都是28x28的像素矩阵。

  • 任务输出:经过了大小归一化和居中处理,输出对应的0~9数字标签。



在处理如 图1 所示的手写邮政编码的简单图像分类任务时,可以使用基于MNIST数据集的手写数字识别模型。MNIST是深度学习领域标准、易用的成熟数据集,包含60000条训练样本和10000条测试样本。


MNIST数据集

MNIST数据集是从NIST的Special Database 3(SD-3)和Special Database 1(SD-1)构建而来。Yann LeCun等人从SD-1和SD-3中各取一半作为MNIST训练集和测试集,其中训练集来自250位不同的标注员,且训练集和测试集的标注员完全不同。

MNIST数据集的发布,吸引了大量科学家训练模型。1998年,LeCun分别用单层线性分类器、多层感知器(Multilayer Perceptron, MLP)和多层卷积神经网络LeNet进行实验使得测试集的误差不断下降(从12%下降到0.7%)。在研究过程中,LeCun提出了卷积神经网络(Convolutional Neural Network,CNN),大幅度地提高了手写字符的识别能力,也因此成为了深度学习领域的奠基人之一。

如今在深度学习领域,卷积神经网络占据了至关重要的地位,从最早LeCun提出的简单LeNet,到如今ImageNet大赛上的优胜模型VGGNet、GoogLeNet、ResNet等,人们在图像分类领域,利用卷积神经网络得到了一系列惊人的结果。

手写数字识别的模型是深度学习中相对简单的模型,非常适用初学者。

构建手写数字识别的神经网络模型

使用飞桨完成手写数字识别模型构建的代码结构如 图2 所示

训练的流程

我们这次训练,用GPU,嘤嘤嘤

看看我的GPU还好吗~

这个是我们今后做实验的流程图

import paddleimport paddle.fluid as fluidfrom paddle.fluid.dygraph.nn import Linearimport numpy as npimport osfrom PIL import Image

在进入训练环境之后,需要导入的Python库有这些

https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/data_cn/dataset_cn.html

相关要处理的数据在这个网址里有,相关的API也封装好了

# 如果~/.cache/paddle/dataset/mnist/目录下没有MNIST数据,API会自动将MINST数据下载到该文件夹下# 设置数据读取器,读取MNIST数据训练集trainset = paddle.dataset.mnist.train()# 包装数据读取器,每次读取的数据数量设置为batch_size=8train_reader = paddle.batch(trainset, batch_size=8)

通过paddle.dataset.mnist.train()函数设置数据读取器,

batch_size设置为8,即一个批次有8张图片和8个标签.

Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz Begin to download
Download finishedCache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz Begin to download........Download finished

这个是下载好的输出结果


长这个样的

# 以迭代的形式读取数据for batch_id, data in enumerate(train_reader()): # 获得图像数据,并转为float32类型的数组 img_data = np.array([x[0] for x in data]).astype('float32') # 获得图像标签数据,并转为float32类型的数组 label_data = np.array([x[1] for x in data]).astype('float32') # 打印数据形状 print("图像数据形状和对应数据为:", img_data.shape, img_data[0]) print("图像标签形状和对应数据为:", label_data.shape, label_data[0]) break
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(label_data[0]))# 显示第一batch的第一个图像import matplotlib.pyplot as pltimg = np.array(img_data[0]+1)*127.5img = np.reshape(img, [28, 28]).astype(np.uint8)
plt.figure("Image") # 图像窗口名称plt.imshow(img)plt.axis('on') # 关掉坐标轴为 offplt.title('image') # 图像题目plt.show()

paddle.batch函数将MNIST数据集拆分成多个批次,

通过如下代码读取第一个批次的数据内容,观察数据打印结果。



执行的结果很多,我就截图一些


从打印结果看,从数据加载器train_reader()中读取一次数据,可以得到形状为(8, 784)的图像数据和形状为(8,)的标签数据。其中,形状中的数字8与设置的batch_size大小对应,784为MINIST数据集中每个图像的像素大小(28*28)。

此外,从打印的图像数据来看,图像数据的范围是[-1, 1],表明这是已经完成图像归一化后的图像数据,并且空白背景部分的值是-1。将图像数据反归一化,并使用matplotlib工具包将其显示出来,如图2 所示。可以看到图片显示的数字是5,和对应标签数字一致。


图2:matplotlib打印结果示意图
因为存储的是28x28的向量图,所以坐标也是对应的


https://www.paddlepaddle.org.cn/documentation/docs/zh/api_guides/index_cn.html



一开始的API中讲了一些训练的基本概念

还有基本的数学概念

我们可以写一个代码验证

要用到的网络模型的样子

模型的输入为784维(28*28)数据,输出为1维数据




输入像素的位置排布信息对理解图像内容非常重要(如将原始尺寸为28*28图像的像素按照7*112的尺寸排布,那么其中的数字将不可识别),因此网络的输入设计为28*28的尺寸,而不是1*784,以便于模型能够正确处理像素之间的空间信息。


事实上,采用只有一层的简单网络(对输入求加权和)时并没有处理位置关系信息,因此可以猜测出此模型的预测效果可能有限。在后续优化环节介绍的卷积神经网络则更好的考虑了这种位置关系信息,模型的预测效果也会有显著提升。

# 定义mnist数据识别网络结构class MNIST(fluid.dygraph.Layer): def __init__(self): super(MNIST, self).__init__()
# 定义一层全连接层,输出维度是1,激活函数为None,即不使用激活函数 self.fc = Linear(input_dim=784, output_dim=1, act=None)
# 定义网络结构的前向计算过程 def forward(self, inputs): outputs = self.fc(inputs) return outputs

定义一个神经网络层

# 定义飞桨动态图工作环境with fluid.dygraph.guard(): # 声明网络结构 model = MNIST() # 启动训练模式 model.train() # 定义数据读取函数,数据读取batch_size设置为16 train_loader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16) # 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001 optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters())

训练配置需要先生成模型实例(设为“训练”状态),再设置优化算法和学习率(使用随机梯度下降SGD,学习率设置为0.001)

# 通过with语句创建一个dygraph运行的context# 动态图下的一些操作需要在guard下进行with fluid.dygraph.guard(): model = MNIST() model.train() train_loader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16) optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters()) EPOCH_NUM = 10 for epoch_id in range(EPOCH_NUM): for batch_id, data in enumerate(train_loader()): #准备数据,格式需要转换成符合框架要求的 image_data = np.array([x[0] for x in data]).astype('float32') label_data = np.array([x[1] for x in data]).astype('float32').reshape(-1, 1) # 将数据转为飞桨动态图格式 image = fluid.dygraph.to_variable(image_data) label = fluid.dygraph.to_variable(label_data)
#前向计算的过程 predict = model(image)
#计算损失,取一个批次样本损失的平均值 loss = fluid.layers.square_error_cost(predict, label) avg_loss = fluid.layers.mean(loss)
#每训练了1000批次的数据,打印下当前Loss的情况 if batch_id !=0 and batch_id % 1000 == 0: print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
#后向传播,更新参数的过程 avg_loss.backward() optimizer.minimize(avg_loss) model.clear_gradients()
# 保存模型 fluid.save_dygraph(model.state_dict(), 'mnist')


训练过程采用二层循环嵌套方式,训练完成后需要保存模型参数,以便后续使用。

  • 内层循环:负责整个数据集的一次遍历,遍历数据集采用分批次(batch)方式。

  • 外层循环:定义遍历数据集的次数,本次训练中外层循环10次,通过参数EPOCH_NUM设置


开始训练

训练到最后一轮的时候,发现损失函数还是这么高

模型测试的主要目的是验证训练好的模型是否能正确识别出数字,包括如下四步:

  • 声明实例

  • 加载模型:加载训练过程中保存的模型参数,

  • 灌入数据:将测试样本传入模型,模型的状态设置为校验状态(eval),显式告诉框架我们接下来只会使用前向计算的流程,不会计算梯度和梯度反向传播。

  • 获取预测结果,取整后作为预测标签输出。

在模型测试之前,需要先从'./work/example_0.jpg'文件中读取样例图片,并进行归一化处理。

# 导入图像读取第三方库import matplotlib.image as mpimgimport matplotlib.pyplot as pltimport cv2import numpy as np# 读取图像img1 = cv2.imread('./work/example_0.png')example = mpimg.imread('./work/example_0.png')# 显示图像plt.imshow(example)plt.show()im = Image.open('./work/example_0.png').convert('L')print(np.array(im).shape)im = im.resize((28, 28), Image.ANTIALIAS)plt.imshow(im)plt.show()print(np.array(im).shape)

加载并处理,很显然

这个是0

# 读取一张本地的样例图片,转变成模型输入的格式def load_image(img_path): # 从img_path中读取图像,并转为灰度图 im = Image.open(img_path).convert('L') print(np.array(im)) im = im.resize((28, 28), Image.ANTIALIAS) im = np.array(im).reshape(1, -1).astype(np.float32) # 图像归一化,保持和数据集的数据范围一致 im = 1 - im / 127.5 return im
# 定义预测过程with fluid.dygraph.guard(): model = MNIST() params_file_path = 'mnist' img_path = './work/example_0.png'# 加载模型参数 model_dict, _ = fluid.load_dygraph("mnist") model.load_dict(model_dict)# 灌入数据 model.eval() tensor_img = load_image(img_path) result = model(fluid.dygraph.to_variable(tensor_img))# 预测输出取整,即为预测的数字,打印结果 print("本次预测的数字是", result.numpy().astype('int32'))

执行结果,出个3

那这个结果肯定是骗不了我的,那证明我姿势不太对,我继续捣鼓~

我的半个小时GPU时间啊,训练个什么东西出来


评论 (0)
  • 数字隔离芯片是一种实现电气隔离功能的集成电路,在工业自动化、汽车电子、光伏储能与电力通信等领域的电气系统中发挥着至关重要的作用。其不仅可令高、低压系统之间相互独立,提高低压系统的抗干扰能力,同时还可确保高、低压系统之间的安全交互,使系统稳定工作,并避免操作者遭受来自高压系统的电击伤害。典型数字隔离芯片的简化原理图值得一提的是,数字隔离芯片历经多年发展,其应用范围已十分广泛,凡涉及到在高、低压系统之间进行信号传输的场景中基本都需要应用到此种芯片。那么,电气工程师在进行电路设计时到底该如何评估选择一
    华普微HOPERF 2025-01-20 16:50 137浏览
  •  光伏及击穿,都可视之为 复合的逆过程,但是,复合、光伏与击穿,不单是进程的方向相反,偏置状态也不一样,复合的工况,是正偏,光伏是零偏,击穿与漂移则是反偏,光伏的能源是外来的,而击穿消耗的是结区自身和电源的能量,漂移的载流子是 客席载流子,须借外延层才能引入,客席载流子 不受反偏PN结的空乏区阻碍,能漂不能漂,只取决于反偏PN结是否处于外延层的「射程」范围,而穿通的成因,则是因耗尽层的过度扩张,致使跟 端子、外延层或其他空乏区 碰触,当耗尽层融通,耐压 (反向阻断能力) 即告彻底丧失,
    MrCU204 2025-01-17 11:30 220浏览
  •  万万没想到!科幻电影中的人形机器人,正在一步步走进我们人类的日常生活中来了。1月17日,乐聚将第100台全尺寸人形机器人交付北汽越野车,再次吹响了人形机器人疯狂进厂打工的号角。无独有尔,银河通用机器人作为一家成立不到两年时间的创业公司,在短短一年多时间内推出革命性的第一代产品Galbot G1,这是一款轮式、双臂、身体可折叠的人形机器人,得到了美团战投、经纬创投、IDG资本等众多投资方的认可。作为一家成立仅仅只有两年多时间的企业,智元机器人也把机器人从梦想带进了现实。2024年8月1
    刘旷 2025-01-21 11:15 737浏览
  • Ubuntu20.04默认情况下为root账号自动登录,本文介绍如何取消root账号自动登录,改为通过输入账号密码登录,使用触觉智能EVB3568鸿蒙开发板演示,搭载瑞芯微RK3568,四核A55处理器,主频2.0Ghz,1T算力NPU;支持OpenHarmony5.0及Linux、Android等操作系统,接口丰富,开发评估快人一步!添加新账号1、使用adduser命令来添加新用户,用户名以industio为例,系统会提示设置密码以及其他信息,您可以根据需要填写或跳过,命令如下:root@id
    Industio_触觉智能 2025-01-17 14:14 159浏览
  • 高速先生成员--黄刚这不马上就要过年了嘛,高速先生就不打算给大家上难度了,整一篇简单但很实用的文章给大伙瞧瞧好了。相信这个标题一出来,尤其对于PCB设计工程师来说,心就立马凉了半截。他们辛辛苦苦进行PCB的过孔设计,高速先生居然说设计多大的过孔他们不关心!另外估计这时候就跳出很多“挑刺”的粉丝了哈,因为翻看很多以往的文章,高速先生都表达了过孔孔径对高速性能的影响是很大的哦!咋滴,今天居然说孔径不关心了?别,别急哈,听高速先生在这篇文章中娓娓道来。首先还是要对各位设计工程师的设计表示肯定,毕竟像我
    一博科技 2025-01-21 16:17 170浏览
  • 现在为止,我们已经完成了Purple Pi OH主板的串口调试和部分配件的连接,接下来,让我们趁热打铁,完成剩余配件的连接!注:配件连接前请断开主板所有供电,避免敏感电路损坏!1.1 耳机接口主板有一路OTMP 标准四节耳机座J6,具备进行音频输出及录音功能,接入耳机后声音将优先从耳机输出,如下图所示:1.21.2 相机接口MIPI CSI 接口如上图所示,支持OV5648 和OV8858 摄像头模组。接入摄像头模组后,使用系统相机软件打开相机拍照和录像,如下图所示:1.3 以太网接口主板有一路
    Industio_触觉智能 2025-01-20 11:04 200浏览
  • 日前,商务部等部门办公厅印发《手机、平板、智能手表(手环)购新补贴实施方案》明确,个人消费者购买手机、平板、智能手表(手环)3类数码产品(单件销售价格不超过6000元),可享受购新补贴。每人每类可补贴1件,每件补贴比例为减去生产、流通环节及移动运营商所有优惠后最终销售价格的15%,每件最高不超过500元。目前,京东已经做好了承接手机、平板等数码产品国补优惠的落地准备工作,未来随着各省市关于手机、平板等品类的国补开启,京东将第一时间率先上线,满足消费者的换新升级需求。为保障国补的真实有效发放,基于
    华尔街科技眼 2025-01-17 10:44 252浏览
  • 书接上回:【2022年终总结】阳光总在风雨后,启航2023-面包板社区  https://mbb.eet-china.com/blog/468701-438244.html 总结2019,松山湖有个欧洲小镇-面包板社区  https://mbb.eet-china.com/blog/468701-413397.html        2025年该是总结下2024年的喜怒哀乐,有个好的开始,才能更好的面对2025年即将
    liweicheng 2025-01-24 23:18 45浏览
  • 临近春节,各方社交及应酬也变得多起来了,甚至一月份就排满了各式约见。有的是关系好的专业朋友的周末“恳谈会”,基本是关于2025年经济预判的话题,以及如何稳定工作等话题;但更多的预约是来自几个客户老板及副总裁们的见面,他们为今年的经济预判与企业发展焦虑而来。在聊天过程中,我发现今年的聊天有个很有意思的“点”,挺多人尤其关心我到底是怎么成长成现在的多领域风格的,还能掌握一些经济趋势的分析能力,到底学过哪些专业、在企业管过哪些具体事情?单单就这个一个月内,我就重复了数次“为什么”,再辅以我上次写的:《
    牛言喵语 2025-01-22 17:10 209浏览
  • 本文介绍瑞芯微开发板/主板Android配置APK默认开启性能模式方法,开启性能模式后,APK的CPU使用优先级会有所提高。触觉智能RK3562开发板演示,搭载4核A53处理器,主频高达2.0GHz;内置独立1Tops算力NPU,可应用于物联网网关、平板电脑、智能家居、教育电子、工业显示与控制等行业。源码修改修改源码根目录下文件device/rockchip/rk3562/package_performance.xml并添加以下内容,注意"+"号为添加内容,"com.tencent.mm"为AP
    Industio_触觉智能 2025-01-17 14:09 218浏览
  • 飞凌嵌入式基于瑞芯微RK3562系列处理器打造的FET3562J-C全国产核心板,是一款专为工业自动化及消费类电子设备设计的产品,凭借其强大的功能和灵活性,自上市以来得到了各行业客户的广泛关注。本文将详细介绍如何启动并测试RK3562J处理器的MCU,通过实际操作步骤,帮助各位工程师朋友更好地了解这款芯片。1、RK3562J处理器概述RK3562J处理器采用了4*Cortex-A53@1.8GHz+Cortex-M0@200MHz架构。其中,4个Cortex-A53核心作为主要核心,负责处理复杂
    飞凌嵌入式 2025-01-24 11:21 99浏览
  • 嘿,咱来聊聊RISC-V MCU技术哈。 这RISC-V MCU技术呢,简单来说就是基于一个叫RISC-V的指令集架构做出的微控制器技术。RISC-V这个啊,2010年的时候,是加州大学伯克利分校的研究团队弄出来的,目的就是想搞个新的、开放的指令集架构,能跟上现代计算的需要。到了2015年,专门成立了个RISC-V基金会,让这个架构更标准,也更好地推广开了。这几年啊,这个RISC-V的生态系统发展得可快了,好多公司和机构都加入了RISC-V International,还推出了不少RISC-V
    丙丁先生 2025-01-21 12:10 815浏览
  • 故障现象 一辆2007款日产天籁车,搭载VQ23发动机(气缸编号如图1所示,点火顺序为1-2-3-4-5-6),累计行驶里程约为21万km。车主反映,该车起步加速时偶尔抖动,且行驶中加速无力。 图1 VQ23发动机的气缸编号 故障诊断接车后试车,发动机怠速运转平稳,但只要换挡起步,稍微踩下一点加速踏板,就能感觉到车身明显抖动。用故障检测仪检测,发动机控制模块(ECM)无故障代码存储,且无失火数据流。用虹科Pico汽车示波器测量气缸1点火信号(COP点火信号)和曲轴位置传感器信
    虹科Pico汽车示波器 2025-01-23 10:46 97浏览
  • 2024年是很平淡的一年,能保住饭碗就是万幸了,公司业绩不好,跳槽又不敢跳,还有一个原因就是老板对我们这些员工还是很好的,碍于人情也不能在公司困难时去雪上加霜。在工作其间遇到的大问题没有,小问题还是有不少,这里就举一两个来说一下。第一个就是,先看下下面的这个封装,你能猜出它的引脚间距是多少吗?这种排线座比较常规的是0.6mm间距(即排线是0.3mm间距)的,而这个规格也是我们用得最多的,所以我们按惯性思维来看的话,就会认为这个座子就是0.6mm间距的,这样往往就不会去细看规格书了,所以这次的运气
    wuliangu 2025-01-21 00:15 395浏览
  •     IPC-2581是基于ODB++标准、结合PCB行业特点而指定的PCB加工文件规范。    IPC-2581旨在替代CAM350格式,成为PCB加工行业的新的工业规范。    有一些免费软件,可以查看(不可修改)IPC-2581数据文件。这些软件典型用途是工艺校核。    1. Vu2581        出品:Downstream     
    电子知识打边炉 2025-01-22 11:12 162浏览
我要评论
0
1
点击右上角,分享到朋友圈 我知道啦
请使用浏览器分享功能 我知道啦