This work builds mainly on LLaMA and ImageBind, combining multi-modal information with text instructions to accomplish a range of tasks. During training, only image-text pairs are used to learn multi-modal feature extraction (only leverage the vision-language data for multi-modality instruction tuning). GitHub code: link.
For a given image-text pair, the overall model is trained in two stages; the first stage mainly serves to align the feature spaces of ImageBind and LLaMA.
Code implementation:
import torch
import torch.nn as nn
import torch.nn.functional as F
# Define the RMSNorm
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
# Define the repeated feedforward block in bind network
class FeedForwardBlock(nn.Module):
    def __init__(self, dim: int, hidden_dim: int):
        super().__init__()
        # normalize the input (pre-norm)
        self.norm = RMSNorm(dim)
        # Define 3 linear projection layers whose parameters are w1, w2 and w3 respectively.
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(dim, hidden_dim, bias=False)
        self.w3 = nn.Linear(hidden_dim, dim, bias=False)

    def forward(self, x):
        # Cascade the linear layers with RMSNorm, a gated SiLU activation
        # (SwiGLU-style) and a residual connection from the un-normalized
        # input, as in standard pre-norm blocks
        h = self.norm(x)
        return x + self.w3(F.silu(self.w1(h)) * self.w2(h))
class bind_network(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.image_dim = args.image_dim  # e.g., 1024, encoded by ImageBind
        self.model_dim = args.model_dim  # e.g., 4096
        self.ffn_dim = self.model_dim * 4  # hidden dim of each feed-forward block
        self.linear_0 = nn.Linear(self.image_dim, self.model_dim)
        self.feed_forward_1 = FeedForwardBlock(dim=self.model_dim, hidden_dim=self.ffn_dim)
        self.feed_forward_2 = FeedForwardBlock(dim=self.model_dim, hidden_dim=self.ffn_dim)
        self.feed_forward_3 = FeedForwardBlock(dim=self.model_dim, hidden_dim=self.ffn_dim)
    def forward(self, image_feature):
        # image_feature: (1, C1) / (1, image_dim)
        # Adopt the linear projection layer at first
        image_feature = self.linear_0(image_feature)  # image_feature, (1, model_dim)
        # Cascade the 3 feed-forward blocks
        image_feature = self.feed_forward_1(image_feature)
        image_feature = self.feed_forward_2(image_feature)
        transformed_image_feature = self.feed_forward_3(image_feature)
        return transformed_image_feature
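A minimal usage sketch of the bind network (the dimensions are the example values from the comments above; SimpleNamespace merely stands in for the real args object):

from types import SimpleNamespace

args = SimpleNamespace(image_dim=1024, model_dim=4096)
bind_net = bind_network(args)

# A dummy ImageBind image embedding of shape (1, image_dim)
image_feature = torch.randn(1, args.image_dim)
transformed = bind_net(image_feature)
print(transformed.shape)  # torch.Size([1, 4096]), now in LLaMA's feature space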
Computation: for an input vector $x \in \mathbb{R}^m$,

LayerNorm re-centers and then re-scales each element:
$\bar{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}\, g_i + b_i$, with $\mu = \frac{1}{m}\sum_{i=1}^{m} x_i$ and $\sigma^2 = \frac{1}{m}\sum_{i=1}^{m}(x_i - \mu)^2$.

RMSNorm drops the re-centering and only re-scales by the root mean square:
$\bar{x}_i = \frac{x_i}{\mathrm{RMS}(x)}\, g_i$, with $\mathrm{RMS}(x) = \sqrt{\frac{1}{m}\sum_{i=1}^{m} x_i^2 + \epsilon}$.
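As a sanity check, the RMSNorm module above can be verified against this formula and contrasted with torch.nn.LayerNorm (a quick sketch; affine weights are left at their defaults, so g = 1):

x = torch.randn(4, 8)

# Manual RMSNorm following the formula above
manual = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
print(torch.allclose(RMSNorm(8)(x), manual, atol=1e-5))  # True

# LayerNorm additionally subtracts the mean mu before re-scaling
layer_norm = nn.LayerNorm(8, elementwise_affine=False)
centered = x - x.mean(-1, keepdim=True)
print(torch.allclose(layer_norm(x),
                     centered / torch.sqrt(x.var(-1, unbiased=False, keepdim=True) + 1e-5),
                     atol=1e-5))  # True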
RMSNorm therefore removes the computation of the mean μ and the re-centering step entirely. While preserving the re-scaling invariance with respect to inputs and weights, keeping gradients stable during training, and retaining fast convergence, it cuts this extra computation and accelerates network training by roughly 7%–64% (the exact speedup depends on the hardware, the network architecture, the cost of other components, and so on).
The model takes images as input (image inputs) and produces text as output (language responses).
Pipeline:
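A rough sketch of that inference pipeline; imagebind_encode and llama_generate below are hypothetical stand-ins (not the repository's actual API) for the frozen ImageBind encoder and a LLaMA wrapper that accepts an extra visual feature:

@torch.no_grad()
def answer(image, prompt, bind_net, imagebind_encode, llama_generate):
    # 1. Encode the image with the frozen ImageBind encoder -> (1, image_dim)
    image_feature = imagebind_encode(image)
    # 2. Map it into LLaMA's feature space with the bind network -> (1, model_dim)
    visual_feature = bind_net(image_feature)
    # 3. Condition LLaMA on the transformed visual feature (e.g., added to the
    #    token embeddings) and decode the language response
    return llama_generate(prompt, visual_feature)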
Limitation: it can only handle simple visual question answering scenarios, such as ScienceQA.