看下面的position embedding的代码:
class LearnablePositionalEmbedding(torch.nn.Module):
"""Shorthand for a learnable embedding."""
def __init__(self, embed_dim, max_position_embeddings=1024, dropout=0.0):
super().__init__()
self.embedding = torch.nn.Embedding(max_position_embeddings, embed_dim)
self.dropout = torch.nn.Dropout(p=dropout)
def forward(self, input_embeddings):
"""This is a batch-first implementation"""
position_ids = torch.arange(input_embeddings.shape[1], device=self.embedding.weight.device)
position_embeddings = self.embedding(position_ids[None, :])
return self.dropout(input_embeddings + position_embeddings)
简言之:
position_ids[None, :]
的目的是为了将其变成一个二维张量,以便与 input_embeddings
进行相加。在这里,position_ids
的长度(N)应该与 input_embeddings
张量的第二个维度长度相同
input_embeddings
的形状是 (batch_size, sequence_length, embed_dim)
,那么 position_ids[None, :]
的形状将变为 (1, sequence_length)
,然后通过广播(broadcasting)机制,它会与 input_embeddings
的第一个维度进行广播,使得两者的形状能够相加。这样,每个位置的嵌入都与相应的位置信息相加,从而引入了位置编码。?