数据集应用|快速入门数据增强方法Mixup,显著提升图像识别准确度

OpenCV学堂 2021-08-14 21:01

作者|Ta-Ying Cheng,牛津大学博士研究生,Medium技术博主,多篇文章均被平台官方刊物Towards Data Science收录

翻译|颂贤

关于格物钛

格物钛公开数据集提供海量优质公开数据集搜索、数据托管、一站式搜索,与全球AI开发者共创公开数据集社区。本文介绍的是以格物钛公开数据集平台中的CIFAR-10数据集为基础,通过数据增强方法Mixup,显著提升图像识别准确度。

深度学习蓬勃发展的这几年来,图像分类一直是最为火热的领域之一。传统上的图像识别严重依赖像是扩张/侵蚀或者是频域变换这样的处理方法,但特征提取的困难性限制了这些方法的进步空间。现如今的神经网络则显著提高了图像识别的准确率,因为神经网络能够寻找输入图像和输出标签之间的关系,并以此不断地调整它的识别策略。

然而,神经网络往往需要大量的数据进行训练,而优质的训练数据并不是唾手可得的。因此现在许多人都在研究如何能够实现所谓的数据增强(Data augmentation),即在一个已有的小数据集中凭空增加数据量,来达到以一敌百的效果。本文就将带大家认识一种简单而有效的数据增强策略Mixup,并介绍直接在PyTorch中实现Mixup的方法。

为什么需要数据增强?

神经网络架构内的参数是根据给定的数据进行训练和更新的。但由于训练数据只覆盖了某一部分可能数据的分布情况,网络很可能就会在分布的“能见”部分过度拟合。因此,我们拥有的训练数据越多,理论上就越能覆盖整个分布的情况(这也正是为什么以数据为中心的AI(data-centric AI)非常重要)。当然,在数据量有限的情况下,我们也并不是没有办法。通过数据增强,我们就可以尝试通过微调原有数据的方式产生新数据,并将其作为“新”样本送入网络进行训练。

什么是Mixup?

图1:Mixup的简易演示图

假设我们现在要做的事情是给猫和狗的图片做分类,并且我们已经有了一组标注好了是猫是狗的数据(例如[1, 0] -> 狗, [0, 1] -> 猫),那么Mixup简单来说就是将两张图像及其标签平均化为一个新数据。

具体而言,我们可以用数学公式写出Mixup的概念:

其中,xy分别是混合xi(标签为yᵢ)和xⱼ(标签为y)后的图像和标签,而λ则是从给定的贝塔分布中取得的随机数。

由此,Mixup能够为我们提供不同数据类别之间的连续数据样本,并因此直接扩大了给定训练集的分布,从而使网络在测试阶段更加强大。

Mixup的万用性

Mixup其实只是一种数据增强方法,它和任何用于分类的网络架构都是正交的。也就是说,我们可以在任何要进行分类任务的网络中对相应的数据集使用Mixup方法。Mixup的提出者张宏毅等人基于其最初发表的论文《mixup: Beyond Empirical Risk Minimization》对多个数据集和架构进行了实验,发现了Mixup在神经网络之外的应用中也能体现其强大能力。

计算环境

我们将通过PyTorch(包括torchvision)来构建整个程序。Mixup需要的从beta分布中生成的样本,我们可以从NumPy库中获得。我们还将使用random来为Mixup寻找随机图像。下面的代码能够导入我们需要的所有库:

"""
Import necessary libraries to train a network using mixup
The code is mainly developed using the PyTorch library
"""

import numpy as np
import pickle
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader

数据集

为了演示,我们将用传统的图像分类任务来说明Mixup的强大,那么这种情况下CIFAR-10则会是非常理想的数据集。CIFAR-10包含10个类别的60000张彩色图像(每类6000张),按5:1的比例分为训练和测试集。这些图像分类起来相当简单,但比最基本的数字识别数据集MNIST要难一些。

有许多方法可以下载CIFAR-10数据集,比如多伦多大学网站里就包含了相关数据集。在这里,我推荐大家使用格物钛的公开数据集平台(graviti.cn/open-datasets),因为在这个平台上,如果使用他们的SDK,不用下载也可以获取免费的数据集资源。事实上,这个公开数据集平台包含了行业内数百个知名的优质数据集,每个数据集都有相关的作者说明,以及不同训练任务的标签,例如分类或目标检测。当然,大家也可以在这个平台下载其他分类数据集,如CompCars或SVHN,来测试Mixup在不同场景下的性能。

硬件要求

一般来说,我们最好用GPU(显卡)来训练神经网络,因为它能显著提高训练速度。不过如果只有CPU可用,我们还是可以对程序进行简单测试的。如果你想让程序能够自行确定所需硬件,使用以下代码即可:

"""
Determine if any GPUs are available
"""

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

实现

网络

这里,我们的目标是要测试Mixup的性能,而不是调试网络本身,所以我们只需要简单实现一个4层卷积层和2层全连接层的卷积神经网络(CNN)即可。为了比较使用和不使用Mixup的区别,我们将应用同一个网络来确保比较的准确性。

我们可以使用下列代码来搭建上面所说的简单网络:

"""
Create a simple CNN
"""

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()

        # Network consists of 4 convolutional layers followed by 2 fully-connected layers
        self.conv11 = nn.Conv2d(3643)
        self.conv12 = nn.Conv2d(64643)
        self.conv21 = nn.Conv2d(641283)
        self.conv22 = nn.Conv2d(1281283)
        self.fc1 = nn.Linear(128 * 5 * 5256)
        self.fc2 = nn.Linear(25610)
    def forward(self, x):
       x = F.relu(self.conv11(x))
       x = F.relu(self.conv12(x))
       x = F.max_pool2d(x, (2,2))
       x = F.relu(self.conv21(x))
       x = F.relu(self.conv22(x))
       x = F.max_pool2d(x, (2,2))

       # Size is calculated based on kernel size 3 and padding 0
       x = x.view(-1128 * 5 * 5)
       x = F.relu(self.fc1(x))
       x = self.fc2(x)

       return nn.Sigmoid()(x)

Mixup

Mixup阶段是在数据集加载过程中完成的,所以我们必须写入我们自己的数据集,而不是使用torchvision.datasets所提供的默认数据集。

下面的代码简单地实现了Mixup,并结合使用了NumPy的贝塔函数。

"""
Dataset and Dataloader creation
All data are downloaded found via Graviti Open Dataset which links to CIFAR-10 official page
The dataset implementation is where mixup take place
"""


class CIFAR_Dataset(Dataset):
    def __init__(self, data_dir, train, transform):
        self.data_dir = data_dir
        self.train = train
        self.transform = transform
        self.data = []
        self.targets = []

        # Loading all the data depending on whether the dataset is training or testing
        if self.train:
            for i in range(5):
                with open(data_dir + 'data_batch_' + str(i+1), 'rb'as f:
                    entry = pickle.load(f, encoding='latin1')
                    self.data.append(entry['data'])
                    self.targets.extend(entry['labels'])
        else:
            with open(data_dir + 'test_batch''rb'as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])

        # Reshape it and turn it into the HWC format which PyTorch takes in the images
        # Original CIFAR format can be seen via its official page
        self.data = np.vstack(self.data).reshape(-133232)
        self.data = self.data.transpose((0231))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):

        # Create a one hot label
        label = torch.zeros(10)
        label[self.targets[idx]] = 1.

        # Transform the image by converting to tensor and normalizing it
        if self.transform:
            image = transform(self.data[idx])

        # If data is for training, perform mixup, only perform mixup roughly on 1 for every 5 images
        if self.train and idx > 0 and idx%5 == 0:

            # Choose another image/label randomly
            mixup_idx = random.randint(0, len(self.data)-1)
            mixup_label = torch.zeros(10)
            label[self.targets[mixup_idx]] = 1.
            if self.transform:
                mixup_image = transform(self.data[mixup_idx])

            # Select a random number from the given beta distribution
            # Mixup the images accordingly
            alpha = 0.2
            lam = np.random.beta(alpha, alpha)
            image = lam * image + (1 - lam) * mixup_image
            label = lam * label + (1 - lam) * mixup_label

        return image, label

需要注意的是,我们并没有对所有的图像都进行Mixup,而是大概每5张处理1张。我们还使用了一个0.2的贝塔分布。你可以自己为不同的实验改变分布以及被混合的图像的数量,或许你会取得更好的结果!

训练和评估

下面的代码展示的是训练过程。我们将批次大小设置为128,学习率为1e-3,总次数为30次。整个训练进行了两次,唯一区别是有没有使用Mixup。需要注意的是, 损失函数需要由我们自己定义,因为目前BCE损失不允许使用带有小数的标签。

"""
Initialize the network, loss Adam optimizer
Torch BCE Loss does not support mixup labels (not 1 or 0), so we implement our own
"""

net = CNN().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)
def bceloss(x, y):
    eps = 1e-6
    return -torch.mean(y * torch.log(x + eps) + (1 - y) * torch.log(1 - x + eps))
best_Acc = 0


"""
Training Procedure
"""

for epoch in range(NUM_EPOCHS):
    net.train()
    # We train and visualize the loss every 100 iterations
    for idx, (imgs, labels) in enumerate(train_dataloader):
        imgs = imgs.to(device)
        labels = labels.to(device)
        preds = net(imgs)
        loss = bceloss(preds, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if idx%100 == 0:
            print("Epoch {} Iteration {}, Current Loss: {}".format(epoch, idx, loss))

    # We evaluate the network after every epoch based on test set accuracy
    net.eval()
    with torch.no_grad():
        total = 0
        numCorrect = 0
        for (imgs, labels) in test_dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            preds = net(imgs)
            numCorrect += (torch.argmax(preds, dim=1) == torch.argmax(labels, dim=1)).float().sum()
            total += len(imgs)
        acc = numCorrect/total
        print("Current image classification accuracy at epoch {}: {}".format(epoch, acc))
        if acc > best_Acc:
            best_Acc = acc

为了评估Mixup的效果,我们进行了三次对照试验来计算最终的准确性。在没有Mixup的情况下,该网络在测试集上的准确率约为74.5%,而在使用了Mixup的情况下,准确率提高到了约76.5%!

图像分类之外

Mixup将图像分类的准确性带到了一个前所未有的高度,但研究表明,Mixup的好处还能延伸到其他计算机视觉任务中,比如对抗性数据的生成和防御。另外也有相关文献在Mixup拓展到三维表示中,目前的结果表明Mixup在这一领域也十分有效的,例如PointMixup(arxiv.org/abs/2008.06374)。

结语

由此,我们用Mixup做的小实验就大功告成啦!在这篇文章中,我们简单介绍了Mixup的概念并演示了如何在图像分类网络训练中应用Mixup。完整的实现方式可以在这—GitHub仓库(github.com/ttchengab/mixup.git)中找到。

Open Datasets

格物钛 | 公开数据集

graviti.cn | open-datasets

订阅号 格物钛

微博 |   |  格物钛

https://www.graviti.cn/

OpenCV学堂 专注计算机视觉开发技术分享,技术框架使用,包括OpenCV,Tensorflow,Pytorch教程与案例,相关算法详解,最新CV方向论文,硬核代码干货与代码案例详解!作者在CV工程化方面深度耕耘15年,感谢您的关注!
评论
  • 作为优秀工程师的你,已身经百战、阅板无数!请先醒醒,新的项目来了,这是一个既要、又要、还要的产品需求,ARM核心板中一个处理器怎么能实现这么丰富的外围接口?踌躇之际,你偶阅此文。于是,“潘多拉”的魔盒打开了!没错,USB资源就是你打开新世界得钥匙,它能做哪些扩展呢?1.1  USB扩网口通用ARM处理器大多带两路网口,如果项目中有多路网路接口的需求,一般会选择在主板外部加交换机/路由器。当然,出于成本考虑,也可以将Switch芯片集成到ARM核心板或底板上,如KSZ9897、
    万象奥科 2024-12-03 10:24 68浏览
  • RDDI-DAP错误通常与调试接口相关,特别是在使用CMSIS-DAP协议进行嵌入式系统开发时。以下是一些可能的原因和解决方法: 1. 硬件连接问题:     检查调试器(如ST-Link)与目标板之间的连接是否牢固。     确保所有必要的引脚都已正确连接,没有松动或短路。 2. 电源问题:     确保目标板和调试器都有足够的电源供应。     检查电源电压是否符合目标板的规格要求。 3. 固件问题: &n
    丙丁先生 2024-12-01 17:37 102浏览
  • 光伏逆变器是一种高效的能量转换设备,它能够将光伏太阳能板(PV)产生的不稳定的直流电压转换成与市电频率同步的交流电。这种转换后的电能不仅可以回馈至商用输电网络,还能供独立电网系统使用。光伏逆变器在商业光伏储能电站和家庭独立储能系统等应用领域中得到了广泛的应用。光耦合器,以其高速信号传输、出色的共模抑制比以及单向信号传输和光电隔离的特性,在光伏逆变器中扮演着至关重要的角色。它确保了系统的安全隔离、干扰的有效隔离以及通信信号的精准传输。光耦合器的使用不仅提高了系统的稳定性和安全性,而且由于其低功耗的
    晶台光耦 2024-12-02 10:40 120浏览
  • 《高速PCB设计经验规则应用实践》+PCB绘制学习与验证读书首先看目录,我感兴趣的是这一节;作者在书中列举了一条经典规则,然后进行详细分析,通过公式推导图表列举说明了传统的这一规则是受到电容加工特点影响的,在使用了MLCC陶瓷电容后这一条规则已经不再实用了。图书还列举了高速PCB设计需要的专业工具和仿真软件,当然由于篇幅所限,只是介绍了一点点设计步骤;我最感兴趣的部分还是元件布局的经验规则,在这里列举如下:在这里,演示一下,我根据书本知识进行电机驱动的布局:这也算知行合一吧。对于布局书中有一句:
    wuyu2009 2024-11-30 20:30 125浏览
  • 学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习笔记&记录学习习笔记&记学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&
    youyeye 2024-11-30 14:30 78浏览
  • 戴上XR眼镜去“追龙”是种什么体验?2024年11月30日,由上海自然博物馆(上海科技馆分馆)与三湘印象联合出品、三湘印象旗下观印象艺术发展有限公司(下简称“观印象”)承制的《又见恐龙》XR嘉年华在上海自然博物馆重磅开幕。该体验项目将于12月1日正式对公众开放,持续至2025年3月30日。双向奔赴,恐龙IP撞上元宇宙不久前,上海市经济和信息化委员会等部门联合印发了《上海市超高清视听产业发展行动方案》,特别提到“支持博物馆、主题乐园等场所推动超高清视听技术应用,丰富线下文旅消费体验”。作为上海自然
    电子与消费 2024-11-30 22:03 98浏览
  • 概述 说明(三)探讨的是比较器一般带有滞回(Hysteresis)功能,为了解决输入信号转换速率不够的问题。前文还提到,即便使能滞回(Hysteresis)功能,还是无法解决SiPM读出测试系统需要解决的问题。本文在说明(三)的基础上,继续探讨为SiPM读出测试系统寻求合适的模拟脉冲检出方案。前四代SiPM使用的高速比较器指标缺陷 由于前端模拟信号属于典型的指数脉冲,所以下降沿转换速率(Slew Rate)过慢,导致比较器检出出现不必要的问题。尽管比较器可以使能滞回(Hysteresis)模块功
    coyoo 2024-12-03 12:20 111浏览
  •         温度传感器的精度受哪些因素影响,要先看所用的温度传感器输出哪种信号,不同信号输出的温度传感器影响精度的因素也不同。        现在常用的温度传感器输出信号有以下几种:电阻信号、电流信号、电压信号、数字信号等。以输出电阻信号的温度传感器为例,还细分为正温度系数温度传感器和负温度系数温度传感器,常用的铂电阻PT100/1000温度传感器就是正温度系数,就是说随着温度的升高,输出的电阻值会增大。对于输出
    锦正茂科技 2024-12-03 11:50 111浏览
  • TOF多区传感器: ND06   ND06是一款微型多区高集成度ToF测距传感器,其支持24个区域(6 x 4)同步测距,测距范围远达5m,具有测距范围广、精度高、测距稳定等特点。适用于投影仪的无感自动对焦和梯形校正、AIoT、手势识别、智能面板和智能灯具等多种场景。                 如果用ND06进行手势识别,只需要经过三个步骤: 第一步&
    esad0 2024-12-04 11:20 58浏览
  • 11-29学习笔记11-29学习笔记习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习笔记&记录学习习笔记&记学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&记录学习学习笔记&学习学习笔记&记录学习学习笔记&记录学习学习笔记&记
    youyeye 2024-12-02 23:58 73浏览
  • 最近几年,新能源汽车愈发受到消费者的青睐,其销量也是一路走高。据中汽协公布的数据显示,2024年10月,新能源汽车产销分别完成146.3万辆和143万辆,同比分别增长48%和49.6%。而结合各家新能源车企所公布的销量数据来看,比亚迪再度夺得了销冠宝座,其10月新能源汽车销量达到了502657辆,同比增长66.53%。众所周知,比亚迪是新能源汽车领域的重要参与者,其一举一动向来为外界所关注。日前,比亚迪汽车旗下品牌方程豹汽车推出了新车方程豹豹8,该款车型一上市就迅速吸引了消费者的目光,成为SUV
    刘旷 2024-12-02 09:32 119浏览
  • 当前,智能汽车产业迎来重大变局,随着人工智能、5G、大数据等新一代信息技术的迅猛发展,智能网联汽车正呈现强劲发展势头。11月26日,在2024紫光展锐全球合作伙伴大会汽车电子生态论坛上,紫光展锐与上汽海外出行联合发布搭载紫光展锐A7870的上汽海外MG量产车型,并发布A7710系列UWB数字钥匙解决方案平台,可应用于数字钥匙、活体检测、脚踢雷达、自动泊车等多种智能汽车场景。 联合发布量产车型,推动汽车智能化出海紫光展锐与上汽海外出行达成战略合作,联合发布搭载紫光展锐A7870的量产车型
    紫光展锐 2024-12-03 11:38 101浏览
  • 遇到部分串口工具不支持1500000波特率,这时候就需要进行修改,本文以触觉智能RK3562开发板修改系统波特率为115200为例,介绍瑞芯微方案主板Linux修改系统串口波特率教程。温馨提示:瑞芯微方案主板/开发板串口波特率只支持115200或1500000。修改Loader打印波特率查看对应芯片的MINIALL.ini确定要修改的bin文件#查看对应芯片的MINIALL.ini cat rkbin/RKBOOT/RK3562MINIALL.ini修改uart baudrate参数修改以下目
    Industio_触觉智能 2024-12-03 11:28 87浏览
我要评论
0
点击右上角,分享到朋友圈 我知道啦
请使用浏览器分享功能 我知道啦