ICCV-2019
分割任务中全局的上下文信息非常重要,如果高效轻量的获取上下文?
Thus, is there an alternative solution to achieve such a target in a more efficient way?
作者提出了 Criss-Cross Attention
相比于 Non-local(【NL】《Non-local Neural Networks》)
复杂度从 O((HxW)x(HxW)) 降低到了 O((HxW)x(H+W-1))
整理流程如下
Criss-Cross Attention Module 用了两次,叫 recurrent Criss-Cross attention (RCCA) module
下面是和 non-local 的对比
比如(b)中,计算蓝色块的 attention,绿色块不同深浅表示与蓝色块的相关程度,第一次结合十字架attention得到黄色块,第二次再结合十字架attention,得到红色块
为什么两次,因为一次捕获不到全局上下文信息,两次就可以,如下图
第一次,计算深绿色块的 Criss-Cross 注意力,只能获取到浅绿色块的信息,蓝色块的信息获取不到,浅绿色可以获取到蓝色块信息
第二次,计算深绿色块的 Criss-Cross 注意力,因为第一次计算浅绿色块注意力时已经有蓝色块信息了,此时,可以获取到蓝色块信息
更细节的 Criss-Cross 注意力图如下
下面结合图 3 看看公式表达
输入 H ∈ R C × W × H H \in \mathbb{R}^{C \times W \times H} H∈RC×W×H
query 和 key, { Q , K } ∈ R C ′ × W × H \{Q, K\} \in \mathbb{R}^{{C}' \times W \times H} {Q,K}∈RC′×W×H, C ′ {C}' C′ 为 1/8 C C C
Q u ∈ R C ′ Q_u \in \mathbb{R}^{{C}'} Qu?∈RC′, u u u 是 H × W H \times W H×W 中空间位置索引,特征图 Q 的子集(每个空间位置)
Ω u ∈ R ( H + W ? 1 ) × C ′ \Omega_{u} \in \mathbb{R}^{(H + W -1) \times {C}' } Ωu?∈R(H+W?1)×C′,特征图 K 的子集(每个十字架)
Affinity operation 可以定义为
d i , u = Q u Ω i , u T d_{i,u} = Q_u \Omega_{i, u}^T di,u?=Qu?Ωi,uT?
Q Q Q上每个空间位置 Q u Q_u Qu?,找到 K K K 上对应的同行同列十字架 Ω u \Omega_{u} Ωu?, i i i 是十字架中空间位置的索引, d i , u ∈ D d_{i,u} \in {D} di,u?∈D, D ∈ R ( H + W ? 1 ) × W × H D \in \mathbb{R}^{(H+W-1) \times W \times H} D∈R(H+W?1)×W×H, Q Q Q 和 K K K 计算的 D D D 经过 softmax 后成 A ∈ R ( H + W ? 1 ) × W × H A \in \mathbb{R}^{(H + W -1) \times W \times H} A∈R(H+W?1)×W×H
Q Q Q 和 K K K 计算出来了权重 A A A 最终作用到 K K K 上,形式如下:
H u ′ = ∑ i ∈ ∣ Φ u ∣ A i , u Φ i , u + H u {H}_u^{'} = \sum_{i \in | \Phi_u|} A_{i,u}\Phi_{i,u} + H_u Hu′?=i∈∣Φu?∣∑?Ai,u?Φi,u?+Hu?
Φ i , u \Phi_{i,u} Φi,u? 同 Ω i , u \Omega_{i, u} Ωi,u?,一个是特征图 V V V 的子集,一个是特征图 K K K 的子集, H H H 是输入, H ′ {H}^{'} H′ 为输出, i i i 是十字架索引, u u u 是 H H H x W W W 空间位置索引
为了使每一个位置 u u u 可以与任何位置对应起来,作者通过两次计算 Criss-cross 来完成,只需对 H ′ {H}^{'} H′ 再次计算 criss-cross attention,输出 H ′ ′ {H}^{''} H′′,此时就有:
u
u
u and
θ
\theta
θ in the same row or column
A
A
A 表示 loop = 1 时的注意力 weight,
A
′
{A}'
A′ 表示 loop = 2 时的 weight
u
u
u and
θ
\theta
θ not in the same row or column,eg 图 4,深绿色位置是
u
u
u,蓝色的位置是
θ
\theta
θ
加上
再看看代码
import torch
import torch.nn as nn
import torch.nn.functional as F
def INF(B,H,W):
return -torch.diag(torch.tensor(float("inf")).cuda().repeat(H),0).unsqueeze(0).repeat(B*W,1,1)
class CrissCrossAttention(nn.Module):
def __init__(self, in_channels):
super(CrissCrossAttention, self).__init__()
self.in_channels = in_channels
self.channels = in_channels // 8
self.ConvQuery = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
self.ConvKey = nn.Conv2d(self.in_channels, self.channels, kernel_size=1)
self.ConvValue = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1)
self.SoftMax = nn.Softmax(dim=3)
self.INF = INF
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
b, _, h, w = x.size()
# [b, c', h, w]
query = self.ConvQuery(x)
# [b, w, c', h] -> [b*w, c', h] -> [b*w, h, c']
query_H = query.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h).permute(0, 2, 1)
# [b, h, c', w] -> [b*h, c', w] -> [b*h, w, c']
query_W = query.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w).permute(0, 2, 1)
# [b, c', h, w]
key = self.ConvKey(x)
# [b, w, c', h] -> [b*w, c', h]
key_H = key.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)
# [b, h, c', w] -> [b*h, c', w]
key_W = key.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)
# [b, c, h, w]
value = self.ConvValue(x)
# [b, w, c, h] -> [b*w, c, h]
value_H = value.permute(0, 3, 1, 2).contiguous().view(b*w, -1, h)
# [b, h, c, w] -> [b*h, c, w]
value_W = value.permute(0, 2, 1, 3).contiguous().view(b*h, -1, w)
# [b*w, h, c']* [b*w, c', h] -> [b*w, h, h] -> [b, h, w, h]
energy_H = (torch.bmm(query_H, key_H) + self.INF(b, h, w)).view(b, w, h, h).permute(0, 2, 1, 3)
# [b*h, w, c']*[b*h, c', w] -> [b*h, w, w] -> [b, h, w, w]
energy_W = torch.bmm(query_W, key_W).view(b, h, w, w)
# [b, h, w, h+w] concate channels in axis=3
concate = self.SoftMax(torch.cat([energy_H, energy_W], 3))
# [b, h, w, h] -> [b, w, h, h] -> [b*w, h, h]
attention_H = concate[:,:,:, 0:h].permute(0, 2, 1, 3).contiguous().view(b*w, h, h)
attention_W = concate[:,:,:, h:h+w].contiguous().view(b*h, w, w)
# [b*w, h, c]*[b*w, h, h] -> [b, w, c, h]
out_H = torch.bmm(value_H, attention_H.permute(0, 2, 1)).view(b, w, -1, h).permute(0, 2, 3, 1)
out_W = torch.bmm(value_W, attention_W.permute(0, 2, 1)).view(b, h, -1, w).permute(0, 2, 1, 3)
return self.gamma*(out_H + out_W) + x
if __name__ == "__main__":
model = CrissCrossAttention(512)
x = torch.randn(2, 512, 28, 28)
model.cuda()
out = model(x.cuda())
print(out.shape)
参考
Mean IoU (mIOU, mean of class-wise intersection over union section over union) for Cityscapes and ADE20K and the standard COCO metrics Average Precision (AP) for COCO
(1)Comparisons with state-of-the-arts
DPC 用了更强的主干,更多的数据集来 train
(2)Ablation studies
消融了下循环的次数,还是很猛的,第一次就提升了 2.9 个点,第二次又提升了 1.8 个
看看效果图,重点看作者圈出来的白色虚线椭圆区域
对比看看其他的 context aggregation 模块
作者的 Criss-Cross Attention 比较猛
其次比较猛的是 Non-local,但是作者的计算量小很多
看看特征图,重点看作者圈出来的绿色十字加号区域
《Large Kernel Matters Improve Semantic Segmentation by Global Convolutional Network》