ViT模型详解与Pytorch代码实现

OpenCV学堂 2025-01-16 17:31

点击上方蓝字关注我们

微信公众号:OpenCV学堂

关注获取更多计算机视觉与深度学习知识

免费领学习资料+微信:OpenCVXueTang_Asst

介绍

使用PyTorch 从头开始实现 ViT模型代码,在 CIFAR-10 数据集上训练ViT模型 以完成图像分类。

ViT的架构

ViT 的架构受到 BERT 的启发,BERT 是一种仅编码器的 transformer 模型,通常用于文本分类或命名实体识别等 NLP 监督学习任务。ViT 背后的主要思想是,图像可以看作是一系列的补丁,在 NLP 任务中可以被视为令牌

输入图像被分割成小块,然后被展平为向量序列。然后,这些向量由 transformer 编码器处理,它允许模型通过自注意力机制学习补丁之间的交互。然后,transformer 编码器的输出被馈送到一个分类层,该分类层输出输入图像的预测类别

代码实现

下面是模型各个部分组件的 PyTorch代码实现。

01

图像嵌入转换

为了将输入图像馈送到 Transformer 模型,我们需要将图像转换为一系列向量。这是通过将图像分割成一个不重叠的补丁网格来完成的,然后线性投影这些补丁以获得每个补丁的固定大小的嵌入向量。为此,我们可以使用 PyTorch 的层:nn.Conv2d
  1. class PatchEmbeddings(nn.Module):

  2. """

  3. Convert the image into patches and then project them into a vector space.

  4. """

  5. def __init__(self, config):

  6. super().__init__()

  7. self.image_size = config["image_size"]

  8. self.patch_size = config["patch_size"]

  9. self.num_channels = config["num_channels"]

  10. self.hidden_size = config["hidden_size"]

  11. # Calculate the number of patches from the image size and patch size

  12. self.num_patches = (self.image_size // self.patch_size) ** 2

  13. # Create a projection layer to convert the image into patches

  14. # The layer projects each patch into a vector of size hidden_size

  15. self.projection = nn.Conv2d(self.num_channels, self.hidden_size, kernel_size=self.patch_size, stride=self.patch_size)

  16. def forward(self, x):

  17. # (batch_size, num_channels, image_size, image_size) -> (batch_size, num_patches, hidden_size)

  18. x = self.projection(x)

  19. x = x.flatten(2).transpose(1, 2)

  20. return x

kernel_size=self.patch_size并确保图层的滤镜应用于非重叠的面片。stride=self.patch_size在补丁转换为嵌入序列后,[CLS] 标记被添加到序列的开头,稍后将在分类层中用于对图像进行分类。[CLS] 令牌的嵌入是在训练期间学习的。

由于来自不同位置的补丁对最终预测的贡献可能不同,我们还需要一种方法将补丁位置编码到序列中。我们将使用可学习的位置嵌入向量将位置信息添加到嵌入向量中。这类似于在 Transformer 模型中为 NLP 任务使用位置嵌入的方式。
  1. class Embeddings(nn.Module):

  2. def __init__(self, config):

  3. super().__init__()

  4. self.config = config

  5. self.patch_embeddings = PatchEmbeddings(config)

  6. # Create a learnable [CLS] token

  7. # Similar to BERT, the [CLS] token is added to the beginning of the input sequence

  8. # and is used to classify the entire sequence

  9. self.cls_token = nn.Parameter(torch.randn(1, 1, config["hidden_size"]))

  10. # Create position embeddings for the [CLS] token and the patch embeddings

  11. # Add 1 to the sequence length for the [CLS] token

  12. self.position_embeddings = \

  13. nn.Parameter(torch.randn(1, self.patch_embeddings.num_patches + 1, config["hidden_size"]))

  14. self.dropout = nn.Dropout(config["hidden_dropout_prob"])

  15. def forward(self, x):

  16. x = self.patch_embeddings(x)

  17. batch_size, _, _ = x.size()

  18. # Expand the [CLS] token to the batch size

  19. # (1, 1, hidden_size) -> (batch_size, 1, hidden_size)

  20. cls_tokens = self.cls_token.expand(batch_size, -1, -1)

  21. # Concatenate the [CLS] token to the beginning of the input sequence

  22. # This results in a sequence length of (num_patches + 1)

  23. x = torch.cat((cls_tokens, x), dim=1)

  24. x = x + self.position_embeddings

  25. x = self.dropout(x)

  26. return x

在此步骤中,输入图像被转换为带有位置信息的嵌入序列,并准备馈送到 transformer 层。

02

多头注意力

在介绍 transformer 编码器之前,我们首先探索 multi-head attention module,这是它的核心组件。多头注意力用于计算输入图像中不同色块之间的交互。多头注意力由多个注意力头组成,每个注意力头都是一个注意力层。

让我们实现多头注意力模块的 head。该模块将一系列嵌入向量作为输入,并计算每个嵌入向量的查询向量、键向量和值向量。然后,使用查询和关键向量来计算每个标记的注意力权重。然后,使用注意力权重通过值向量的加权和来计算新的嵌入。我们可以将此机制视为数据库查询的软版本,其中查询向量在数据库中查找最相关的键向量,并检索值向量以计算查询输出。
  1. class AttentionHead(nn.Module):

  2. """

  3. A single attention head.

  4. This module is used in the MultiHeadAttention module.

  5. """

  6. def __init__(self, hidden_size, attention_head_size, dropout, bias=True):

  7. super().__init__()

  8. self.hidden_size = hidden_size

  9. self.attention_head_size = attention_head_size

  10. # Create the query, key, and value projection layers

  11. self.query = nn.Linear(hidden_size, attention_head_size, bias=bias)

  12. self.key = nn.Linear(hidden_size, attention_head_size, bias=bias)

  13. self.value = nn.Linear(hidden_size, attention_head_size, bias=bias)

  14. self.dropout = nn.Dropout(dropout)

  15. def forward(self, x):

  16. # Project the input into query, key, and value

  17. # The same input is used to generate the query, key, and value,

  18. # so it's usually called self-attention.

  19. # (batch_size, sequence_length, hidden_size) -> (batch_size, sequence_length, attention_head_size)

  20. query = self.query(x)

  21. key = self.key(x)

  22. value = self.value(x)

  23. # Calculate the attention scores

  24. # softmax(Q*K.T/sqrt(head_size))*V

  25. attention_scores = torch.matmul(query, key.transpose(-1, -2))

  26. attention_scores = attention_scores / math.sqrt(self.attention_head_size)

  27. attention_probs = nn.functional.softmax(attention_scores, dim=-1)

  28. attention_probs = self.dropout(attention_probs)

  29. # Calculate the attention output

  30. attention_output = torch.matmul(attention_probs, value)

  31. return (attention_output, attention_probs)


然后,所有注意力头的输出被连接起来并线性投影,以获得多头注意力模块的最终输出。

class MultiHeadAttention(nn.Module):    """    Multi-head attention module.    This module is used in the TransformerEncoder module.    """
def __init__(self, config): super().__init__() self.hidden_size = config["hidden_size"] self.num_attention_heads = config["num_attention_heads"] # The attention head size is the hidden size divided by the number of attention heads self.attention_head_size = self.hidden_size // self.num_attention_heads self.all_head_size = self.num_attention_heads * self.attention_head_size # Whether or not to use bias in the query, key, and value projection layers self.qkv_bias = config["qkv_bias"] # Create a list of attention heads self.heads = nn.ModuleList([]) for _ in range(self.num_attention_heads): head = AttentionHead( self.hidden_size, self.attention_head_size, config["attention_probs_dropout_prob"], self.qkv_bias ) self.heads.append(head) # Create a linear layer to project the attention output back to the hidden size # In most cases, all_head_size and hidden_size are the same self.output_projection = nn.Linear(self.all_head_size, self.hidden_size) self.output_dropout = nn.Dropout(config["hidden_dropout_prob"])

def forward(self, x, output_attentions=False): # Calculate the attention output for each attention head attention_outputs = [head(x) for head in self.heads] # Concatenate the attention outputs from each attention head attention_output = torch.cat([attention_output for attention_output, _ in attention_outputs], dim=-1) # Project the concatenated attention output back to the hidden size attention_output = self.output_projection(attention_output) attention_output = self.output_dropout(attention_output) # Return the attention output and the attention probabilities (optional) if not output_attentions: return (attention_output, None) else: attention_probs = torch.stack([attention_probs for _, attention_probs in attention_outputs], dim=1) return (attention_output, attention_probs)

03

编码器

编码器由一堆MHA + MLP组成。每个 transformer 层主要由我们刚刚实现的多头注意力模块和前馈网络组成。为了更好地扩展模型并稳定训练,向 transformer 层添加了两个 Layer 归一化层和跳过连接。

让我们实现一个 transformer 层(在代码中称为 ,因为它是 transformer 编码器的构建块)。我们将从前馈网络开始,这是一个简单的两层 MLP,中间有 GELU 激活。Block

class MLP(nn.Module):    """    A multi-layer perceptron module.    """    def __init__(self, config):        super().__init__()        self.dense_1 = nn.Linear(config["hidden_size"], config["intermediate_size"])        self.activation = NewGELUActivation()        self.dense_2 = nn.Linear(config["intermediate_size"], config["hidden_size"])        self.dropout = nn.Dropout(config["hidden_dropout_prob"])
def forward(self, x): x = self.dense_1(x) x = self.activation(x) x = self.dense_2(x) x = self.dropout(x)        return x
我们已经实现了多头注意力和 MLP,我们可以将它们组合起来创建变压器层。跳过连接和层标准化将应用于每个层的输入
class Block(nn.Module):    """    A single transformer block.    """
def __init__(self, config): super().__init__() self.attention = MultiHeadAttention(config) self.layernorm_1 = nn.LayerNorm(config["hidden_size"]) self.mlp = MLP(config)        self.layernorm_2 = nn.LayerNorm(config["hidden_size"])
def forward(self, x, output_attentions=False): # Self-attention attention_output, attention_probs = \ self.attention(self.layernorm_1(x), output_attentions=output_attentions) # Skip connection x = x + attention_output # Feed-forward network mlp_output = self.mlp(self.layernorm_2(x)) # Skip connection x = x + mlp_output # Return the transformer block's output and the attention probabilities (optional) if not output_attentions: return (x, None) else: return (x, attention_probs)

transformer 编码器按顺序堆叠多个 transformer 层:

class Encoder(nn.Module):    """    The transformer encoder module.    """
def __init__(self, config): super().__init__() # Create a list of transformer blocks self.blocks = nn.ModuleList([]) for _ in range(config["num_hidden_layers"]): block = Block(config)            self.blocks.append(block)
def forward(self, x, output_attentions=False): # Calculate the transformer block's output for each block all_attentions = [] for block in self.blocks: x, attention_probs = block(x, output_attentions=output_attentions) if output_attentions: all_attentions.append(attention_probs) # Return the encoder's output and the attention probabilities (optional) if not output_attentions: return (x, None) else: return (x, all_attentions)

04

ViT模型构建

将图像输入到 embedding 层和 transformer 编码器后,我们获得图像补丁和 [CLS] 标记的新嵌入。此时,嵌入在经过 transformer 编码器处理后应该有一些有用的信号用于分类。与 BERT 类似,我们将仅使用 [CLS] 标记的嵌入传递到分类层。

分类层是一个完全连接的层,它将 [CLS] 嵌入作为输入并输出每个图像的 logit。以下代码实现了用于图像分类的 ViT 模型:
class ViTForClassfication(nn.Module):
"""
The ViT model for classification.
"""


def __init__(self, config):
super().__init__()
self.config = config
self.image_size = config["image_size"]
self.hidden_size = config["hidden_size"]
self.num_classes = config["num_classes"]
# Create the embedding module
self.embedding = Embeddings(config)
# Create the transformer encoder module
self.encoder = Encoder(config)
# Create a linear layer to project the encoder's output to the number of classes
self.classifier = nn.Linear(self.hidden_size, self.num_classes)
# Initialize the weights
self.apply(self._init_weights)

def forward(self, x, output_attentions=False):
# Calculate the embedding output
embedding_output = self.embedding(x)
# Calculate the encoder's output
encoder_output, all_attentions = self.encoder(embedding_output, output_attentions=output_attentions)
# Calculate the logits, take the [CLS] token's output as features for classification
logits = self.classifier(encoder_output[:, 0])
# Return the logits and the attention probabilities (optional)
if not output_attentions:
return (logits, None)
else:
return (logits, all_attentions)


参考

代码其实是我从github上面整理加工跟翻译得到的(个人认为非常的通俗易懂,有点pytorch基础都可以看懂学会),感兴趣的可以看这里:

https://github.com/lukemelas/PyTorch-Pretrained-ViT/blob/master/pytorch_pretrained_vit/transformer.pyhttps://tintn.github.io/Implementing-Vision-Transformer-from-Scratch/


系统化学习QT5 + OpenCV4

原价:498

折扣:399


推荐阅读

OpenCV4.8+YOLOv8对象检测C++推理演示

ZXING+OpenCV打造开源条码检测应用

总结 | OpenCV4 Mat操作全接触

三行代码实现 TensorRT8.6 C++ 深度学习模型部署

实战 | YOLOv8+OpenCV 实现DM码定位检测与解析

对象检测边界框损失 – 从IOU到ProbIOU

YOLOv8 OBB实现自定义旋转对象检测

初学者必看 | 学习深度学习的五个误区

YOLOv8自定义数据集训练实现安全帽检测


OpenCV学堂 专注计算机视觉开发技术分享,技术框架使用,包括OpenCV,Tensorflow,Pytorch教程与案例,相关算法详解,最新CV方向论文,硬核代码干货与代码案例详解!作者在CV工程化方面深度耕耘15年,感谢您的关注!
评论 (0)
  • 文/Leon编辑/cc孙聪颖‍2023年,厨电行业在相对平稳的市场环境中迎来温和复苏,看似为行业增长积蓄势能。带着对市场向好的预期,2024 年初,老板电器副董事长兼总经理任富佳为企业定下双位数增长目标。然而现实与预期相悖,过去一年,这家老牌厨电企业不仅未能达成业绩目标,曾提出的“三年再造一个老板电器”愿景,也因市场下行压力面临落空风险。作为“企二代”管理者,任富佳在掌舵企业穿越市场周期的过程中,正面临着前所未有的挑战。4月29日,老板电器(002508.SZ)发布了2024年年度报告及2025
    华尔街科技眼 2025-04-30 12:40 271浏览
  • 在电子电路设计和调试中,晶振为电路提供稳定的时钟信号。我们可能会遇到晶振有电压,但不起振,从而导致整个电路无法正常工作的情况。今天凯擎小妹聊一下可能的原因和解决方案。1. 误区解析在硬件调试中,许多工程师在测量晶振时发现两端都有电压,例如1.6V,但没有明显的压差,第一反应可能是怀疑短路。晶振电路本质上是一个交流振荡电路。当晶振未起振时,两端会静止在一个中间电位,通常接近电源电压的一半。万用表测得的是稳定的直流电压,因此没有压差。这种情况一般是:晶振没起振,并不是短路。2. 如何判断真
    koan-xtal 2025-04-28 05:09 294浏览
  • 一、gao效冷却与控温机制‌1、‌冷媒流动设计‌采用低压液氮(或液氦)通过毛细管路导入蒸发器,蒸汽喷射至样品腔实现快速冷却,冷却效率高(室温至80K约20分钟,至4.2K约30分钟)。通过控温仪动态调节蒸发器加热功率,结合温度传感器(如PT100铂电阻或Cernox磁场不敏感传感器),实现±0.01K的高精度温度稳定性。2、‌宽温区覆盖与扩展性‌标准温区为80K-325K,通过降压选件可将下限延伸至65K(液氮模式)或4K(液氦模式)。可选配475K高温模块,满足材料在ji端温度下的性能测试需求
    锦正茂科技 2025-04-30 13:08 371浏览
  • 文/郭楚妤编辑/cc孙聪颖‍越来越多的企业开始蚕食动力电池市场,行业“去宁王化”态势逐渐明显。随着这种趋势的加强,打开新的市场对于宁德时代而言至关重要。“我们不希望被定义为电池的制造者,而是希望把自己称作新能源产业的开拓者。”4月21日,在宁德时代举行的“超级科技日”发布会上,宁德时代掌门人曾毓群如是说。随着宁德时代核心新品骁遥双核电池的发布,其搭载的“电电增程”技术也走进业界视野。除此之外,经过近3年试水,宁德时代在换电业务上重资加码。曾毓群认为换电是一个重资产、高投入、长周期的产业,涉及的利
    华尔街科技眼 2025-04-28 21:55 197浏览
  •  探针台的维护直接影响其测试精度与使用寿命,需结合日常清洁、环境控制、定期校准等多维度操作,具体方法如下:一、日常清洁与保养1.‌表面清洁‌l 使用无尘布或软布擦拭探针台表面,避免残留清洁剂或硬物划伤精密部件。l 探针头清洁需用非腐蚀性溶剂(如异丙醇)擦拭,检查是否弯曲或损坏。2.‌光部件维护‌l 镜头、观察窗等光学部件用镜头纸蘸取wu水jiu精从中心向外轻擦,操作时远离火源并保持通风。3.‌内部防尘‌l 使用后及时吹扫灰尘,防止污染物进入机械滑
    锦正茂科技 2025-04-28 11:45 121浏览
  • 随着电子元器件的快速发展,导致各种常见的贴片电阻元器件也越来越小,给我们分辨也就变得越来越难,下面就由smt贴片加工厂_安徽英特丽就来告诉大家如何分辨的SMT贴片元器件。先来看看贴片电感和贴片电容的区分:(1)看颜色(黑色)——一般黑色都是贴片电感。贴片电容只有勇于精密设备中的贴片钽电容才是黑色的,其他普通贴片电容基本都不是黑色的。(2)看型号标码——贴片电感以L开头,贴片电容以C开头。从外形是圆形初步判断应为电感,测量两端电阻为零点几欧,则为电感。(3)检测——贴片电感一般阻值小,更没有“充放
    贴片加工小安 2025-04-29 14:59 297浏览
  • 你是不是也有在公共场合被偷看手机或笔电的经验呢?科技时代下,不少现代人的各式机密数据都在手机、平板或是笔电等可携式的3C产品上处理,若是经常性地需要在公共场合使用,不管是工作上的机密文件,或是重要的个人信息等,民众都有防窃防盗意识,为了避免他人窥探内容,都会选择使用「防窥保护贴片」,以防止数据外泄。现今市面上「防窥保护贴」、「防窥片」、「屏幕防窥膜」等产品就是这种目的下产物 (以下简称防窥片)!防窥片功能与常见问题解析首先,防窥片最主要的功能就是用来防止他人窥视屏幕上的隐私信息,它是利用百叶窗的
    百佳泰测试实验室 2025-04-30 13:28 480浏览
  • 网约车,真的“饱和”了?近日,网约车市场的 “饱和” 话题再度引发热议。多地陆续发布网约车风险预警,提醒从业者谨慎入局,这背后究竟隐藏着怎样的市场现状呢?从数据来看,网约车市场的“过剩”现象已愈发明显。以东莞为例,截至2024年12月底,全市网约车数量超过5.77万辆,考取网约车驾驶员证的人数更是超过13.48万人。随着司机数量的不断攀升,订单量却未能同步增长,导致单车日均接单量和营收双双下降。2024年下半年,东莞网约出租车单车日均订单量约10.5单,而单车日均营收也不容乐
    用户1742991715177 2025-04-29 18:28 272浏览
  • 浪潮之上:智能时代的觉醒    近日参加了一场课题的答辩,这是医疗人工智能揭榜挂帅的国家项目的地区考场,参与者众多,围绕着医疗健康的主题,八仙过海各显神通,百花齐放。   中国大地正在发生着激动人心的场景:深圳前海深港人工智能算力中心高速运转的液冷服务器,武汉马路上自动驾驶出租车穿行的智慧道路,机器人参与北京的马拉松竞赛。从中央到地方,人工智能相关政策和消息如雨后春笋般不断出台,数字中国的建设图景正在智能浪潮中徐徐展开,战略布局如同围棋
    广州铁金刚 2025-04-30 15:24 260浏览
  • 贞光科技代理品牌紫光国芯的车规级LPDDR4内存正成为智能驾驶舱的核心选择。在汽车电子国产化浪潮中,其产品以宽温域稳定工作能力、优异电磁兼容性和超长使用寿命赢得市场认可。紫光国芯不仅确保供应链安全可控,还提供专业本地技术支持。面向未来,紫光国芯正研发LPDDR5车规级产品,将以更高带宽、更低功耗支持汽车智能化发展。随着智能网联汽车的迅猛发展,智能驾驶舱作为人机交互的核心载体,对处理器和存储器的性能与可靠性提出了更高要求。在汽车电子国产化浪潮中,贞光科技代理品牌紫光国芯的车规级LPDDR4内存凭借
    贞光科技 2025-04-28 16:52 317浏览
  • 晶振在使用过程中可能会受到污染,导致性能下降。可是污染物是怎么进入晶振内部的?如何检测晶振内部污染物?我可不可以使用超声波清洗?今天KOAN凯擎小妹将逐一解答。1. 污染物来源a. 制造过程:生产环境不洁净或封装密封不严,可能导致灰尘和杂质进入晶振。b. 使用环境:高湿度、温度变化、化学物质和机械应力可能导致污染物渗入。c. 储存不当:不良的储存环境和不合适的包装材料可能引发化学物质迁移。建议储存湿度维持相对湿度在30%至75%的范围内,有助于避免湿度对晶振的不利影响。避免雨淋或阳光直射。d.
    koan-xtal 2025-04-28 06:11 165浏览
  • 4月22日下午,备受瞩目的飞凌嵌入式「2025嵌入式及边缘AI技术论坛」在深圳深铁皇冠假日酒店盛大举行,此次活动邀请到了200余位嵌入式技术领域的技术专家、企业代表和工程师用户,共享嵌入式及边缘AI技术的盛宴!1、精彩纷呈的展区产品及方案展区是本场活动的第一场重头戏,从硬件产品到软件系统,从企业级应用到高校教学应用,都吸引了现场来宾的驻足观看和交流讨论。全产品矩阵展区展示了飞凌嵌入式丰富的产品线,从嵌入式板卡到工控机,从进口芯片平台到全国产平台,无不体现出飞凌嵌入式在嵌入式主控设备研发设计方面的
    飞凌嵌入式 2025-04-28 14:43 180浏览
  • 在智能硬件设备趋向微型化的背景下,语音芯片方案厂商针对小体积设备开发了多款超小型语音芯片方案,其中WTV系列和WT2003H系列凭借其QFN封装设计、高性能与高集成度,成为微型设备语音方案的理想选择。以下从封装特性、功能优势及典型应用场景三个方面进行详细介绍。一、超小体积封装:QFN技术的核心优势WTV系列与WT2003H系列均提供QFN封装(如QFN32,尺寸为4×4mm),这种封装形式具有以下特点:体积紧凑:QFN封装通过减少引脚间距和优化内部结构,显著缩小芯片体积,适用于智能门铃、穿戴设备
    广州唯创电子 2025-04-30 09:02 321浏览
  • 在CAN总线分析软件领域,当CANoe不再是唯一选择时,虹科PCAN-Explorer 6软件成为了一个有竞争力的解决方案。在现代工业控制和汽车领域,CAN总线分析软件的重要性不言而喻。随着技术的进步和市场需求的多样化,单一的解决方案已无法满足所有用户的需求。正是在这样的背景下,虹科PCAN-Explorer 6软件以其独特的模块化设计和灵活的功能扩展,为CAN总线分析领域带来了新的选择和可能性。本文将深入探讨虹科PCAN-Explorer 6软件如何以其创新的模块化插件策略,提供定制化的功能选
    虹科汽车智能互联 2025-04-28 16:00 226浏览
  • 一、智能家居的痛点与创新机遇随着城市化进程加速,现代家庭正面临两大核心挑战:情感陪伴缺失:超60%的双职工家庭存在“亲子陪伴真空期”,儿童独自居家场景增加;操作复杂度攀升:智能设备功能迭代导致用户学习成本陡增,超40%用户因操作困难放弃高阶功能。而WTR096-16S录音语音芯片方案,通过“语音交互+智能录音”双核驱动,不仅解决设备易用性问题,更构建起家庭成员间的全天候情感纽带。二、WTR096-16S方案的核心技术突破1. 高保真语音交互系统动态情绪语音库:支持8种语气模板(温柔提醒/紧急告警
    广州唯创电子 2025-04-28 09:24 193浏览
我要评论
0
0
点击右上角,分享到朋友圈 我知道啦
请使用浏览器分享功能 我知道啦