金字塔ViT|华为提出使用金字塔结构改进Transformer,涨点明显(Pytorch逐行解读)

OpenCV学堂 2022-01-09 23:43

Transformer在计算机视觉任务方面取得了很大的进展。Transformer-in-Transformer (TNT)体系结构利用内部Transformer和外部Transformer来提取局部和全局表示。在这项工作中,通过引入2种先进的设计来提出新的TNT Baseline:

  1. Pyramid Architecture

  2. Convolutional Stem

新的“PyramidTNT”通过建立层次表示,显著地改进了原来的TNT。PyramidTNT相较于之前最先进的Vision Transformer具有更好的性能,如Swin-Transformer。

1简介

Vision Transformer为计算机视觉提供了一种新的解决思路。从ViT开始,提出了一系列改进Vision Transformer体系结构的工作。

  • PVT介绍了Vision Transformer的金字塔网络体系结构

  • T2T-ViT-14 递归地将相邻的Token聚合为一个Token,以提取局部结构,减少Token的数量

  • TNT 利用 inner Transformer和outer Transformer来建模 word-level 和 sentence-level 的视觉表示

  • Swin-Transformer提出了一种分层Transformer,其表示由Shifted windows来进行计算

随着近年来的研究进展,Vision Transformer的性能已经可以优于卷积神经网络(CNN)。而本文的这项工作是建立了基于TNT框架的改进的 Vision Transformer Baseline。这里主要引入了两个主要的架构修改:

  1. Pyramid Architecture:逐渐降低分辨率,提取多尺度表示

  2. Convolutional Stem:修补Stem和稳定训练

这里作者还使用了几个其他技巧来进一步提高效率。新的Transformer被命名为PyramidTNT

对图像分类和目标检测的实验证明了金字塔检测的优越性。具体来说,PyramidTNT-S在只有3.3B FLOPs的情况下获得了82.0%的ImageNet分类准确率,明显优于原来的TNT-S和Swin-T。

对于COCO检测,PyramidTNT-S比现有的Transformer和MLP检测模型以更少的计算成本实现42.0的mAP。

2本文方法

2.1 Convolutional Stem

给定一个输入图像,TNT模型首先将图像分割成多个patch,并进一步将每个patch视为一个sub-patch序列。然后应用线性层将sub-patch投射到visual word vector(又称token)。这些视觉word被拼接在一起并转换成一个visual sentence vector。

肖奥等人发现在ViT中使用多个卷积作为Stem可以提高优化稳定性,也能提高性能。在此基础上,本文构造了一个金字塔的卷积Stem。利用3×3卷积的堆栈产生visual word vector ,其中C是visual word vector的维度。同样也可以得到visual sentence vector ,其中D是visual sentence vector 的维度。word-level 和 sentence-level位置编码分别添加到visual words和sentences上,和原始的TNT一样。

class Stem(nn.Module):
    """ 
    Image to Visual Word Embedding
    """

    def __init__(self, img_size=224, in_chans=3, outer_dim=768, inner_dim=24):
        super().__init__()
        img_size = to_2tuple(img_size)
        self.img_size = img_size
        self.inner_dim = inner_dim
        self.num_patches = img_size[0] // 8 * img_size[1] // 8
        self.num_words = 16
        
        self.common_conv = nn.Sequential(
            nn.Conv2d(in_chans, inner_dim*23, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*2),
            nn.ReLU(inplace=True),
        )
        # 利用 inner Transformer来建模 word-level
        self.inner_convs = nn.Sequential(
            nn.Conv2d(inner_dim*2, inner_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(inner_dim),
            nn.ReLU(inplace=False),
        )
        # 利用outer Transformer来建模 sentence-level 的视觉表示
        self.outer_convs = nn.Sequential(
            nn.Conv2d(inner_dim*2, inner_dim*43, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*4),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_dim*4, inner_dim*83, stride=2, padding=1),
            nn.BatchNorm2d(inner_dim*8),
            nn.ReLU(inplace=True),
            nn.Conv2d(inner_dim*8, outer_dim, 3, stride=1, padding=1),
            nn.BatchNorm2d(outer_dim),
            nn.ReLU(inplace=False),
        )
        
        self.unfold = nn.Unfold(kernel_size=4, padding=0, stride=4)

    def forward(self, x):
        B, C, H, W = x.shape
        H_out, W_out = H // 8, W // 8
        H_in, W_in = 44
        x = self.common_conv(x)
        # inner_tokens建模word level表征
        inner_tokens = self.inner_convs(x) # B, C, H, W
        inner_tokens = self.unfold(inner_tokens).transpose(12# B, N, Ck2
        inner_tokens = inner_tokens.reshape(B * H_out * W_out, self.inner_dim, H_in*W_in).transpose(12# B*N, C, 4*4
        # outer_tokens建模 sentence level表征
        outer_tokens = self.outer_convs(x) # B, C, H_out, W_out
        outer_tokens = outer_tokens.permute(0231).reshape(B, H_out * W_out, -1)
        return inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in)

2.2  Pyramid Architecture

原始的TNT网络在继ViT之后的每个块中保持相同数量的token。visual words和visual sentences的数量从下到上保持不变。

本文受PVT的启发,为TNT构建了4个不同数量的Token阶段,如图1(b)。所示在这4个阶段中,visual words的空间形状分别设置为H/2×W/2、H/4×W/4、H/8×W/8、H/16×W/16;visual sentences的空间形状分别设置为H/8×W/8、H/16×W/16、H/32×W/32、H/64×W/64。下采样操作是通过stride=2的卷积来实现的。每个阶段由几个TNT块组成,TNT块在word-level 和 sentence-level特征上操作。最后,利用全局平均池化操作,将输出的visual sentences融合成一个向量作为图像表示。

class SentenceAggregation(nn.Module):
    """ 
    Sentence Aggregation
    """

    def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
        super().__init__()
        self.stride = stride
        self.norm = nn.LayerNorm(dim_in)
        self.conv = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
        )
        
    def forward(self, x, H, W):
        B, N, C = x.shape # B, N, C
        x = self.norm(x)
        x = x.transpose(12).reshape(B, C, H, W)
        x = self.conv(x)
        H, W = math.ceil(H / self.stride), math.ceil(W / self.stride)
        x = x.reshape(B, -1, H * W).transpose(12)
        return x, H, W


class WordAggregation(nn.Module):
    """ 
    Word Aggregation
    """

    def __init__(self, dim_in, dim_out, stride=2, act_layer=nn.GELU):
        super().__init__()
        self.stride = stride
        self.dim_out = dim_out
        self.norm = nn.LayerNorm(dim_in)
        self.conv = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=2*stride-1, padding=stride-1, stride=stride),
        )

    def forward(self, x, H_out, W_out, H_in, W_in):
        B_N, M, C = x.shape # B*N, M, C
        x = self.norm(x)
        x = x.reshape(-1, H_out, W_out, H_in, W_in, C)
        
        # padding to fit (1333, 800) in detection.
        pad_input = (H_out % 2 == 1or (W_out % 2 == 1)
        if pad_input:
            x = F.pad(x.permute(034512), (0, W_out % 20, H_out % 2))
            x = x.permute(045123)            
        # patch merge
        x1 = x[:, 0::20::2, :, :, :]  # B, H/2, W/2, H_in, W_in, C
        x2 = x[:, 1::20::2, :, :, :]
        x3 = x[:, 0::21::2, :, :, :]
        x4 = x[:, 1::21::2, :, :, :]
        x = torch.cat([torch.cat([x1, x2], 3), torch.cat([x3, x4], 3)], 4# B, H/2, W/2, 2*H_in, 2*W_in, C
        x = x.reshape(-12*H_in, 2*W_in, C).permute(0312# B_N/4, C, 2*H_in, 2*W_in
        x = self.conv(x)  # B_N/4, C, H_in, W_in
        x = x.reshape(-1, self.dim_out, M).transpose(12)
        return x
    

class Stage(nn.Module):
    """ 
    PyramidTNT stage
    """

    def __init__(self, num_blocks, outer_dim, inner_dim, outer_head, inner_head, num_patches, num_words, mlp_ratio=4.,
                 qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm, se=0, sr_ratio=1)
:

        super().__init__()
        blocks = []
        drop_path = drop_path if isinstance(drop_path, list) else [drop_path] * num_blocks
        
        for j in range(num_blocks):
            if j == 0:
                _inner_dim = inner_dim
            elif j == 1 and num_blocks > 6:
                _inner_dim = inner_dim
            else:
                _inner_dim = -1
            blocks.append(Block(
                outer_dim, _inner_dim, outer_head=outer_head, inner_head=inner_head,
                num_words=num_words, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop,
                attn_drop=attn_drop, drop_path=drop_path[j], act_layer=act_layer, norm_layer=norm_layer,
                se=se, sr_ratio=sr_ratio))

        self.blocks = nn.ModuleList(blocks)
        self.relative_pos = nn.Parameter(torch.randn(1, outer_head, num_patches, num_patches // sr_ratio // sr_ratio))

    def forward(self, inner_tokens, outer_tokens, H_out, W_out, H_in, W_in):
        for blk in self.blocks:
            inner_tokens, outer_tokens = blk(inner_tokens, outer_tokens, H_out, W_out, H_in, W_in, self.relative_pos)
        return inner_tokens, outer_tokens
    
    
class PyramidTNT(nn.Module):
    """ 
    PyramidTNT 
    """

    def __init__(self, configs=None, img_size=224, in_chans=3, num_classes=1000, mlp_ratio=4., qkv_bias=False,
                qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, se=0)
:

        super().__init__()
        self.num_classes = num_classes
        depths = configs['depths']
        outer_dims = configs['outer_dims']
        inner_dims = configs['inner_dims']
        outer_heads = configs['outer_heads']
        inner_heads = configs['inner_heads']
        sr_ratios = [4211]
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule 
        self.num_features = outer_dims[-1]  # num_features for consistency with other models       

        self.patch_embed = Stem(
            img_size=img_size, in_chans=in_chans, outer_dim=outer_dims[0], inner_dim=inner_dims[0])
        num_patches = self.patch_embed.num_patches
        num_words = self.patch_embed.num_words
        
        self.outer_pos = nn.Parameter(torch.zeros(1, num_patches, outer_dims[0]))
        self.inner_pos = nn.Parameter(torch.zeros(1, num_words, inner_dims[0]))
        self.pos_drop = nn.Dropout(p=drop_rate)

        depth = 0
        self.word_merges = nn.ModuleList([])
        self.sentence_merges = nn.ModuleList([])
        self.stages = nn.ModuleList([])
        # 搭建PyramidTNT所需要的4个Stage
        for i in range(4):
            if i > 0:
                self.word_merges.append(WordAggregation(inner_dims[i-1], inner_dims[i], stride=2))
                self.sentence_merges.append(SentenceAggregation(outer_dims[i-1], outer_dims[i], stride=2))
            self.stages.append(Stage(depths[i], outer_dim=outer_dims[i], inner_dim=inner_dims[i],
                        outer_head=outer_heads[i], inner_head=inner_heads[i],
                        num_patches=num_patches // (2 ** i) // (2 ** i), num_words=num_words, mlp_ratio=mlp_ratio,
                        qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                        drop_path=dpr[depth:depth+depths[i]], norm_layer=norm_layer, se=se, sr_ratio=sr_ratios[i])
            )
            depth += depths[i]
        
        self.norm = norm_layer(outer_dims[-1])

        # Classifier head
        self.head = nn.Linear(outer_dims[-1], num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        inner_tokens, outer_tokens, (H_out, W_out), (H_in, W_in) = self.patch_embed(x)
        inner_tokens = inner_tokens + self.inner_pos # B*N, 8*8, C
        outer_tokens = outer_tokens + self.pos_drop(self.outer_pos)  # B, N, D
        
        for i in range(4):
            if i > 0:
                inner_tokens = self.word_merges[i-1](inner_tokens, H_out, W_out, H_in, W_in)
                outer_tokens, H_out, W_out = self.sentence_merges[i-1](outer_tokens, H_out, W_out)
            inner_tokens, outer_tokens = self.stages[i](inner_tokens, outer_tokens, H_out, W_out, H_in, W_in)
        
        outer_tokens = self.norm(outer_tokens)
        return outer_tokens.mean(dim=1)

    def forward(self, x):
        # 特征提取层,可以作为Backbone用到下游任务
        x = self.forward_features(x)
        # 分类层
        x = self.head(x)
        return x

2.3 其他的Tricks

除了修改网络体系结构外,还采用了几种Vision Transformer的高级技巧。

  1. 在自注意力模块上添加相对位置编码,以更好地表示Token之间的相对位置。

  2. 前两个阶段利用Linear spatial reduction attention(LSRA)来降低长序列自注意力的计算复杂度。

3实验

3.1 分类

表3显示了ImageNet-1K分类结果。与原来的TNT相比,PyramidTNT实现了更好的图像分类精度。例如,与TNT-S相比,使用少1.9B的TNT-S的Top-1精度高0.5%。这里还将PyramidTNT与其他具有代表性的CNN、MLP和基于Transformer的模型进行了比较。从结果中可以看到PyramidTNT是最先进的Vision Transformer。

3.2 目标检测

表4报告了“1x”训练计划下的目标检测和实例分割的结果。PyramidTNT-S在One-Stage和Two-Stage检测器上都显著优于其他Backbone,且计算成本相似。例如,基于PyramidTNT-S的RetinaNet达到了42.0 AP和57.7AP-L,分别高出使用Swin-Transformer的模型0.5AP和2.2APL。

这些结果表明,PyramidTNT体系结构可以更好地捕获大型物体的全局信息。金字塔的简单的上采样策略和较小的空间形状使AP-S从一个大规模的推广。

3.3 实例分割

PyramidTNT-S在Mask R-CNN和Cascade Mask R-CNN上的AP-m可以获得更好的AP-b和AP-m,显示出更好的特征表示能力。例如,在ParamidTNN约束上,MaskR-CNN-S超过Hire-MLPS 的0.9AP-b。

4参考

[1].PyramidTNT:Improved Transformer-in-Transformer Baselines with Pyramid Architecture

OpenCV学堂 专注计算机视觉开发技术分享,技术框架使用,包括OpenCV,Tensorflow,Pytorch教程与案例,相关算法详解,最新CV方向论文,硬核代码干货与代码案例详解!作者在CV工程化方面深度耕耘15年,感谢您的关注!
评论
  •     IPC-2581是基于ODB++标准、结合PCB行业特点而指定的PCB加工文件规范。    IPC-2581旨在替代CAM350格式,成为PCB加工行业的新的工业规范。    有一些免费软件,可以查看(不可修改)IPC-2581数据文件。这些软件典型用途是工艺校核。    1. Vu2581        出品:Downstream     
    电子知识打边炉 2025-01-22 11:12 153浏览
  • 嘿,咱来聊聊RISC-V MCU技术哈。 这RISC-V MCU技术呢,简单来说就是基于一个叫RISC-V的指令集架构做出的微控制器技术。RISC-V这个啊,2010年的时候,是加州大学伯克利分校的研究团队弄出来的,目的就是想搞个新的、开放的指令集架构,能跟上现代计算的需要。到了2015年,专门成立了个RISC-V基金会,让这个架构更标准,也更好地推广开了。这几年啊,这个RISC-V的生态系统发展得可快了,好多公司和机构都加入了RISC-V International,还推出了不少RISC-V
    丙丁先生 2025-01-21 12:10 751浏览
  • 数字隔离芯片是一种实现电气隔离功能的集成电路,在工业自动化、汽车电子、光伏储能与电力通信等领域的电气系统中发挥着至关重要的作用。其不仅可令高、低压系统之间相互独立,提高低压系统的抗干扰能力,同时还可确保高、低压系统之间的安全交互,使系统稳定工作,并避免操作者遭受来自高压系统的电击伤害。典型数字隔离芯片的简化原理图值得一提的是,数字隔离芯片历经多年发展,其应用范围已十分广泛,凡涉及到在高、低压系统之间进行信号传输的场景中基本都需要应用到此种芯片。那么,电气工程师在进行电路设计时到底该如何评估选择一
    华普微HOPERF 2025-01-20 16:50 134浏览
  • 临近春节,各方社交及应酬也变得多起来了,甚至一月份就排满了各式约见。有的是关系好的专业朋友的周末“恳谈会”,基本是关于2025年经济预判的话题,以及如何稳定工作等话题;但更多的预约是来自几个客户老板及副总裁们的见面,他们为今年的经济预判与企业发展焦虑而来。在聊天过程中,我发现今年的聊天有个很有意思的“点”,挺多人尤其关心我到底是怎么成长成现在的多领域风格的,还能掌握一些经济趋势的分析能力,到底学过哪些专业、在企业管过哪些具体事情?单单就这个一个月内,我就重复了数次“为什么”,再辅以我上次写的:《
    牛言喵语 2025-01-22 17:10 203浏览
  • 2024年是很平淡的一年,能保住饭碗就是万幸了,公司业绩不好,跳槽又不敢跳,还有一个原因就是老板对我们这些员工还是很好的,碍于人情也不能在公司困难时去雪上加霜。在工作其间遇到的大问题没有,小问题还是有不少,这里就举一两个来说一下。第一个就是,先看下下面的这个封装,你能猜出它的引脚间距是多少吗?这种排线座比较常规的是0.6mm间距(即排线是0.3mm间距)的,而这个规格也是我们用得最多的,所以我们按惯性思维来看的话,就会认为这个座子就是0.6mm间距的,这样往往就不会去细看规格书了,所以这次的运气
    wuliangu 2025-01-21 00:15 378浏览
  •  万万没想到!科幻电影中的人形机器人,正在一步步走进我们人类的日常生活中来了。1月17日,乐聚将第100台全尺寸人形机器人交付北汽越野车,再次吹响了人形机器人疯狂进厂打工的号角。无独有尔,银河通用机器人作为一家成立不到两年时间的创业公司,在短短一年多时间内推出革命性的第一代产品Galbot G1,这是一款轮式、双臂、身体可折叠的人形机器人,得到了美团战投、经纬创投、IDG资本等众多投资方的认可。作为一家成立仅仅只有两年多时间的企业,智元机器人也把机器人从梦想带进了现实。2024年8月1
    刘旷 2025-01-21 11:15 726浏览
  • 飞凌嵌入式基于瑞芯微RK3562系列处理器打造的FET3562J-C全国产核心板,是一款专为工业自动化及消费类电子设备设计的产品,凭借其强大的功能和灵活性,自上市以来得到了各行业客户的广泛关注。本文将详细介绍如何启动并测试RK3562J处理器的MCU,通过实际操作步骤,帮助各位工程师朋友更好地了解这款芯片。1、RK3562J处理器概述RK3562J处理器采用了4*Cortex-A53@1.8GHz+Cortex-M0@200MHz架构。其中,4个Cortex-A53核心作为主要核心,负责处理复杂
    飞凌嵌入式 2025-01-24 11:21 88浏览
  • 现在为止,我们已经完成了Purple Pi OH主板的串口调试和部分配件的连接,接下来,让我们趁热打铁,完成剩余配件的连接!注:配件连接前请断开主板所有供电,避免敏感电路损坏!1.1 耳机接口主板有一路OTMP 标准四节耳机座J6,具备进行音频输出及录音功能,接入耳机后声音将优先从耳机输出,如下图所示:1.21.2 相机接口MIPI CSI 接口如上图所示,支持OV5648 和OV8858 摄像头模组。接入摄像头模组后,使用系统相机软件打开相机拍照和录像,如下图所示:1.3 以太网接口主板有一路
    Industio_触觉智能 2025-01-20 11:04 200浏览
  • 高速先生成员--黄刚这不马上就要过年了嘛,高速先生就不打算给大家上难度了,整一篇简单但很实用的文章给大伙瞧瞧好了。相信这个标题一出来,尤其对于PCB设计工程师来说,心就立马凉了半截。他们辛辛苦苦进行PCB的过孔设计,高速先生居然说设计多大的过孔他们不关心!另外估计这时候就跳出很多“挑刺”的粉丝了哈,因为翻看很多以往的文章,高速先生都表达了过孔孔径对高速性能的影响是很大的哦!咋滴,今天居然说孔径不关心了?别,别急哈,听高速先生在这篇文章中娓娓道来。首先还是要对各位设计工程师的设计表示肯定,毕竟像我
    一博科技 2025-01-21 16:17 169浏览
  • 故障现象 一辆2007款日产天籁车,搭载VQ23发动机(气缸编号如图1所示,点火顺序为1-2-3-4-5-6),累计行驶里程约为21万km。车主反映,该车起步加速时偶尔抖动,且行驶中加速无力。 图1 VQ23发动机的气缸编号 故障诊断接车后试车,发动机怠速运转平稳,但只要换挡起步,稍微踩下一点加速踏板,就能感觉到车身明显抖动。用故障检测仪检测,发动机控制模块(ECM)无故障代码存储,且无失火数据流。用虹科Pico汽车示波器测量气缸1点火信号(COP点火信号)和曲轴位置传感器信
    虹科Pico汽车示波器 2025-01-23 10:46 95浏览
我要评论
0
点击右上角,分享到朋友圈 我知道啦
请使用浏览器分享功能 我知道啦