Prefix-tuning: 对于每个任务,都有一个特定的前缀被添加到输入序列的开始部分。这些前缀相当于任务特定的提示,可以是一组固定的词或是可训练的嵌入向量。同时,为了防止直接更新Prefix的参数导致训练不稳定和性能下降的情况,在Prefix层前面加了MLP结构,训练完成后,只保留Prefix的参数。Prefix-tuning形式为 [PREFIX; x; y]。
class PrefixEncoder(torch.nn.Module):
......
if self.prefix_projection and not config.inference_mode:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim)
self.transform