目录
Masked Autoencoders Are Scalable Vision Learnershttps://arxiv.org/pdf/2111.06377.pdf
????????在深度学习和计算机视觉的领域中,预训练模型已经成为了提高下游任务性能的重要手段。传统上,许多预训练模型如ResNet、VGG等都是在大规模数据集(如ImageNet)上通过监督学习训练得到的。然而,监督学习需要大量的标记数据,这在成本和可扩展性上都是一个不小的挑战。
????????最近,自监督学习作为一个新兴研究领域,提供了一种无需手工标注数据的解决方案。自监督学习的一个关键点是设计预测任务,通过这些任务模型可以从输入数据本身学习到有用的表示。在自然语言处理(NLP)领域,BERT通过掩码语言模型(MLM)任务表现出色,这激发了计算机视觉领域对类似方法的探索。
????????MAE (Masked Autoencoder) 正是从这样的背景和动机出发,它将自监督学习中的掩码预测任务引入到视觉领域,致力于从图像数据中以无监督的方式学习高效的特征表示。
????????论文中回答了一个问题。为什么自监督在CV领域的发展要滞后于NLP呢?论文中给了两个解释:
(1)NLP主流方法是Transformer,视觉里CNN是主流方法,结构差异让视觉很难构造类似于“masked autoencoding”的任务。但是ViT的提出解决了这个问题;
(2)语言和视觉的信息密度(information density)差异巨大,前者是强语义的,高信息密度的(highly semantic and information-dense),在NLP中即使只mask一个token,对模型来说可能都是很难的任务,因此模型可以通过学习获得复杂的语言理解能力(sophisticated language understanding),但是对视觉图像来说,信息是高度冗余的,缺失一个patch,可能并不会让模型产生多少困惑,模型可以通过周围的像素信息进行推断
????????所以MAE做的一件事就是mask很高比例的patches,制造高难度的学习任务,方法简单但是极其有效
MAE 的核心创新在于其独特的自监督预训练方法。不同于之前的自监督视觉模型通常需要对比学习或复杂的数据增强,MAE 提出了一种简洁高效的方法:
Masking 策略:MAE 对输入图像进行随机遮蔽,只露出一小部分像素,模型的任务是预测被遮蔽部分的原始像素。这种策略减少了模型需要处理的数据量,同时迫使模型学习丰富的上下文信息来重建图像。
编码器-解码器架构:MAE 采用了一个不对称的编码器-解码器架构,其中编码器只对未被遮蔽的部分进行处理,大幅减少了计算量。解码器则负责图像的重建工作,它的结构相对简单,因为其主要任务是理解编码器提供的特征。
预训练与微调:MAE 的预训练阶段不依赖于标签,这使得模型可以在非常大的数据集上进行训练。一旦预训练完成,MAE 可以通过微调在各种下游任务上实现优异的性能,包括分类、检测和分割等。
数据遮掩:首先,在输入图像或序列数据中随机选择一定比例的区域进行遮掩,将其替换为特定的遮掩标记(如0或[MASK])。
编码阶段:仅将未遮掩的数据部分输入到一个轻量级的Transformer编码器中,以提取局部上下文特征。
解码阶段:将编码后的向量传递给一个解码器,该解码器通常也是一个Transformer,但会对所有像素或位置进行解码预测,恢复出被遮掩部分的信息。
损失函数:使用L1或L2距离作为损失函数,衡量预测的像素值或词向量与原始未遮掩数据之间的差异。
预训练与微调:经过大规模无标签数据上的预训练后,可以将模型参数迁移到特定的下游任务中进行微调,进一步提升任务性能。
import torch
import torch.nn as nn
import torch.nn.functional as F
class PositionalEncoding(nn.Module):
# 用于添加位置信息的模块,通常在Transformer结构中使用
def __init__(self, d_model, dropout=0.1, max_len=5000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(p=dropout)
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0).transpose(0, 1)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:x.size(0), :]
return self.dropout(x)
class Encoder(nn.Module):
def __init__(self, embed_dim, num_layers, num_heads, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.):
super(Encoder, self).__init__()
self.layers = nn.ModuleList([
nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, dim_feedforward=int(embed_dim * mlp_ratio),
dropout=drop_rate, attention_dropout=attn_drop_rate, bias_qkv=qkv_bias)
for _ in range(num_layers)])
def forward(self, src, mask=None):
output = src
for layer in self.layers:
output = layer(output, src_key_padding_mask=mask)
return output
class MaskedAutoencoder(nn.Module):
def __init__(self, image_size, patch_size, num_channels, embed_dim, num_layers, num_heads, mlp_ratio, num_classes):
super(MaskedAutoencoder, self).__init__()
self.patch_size = patch_size
self.embed_dim = embed_dim
self.num_patches = (image_size // patch_size) ** 2
self.encoder = nn.Sequential(
nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size),
nn.LayerNorm(embed_dim),
)
self.pos_embed = PositionalEncoding(embed_dim)
self.transformer_encoder = Encoder(embed_dim, num_layers, num_heads, mlp_ratio)
self.decoder = nn.Sequential(
nn.Linear(embed_dim, embed_dim),
nn.GELU(),
nn.Linear(embed_dim, num_channels * patch_size ** 2),
nn.PixelShuffle(patch_size),
)
self.to_patch_embedding = nn.Sequential(
nn.Unflatten(dim=1, unflattened_size=(num_patches, embed_dim)),
nn.Dropout(p=0.1),
)
def forward(self, x, mask_ratio=0.75):
B, C, H, W = x.shape
assert H == W, "Input image must be square"
x = self.encoder(x)
x = self.pos_embed(x)
# 随机掩码
rand_mask = torch.rand(B, self.num_patches, 1, 1, device=x.device) < mask_ratio
masked_x = x.clone()
masked_x[rand_mask] = 0.
# 编码
encoded_patches = self.transformer_encoder(self.to_patch_embedding(masked_x))
# 解码
reconstructed_image = self.decoder(encoded_patches)
return reconstructed_image
# 初始化模型
model = MaskedAutoencoder(image_size=224, patch_size=16, num_channels=3, embed_dim=768, num_layers=12, num_heads=12, mlp_ratio=4., num_classes=0)
# 假设我们有输入数据x
x = torch.randn((10, 3, 224, 224))
# 计算重构后的图像
reconstruction = model(x)