手写数字识别任务之数据处理

云深之无迹 2021-01-27 00:00

舍不得我的GPU

这次横向逐步进行优化


在前文中,我们直接用API导入了数据,但是现实中,搬砖环境千变万化,我们总是要拿自己的数据的处理的:

  • 读入数据

  • 划分数据集

  • 生成批次数据

  • 训练样本集乱序

  • 校验数据有效性

import paddleimport paddle.fluid as fluidfrom paddle.fluid.dygraph.nn import Linearimport numpy as npimport osimport gzipimport jsonimport random

导入一些工具库

数据集的储存结构


data包含三个元素的列表:train_setval_set、 test_set

  • train_set(训练集):包含50000条手写数字图片和对应的标签,用于确定模型参数。

  • val_set(验证集):包含10000条手写数字图片和对应的标签,用于调节模型超参数(如多个网络结构、正则化权重的最优选择)。

  • test_set(测试集):包含10000条手写数字图片和对应的标签,用于估计应用效果(没有在模型中应用过的数据,更贴近模型在真实场景应用的效果)。

train_set包含两个元素的列表:train_imagestrain_labels

  • train_images:[5000, 784]的二维列表,包含5000张图片。每张图片用一个长度为784的向量表示,内容是28*28尺寸的像素灰度值(黑白图片)。

  • train_labels:[5000, ]的列表,表示这些图片对应的分类标签,即0-9之间的一个数字。


# 声明数据集文件位置datafile = './work/mnist.json.gz'print('loading mnist dataset from {} ......'.format(datafile))# 加载json数据文件data = json.load(gzip.open(datafile))print('mnist dataset load done')# 读取到的数据区分训练集,验证集,测试集train_set, val_set, eval_set = data
# 数据集相关参数,图片高度IMG_ROWS, 图片宽度IMG_COLSIMG_ROWS = 28IMG_COLS = 28
# 打印数据信息imgs, labels = train_set[0], train_set[1]print("训练数据集数量: ", len(imgs))
# 观察验证集数量imgs, labels = val_set[0], val_set[1]print("验证数据集数量: ", len(imgs))
# 观察测试集数量imgs, labels = val= eval_set[0], eval_set[1]print("测试数据集数量: ", len(imgs))

读取以及拆分的代码



  • 训练样本乱序: 先将样本按顺序进行编号,建立ID集合index_list。然后将index_list乱序,最后按乱序后的顺序读取数据。


说明:

通过大量实验发现,模型对最后出现的数据印象更加深刻。训练数据导入后,越接近模型训练结束,最后几个批次数据对模型参数的影响越大。为了避免模型记忆影响训练效果,需要进行样本乱序操作。


  • 生成批次数据: 先设置合理的batch_size,再将数据转变成符合模型输入要求的np.array格式返回。同时,在返回数据时将Python生成器设置为yield模式,以减少内存占用。

在执行如上两个操作之前,需要先将数据处理代码封装成load_data函数,方便后续调用。load_data有三种模型:trainvalideval,分为对应返回的数据是训练集、验证集、测试集。

imgs, labels = train_set[0], train_set[1]print("训练数据集数量: ", len(imgs))# 获得数据集长度imgs_length = len(imgs)# 定义数据集每个数据的序号,根据序号读取数据index_list = list(range(imgs_length))# 读入数据时用到的批次大小BATCHSIZE = 100
# 随机打乱训练数据的索引序号random.shuffle(index_list)
# 定义数据生成器,返回批次数据def data_generator():
imgs_list = [] labels_list = [] for i in index_list: # 将数据处理成希望的格式,比如类型为float32,shape为[1, 28, 28] img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32') label = np.reshape(labels[i], [1]).astype('float32') imgs_list.append(img) labels_list.append(label) if len(imgs_list) == BATCHSIZE: # 获得一个batchsize的数据,并返回 yield np.array(imgs_list), np.array(labels_list) # 清空数据读取列表 imgs_list = [] labels_list = []
# 如果剩余数据的数目小于BATCHSIZE, # 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch if len(imgs_list) > 0: yield np.array(imgs_list), np.array(labels_list) return data_generator

# 声明数据读取函数,从训练集中读取数据train_loader = data_generator# 以迭代的形式读取数据for batch_id, data in enumerate(train_loader()): image_data, label_data = data if batch_id == 0: # 打印数据shape和类型 print("打印第一个batch数据的维度:") print("图像维度: {}, 标签维度: {}".format(image_data.shape, label_data.shape)) break



在实际应用中,原始数据可能存在标注不准确、数据杂乱或格式不统一等情况。因此在完成数据处理流程后,还需要进行数据校验,一般有两种方式:

  • 机器校验:加入一些校验和清理数据的操作。

  • 人工校验:先打印数据输出结果,观察是否是设置的格式。再从训练的结果验证数据处理和读取的有效性。


 imgs_length = len(imgs)
assert len(imgs) == len(labels), \ "length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(label))


# 声明数据读取函数,从训练集中读取数据train_loader = data_generator# 以迭代的形式读取数据for batch_id, data in enumerate(train_loader()): image_data, label_data = data if batch_id == 0: # 打印数据shape和类型 print("打印第一个batch数据的维度,以及数据的类型:") print("图像维度: {}, 标签维度: {}, 图像数据类型: {}, 标签数据类型: {}".format(image_data.shape, label_data.shape, type(image_data), type(label_data))) break

再放一个人工校验的:

人工校验是指打印数据输出结果,观察是否是预期的格式。实现数据处理和加载函数后,我们可以调用它读取一次数据,观察数据的shape和类型是否与函数中设置的一致。

def load_data(mode='train'): datafile = './work/mnist.json.gz' print('loading mnist dataset from {} ......'.format(datafile)) # 加载json数据文件 data = json.load(gzip.open(datafile)) print('mnist dataset load done')
# 读取到的数据区分训练集,验证集,测试集 train_set, val_set, eval_set = data if mode=='train': # 获得训练数据集 imgs, labels = train_set[0], train_set[1] elif mode=='valid': # 获得验证数据集 imgs, labels = val_set[0], val_set[1] elif mode=='eval': # 获得测试数据集 imgs, labels = eval_set[0], eval_set[1] else: raise Exception("mode can only be one of ['train', 'valid', 'eval']") print("训练数据集数量: ", len(imgs))
# 校验数据 imgs_length = len(imgs)
assert len(imgs) == len(labels), \ "length of train_imgs({}) should be the same as train_labels({})".format(len(imgs), len(label))
# 获得数据集长度 imgs_length = len(imgs)
# 定义数据集每个数据的序号,根据序号读取数据 index_list = list(range(imgs_length)) # 读入数据时用到的批次大小 BATCHSIZE = 100
# 定义数据生成器 def data_generator(): if mode == 'train': # 训练模式下打乱数据 random.shuffle(index_list) imgs_list = [] labels_list = [] for i in index_list: # 将数据处理成希望的格式,比如类型为float32,shape为[1, 28, 28] img = np.reshape(imgs[i], [1, IMG_ROWS, IMG_COLS]).astype('float32') label = np.reshape(labels[i], [1]).astype('float32') imgs_list.append(img) labels_list.append(label) if len(imgs_list) == BATCHSIZE: # 获得一个batchsize的数据,并返回 yield np.array(imgs_list), np.array(labels_list) # 清空数据读取列表 imgs_list = [] labels_list = []
# 如果剩余数据的数目小于BATCHSIZE, # 则剩余数据一起构成一个大小为len(imgs_list)的mini-batch if len(imgs_list) > 0: yield np.array(imgs_list), np.array(labels_list) return data_generator

怕流程太复杂,那就写一个打包的函数就是这样~

#数据处理部分之后的代码,数据读取的部分调用Load_data函数# 定义网络结构,同上一节所使用的网络结构class MNIST(fluid.dygraph.Layer): def __init__(self): super(MNIST, self).__init__() self.fc = Linear(input_dim=784, output_dim=1, act=None)
def forward(self, inputs): inputs = fluid.layers.reshape(inputs, (-1, 784)) outputs = self.fc(inputs) return outputs
# 训练配置,并启动训练过程with fluid.dygraph.guard(): model = MNIST() model.train() #调用加载数据的函数 train_loader = load_data('train') optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters()) EPOCH_NUM = 10 for epoch_id in range(EPOCH_NUM): for batch_id, data in enumerate(train_loader()): #准备数据,变得更加简洁 image_data, label_data = data image = fluid.dygraph.to_variable(image_data) label = fluid.dygraph.to_variable(label_data)
#前向计算的过程 predict = model(image)
#计算损失,取一个批次样本损失的平均值 loss = fluid.layers.square_error_cost(predict, label) avg_loss = fluid.layers.mean(loss)
#每训练了200批次的数据,打印下当前Loss的情况 if batch_id % 200 == 0: print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))
#后向传播,更新参数的过程 avg_loss.backward() optimizer.minimize(avg_loss) model.clear_gradients()
#保存模型参数 fluid.save_dygraph(model.state_dict(), 'mnist')

定义一个神经网络层,然后开始训练

因为代码没有更改,所以.一样的lose

同步和异步的数据读取方式


上面提到的数据读取采用的是同步数据读取方式。对于样本量较大、数据读取较慢的场景,建议采用异步数据读取方式。

异步读取数据时,数据读取和模型训练并行执行,从而加快了数据读取速度,牺牲一小部分内存换取数据读取效率的提升


  • 同步数据读取:数据读取与模型训练串行。当模型需要数据时,才运行数据读取函数获得当前批次的数据。在读取数据期间,模型一直等待数据读取结束才进行训练,数据读取速度相对较慢。

  • 异步数据读取:数据读取和模型训练并行。读取到的数据不断的放入缓存区,无需等待模型训练就可以启动下一轮数据读取。当模型训练完一个批次后,不用等待数据读取过程,直接从缓存区获得下一批次数据进行训练,从而加快了数据读取速度。

  • 异步队列:数据读取和模型训练交互的仓库,二者均可以从仓库中读取数据,它的存在使得两者的工作节奏可以解耦。

# 定义数据读取后存放的位置,CPU或者GPU,这里使用CPU# place = fluid.CUDAPlace(0) 时,数据读取到GPU上place = fluid.CPUPlace()with fluid.dygraph.guard(place): # 声明数据加载函数,使用训练模式 train_loader = load_data(mode='train') # 定义DataLoader对象用于加载Python生成器产生的数据 data_loader = fluid.io.DataLoader.from_generator(capacity=5, return_list=True) # 设置数据生成器 data_loader.set_batch_generator(train_loader, places=place) # 迭代的读取数据并打印数据的形状 for i, data in enumerate(data_loader): image_data, label_data = data print(i, image_data.shape, label_data.shape) if i>=5: break

飞桨的异步读取是这样的

与同步数据读取相比,异步数据读取仅增加了三行代码,如下所示。

place = fluid.CPUPlace()

# 设置读取的数据是放在CPU还是GPU上。

data_loader = fluid.io.DataLoader.from_generator(capacity=5, return_list=True)

# 创建一个DataLoader对象用于加载Python生成器产生的数据。数据会由Python线程预先读取,并异步送入一个队列中。

data_loader.set_batch_generator(train_loader, place)

# 用创建的DataLoader对象设置一个数据生成器set_batch_generator,输入的参数是一个Python数据生成器train_loader和服务器资源类型place(标明CPU还是GPU)

fluid.io.DataLoader.from_generator参数名称和含义如下:

  • feed_list:仅在PaddlePaddle静态图中使用,动态图中设置为“None”,本教程默认使用动态图的建模方式;

  • capacity:表示在DataLoader中维护的队列容量,如果读取数据的速度很快,建议设置为更大的值;

  • use_double_buffer:是一个布尔型的参数,设置为“True”时,Dataloader会预先异步读取下一个batch的数据并放到缓存区;

  • iterable:表示创建的Dataloader对象是否是可迭代的,一般设置为“True”;

  • return_list:在动态图模式下需要设置为“True”。

异步数据读取并训练的完整案例代码如下

with fluid.dygraph.guard(): model = MNIST() model.train() #调用加载数据的函数 train_loader = load_data('train') # 创建异步数据读取器 place = fluid.CPUPlace() data_loader = fluid.io.DataLoader.from_generator(capacity=5, return_list=True) data_loader.set_batch_generator(train_loader, places=place)  optimizer = fluid.optimizer.SGDOptimizer(learning_rate=0.001, parameter_list=model.parameters()) EPOCH_NUM = 3 for epoch_id in range(EPOCH_NUM): for batch_id, data in enumerate(data_loader): image_data, label_data = data image = fluid.dygraph.to_variable(image_data) label = fluid.dygraph.to_variable(label_data)  predict = model(image)  loss = fluid.layers.square_error_cost(predict, label) avg_loss = fluid.layers.mean(loss)  if batch_id % 200 == 0: print("epoch: {}, batch: {}, loss is: {}".format(epoch_id, batch_id, avg_loss.numpy()))  avg_loss.backward() optimizer.minimize(avg_loss) model.clear_gradients()
fluid.save_dygraph(model.state_dict(), 'mnist')
评论
  • 近日,搭载紫光展锐W517芯片平台的INMO GO2由影目科技正式推出。作为全球首款专为商务场景设计的智能翻译眼镜,INMO GO2 以“快、准、稳”三大核心优势,突破传统翻译产品局限,为全球商务人士带来高效、自然、稳定的跨语言交流体验。 INMO GO2内置的W517芯片,是紫光展锐4G旗舰级智能穿戴平台,采用四核处理器,具有高性能、低功耗的优势,内置超微高集成技术,采用先进工艺,计算能力相比同档位竞品提升4倍,强大的性能提供更加多样化的应用场景。【视频见P盘链接】 依托“
    紫光展锐 2024-12-11 11:50 62浏览
  • 习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习笔记&记录学习习笔记&记学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记
    youyeye 2024-12-11 17:58 69浏览
  • 铁氧体芯片是一种基于铁氧体磁性材料制成的芯片,在通信、传感器、储能等领域有着广泛的应用。铁氧体磁性材料能够通过外加磁场调控其导电性质和反射性质,因此在信号处理和传感器技术方面有着独特的优势。以下是对半导体划片机在铁氧体划切领域应用的详细阐述: 一、半导体划片机的工作原理与特点半导体划片机是一种使用刀片或通过激光等方式高精度切割被加工物的装置,是半导体后道封测中晶圆切割和WLP切割环节的关键设备。它结合了水气电、空气静压高速主轴、精密机械传动、传感器及自动化控制等先进技术,具有高精度、高
    博捷芯划片机 2024-12-12 09:16 67浏览
  • 在智能化技术快速发展当下,图像数据的采集与处理逐渐成为自动驾驶、工业等领域的一项关键技术。高质量的图像数据采集与算法集成测试都是确保系统性能和可靠性的关键。随着技术的不断进步,对于图像数据的采集、处理和分析的需求日益增长,这不仅要求我们拥有高性能的相机硬件,还要求我们能够高效地集成和测试各种算法。我们探索了一种多源相机数据采集与算法集成测试方案,能够满足不同应用场景下对图像采集和算法测试的多样化需求,确保数据的准确性和算法的有效性。一、相机组成相机一般由镜头(Lens),图像传感器(Image
    康谋 2024-12-12 09:45 53浏览
  • 全球知名半导体制造商ROHM Co., Ltd.(以下简称“罗姆”)宣布与Taiwan Semiconductor Manufacturing Company Limited(以下简称“台积公司”)就车载氮化镓功率器件的开发和量产事宜建立战略合作伙伴关系。通过该合作关系,双方将致力于将罗姆的氮化镓器件开发技术与台积公司业界先进的GaN-on-Silicon工艺技术优势结合起来,满足市场对高耐压和高频特性优异的功率元器件日益增长的需求。氮化镓功率器件目前主要被用于AC适配器和服务器电源等消费电子和
    电子资讯报 2024-12-10 17:09 92浏览
  • 智能汽车可替换LED前照灯控制运行的原理涉及多个方面,包括自适应前照灯系统(AFS)的工作原理、传感器的应用、步进电机的控制以及模糊控制策略等。当下时代的智能汽车灯光控制系统通过车载网关控制单元集中控制,表现特殊点的有特斯拉,仅通过前车身控制器,整个系统就包括了灯光旋转开关、车灯变光开关、左LED前照灯总成、右LED前照灯总成、转向柱电子控制单元、CAN数据总线接口、组合仪表控制单元、车载网关控制单元等器件。变光开关、转向开关和辅助操作系统一般连为一体,开关之间通过内部线束和转向柱装置连接为多,
    lauguo2013 2024-12-10 15:53 93浏览
  • 首先在gitee上打个广告:ad5d2f3b647444a88b6f7f9555fd681f.mp4 · 丙丁先生/香河英茂工作室中国 - Gitee.com丙丁先生 (mr-bingding) - Gitee.com2024年对我来说是充满挑战和机遇的一年。在这一年里,我不仅进行了多个开发板的测评,还尝试了多种不同的项目和技术。今天,我想分享一下这一年的故事,希望能给大家带来一些启发和乐趣。 年初的时候,我开始对各种开发板进行测评。从STM32WBA55CG到瑞萨、平头哥和平海的开发板,我都
    丙丁先生 2024-12-11 20:14 56浏览
  • RK3506 是瑞芯微推出的MPU产品,芯片制程为22nm,定位于轻量级、低成本解决方案。该MPU具有低功耗、外设接口丰富、实时性高的特点,适合用多种工商业场景。本文将基于RK3506的设计特点,为大家分析其应用场景。RK3506核心板主要分为三个型号,各型号间的区别如下图:​图 1  RK3506核心板处理器型号场景1:显示HMIRK3506核心板显示接口支持RGB、MIPI、QSPI输出,且支持2D图形加速,轻松运行QT、LVGL等GUI,最快3S内开
    万象奥科 2024-12-11 15:42 80浏览
  • 天问Block和Mixly是两个不同的编程工具,分别在单片机开发和教育编程领域有各自的应用。以下是对它们的详细比较: 基本定义 天问Block:天问Block是一个基于区块链技术的数字身份验证和数据交换平台。它的目标是为用户提供一个安全、去中心化、可信任的数字身份验证和数据交换解决方案。 Mixly:Mixly是一款由北京师范大学教育学部创客教育实验室开发的图形化编程软件,旨在为初学者提供一个易于学习和使用的Arduino编程环境。 主要功能 天问Block:支持STC全系列8位单片机,32位
    丙丁先生 2024-12-11 13:15 57浏览
  • 习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习笔记&记录学习习笔记&记学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记
    youyeye 2024-12-10 16:13 113浏览
  • 时源芯微——RE超标整机定位与解决详细流程一、 初步测量与问题确认使用专业的电磁辐射测量设备,对整机的辐射发射进行精确测量。确认是否存在RE超标问题,并记录超标频段和幅度。二、电缆检查与处理若存在信号电缆:步骤一:拔掉所有信号电缆,仅保留电源线,再次测量整机的辐射发射。若测量合格:判定问题出在信号电缆上,可能是电缆的共模电流导致。逐一连接信号电缆,每次连接后测量,定位具体哪根电缆或接口导致超标。对问题电缆进行处理,如加共模扼流圈、滤波器,或优化电缆布局和屏蔽。重新连接所有电缆,再次测量
    时源芯微 2024-12-11 17:11 99浏览
  • 【萤火工场CEM5826-M11测评】OLED显示雷达数据本文结合之前关于串口打印雷达监测数据的研究,进一步扩展至 OLED 屏幕显示。该项目整体分为两部分: 一、框架显示; 二、数据采集与填充显示。为了减小 MCU 负担,采用 局部刷新 的方案。1. 显示框架所需库函数 Wire.h 、Adafruit_GFX.h 、Adafruit_SSD1306.h . 代码#include #include #include #include "logo_128x64.h"#include "logo_
    无垠的广袤 2024-12-10 14:03 74浏览
  • 一、SAE J1939协议概述SAE J1939协议是由美国汽车工程师协会(SAE,Society of Automotive Engineers)定义的一种用于重型车辆和工业设备中的通信协议,主要应用于车辆和设备之间的实时数据交换。J1939基于CAN(Controller Area Network)总线技术,使用29bit的扩展标识符和扩展数据帧,CAN通信速率为250Kbps,用于车载电子控制单元(ECU)之间的通信和控制。小北同学在之前也对J1939协议做过扫盲科普【科普系列】SAE J
    北汇信息 2024-12-11 15:45 99浏览
  • 我的一台很多年前人家不要了的九十年代SONY台式组合音响,接手时只有CD功能不行了,因为不需要,也就没修,只使用收音机、磁带机和外接信号功能就够了。最近五年在外地,就断电闲置,没使用了。今年9月回到家里,就一个劲儿地忙着收拾家当,忙了一个多月,太多事啦!修了电气,清理了闲置不用了的电器和电子,就是一个劲儿地扔扔扔!几十年的“工匠式”收留收藏,只能断舍离,拆解不过来的了。一天,忽然感觉室内有股臭味,用鼻子的嗅觉功能朝着臭味重的方向寻找,觉得应该就是这台组合音响?怎么会呢?这无机物的东西不会腐臭吧?
    自做自受 2024-12-10 16:34 153浏览
我要评论
0
点击右上角,分享到朋友圈 我知道啦
请使用浏览器分享功能 我知道啦