
-
论文下载地址和代码开源地址:https://github.com/LandskapeAI/triplet-attention
https://arxiv.org/abs/2010.03045
在本文中研究了轻量且有效的注意力机制,并提出了Triplet Attention,该注意力机制是一种经过使用Triplet Branch结构捕获跨维度交互来计算注意力权重的新方法。对于输入张量,Triplet Attention经过旋转操做和残差变换创建维度间的依存关系,并以可忽略的计算开销对通道和空间信息进行编码。该方法既简单又有效,而且能够轻松地插入经典Backbone中。node
一、简介和相关方法
最近许多工做提出使用Channel Attention或Spatial Attention,或二者结合起来提升神经网络的性能。这些Attention机制经过创建Channel之间的依赖关系或加权空间注意Mask有能力改善由标准CNN生成的特征表示。学习注意力权重背后是让网络有能力学习关注哪里,并进一步关注目标对象。这里列举一些具备表明的工做:
一、SENet(Squeeze and Excite module)
二、CBAM(Convolutional Block Attention Module)
三、BAM(Bottleneck Attention Module)
四、Grad-CAM
五、Grad-CAM++
六、
-Nets(Double Attention Networks)
七、NL(Non-Local blocks)
八、GSoP-Net(Global Second order Pooling Networks)
九、GC-Net(Global Context Networks)
十、CC-Net(Criss-Cross Networks)
十一、SPNet
等等方法(这些方法都值得你们去学习和调研,说不定会给你的项目带来意想不到的效果)。
以上大多数方法都有明显的缺点(Cross-dimension),Triplet Attention解决了这些缺点。Triplet Attention模块旨在捕捉Cross-dimension交互,从而可以在一个合理的计算开销内(与上述方法相比能够忽略不计)提供显著的性能收益。git
二、本文方法
2.一、分析
本文的目标是研究如何在不涉及任何维数下降的状况下创建廉价但有效的通道注意力模型。Triplet Attention不像CBAM和SENet须要必定数量的可学习参数来创建通道间的依赖关系,本文提出了一个几乎无参数的注意机制来建模通道注意和空间注意,即Triplet Attention。github
2.二、Triplet Attention
所提出的Triplet Attention见下图所示。顾名思义,Triplet Attention由3个平行的Branch组成,其中两个负责捕获通道C和空间H或W之间的跨维交互。最后一个Branch相似于CBAM,用于构建Spatial Attention。最终3个Branch的输出使用平均进行聚合。web

一、Cross-Dimension Interaction
传统的计算通道注意力的方法涉及计算一个权值,而后使用权值统一缩放这些特征图。可是在考虑这种方法时,有一个重要的缺失。一般,为了计算这些通道的权值,输入张量在空间上经过全局平均池化分解为一个像素。这致使了空间信息的大量丢失,所以在单像素通道上计算注意力时,通道维数和空间维数之间的相互依赖性也不存在。微信
虽而后期提出基于Spatial和Channel的CBAM模型缓解了空间相互依赖的问题,可是依然存在一个问题,即,通道注意和空间注意是分离的,计算是相互独立的。基于创建空间注意力的方法,本文提出了跨维度交互做用(cross dimension interaction)的概念,经过捕捉空间维度和输入张量通道维度之间的交互做用,解决了这一问题。网络

这里是经过三个分支分别捕捉输入张量的(C, H),(C, W)和(H, W)维间的依赖关系来引入Triplet Attention中的跨维交互做用。架构
二、Z-pool
Z-pool层负责将C维度的Tensor缩减到2维,将该维上的平均聚集特征和最大聚集特征链接起来。这使得该层可以保留实际张量的丰富表示,同时缩小其深度以使进一步的计算量更轻。能够用下式表示:app
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=11)
三、Triplet Attention
给定一个输入张量 ,首先将其传递到Triplet Attention模块中的三个分支中。编辑器
在第1个分支中,在H维度和C维度之间创建了交互:ide

为了实现这一点,输入张量 沿H轴逆时针旋转90°。这个旋转张量表示为 的形状为(W×H×C),再而后通过Z-Pool后的张量 的shape为(2×H×C),而后, 经过内核大小为k×k的标准卷积层,再经过批处理归一化层,提供维数(1×H×C)的中间输出。而后,经过将张量经过sigmoid来生成的注意力权值。在最后输出是沿着H轴进行顺时针旋转90°保持和输入的shape一致。
在第2个分支中,在C维度和W维度之间创建了交互:

为了实现这一点,输入张量 沿W轴逆时针旋转90°。这个旋转张量表示为 的形状为(H×C×W),再而后通过Z-Pool后的张量 的shape为(2×C×W ),而后, 经过内核大小为k×k的标准卷积层,再经过批处理归一化层,提供维数(1×C×W)的中间输出。而后,经过将张量经过sigmoid来生成的注意力权值。在最后输出是沿着W轴进行顺时针旋转90°保持和输入的shape一致。
在第3个分支中,在H维度和W维度之间创建了交互:
输入张量
的通道经过Z-pool将变量简化为2。将这个形状的简化张量(2×H×W)简化后经过核大小k定义的标准卷积层,而后经过批处理归一化层。输出经过sigmoid激活层生成形状为(1×H×W)的注意权值,并将其应用于输入
,获得结果
。而后经过简单的平均将3个分支产生的精细张量(C×H×W)聚合在一块儿。
**最终输出的Tensor:
class BasicConv(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None
def forward(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x
class ChannelPool(nn.Module):
def forward(self, x):
return torch.cat( (torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1 )
class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
def forward(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = torch.sigmoid_(x_out)
return x * scale
class TripletAttention(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
super(TripletAttention, self).__init__()
self.ChannelGateH = SpatialGate()
self.ChannelGateW = SpatialGate()
self.no_spatial=no_spatial
if not no_spatial:
self.SpatialGate = SpatialGate()
def forward(self, x):
x_perm1 = x.permute(0,2,1,3).contiguous()
x_out1 = self.ChannelGateH(x_perm1)
x_out11 = x_out1.permute(0,2,1,3).contiguous()
x_perm2 = x.permute(0,3,2,1).contiguous()
x_out2 = self.ChannelGateW(x_perm2)
x_out21 = x_out2.permute(0,3,2,1).contiguous()
if not self.no_spatial:
x_out = self.SpatialGate(x)
x_out = (1/3)*(x_out + x_out11 + x_out21)
else:
x_out = (1/2)*(x_out11 + x_out21)
return x_out
四、Complexity Analysis
经过与其余标准注意力机制的比较,验证了Triplet Attention的效率,C为该层的输入通道数,r为MLP在计算通道注意力时瓶颈处使用的缩减比,用于2D卷积的核大小用k表示,k<<<C。
三、实验结果
3.一、图像分类实验

3.二、目标检测实验


3.三、消融实验

3.四、HeatMap输出对比


四、总结
在这项工做中提出了一个新的注意力机制Triplet Attention,它抓住了张量中各个维度特征的重要性。Triplet Attention使用了一种有效的注意计算方法,不存在任何信息瓶颈。实验证实,Triplet Attention提升了ResNet和MobileNet等标准神经网络架构在ImageNet上的图像分类和MS COCO上的目标检测等任务上的Baseline性能,而只引入了最小的计算开销。是一个很是不错的即插即用的注意力模块。
更为详细内容能够参见论文中的描述。
References
[1] Rotate to Attend: Convolutional Triplet Attention Module
声明:转载请说明出处
扫描下方二维码关注【AI人工智能初学者】公众号,获取更多实践项目源码和论文解读,很是期待你个人相遇,让咱们以梦为马,砥砺前行!!!
点“在看”给我一朵小黄花呗
本文分享自微信公众号 - AI人工智能初学者(ChaucerG)。
若有侵权,请联系 support@oschina.cn 删除。
本文参与“OSC源创计划”,欢迎正在阅读的你也加入,一块儿分享。