搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了(十九)

极市平台

共 20226字,需浏览 41分钟

 · 2021-12-01

↑ 点击蓝字 关注极市平台

作者丨科技猛兽
编辑丨极市平台

极市导读

 

本文说明了只要一个模型采用元变换器 (MetaFormer) 作为通用架构,即只要模型的基本架构采用Token information mixing模块+Channel MLP模块的Meta 形式,而不论Token information mixing模块取什么样子、什么形式,模型都可以得到有希望的结果。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

专栏目录:

https://zhuanlan.zhihu.com/p/348593638

本文目录

38 MetaTransformer:简单到尴尬的视觉模型
(来自 Sea AI Lab,新加坡国立大学)
38.1 MetaTransformer 原理分析
38.2 MetaTransformer 代码解读
Transformer 是 Google 的团队在 2017 年提出的一种 NLP 经典模型,现在比较火热的 Bert 也是基于 Transformer。Transformer 模型使用了 Self-Attention 机制,不采用 RNN 的顺序结构,使得模型可以并行化训练,而且能够拥有全局信息。

38 MetaTransformer:简单到尴尬的视觉模型

论文名称:PoolFormer: MetaFormer is Actually What You Need for Vision
论文地址:
https://arxiv.org/abs/2111.11418

38.1 MetaTransformer 原理分析:

Transformer 做视觉取得巨大成功,视觉 Transformer 模型的基本架构是 Token information mixing 模块 + Channel MLP 模块。分块后的图片 (image patch) 输入模型之后首先被编码为 image tokens,再经过 Token information mixing 模块来融合各个 tokens 之间的信息,这个模块在模型中的作用是最大的,因为图片不同 patch 之间的信息交互就是靠它完成的。
而 Token information mixing 模块有很多种实现的方式,最典型的就是以 Vision Transformer (ViT,DeiT) 为代表的注意力模块 self-attention 机制。self-attention 机制具体也有很多种实现形式,比如采用 window-based self-attention 的 Swin Transformer,融合了卷积操作的 Conv-ViT 等等。
后续 Vision MLP 的研究发现不用注意力换成 Spatial MLP (MLP-Mixer,ResMLP) 效果也很好,他们把 self-attention 机制换成了简单的 MLP 模块来混合不同 tokens 之间的信息。这一发现挑战了基于 self-attention 机制的 token mixer 的统治地位,并引发了研究界关于哪种 token mixer 是更好的热烈讨论。
再后续有研究发现甚至是使用傅立叶变换模块也能够混合不同 tokens 之间的信息,所以能够很好地替代 self-attention 机制或者 Spatial MLP 来充当 Token information mixing 模块的角色。
所以,综合所有这些结果,似乎只要一个模型采用元变换器 (MetaFormer) 作为通用架构,即,只要模型的基本架构采用 Token information mixing 模块 + Channel MLP 模块的 Meta 形式,而不论 Token information mixing 模块取什么样子,什么形式,模型都可以得到有希望的结果。 一个自然的问题是:
什么才是 Transformer 及其变种成功的真正原因?Token information mixing 模块到底能简单到啥地步?它是否是模型 work 的关键?
争议之下,颜水成团队的最新论文给出一个不同观点
其实 Token information mixing 模块的具体形式并不重要,Transformer 的成功来自其整体架构--MetaFormer。
这项工作的目标既不是参与 "究竟 Token information mixing 模块是最好的?" 的这场辩论,也不是设计新的复杂的 Token information mixing 模块,以达到 SOTA。相反,作者要研究一个基本问题:什么才是 Transformer 及其变种成功的真正原因?
作者的答案是:"MetaFormer 才是 Transformer 及其变种成功的真正原因"
为了证实这一点,他们把 Transformer 中的注意力模块替换成了一个没有参数的 Operator,即:空间池化算子 (Pooling),新模型命名为 PoolFormer。这里原文的说法很有意思,“简单到让人尴尬 (an embarrassingly simple spatial pooling operator)”。
什么是 MetaFormer?
作者在本文中提出了 MetaFormer 的概念,即:Transformer 模型及其变种结构的一般形式
图1:Transformer 模型及其变种结构的一般形式:MetaFormer
如上图1所示,MetaFormer 是一种通用体系结构,其中没有指定具体的 Token Mixer,而其他的 component 保持与 Transformer 一模一样。
第1步:把输入图片分 patch,并编码成 token:
式中,  ,是 embedding tokens,长度为  ,  是 embedding dimension。
这个 embedded token 会继续输入给 MetaFormer Blocks,每个 MetaFormer Block 包括2个 sub-block,第1个 sub-block 就是 token mixer 模块:
式中,  表示范数,比如 Layer Normalization 或者 Batch Normalization。token mixer 模块的主要功能是混合 token 之间的信息。
第2个 sub-block 就是简单的两层 Channel MLP 模块:
式中,  ,  是其权值,  是非线性的激活函数。
作者认为这种 MetaFormer 通用体系结构对最近的 Transformer 和 MLP like 模型的成功贡献最大。为了证实这一点,作者特意使用了一个简单得令人尴尬的操作符,即 pooling,作为 token mixer。这个操作符没有可学习的参数,它只是使每个 token 平均聚合其附近的 tokens。Pooling operator 可以写成下式:
式中,  是 Pool size,式4中取减号减掉输入  是因为后续有残差连接会再把 LN 层之前的输入加上,伪代码如下:
图2:Pooling operator 伪代码
Pool_size 是 Pooling operator 的基本参数,padding 取 Pool_size // 2,使得输出分辨率和输入分辨率相等。
从计算复杂度的角度分析,self-attention 和 spatial MLP 的计算量随着序列长度呈现平方关系。而且,spatial MLP 会给模型带来额外的参数,所以,self-attention 和 spatial MLP 只适合处理序列长度较短的模型。相比之下,池化操作需要计算复杂度与序列长度成线性关系,而且不需要任何可学习的参数。
图3:PoolFormer 结构
如上图3所示是 PoolFormer 的整体结构,也是金字塔结构,包含4个 stage,分辨率分别是  。四个阶段的层数分别为 
实验结果
Image classification
Dataset: ImageNet-1k
数据增强: MixUp,CutMix,CutOut,RandAugment
优化器: AdamW,weight decay=0.05,base lr=1e-3× bs / 1024,warmup epochs=5,cosine lr decay
Epochs: 300
Batch size: 4096
Label Smoothing: 0.1
如下图4所示是 ImageNet 实验结果。21M 和 30M 的小模型 PoolFormer-S24 和 PoolFormer-S36 可以分别达到80.3%和81.4%的 Top-1 Accuracy,只需要3.6G 和5.2G 的 MACs,超过了几种典型的 Transformer 和 MLP 架构。通过 Pool operator,每个 token 平均聚合其附近 token 的特性。实验结果表明,即使使用这种极其简单的 token mixer,MetaFormer 仍然具有很具有竞争力的性能,说明 MetaFormer 才是我们在设计视觉模型时所真正需要的。
图4:ImageNet 实验结果
Object Detection
Dataset: COCO
检测实验:PoolFormer 作为 backbone,RetinaNet 作为检测头。
实例分割实验:PoolFormer 作为 backbone,Mask R-CNN 作为分割头。
ImageNet pre-trained weights 作为 backbone 初始化权重,其他部分的权重采用 Xavier 初始化。
优化器: AdamW,batch size=16,base lr=1e-4
1×training schedule:训练 Detector 12 epochs。
框架基于 mmdetection。
如下图5所示是 Object Detection 实验结果,PoolFormer 以更少的参数取得比 ResNet 更高的性能。
图5:Object Detection 实验结果,模型 RetinaNet
如下图6所示是 Instance Segmentation 实验结果,PoolFormer 以更少的参数取得比 ResNet 更高的性能。
图6:Object Detection 和 Instance Segmentation 实验结果,模型 Mask R-CNN
Sementic Segmentation
Dataset: ADE20K
语义分割实验:PoolFormer 作为 backbone,Semantic FPN 作为分割头。
ImageNet pre-trained weights 作为 backbone 初始化权重,其他部分的权重采用 Xavier 初始化。
训练 40k iteration,batch size=32。
优化器: AdamW,base lr=2e-4,polynomial decay schedule (0.9)
框架基于 mmsegmentation。
如下图7所示是 Sementic Segmentation 实验结果,PoolFormer 的表现也超过了 ResNet、ResNeXt 和 PVT。
图7:Sementic Segmentation 实验结果
图8:以 PoolFormer-S12 为基线,对比实验结果
对比实验:Normalization 的选择
如上图8所示,Group Normalization 的性能比 Layer Normalization 和 Batch Normalization 分别高了0.7%和0.8%。
对比实验:Hybrid Stages
如上图8所示,4个阶段的 Pooling 操作使用 Attention 模块或者 Spatial MLP 来替换,在4个阶段中把注意力和空间全连接层等机制混合起来用性能也不会下降。其中特别观察到,前两阶段用池化后两阶段用注意力这种组合表现突出。这样的配置下稍微增加一下规模精度就可达到 81%,作为对比的 ResMLP-B24 模型达到相同性能需要7倍的参数规模和8.5倍的累计乘加操作。

38.2 MetaTransformer 代码解读:

代码来自:
https://github.com/sail-sg/poolformer
Patch Embedding 的形式和 ViT 一致:
class PatchEmbed(nn.Module):
"""
Patch Embedding that is implemented by a layer of conv.
Input: tensor in shape [B, C, H, W]
Output: tensor in shape [B, C, H/stride, W/stride]
"""
def __init__(self, patch_size=16, stride=16, padding=0,
in_chans=3, embed_dim=768, norm_layer=None):
super().__init__()
patch_size = to_2tuple(patch_size)
stride = to_2tuple(stride)
padding = to_2tuple(padding)
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size,
stride=stride, padding=padding)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

def forward(self, x):
x = self.proj(x)
x = self.norm(x)
return x
定义只在 Channel 维度上进行的 Layer Normalization 操作:
class LayerNormChannel(nn.Module):
"""
LayerNorm only for Channel Dimension.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, eps=1e-05):
super().__init__()
self.weight = nn.Parameter(torch.ones(num_channels))
self.bias = nn.Parameter(torch.zeros(num_channels))
self.eps = eps

def forward(self, x):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight.unsqueeze(-1).unsqueeze(-1) * x \
+ self.bias.unsqueeze(-1).unsqueeze(-1)
return x
定义 Group Normalization 操作:
class GroupNorm(nn.GroupNorm):
"""
Group Normalization with 1 group.
Input: tensor in shape [B, C, H, W]
"""
def __init__(self, num_channels, **kwargs):
super().__init__(1, num_channels, **kwargs)
定义 Pooling 操作,和伪代码一致:
class Pooling(nn.Module):
"""
Implementation of pooling for PoolFormer
--pool_size: pooling size
"""
def __init__(self, pool_size=3):
super().__init__()
self.pool = nn.AvgPool2d(
pool_size, stride=1, padding=pool_size//2, count_include_pad=False)

def forward(self, x):
return self.pool(x) - x
借助1×1卷积来实现 MLP (这个 MLP 类自带初始化):
class Mlp(nn.Module):
"""
Implementation of MLP with 1*1 convolutions.
Input: tensor with shape [B, C, H, W]
"""
def __init__(self, in_features, hidden_features=None,
out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
self.act = act_layer()
self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)

def _init_weights(self, m):
if isinstance(m, nn.Conv2d):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
一个 PoolFormer Block:
class PoolFormerBlock(nn.Module):
"""
Implementation of one PoolFormer block.
--dim: embedding dim
--pool_size: pooling size
--mlp_ratio: mlp expansion ratio
--act_layer: activation
--norm_layer: normalization
--drop: dropout rate
--drop path: Stochastic Depth,
refer to https://arxiv.org/abs/1603.09382
--use_layer_scale, --layer_scale_init_value: LayerScale,
refer to https://arxiv.org/abs/2103.17239
"""
def __init__(self, dim, pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm,
drop=0., drop_path=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):

super().__init__()

self.norm1 = norm_layer(dim)
self.token_mixer = Pooling(pool_size=pool_size)
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
act_layer=act_layer, drop=drop)

# The following two techniques are useful to train deep PoolFormers.
self.drop_path = DropPath(drop_path) if drop_path > 0. \
else nn.Identity()
self.use_layer_scale = use_layer_scale
if use_layer_scale:
self.layer_scale_1 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)
self.layer_scale_2 = nn.Parameter(
layer_scale_init_value * torch.ones((dim)), requires_grad=True)

def forward(self, x):
if self.use_layer_scale:
x = x + self.drop_path(
self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
* self.token_mixer(self.norm1(x)))
x = x + self.drop_path(
self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
* self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.token_mixer(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
定义 basic_block,由 layers[i] 个 PoolFormerBlock 构成:
def basic_blocks(dim, index, layers,
pool_size=3, mlp_ratio=4.,
act_layer=nn.GELU, norm_layer=GroupNorm,
drop_rate=.0, drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5):
"""
generate PoolFormer blocks for a stage
return: PoolFormer blocks
"""
blocks = []
for block_idx in range(layers[index]):
block_dpr = drop_path_rate * (
block_idx + sum(layers[:index])) / (sum(layers) - 1)
blocks.append(PoolFormerBlock(
dim, pool_size=pool_size, mlp_ratio=mlp_ratio,
act_layer=act_layer, norm_layer=norm_layer,
drop=drop_rate, drop_path=block_dpr,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value,
))
blocks = nn.Sequential(*blocks)

return blocks
定义整体的 PoolFormer 模型:
class PoolFormer(nn.Module):
"""
PoolFormer, the main class of our model
--layers: [x,x,x,x], number of blocks for the 4 stages
--embed_dims, --mlp_ratios, --pool_size: the embedding dims, mlp ratios and
pooling size for the 4 stages
--downsamples: flags to apply downsampling or not
--norm_layer, --act_layer: define the types of normalizaiotn and activation
--num_classes: number of classes for the image classification
--in_patch_size, --in_stride, --in_pad: specify the patch embedding
for the input image
--down_patch_size --down_stride --down_pad:
specify the downsample (patch embed.)
--fork_faat: whetehr output features of the 4 stages, for dense prediction
--init_cfg,--pretrained:
for mmdetection and mmsegmentation to load pretrianfed weights
"""
def __init__(self, layers, embed_dims=None,
mlp_ratios=None, downsamples=None,
pool_size=3,
norm_layer=GroupNorm, act_layer=nn.GELU,
num_classes=1000,
in_patch_size=7, in_stride=4, in_pad=2,
down_patch_size=3, down_stride=2, down_pad=1,
drop_rate=0., drop_path_rate=0.,
use_layer_scale=True, layer_scale_init_value=1e-5,
fork_feat=False,
init_cfg=None,
pretrained=None,
**kwargs):

super().__init__()

if not fork_feat:
self.num_classes = num_classes
self.fork_feat = fork_feat

self.patch_embed = PatchEmbed(
patch_size=in_patch_size, stride=in_stride, padding=in_pad,
in_chans=3, embed_dim=embed_dims[0])

# set the main block in network
network = []
for i in range(len(layers)):
stage = basic_blocks(embed_dims[i], i, layers,
pool_size=pool_size, mlp_ratio=mlp_ratios[i],
act_layer=act_layer, norm_layer=norm_layer,
drop_rate=drop_rate,
drop_path_rate=drop_path_rate,
use_layer_scale=use_layer_scale,
layer_scale_init_value=layer_scale_init_value)
network.append(stage)
if i >= len(layers) - 1:
break
if downsamples[i] or embed_dims[i] != embed_dims[i+1]:
# downsampling between two stages
network.append(
PatchEmbed(
patch_size=down_patch_size, stride=down_stride,
padding=down_pad,
in_chans=embed_dims[i], embed_dim=embed_dims[i+1]
)
)

self.network = nn.ModuleList(network)

if self.fork_feat:
# add a norm layer for each output
self.out_indices = [0, 2, 4, 6]
for i_emb, i_layer in enumerate(self.out_indices):
if i_emb == 0 and os.environ.get('FORK_LAST3', None):
# TODO: more elegant way
"""For RetinaNet, `start_level=1`. The first norm layer will not used.
cmd: `FORK_LAST3=1 python -m torch.distributed.launch ...`
"""
layer = nn.Identity()
else:
layer = norm_layer(embed_dims[i_emb])
layer_name = f'norm{i_layer}'
self.add_module(layer_name, layer)
else:
# Classifier head
self.norm = norm_layer(embed_dims[-1])
self.head = nn.Linear(
embed_dims[-1], num_classes) if num_classes > 0 \
else nn.Identity()

self.apply(self.cls_init_weights)

self.init_cfg = copy.deepcopy(init_cfg)
# load pre-trained model
if self.fork_feat and (
self.init_cfg is not None or pretrained is not None):
self.init_weights()

# init for classification
def cls_init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)

# init for mmdetection or mmsegmentation by loading
# imagenet pre-trained weights
def init_weights(self, pretrained=None):
logger = get_root_logger()
if self.init_cfg is None and pretrained is None:
logger.warn(f'No pre-trained weights for '
f'{self.__class__.__name__}, '
f'training start from scratch')
pass
else:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
if self.init_cfg is not None:
ckpt_path = self.init_cfg['checkpoint']
elif pretrained is not None:
ckpt_path = pretrained

ckpt = _load_checkpoint(
ckpt_path, logger=logger, map_location='cpu')
if 'state_dict' in ckpt:
_state_dict = ckpt['state_dict']
elif 'model' in ckpt:
_state_dict = ckpt['model']
else:
_state_dict = ckpt

state_dict = _state_dict
missing_keys, unexpected_keys = \
self.load_state_dict(state_dict, False)
print('missing_keys: ', missing_keys)
print('unexpected_keys: ', unexpected_keys)

def get_classifier(self):
return self.head

def reset_classifier(self, num_classes):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

def forward_embeddings(self, x):
x = self.patch_embed(x)
return x

def forward_tokens(self, x):
outs = []
for idx, block in enumerate(self.network):
x = block(x)
if self.fork_feat and idx in self.out_indices:
norm_layer = getattr(self, f'norm{idx}')
x_out = norm_layer(x)
outs.append(x_out)
if self.fork_feat:
# output the features of four stages for dense prediction
return outs
# output only the features of last layer for image classification
return x

def forward(self, x):
# input embedding
x = self.forward_embeddings(x)
# through backbone
x = self.forward_tokens(x)
if self.fork_feat:
# otuput features of four stages for dense prediction
return x
x = self.norm(x)
cls_out = self.head(x.mean([-2, -1]))
# for image classification
return cls_out
1) self.forward_embeddings(x): 一开始的图像分 patch 的 Patch Embedding 操作。
2) 这套代码的 x 即图片的特征维度和 ViT 不同,自始至终都保持:[B, C, W, H],其中,C 是 Embedding dimension。所以作者的 MLP 是通过 Conv1×1 实现的,且最后的 Average Pool 操作是:x.mean([-2, -1]),默认对 H 和 W 的最后2维使用 Pooling。而 ViT 中使用的是 x.mean(1)。
3) cls_init_weights 这个函数只是给分类头 self.head 做初始化。init_weights 这个函数是在检测,分割实验中使用,将 init_cfg 里面的 'checkpoint' 加载到模型里面。
4) 如果 self.fork_feat==True,即进行检测,分割实验,则 forward() 函数只输出4个 stage 结束的特征,并方便 mmdetection 或者 mmsegmentation 调用这个输出特征。如果 self.fork_feat==False,即进行分类实验,就依次进行 LN,Average Pooling 和通过分类头得到分类结果。

总结

本文说明了只要一个模型采用元变换器 (MetaFormer) 作为通用架构,即,只要模型的基本架构采用 Token information mixing 模块 + Channel MLP 模块的 Meta 形式,而不论 Token information mixing 模块取什么样子,什么形式,模型都可以得到有希望的结果。其实 Token information mixing 模块的具体形式并不重要,Transformer 的成功来自其整体架构--MetaFormer,它才是 Transformer 及其变种成功的真正原因。


如果觉得有用,就请分享到朋友圈吧!

△点击卡片关注极市平台,获取最新CV干货

公众号后台回复“transformer”获取最新Transformer综述论文下载~


极市干货
课程/比赛:珠港澳人工智能算法大赛保姆级零基础人工智能教程
算法trick目标检测比赛中的tricks集锦从39个kaggle竞赛中总结出来的图像分割的Tips和Tricks
技术综述:一文弄懂各种loss function工业图像异常检测最新研究总结(2019-2020)


极市平台签约作者#


科技猛兽

知乎:科技猛兽


清华大学自动化系19级硕士

研究领域:AI边缘计算 (Efficient AI with Tiny Resource):专注模型压缩,搜索,量化,加速,加法网络,以及它们与其他任务的结合,更好地服务于端侧设备。


作品精选

搞懂 Vision Transformer 原理和代码,看这篇技术综述就够了
用Pytorch轻松实现28个视觉Transformer,开源库 timm 了解一下!(附代码解读)
轻量高效!清华智能计算实验室开源基于PyTorch的视频 (图片) 去模糊框架SimDeblur



投稿方式:
添加小编微信Fengcall(微信号:fengcall19),备注:姓名-投稿
△长按添加极市平台小编

觉得有用麻烦给个在看啦~  
浏览 28
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报