介绍了一种基于BERT和ResNet的多模态模型,该模型在图像和文本信息上进行联合训练,实现了卓越的性能。
1 原理介绍
(1)模型结构
模型以BERT和ResNet为基础,分别处理文本和图像输入。接着,通过自注意力机制实现多模态信息的融合。让我们逐步解析这个模型的关键部分。除了多模态处理外,该模型还支持单一模态的处理,即只有文本输入或只有图像输入。在这两种情况下,模型分别提取单一模态的特征,并通过分类器得到输出。
(2)特征提取
模型首先通过BERT处理文本输入,获取文本的隐藏状态。这些隐藏状态包含了文本的语义信息,为后续的多模态融合做准备。对于图像输入,模型使用了预训练的ResNet152模型提取图像的特征。通过对图像特征进行池化和线性变换,得到图像的隐藏状态。
(3)特征融合
接下来,模型将文本和图像的隐藏状态进行拼接,构成共同的特征表示。通过设置attention_mask,模型实现了对文本中padding部分的处理,并使用self-attention机制进行多模态融合。
(4)分类
最后,模型分别提取多模态融合后的图像和文本特征,并通过线性变换进行分类,得到最终的输出。
2 使用的数据集
(1)Twitter-15和17数据集
一种是基于LSTM模型的“.txt”格式,另一种是BERT模型的“tsv”格式。例如,对于twitter2015的“train.tsv”,每一行都是一个样本:第一列是索引;第二列是情感标签(0表示负面,1表示中性,2表示正面);第三列是该推文对应图像的ID,可以在“twitter2015_images”文件夹中找到;第四和第五列分别是通过掩码当前意见目标和意见目标(即实体)的原始推文。请注意,每个推文可能包含多个意见目标(即实体),它可能对应于多个连续的样本。例如,对于twitter2015的“train.tsv”,第一个和第二个样本都是关于同一条推文的,但是涉及不同的实体。“.txt”文件与“train.tsv”类似。
(2)MVSA-Single和Multi数据集
由两个独立的数据集组成,分别是MVSA-Single数据集和 MVSA-Multi数据集,前者的每条图文对只有一个标注,后者的每条图文对由三个标注者给出。官方声明MVSA-Single数据集包含 5,129 条图文对(实际只有4869条),MVSA-Multi 包含了 19,600 条图文对(实际19600条)
?3?代码示例
"""拼接文本和图像,拼接得到共同特征"""
image_text_hidden_state = torch.cat([image_hidden_state, text_hidden_state], 1)
"""设置attention_mask,padding部分置-10000"""
attention_mask = text_input.attention_mask
image_attention_mask = torch.ones((attention_mask.size(0), 1)).to(device)
attention_mask = torch.cat([image_attention_mask, attention_mask], 1).unsqueeze(1).unsqueeze(2)
attention_mask = (1.0 - attention_mask) * -10000
"""利用self-attention机制进行多模态融合"""
image_text_attention_state = self.comb_attention(image_text_hidden_state, attention_mask)[0]
"""分别提取图像和文本特征,拼接"""
image_pooled_output = self.image_pool(image_text_attention_state[:, 0, :])
text_pooled_output = self.text_pool(image_text_attention_state[:, 1, :])
final_output = torch.cat([image_pooled_output, text_pooled_output], 1)
"""利用拼接向量进行分类"""
out = self.classifier(final_output)
?4?完整代码获取
最后:
如果你想要进一步了解更多的相关知识,可以关注下面公众号联系~会不定期发布相关设计内容包括但不限于如下内容:信号处理、通信仿真、算法设计、matlab appdesigner,gui设计、simulink仿真......希望能帮到你!