手写数字识别任务第一次训练(结果不好)

云深之无迹 2021-01-27 00:00

数字识别是计算机从纸质文档、照片或其他来源接收、理解并识别可读的数字的能力,目前比较受关注的是手写数字识别。手写数字识别是一个典型的图像分类问题,已经被广泛应用于汇款单号识别、手写邮政编码识别,大大缩短了业务处理时间,提升了工作效率和质量。另一个重要的原因是,对于编程来说入门是打印一个HelloWorld,但是深度学习入门就是实现一个手写数字的识别~



图1:手写数字识别任务示意图
  • 任务输入:一系列手写数字图片,其中每张图片都是28x28的像素矩阵。

  • 任务输出:经过了大小归一化和居中处理,输出对应的0~9数字标签。



在处理如 图1 所示的手写邮政编码的简单图像分类任务时,可以使用基于MNIST数据集的手写数字识别模型。MNIST是深度学习领域标准、易用的成熟数据集,包含60000条训练样本和10000条测试样本。


MNIST数据集

MNIST数据集是从NIST的Special Database 3(SD-3)和Special Database 1(SD-1)构建而来。Yann LeCun等人从SD-1和SD-3中各取一半作为MNIST训练集和测试集,其中训练集来自250位不同的标注员,且训练集和测试集的标注员完全不同。

MNIST数据集的发布,吸引了大量科学家训练模型。1998年,LeCun分别用单层线性分类器、多层感知器(Multilayer Perceptron, MLP)和多层卷积神经网络LeNet进行实验使得测试集的误差不断下降(从12%下降到0.7%)。在研究过程中,LeCun提出了卷积神经网络(Convolutional Neural Network,CNN),大幅度地提高了手写字符的识别能力,也因此成为了深度学习领域的奠基人之一。

如今在深度学习领域,卷积神经网络占据了至关重要的地位,从最早LeCun提出的简单LeNet,到如今ImageNet大赛上的优胜模型VGGNet、GoogLeNet、ResNet等,人们在图像分类领域,利用卷积神经网络得到了一系列惊人的结果。

手写数字识别的模型是深度学习中相对简单的模型,非常适用初学者。

构建手写数字识别的神经网络模型

使用飞桨完成手写数字识别模型构建的代码结构如 图2 所示

训练的流程

我们这次训练,用GPU,嘤嘤嘤

看看我的GPU还好吗~

这个是我们今后做实验的流程图

import paddleimport paddle.fluid as fluidfrom paddle.fluid.dygraph.nn import Linearimport numpy as npimport osfrom PIL import Image

在进入训练环境之后,需要导入的Python库有这些

https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/data_cn/dataset_cn.html

相关要处理的数据在这个网址里有,相关的API也封装好了

# 如果~/.cache/paddle/dataset/mnist/目录下没有MNIST数据,API会自动将MINST数据下载到该文件夹下# 设置数据读取器,读取MNIST数据训练集trainset = paddle.dataset.mnist.train()# 包装数据读取器,每次读取的数据数量设置为batch_size=8train_reader = paddle.batch(trainset, batch_size=8)

通过paddle.dataset.mnist.train()函数设置数据读取器,

batch_size设置为8,即一个批次有8张图片和8个标签.

Cache file /home/aistudio/.cache/paddle/dataset/mnist/train-images-idx3-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-images-idx3-ubyte.gz Begin to download
Download finishedCache file /home/aistudio/.cache/paddle/dataset/mnist/train-labels-idx1-ubyte.gz not found, downloading https://dataset.bj.bcebos.com/mnist/train-labels-idx1-ubyte.gz Begin to download........Download finished

这个是下载好的输出结果


长这个样的

# 以迭代的形式读取数据for batch_id, data in enumerate(train_reader()): # 获得图像数据,并转为float32类型的数组 img_data = np.array([x[0] for x in data]).astype('float32') # 获得图像标签数据,并转为float32类型的数组 label_data = np.array([x[1] for x in data]).astype('float32') # 打印数据形状 print("图像数据形状和对应数据为:", img_data.shape, img_data[0]) print("图像标签形状和对应数据为:", label_data.shape, label_data[0]) break
print("\n打印第一个batch的第一个图像,对应标签数字为{}".format(label_data[0]))# 显示第一batch的第一个图像import matplotlib.pyplot as pltimg = np.array(img_data[0]+1)*127.5img = np.reshape(img, [28, 28]).astype(np.uint8)
plt.figure("Image") # 图像窗口名称plt.imshow(img)plt.axis('on') # 关掉坐标轴为 offplt.title('image') # 图像题目plt.show()

paddle.batch函数将MNIST数据集拆分成多个批次,

通过如下代码读取第一个批次的数据内容,观察数据打印结果。



执行的结果很多,我就截图一些


从打印结果看,从数据加载器train_reader()中读取一次数据,可以得到形状为(8, 784)的图像数据和形状为(8,)的标签数据。其中,形状中的数字8与设置的batch_size大小对应,784为MINIST数据集中每个图像的像素大小(28*28)。

此外,从打印的图像数据来看,图像数据的范围是[-1, 1],表明这是已经完成图像归一化后的图像数据,并且空白背景部分的值是-1。将图像数据反归一化,并使用matplotlib工具包将其显示出来,如图2 所示。可以看到图片显示的数字是5,和对应标签数字一致。


图2:matplotlib打印结果示意图
因为存储的是28x28的向量图,所以坐标也是对应的


https://www.paddlepaddle.org.cn/documentation/docs/zh/api_guides/index_cn.html



一开始的API中讲了一些训练的基本概念

还有基本的数学概念

我们可以写一个代码验证

要用到的网络模型的样子

模型的输入为784维(28*28)数据,输出为1维数据




输入像素的位置排布信息对理解图像内容非常重要(如将原始尺寸为28*28图像的像素按照7*112的尺寸排布,那么其中的数字将不可识别),因此网络的输入设计为28*28的尺寸,而不是1*784,以便于模型能够正确处理像素之间的空间信息。


事实上,采用只有一层的简单网络(对输入求加权和)时并没有处理位置关系信息,因此可以猜测出此模型的预测效果可能有限。在后续优化环节介绍的卷积神经网络则更好的考虑了这种位置关系信息,模型的预测效果也会有显著提升。

# 定义mnist数据识别网络结构class MNIST(fluid.dygraph.Layer): def __init__(self): super(MNIST, self).__init__()
# 定义一层全连接层,输出维度是1,激活函数为None,即不使用激活函数 self.fc = Linear(input_dim=784, output_dim=1, act=None)
# 定义网络结构的前向计算过程 def forward(self, inputs): outputs = self.fc(inputs) return outputs

定义一个神经网络层

# 定义飞桨动态图工作环境with fluid.dygraph.guard(): # 声明网络结构 model = MNIST() # 启动训练模式 model.train() # 定义数据读取函数,数据读取batch_size设置为16 train_loader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16) # 定义优化器,使用随机梯度下降SGD优化器,学习率设置为0.001 optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters())

训练配置需要先生成模型实例(设为“训练”状态),再设置优化算法和学习率(使用随机梯度下降SGD,学习率设置为0.001)

# 通过with语句创建一个dygraph运行的context# 动态图下的一些操作需要在guard下进行with fluid.dygraph.guard(): model = MNIST() model.train() train_loader = paddle.batch(paddle.dataset.mnist.train(), batch_size=16) optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters()) EPOCH_NUM = 10 for epoch_id in range(EPOCH_NUM): for batch_id, data in enumerate(train_loader()): #准备数据,格式需要转换成符合框架要求的 image_data = np.array([x[0] for x in data]).astype('float32') label_data = np.array([x[1] for x in data]).astype('float32').reshape(-1, 1) # 将数据转为飞桨动态图格式 image = fluid.dygraph.to_variable(image_data) label = fluid.dygraph.to_variable(label_data)
#前向计算的过程 predict = model(image)
#计算损失,取一个批次样本损失的平均值 loss = fluid.layers.square_error_cost(predict, label) avg_loss = fluid.layers.mean(loss)
#每训练了1000批次的数据,打印下当前Loss的情况 if batch_id !=0 and batch_id % 1000 == 0: print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
#后向传播,更新参数的过程 avg_loss.backward() optimizer.minimize(avg_loss) model.clear_gradients()
# 保存模型 fluid.save_dygraph(model.state_dict(), 'mnist')


训练过程采用二层循环嵌套方式,训练完成后需要保存模型参数,以便后续使用。

  • 内层循环:负责整个数据集的一次遍历,遍历数据集采用分批次(batch)方式。

  • 外层循环:定义遍历数据集的次数,本次训练中外层循环10次,通过参数EPOCH_NUM设置


开始训练

训练到最后一轮的时候,发现损失函数还是这么高

模型测试的主要目的是验证训练好的模型是否能正确识别出数字,包括如下四步:

  • 声明实例

  • 加载模型:加载训练过程中保存的模型参数,

  • 灌入数据:将测试样本传入模型,模型的状态设置为校验状态(eval),显式告诉框架我们接下来只会使用前向计算的流程,不会计算梯度和梯度反向传播。

  • 获取预测结果,取整后作为预测标签输出。

在模型测试之前,需要先从'./work/example_0.jpg'文件中读取样例图片,并进行归一化处理。

# 导入图像读取第三方库import matplotlib.image as mpimgimport matplotlib.pyplot as pltimport cv2import numpy as np# 读取图像img1 = cv2.imread('./work/example_0.png')example = mpimg.imread('./work/example_0.png')# 显示图像plt.imshow(example)plt.show()im = Image.open('./work/example_0.png').convert('L')print(np.array(im).shape)im = im.resize((28, 28), Image.ANTIALIAS)plt.imshow(im)plt.show()print(np.array(im).shape)

加载并处理,很显然

这个是0

# 读取一张本地的样例图片,转变成模型输入的格式def load_image(img_path): # 从img_path中读取图像,并转为灰度图 im = Image.open(img_path).convert('L') print(np.array(im)) im = im.resize((28, 28), Image.ANTIALIAS) im = np.array(im).reshape(1, -1).astype(np.float32) # 图像归一化,保持和数据集的数据范围一致 im = 1 - im / 127.5 return im
# 定义预测过程with fluid.dygraph.guard(): model = MNIST() params_file_path = 'mnist' img_path = './work/example_0.png'# 加载模型参数 model_dict, _ = fluid.load_dygraph("mnist") model.load_dict(model_dict)# 灌入数据 model.eval() tensor_img = load_image(img_path) result = model(fluid.dygraph.to_variable(tensor_img))# 预测输出取整,即为预测的数字,打印结果 print("本次预测的数字是", result.numpy().astype('int32'))

执行结果,出个3

那这个结果肯定是骗不了我的,那证明我姿势不太对,我继续捣鼓~

我的半个小时GPU时间啊,训练个什么东西出来


评论
  • 本文介绍瑞芯微RK3588主板/开发板Android12系统下,APK签名文件生成方法。触觉智能EVB3588开发板演示,搭载了瑞芯微RK3588芯片,该开发板是核心板加底板设计,音视频接口、通信接口等各类接口一应俱全,可帮助企业提高产品开发效率,缩短上市时间,降低成本和设计风险。工具准备下载Keytool-ImportKeyPair工具在源码:build/target/product/security/系统初始签名文件目录中,将以下三个文件拷贝出来:platform.pem;platform.
    Industio_触觉智能 2024-12-12 10:27 32浏览
  • 全球智能电视时代来临这年头若是消费者想随意地从各个通路中选购电视时,不难发现目前市场上的产品都已是具有智能联网功能的智能电视了,可以宣告智能电视的普及时代已到临!Google从2021年开始大力推广Google TV(即原Android TV的升级版),其他各大品牌商也都跟进推出搭载Google TV操作系统的机种,除了Google TV外,LG、Samsung、Panasonic等大厂牌也开发出自家的智能电视平台,可以看出各家业者都一致地看好这块大饼。智能电视的Wi-Fi连线怎么消失了?智能电
    百佳泰测试实验室 2024-12-12 17:33 41浏览
  • 铁氧体芯片是一种基于铁氧体磁性材料制成的芯片,在通信、传感器、储能等领域有着广泛的应用。铁氧体磁性材料能够通过外加磁场调控其导电性质和反射性质,因此在信号处理和传感器技术方面有着独特的优势。以下是对半导体划片机在铁氧体划切领域应用的详细阐述: 一、半导体划片机的工作原理与特点半导体划片机是一种使用刀片或通过激光等方式高精度切割被加工物的装置,是半导体后道封测中晶圆切割和WLP切割环节的关键设备。它结合了水气电、空气静压高速主轴、精密机械传动、传感器及自动化控制等先进技术,具有高精度、高
    博捷芯划片机 2024-12-12 09:16 82浏览
  • RK3506 是瑞芯微推出的MPU产品,芯片制程为22nm,定位于轻量级、低成本解决方案。该MPU具有低功耗、外设接口丰富、实时性高的特点,适合用多种工商业场景。本文将基于RK3506的设计特点,为大家分析其应用场景。RK3506核心板主要分为三个型号,各型号间的区别如下图:​图 1  RK3506核心板处理器型号场景1:显示HMIRK3506核心板显示接口支持RGB、MIPI、QSPI输出,且支持2D图形加速,轻松运行QT、LVGL等GUI,最快3S内开
    万象奥科 2024-12-11 15:42 83浏览
  • 一、SAE J1939协议概述SAE J1939协议是由美国汽车工程师协会(SAE,Society of Automotive Engineers)定义的一种用于重型车辆和工业设备中的通信协议,主要应用于车辆和设备之间的实时数据交换。J1939基于CAN(Controller Area Network)总线技术,使用29bit的扩展标识符和扩展数据帧,CAN通信速率为250Kbps,用于车载电子控制单元(ECU)之间的通信和控制。小北同学在之前也对J1939协议做过扫盲科普【科普系列】SAE J
    北汇信息 2024-12-11 15:45 108浏览
  • 习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习笔记&记录学习习笔记&记学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记
    youyeye 2024-12-11 17:58 83浏览
  • 应用环境与极具挑战性的测试需求在服务器制造领域里,系统整合测试(System Integration Test;SIT)是确保产品质量和性能的关键步骤。随着服务器系统的复杂性不断提升,包括:多种硬件组件、操作系统、虚拟化平台以及各种应用程序和服务的整合,服务器制造商面临着更有挑战性的测试需求。这些挑战主要体现在以下五个方面:1. 硬件和软件的高度整合:现代服务器通常包括多个处理器、内存模块、储存设备和网络接口。这些硬件组件必须与操作系统及应用软件无缝整合。SIT测试可以帮助制造商确保这些不同组件
    百佳泰测试实验室 2024-12-12 17:45 32浏览
  • 天问Block和Mixly是两个不同的编程工具,分别在单片机开发和教育编程领域有各自的应用。以下是对它们的详细比较: 基本定义 天问Block:天问Block是一个基于区块链技术的数字身份验证和数据交换平台。它的目标是为用户提供一个安全、去中心化、可信任的数字身份验证和数据交换解决方案。 Mixly:Mixly是一款由北京师范大学教育学部创客教育实验室开发的图形化编程软件,旨在为初学者提供一个易于学习和使用的Arduino编程环境。 主要功能 天问Block:支持STC全系列8位单片机,32位
    丙丁先生 2024-12-11 13:15 63浏览
  • 我的一台很多年前人家不要了的九十年代SONY台式组合音响,接手时只有CD功能不行了,因为不需要,也就没修,只使用收音机、磁带机和外接信号功能就够了。最近五年在外地,就断电闲置,没使用了。今年9月回到家里,就一个劲儿地忙着收拾家当,忙了一个多月,太多事啦!修了电气,清理了闲置不用了的电器和电子,就是一个劲儿地扔扔扔!几十年的“工匠式”收留收藏,只能断舍离,拆解不过来的了。一天,忽然感觉室内有股臭味,用鼻子的嗅觉功能朝着臭味重的方向寻找,觉得应该就是这台组合音响?怎么会呢?这无机物的东西不会腐臭吧?
    自做自受 2024-12-10 16:34 170浏览
  • 在智能化技术快速发展当下,图像数据的采集与处理逐渐成为自动驾驶、工业等领域的一项关键技术。高质量的图像数据采集与算法集成测试都是确保系统性能和可靠性的关键。随着技术的不断进步,对于图像数据的采集、处理和分析的需求日益增长,这不仅要求我们拥有高性能的相机硬件,还要求我们能够高效地集成和测试各种算法。我们探索了一种多源相机数据采集与算法集成测试方案,能够满足不同应用场景下对图像采集和算法测试的多样化需求,确保数据的准确性和算法的有效性。一、相机组成相机一般由镜头(Lens),图像传感器(Image
    康谋 2024-12-12 09:45 74浏览
  • 首先在gitee上打个广告:ad5d2f3b647444a88b6f7f9555fd681f.mp4 · 丙丁先生/香河英茂工作室中国 - Gitee.com丙丁先生 (mr-bingding) - Gitee.com2024年对我来说是充满挑战和机遇的一年。在这一年里,我不仅进行了多个开发板的测评,还尝试了多种不同的项目和技术。今天,我想分享一下这一年的故事,希望能给大家带来一些启发和乐趣。 年初的时候,我开始对各种开发板进行测评。从STM32WBA55CG到瑞萨、平头哥和平海的开发板,我都
    丙丁先生 2024-12-11 20:14 68浏览
  • 时源芯微——RE超标整机定位与解决详细流程一、 初步测量与问题确认使用专业的电磁辐射测量设备,对整机的辐射发射进行精确测量。确认是否存在RE超标问题,并记录超标频段和幅度。二、电缆检查与处理若存在信号电缆:步骤一:拔掉所有信号电缆,仅保留电源线,再次测量整机的辐射发射。若测量合格:判定问题出在信号电缆上,可能是电缆的共模电流导致。逐一连接信号电缆,每次连接后测量,定位具体哪根电缆或接口导致超标。对问题电缆进行处理,如加共模扼流圈、滤波器,或优化电缆布局和屏蔽。重新连接所有电缆,再次测量
    时源芯微 2024-12-11 17:11 109浏览
  • 习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习笔记&记录学习习笔记&记学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记
    youyeye 2024-12-12 10:13 30浏览
  • 全球知名半导体制造商ROHM Co., Ltd.(以下简称“罗姆”)宣布与Taiwan Semiconductor Manufacturing Company Limited(以下简称“台积公司”)就车载氮化镓功率器件的开发和量产事宜建立战略合作伙伴关系。通过该合作关系,双方将致力于将罗姆的氮化镓器件开发技术与台积公司业界先进的GaN-on-Silicon工艺技术优势结合起来,满足市场对高耐压和高频特性优异的功率元器件日益增长的需求。氮化镓功率器件目前主要被用于AC适配器和服务器电源等消费电子和
    电子资讯报 2024-12-10 17:09 98浏览
  • 近日,搭载紫光展锐W517芯片平台的INMO GO2由影目科技正式推出。作为全球首款专为商务场景设计的智能翻译眼镜,INMO GO2 以“快、准、稳”三大核心优势,突破传统翻译产品局限,为全球商务人士带来高效、自然、稳定的跨语言交流体验。 INMO GO2内置的W517芯片,是紫光展锐4G旗舰级智能穿戴平台,采用四核处理器,具有高性能、低功耗的优势,内置超微高集成技术,采用先进工艺,计算能力相比同档位竞品提升4倍,强大的性能提供更加多样化的应用场景。【视频见P盘链接】 依托“
    紫光展锐 2024-12-11 11:50 69浏览
我要评论
0
点击右上角,分享到朋友圈 我知道啦
请使用浏览器分享功能 我知道啦