The All Things ViTs lecture series starts from the attention mechanism of ViT vision models. This post records my reading notes on the DINO attention map visualization part.
Course video and slides: https://all-things-vits.github.io/atv/
Code: https://colab.research.google.com/github/all-things-vits/code-samples/blob/main/probing/dino_attention_maps.ipynb
Paper: Emerging Properties in Self-Supervised Vision Transformers
This code is quite similar to the mean attention distance (MAD) part of the CVPR 2023 Hybrid Tutorial: All Things ViTs. The paper Emerging Properties in Self-Supervised Vision Transformers argues that training with DINO (self-distillation with no labels) can match the quality of supervised training, and it visualizes attention maps to show that DINO really does learn the intended knowledge. The core of this code is to visualize the attention map between the CLS token and the other image patch tokens in the last block, i.e., to show how strongly each patch is associated with the CLS token, as in Fig 1:
Fig 1: DINO attention map
import torch
from PIL import Image


def get_attention_scores(image: Image.Image, model: torch.nn.Module, processor):
    """Extracts attention scores given an image, a model, and its processor."""
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True)
    # Tuple of 12 tensors (one per block), each of shape [1, 12, 197, 197].
    return outputs.attentions
This code feeds the image into the ViT and returns the attention scores of every head (outputs.attentions). outputs.attentions is a tuple of 12 tensors, one per transformer block, each of shape [1, 12, 197, 197]. In each tensor, 12 is the number of attention heads and 197 is the number of tokens (1 CLS token plus 196 patch tokens for a 224x224 image with 16x16 patches); the 197x197 matrix holds the attention score between every pair of tokens.
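For context, here is a minimal sketch of how the model and processor could be loaded and the function called. The checkpoint name facebook/dino-vitb16 is an assumption (a DINO ViT-Base/16 model on Hugging Face, consistent with the 12 heads and 16x16 patches above), as is the example image path.

from PIL import Image
from transformers import ViTImageProcessor, ViTModel

# Assumption: the DINO ViT-Base/16 checkpoint hosted on Hugging Face.
processor = ViTImageProcessor.from_pretrained("facebook/dino-vitb16")
model = ViTModel.from_pretrained("facebook/dino-vitb16")

image = Image.open("example.jpg")  # hypothetical example image path
attention_scores = get_attention_scores(image, model, processor)
print(len(attention_scores))      # 12 (one entry per transformer block)
print(attention_scores[0].shape)  # torch.Size([1, 12, 197, 197])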
from torchvision.transforms import functional as F


def process_attention_map(
    image: torch.Tensor, attention_scores, block_id=11, patch_size=16
):
    """
    Processes the attention scores such that they can be overlaid on the input image.

    Args:
        image (torch.Tensor): The input image tensor.
        attention_scores (Tuple[torch.Tensor]): Tuple of attention scores.
        block_id (int, optional): The block ID. Default is 11. 11 is the last
            transformer block for a DINO base model.
        patch_size (int, optional): The size of the patches. Default is 16.

    Returns:
        numpy.ndarray: The processed attention map as a NumPy array.
    """
    height, width = image.shape[2:]
    w_featmap = width // patch_size
    h_featmap = height // patch_size
    num_heads = attention_scores[block_id].shape[1]  # Number of attention heads: 12.

    # Take the attention of the CLS token (query index 0) to every patch token.
    attentions = attention_scores[block_id][0, :, 0, 1:].reshape(num_heads, -1)
    print(attentions.shape)  # [12, 196]: per head, the CLS attention over 196 patches.

    # Reshape the attention scores into the patch grid (patch tokens are row-major).
    attentions = attentions.reshape(num_heads, h_featmap, w_featmap)  # [12, 14, 14]
    print(attentions.shape)

    # Upsample the patch grid back to the image resolution (224 = 14 x 16).
    attentions = F.resize(
        attentions, size=(h_featmap * patch_size, w_featmap * patch_size)
    )
    return attentions.numpy()
This code extracts the attention vector between the CLS token and the other patch tokens, upsamples it to the size of the original image, and plots it.
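As a minimal sketch of that plotting step (reusing the image, processor, and attention_scores from the loading example above; the 3x4 grid is just one way to lay out the 12 heads):

import matplotlib.pyplot as plt

# [1, 3, 224, 224] tensor, needed by process_attention_map for the image size.
pixel_values = processor(image, return_tensors="pt")["pixel_values"]
attention_maps = process_attention_map(pixel_values, attention_scores)

fig, axes = plt.subplots(3, 4, figsize=(12, 9))
for head_id, ax in enumerate(axes.flat):
    ax.imshow(attention_maps[head_id])  # one CLS attention map per head
    ax.set_title(f"Head {head_id}")
    ax.axis("off")
plt.tight_layout()
plt.show()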
The visualization code itself is easy to follow; next, let us look at the DINO visualization results more broadly:
As these maps show, after DINO training the ViT attends effectively to the semantically meaningful parts of the image, and different heads focus on different regions. With simple fine-tuning, the model then generalizes well to downstream tasks.