ViT模型详解与Pytorch代码实现

OpenCV学堂 2025-01-16 17:31

点击上方蓝字关注我们

微信公众号:OpenCV学堂

关注获取更多计算机视觉与深度学习知识

免费领学习资料+微信:OpenCVXueTang_Asst

介绍

使用PyTorch 从头开始实现 ViT模型代码,在 CIFAR-10 数据集上训练ViT模型 以完成图像分类。

ViT的架构

ViT 的架构受到 BERT 的启发,BERT 是一种仅编码器的 transformer 模型,通常用于文本分类或命名实体识别等 NLP 监督学习任务。ViT 背后的主要思想是,图像可以看作是一系列的补丁,在 NLP 任务中可以被视为令牌

输入图像被分割成小块,然后被展平为向量序列。然后,这些向量由 transformer 编码器处理,它允许模型通过自注意力机制学习补丁之间的交互。然后,transformer 编码器的输出被馈送到一个分类层,该分类层输出输入图像的预测类别

代码实现

下面是模型各个部分组件的 PyTorch代码实现。

01

图像嵌入转换

为了将输入图像馈送到 Transformer 模型,我们需要将图像转换为一系列向量。这是通过将图像分割成一个不重叠的补丁网格来完成的,然后线性投影这些补丁以获得每个补丁的固定大小的嵌入向量。为此,我们可以使用 PyTorch 的层:nn.Conv2d
  1. class PatchEmbeddings(nn.Module):

  2. """

  3. Convert the image into patches and then project them into a vector space.

  4. """

  5. def __init__(self, config):

  6. super().__init__()

  7. self.image_size = config["image_size"]

  8. self.patch_size = config["patch_size"]

  9. self.num_channels = config["num_channels"]

  10. self.hidden_size = config["hidden_size"]

  11. # Calculate the number of patches from the image size and patch size

  12. self.num_patches = (self.image_size // self.patch_size) ** 2

  13. # Create a projection layer to convert the image into patches

  14. # The layer projects each patch into a vector of size hidden_size

  15. self.projection = nn.Conv2d(self.num_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)

  16. def forward(self, x):

  17. # (batch_size, num_channels, image_size, image_size) -> (batch_size, num_patches, hidden_size)

  18. x = self.projection(x)

  19. x = x.flatten(2).transpose(1, 2)

  20. return x

kernel_size=self.patch_size并确保图层的滤镜应用于非重叠的面片。stride=self.patch_size在补丁转换为嵌入序列后,[CLS] 标记被添加到序列的开头,稍后将在分类层中用于对图像进行分类。[CLS] 令牌的嵌入是在训练期间学习的。

由于来自不同位置的补丁对最终预测的贡献可能不同,我们还需要一种方法将补丁位置编码到序列中。我们将使用可学习的位置嵌入向量将位置信息添加到嵌入向量中。这类似于在 Transformer 模型中为 NLP 任务使用位置嵌入的方式。
  1. class Embeddings(nn.Module):

  2. def __init__(self, config):

  3. super().__init__()

  4. self.config = config

  5. self.patch_embeddings = PatchEmbeddings(config)

  6. # Create a learnable [CLS] token

  7. # Similar to BERT, the [CLS] token is added to the beginning of the input sequence

  8. # and is used to classify the entire sequence

  9. self.cls_token = nn.Parameter(torch.randn(1, 1, config["hidden_size"]))

  10. # Create position embeddings for the [CLS] token and the patch embeddings

  11. # Add 1 to the sequence length for the [CLS] token

  12. self.position_embeddings = \

  13. nn.Parameter(torch.randn(1, self.patch_embeddings.num_patches + 1, config["hidden_size"]))

  14. self.dropout = nn.Dropout(config["hidden_dropout_prob"])

  15. def forward(self, x):

  16. x = self.patch_embeddings(x)

  17. batch_size, _, _ = x.size()

  18. # Expand the [CLS] token to the batch size

  19. # (1, 1, hidden_size) -> (batch_size, 1, hidden_size)

  20. cls_tokens = self.cls_token.expand(batch_size, -1, -1)

  21. # Concatenate the [CLS] token to the beginning of the input sequence

  22. # This results in a sequence length of (num_patches + 1)

  23. x = torch.cat((cls_tokens, x), dim=1)

  24. x = x + self.position_embeddings

  25. x = self.dropout(x)

  26. return x

在此步骤中,输入图像被转换为带有位置信息的嵌入序列,并准备馈送到 transformer 层。

02

多头注意力

在介绍 transformer 编码器之前,我们首先探索 multi-head attention module,这是它的核心组件。多头注意力用于计算输入图像中不同色块之间的交互。多头注意力由多个注意力头组成,每个注意力头都是一个注意力层。

让我们实现多头注意力模块的 head。该模块将一系列嵌入向量作为输入,并计算每个嵌入向量的查询向量、键向量和值向量。然后,使用查询和关键向量来计算每个标记的注意力权重。然后,使用注意力权重通过值向量的加权和来计算新的嵌入。我们可以将此机制视为数据库查询的软版本,其中查询向量在数据库中查找最相关的键向量,并检索值向量以计算查询输出。
  1. class AttentionHead(nn.Module):

  2. """

  3. A single attention head.

  4. This module is used in the MultiHeadAttention module.

  5. """

  6. def __init__(self, hidden_size, attention_head_size, dropout, bias=True):

  7. super().__init__()

  8. self.hidden_size = hidden_size

  9. self.attention_head_size = attention_head_size

  10. # Create the query, key, and value projection layers

  11. self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)

  12. self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)

  13. self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

  14. self.dropout = nn.Dropout(dropout)

  15. def forward(self, x):

  16. # Project the input into query, key, and value

  17. # The same input is used to generate the query, key, and value,

  18. # so it's usually called self-attention.

  19. # (batch_size, sequence_length, hidden_size) -> (batch_size, sequence_length, attention_head_size)

  20. query = self.query(x)

  21. key = self.key(x)

  22. value = self.value(x)

  23. # Calculate the attention scores

  24. # softmax(Q*K.T/sqrt(head_size))*V

  25. attention_scores = torch.matmul(query, key.transpose(-1, -2))

  26. attention_scores = attention_scores / math.sqrt(self.attention_head_size)

  27. attention_probs = nn.functional.softmax(attention_scores, dim=-1)

  28. attention_probs = self.dropout(attention_probs)

  29. # Calculate the attention output

  30. attention_output = torch.matmul(attention_probs, value)

  31. return (attention_output, attention_probs)


然后,所有注意力头的输出被连接起来并线性投影,以获得多头注意力模块的最终输出。

class MultiHeadAttention(nn.Module):    """    Multi-head attention module.    This module is used in the TransformerEncoder module.    """
def __init__(self, config): super().__init__() self.hidden_size = config["hidden_size"] self.num_attention_heads = config["num_attention_heads"] # The attention head size is the hidden size divided by the number of attention heads self.attention_head_size = self.hidden_size // self.num_attention_heads self.all_head_size = self.num_attention_heads * self.attention_head_size # Whether or not to use bias in the query, key, and value projection layers self.qkv_bias = config["qkv_bias"] # Create a list of attention heads self.heads = nn.ModuleList([]) for _ in range(self.num_attention_heads): head = AttentionHead( self.hidden_size, self.attention_head_size, config["attention_probs_dropout_prob"], self.qkv_bias ) self.heads.append(head) # Create a linear layer to project the attention output back to the hidden size # In most cases, all_head_size and hidden_size are the same self.output_projection = nn.Linear(self.all_head_size, self.hidden_size) self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

def forward(self, x, output_attentions=False): # Calculate the attention output for each attention head attention_outputs = [head(x) for head in self.heads] # Concatenate the attention outputs from each attention head attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1) # Project the concatenated attention output back to the hidden size attention_output = self.output_projection(attention_output) attention_output = self.output_dropout(attention_output) # Return the attention output and the attention probabilities (optional) if not output_attentions: return (attention_output, None) else: attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1) return (attention_output, attention_probs)

03

编码器

编码器由一堆MHA + MLP组成。每个 transformer 层主要由我们刚刚实现的多头注意力模块和前馈网络组成。为了更好地扩展模型并稳定训练,向 transformer 层添加了两个 Layer 归一化层和跳过连接。

让我们实现一个 transformer 层(在代码中称为 ,因为它是 transformer 编码器的构建块)。我们将从前馈网络开始,这是一个简单的两层 MLP,中间有 GELU 激活。Block

class MLP(nn.Module):    """    A multi-layer perceptron module.    """    def __init__(self, config):        super().__init__()        self.dense_1 = nn.Linear(config["hidden_size"], config["intermediate_size"])        self.activation = NewGELUActivation()        self.dense_2 = nn.Linear(config["intermediate_size"], config["hidden_size"])        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
def forward(self, x): x = self.dense_1(x) x = self.activation(x) x = self.dense_2(x) x = self.dropout(x)        return x
我们已经实现了多头注意力和 MLP,我们可以将它们组合起来创建变压器层。跳过连接和层标准化将应用于每个层的输入
class Block(nn.Module):    """    A single transformer block.    """
def __init__(self, config): super().__init__() self.attention = MultiHeadAttention(config) self.layernorm_1 = nn.LayerNorm(config["hidden_size"]) self.mlp = MLP(config)        self.layernorm_2 = nn.LayerNorm(config["hidden_size"])
def forward(self, x, output_attentions=False): # Self-attention attention_output, attention_probs = \ self.attention(self.layernorm_1(x), output_attentions=output_attentions) # Skip connection x = x + attention_output # Feed-forward network mlp_output = self.mlp(self.layernorm_2(x)) # Skip connection x = x + mlp_output # Return the transformer block's output and the attention probabilities (optional) if not output_attentions: return (x, None) else: return (x, attention_probs)

transformer 编码器按顺序堆叠多个 transformer 层:

class Encoder(nn.Module):    """    The transformer encoder module.    """
def __init__(self, config): super().__init__() # Create a list of transformer blocks self.blocks = nn.ModuleList([]) for _ in range(config["num_hidden_layers"]): block = Block(config)            self.blocks.append(block)
def forward(self, x, output_attentions=False): # Calculate the transformer block's output for each block all_attentions = [] for block in self.blocks: x, attention_probs = block(x, output_attentions=output_attentions) if output_attentions: all_attentions.append(attention_probs) # Return the encoder's output and the attention probabilities (optional) if not output_attentions: return (x, None) else: return (x, all_attentions)

04

ViT模型构建

将图像输入到 embedding 层和 transformer 编码器后,我们获得图像补丁和 [CLS] 标记的新嵌入。此时,嵌入在经过 transformer 编码器处理后应该有一些有用的信号用于分类。与 BERT 类似,我们将仅使用 [CLS] 标记的嵌入传递到分类层。

分类层是一个完全连接的层,它将 [CLS] 嵌入作为输入并输出每个图像的 logit。以下代码实现了用于图像分类的 ViT 模型:
class ViTForClassfication(nn.Module):
"""
The ViT model for classification.
"""


def __init__(self, config):
super().__init__()
self.config = config
self.image_size = config["image_size"]
self.hidden_size = config["hidden_size"]
self.num_classes = config["num_classes"]
# Create the embedding module
self.embedding = Embeddings(config)
# Create the transformer encoder module
self.encoder = Encoder(config)
# Create a linear layer to project the encoder's output to the number of classes
self.classifier = nn.Linear(self.hidden_size, self.num_classes)
# Initialize the weights
self.apply(self._init_weights)

def forward(self, x, output_attentions=False):
# Calculate the embedding output
embedding_output = self.embedding(x)
# Calculate the encoder's output
encoder_output, all_attentions = self.encoder(embedding_output, output_attentions=output_attentions)
# Calculate the logits, take the [CLS] token's output as features for classification
logits = self.classifier(encoder_output[:, 0])
# Return the logits and the attention probabilities (optional)
if not output_attentions:
return (logits, None)
else:
return (logits, all_attentions)


参考

代码其实是我从github上面整理加工跟翻译得到的(个人认为非常的通俗易懂,有点pytorch基础都可以看懂学会),感兴趣的可以看这里:

https://github.com/lukemelas/PyTorch-Pretrained-ViT/blob/master/pytorch_pretrained_vit/transformer.pyhttps://tintn.github.io/Implementing-Vision-Transformer-from-Scratch/


系统化学习QT5 + OpenCV4

原价:498

折扣:399


推荐阅读

OpenCV4.8+YOLOv8对象检测C++推理演示

ZXING+OpenCV打造开源条码检测应用

总结 | OpenCV4 Mat操作全接触

三行代码实现 TensorRT8.6 C++ 深度学习模型部署

实战 | YOLOv8+OpenCV 实现DM码定位检测与解析

对象检测边界框损失 – 从IOU到ProbIOU

YOLOv8 OBB实现自定义旋转对象检测

初学者必看 | 学习深度学习的五个误区

YOLOv8自定义数据集训练实现安全帽检测


OpenCV学堂 专注计算机视觉开发技术分享,技术框架使用,包括OpenCV,Tensorflow,Pytorch教程与案例,相关算法详解,最新CV方向论文,硬核代码干货与代码案例详解!作者在CV工程化方面深度耕耘15年,感谢您的关注!
评论
  • PNT、GNSS、GPS均是卫星定位和导航相关领域中的常见缩写词,他们经常会被用到,且在很多情况下会被等同使用或替换使用。我们会把定位导航功能测试叫做PNT性能测试,也会叫做GNSS性能测试。我们会把定位导航终端叫做GNSS模块,也会叫做GPS模块。但是实际上他们之间是有一些重要的区别。伴随着技术发展与越发深入,我们有必要对这三个词汇做以清晰的区分。一、什么是GPS?GPS是Global Positioning System(全球定位系统)的缩写,它是美国建立的全球卫星定位导航系统,是GNSS概
    德思特测试测量 2025-01-13 15:42 540浏览
  • 流量传感器是实现对燃气、废气、生活用水、污水、冷却液、石油等各种流体流量精准计量的关键手段。但随着工业自动化、数字化、智能化与低碳化进程的不断加速,采用传统机械式检测方式的流量传感器已不能满足当代流体计量行业对于测量精度、测量范围、使用寿命与维护成本等方面的精细需求。流量传感器的应用场景(部分)超声波流量传感器,是一种利用超声波技术测量流体流量的新型传感器,其主要通过发射超声波信号并接收反射回来的信号,根据超声波在流体中传播的时间、幅度或相位变化等参数,间接计算流体的流量,具有非侵入式测量、高精
    华普微HOPERF 2025-01-13 14:18 519浏览
  • 全球领先的光学解决方案供应商艾迈斯欧司朗(SIX:AMS)近日宣布,与汽车技术领先者法雷奥合作,采用创新的开放系统协议(OSP)技术,旨在改变汽车内饰照明方式,革新汽车行业座舱照明理念。结合艾迈斯欧司朗开创性的OSIRE® E3731i智能LED和法雷奥的动态环境照明系统,两家公司将为车辆内饰设计和功能设立一套全新标准。汽车内饰照明的作用日益凸显,座舱设计的主流趋势应满足终端用户的需求:即易于使用、个性化,并能提供符合用户生活方式的清晰信息。因此,动态环境照明带来了众多新机遇。智能LED的应用已
    艾迈斯欧司朗 2025-01-15 19:00 49浏览
  • 一个易用且轻量化的UI可以大大提高用户的使用效率和满意度——通过快速启动、直观操作和及时反馈,帮助用户快速上手并高效完成任务;轻量化设计则可以减少资源占用,提升启动和运行速度,增强产品竞争力。LVGL(Light and Versatile Graphics Library)是一个免费开源的图形库,专为嵌入式系统设计。它以轻量级、高效和易于使用而著称,支持多种屏幕分辨率和硬件配置,并提供了丰富的GUI组件,能够帮助开发者轻松构建出美观且功能强大的用户界面。近期,飞凌嵌入式为基于NXP i.MX9
    飞凌嵌入式 2025-01-16 13:15 61浏览
  • 百佳泰特为您整理2025年1月各大Logo的最新规格信息,本月有更新信息的logo有HDMI、Wi-Fi、Bluetooth、DisplayHDR、ClearMR、Intel EVO。HDMI®▶ 2025年1月6日,HDMI Forum, Inc. 宣布即将发布HDMI规范2.2版本。新规范将支持更高的分辨率和刷新率,并提供更多高质量选项。更快的96Gbps 带宽可满足数据密集型沉浸式和虚拟应用对传输的要求,如 AR/VR/MR、空间现实和光场显示,以及各种商业应用,如大型数字标牌、医疗成像和
    百佳泰测试实验室 2025-01-16 15:41 50浏览
  • 晶台光耦KL817和KL3053在小家电产品(如微波炉等)辅助电源中的广泛应用。具备小功率、高性能、高度集成以及低待机功耗的特点,同时支持宽输入电压范围。▲光耦在实物应用中的产品图其一次侧集成了交流电压过零检测与信号输出功能,该功能产生的过零信号可用于精确控制继电器、可控硅等器件的过零开关动作,从而有效减小开关应力,显著提升器件的使用寿命。通过高度的集成化和先进的控制技术,该电源大幅减少了所需的外围器件数量,不仅降低了系统成本和体积,还进一步增强了整体的可靠性。▲电路示意图该电路的过零检测信号由
    晶台光耦 2025-01-16 10:12 40浏览
  • 食物浪费已成为全球亟待解决的严峻挑战,并对环境和经济造成了重大影响。最新统计数据显示,全球高达三分之一的粮食在生产过程中损失或被无谓浪费,这不仅导致了资源消耗,还加剧了温室气体排放,并带来了巨大经济损失。全球领先的光学解决方案供应商艾迈斯欧司朗(SIX:AMS)近日宣布,艾迈斯欧司朗基于AS7341多光谱传感器开发的创新应用来解决食物浪费这一全球性难题。其多光谱传感解决方案为农业与食品行业带来深远变革,该技术通过精确判定最佳收获时机,提升质量控制水平,并在整个供应链中有效减少浪费。 在2024
    艾迈斯欧司朗 2025-01-14 18:45 96浏览
  • 随着智慧科技的快速发展,智能显示器的生态圈应用变得越来越丰富多元,智能显示器不仅仅是传统的显示设备,透过结合人工智能(AI)和语音助理,它还可以成为家庭、办公室和商业环境中的核心互动接口。提供多元且个性化的服务,如智能家居控制、影音串流拨放、实时信息显示等,极大提升了使用体验。此外,智能家居系统的整合能力也不容小觑,透过智能装置之间的无缝连接,形成了强大的多元应用生态圈。企业也利用智能显示器进行会议展示和多方远程合作,大大提高效率和互动性。Smart Display Ecosystem示意图,作
    百佳泰测试实验室 2025-01-16 15:37 45浏览
  •   在信号处理过程中,由于信号的时域截断会导致频谱扩展泄露现象。那么导致频谱泄露发生的根本原因是什么?又该采取什么样的改善方法。本文以ADC性能指标的测试场景为例,探讨了对ADC的输出结果进行非周期截断所带来的影响及问题总结。 两个点   为了更好的分析或处理信号,实际应用时需要从频域而非时域的角度观察原信号。但物理意义上只能直接获取信号的时域信息,为了得到信号的频域信息需要利用傅里叶变换这个工具计算出原信号的频谱函数。但对于计算机来说实现这种计算需要面对两个问题: 1.
    TIAN301 2025-01-14 14:15 138浏览
  • 数字隔离芯片是现代电气工程师在进行电路设计时所必须考虑的一种电子元件,主要用于保护低压控制电路中敏感电子设备的稳定运行与操作人员的人身安全。其不仅能隔离两个或多个高低压回路之间的电气联系,还能防止漏电流、共模噪声与浪涌等干扰信号的传播,有效增强电路间信号传输的抗干扰能力,同时提升电子系统的电磁兼容性与通信稳定性。容耦隔离芯片的典型应用原理图值得一提的是,在电子电路中引入隔离措施会带来传输延迟、功耗增加、成本增加与尺寸增加等问题,而数字隔离芯片的目标就是尽可能消除这些不利影响,同时满足安全法规的要
    华普微HOPERF 2025-01-15 09:48 119浏览
  • 近期,智能家居领域Matter标准的制定者,全球最具影响力的科技联盟之一,连接标准联盟(Connectivity Standards Alliance,简称CSA)“利好”频出,不仅为智能家居领域的设备制造商们提供了更为快速便捷的Matter认证流程,而且苹果、三星与谷歌等智能家居平台厂商都表示会接纳CSA的Matter认证体系,并计划将其整合至各自的“Works with”项目中。那么,在本轮“利好”背景下,智能家居的设备制造商们该如何捉住机会,“掘金”万亿市场呢?重认证快通道计划,为家居设备
    华普微HOPERF 2025-01-16 10:22 72浏览
  • 电竞鼠标应用环境与客户需求电竞行业近年来发展迅速,「鼠标延迟」已成为决定游戏体验与比赛结果的关键因素。从技术角度来看,传统鼠标的延迟大约为20毫秒,入门级电竞鼠标通常为5毫秒,而高阶电竞鼠标的延迟可降低至仅2毫秒。这些差异看似微小,但在竞技激烈的游戏中,尤其在对反应和速度要求极高的场景中,每一毫秒的优化都可能带来致胜的优势。电竞比赛的普及促使玩家更加渴望降低鼠标延迟以提升竞技表现。他们希望通过精确的测试,了解不同操作系统与设定对延迟的具体影响,并寻求最佳配置方案来获得竞技优势。这样的需求推动市场
    百佳泰测试实验室 2025-01-16 15:45 55浏览
  • 实用性高值得收藏!! (时源芯微)时源专注于EMC整改与服务,配备完整器件 TVS全称Transient Voltage Suppre,亦称TVS管、瞬态抑制二极管等,有单向和双向之分。单向TVS 一般应用于直流供电电路,双向TVS 应用于电压交变的电路。在直流电路的应用中,TVS被并联接入电路中。在电路处于正常运行状态时,TVS会保持截止状态,从而不对电路的正常工作产生任何影响。然而,一旦电路中出现异常的过电压,并且这个电压达到TVS的击穿阈值时,TVS的状态就会
    时源芯微 2025-01-16 14:23 71浏览
  • 故障现象 一辆2007款法拉利599 GTB车,搭载6.0 L V12自然吸气发动机(图1),累计行驶里程约为6万km。该车因发动机故障灯异常点亮进厂检修。 图1 发动机的布置 故障诊断接车后试车,发动机怠速轻微抖动,发动机故障灯长亮。用故障检测仪检测,发现发动机控制单元(NCM)中存储有故障代码“P0300 多缸失火”“P0309 气缸9失火”“P0307 气缸7失火”,初步判断发动机存在失火故障。考虑到该车使用年数较长,决定先使用虹科Pico汽车示波器进行相对压缩测试,以
    虹科Pico汽车示波器 2025-01-15 17:30 43浏览
我要评论
0
点击右上角,分享到朋友圈 我知道啦
请使用浏览器分享功能 我知道啦