Transformer原理与代码实现

发布时间:2024年01月15日

Transformer作为进年来语言模型的底层架构意义重大,如果不能仔细研读并尝试训练一下,总觉得自己的基础缺缺。Transformer是在这篇文章《Attention Is All You Need》中提出, 2年前写过这篇论文的阅读笔记:【文本分类】Attention Is All You Need。Transformer已经有代码实现, 我已经在参考部分列出了其中一些。在阅读这些源代码时,我学习了一些技巧,这些技巧并未写在论文中,所以我们想专门写一篇教程来介绍代码细节。

概览

在这里插入图片描述
??从整体角度上来说,编码器将输入序列映射到向量中,该向量保存该输入的所有学习信息。然后,解码器获取该连续向量,同时还被输入先前的输出序列,然后逐步生成单个输出。

??从代码实现上来看,我们依次需要实现的模块有:

??嵌入层 Embedding
??位置编码 Positional Encoding
??Transformer嵌入层 Transformer Embedding
??带缩放的点积注意力机制 Scaled Dot-Product Attention
??多头注意力 Multi-Head Attention
??分位置的前馈机制 Position-wise Feed-Forward
??序列掩码 Look-Ahead Mask
??掩码多头注意力 Masked Multi-Head Attention
??(整合)编码器块EncoderBlock
??编码器 Encoder
??(整合)解码器块DecoderBlock
??解码器 Decoder
??(整合)Transformer

??任重而道远,现在让我们开始吧。

一、嵌入层 Embedding

??第一步是将输入输出到单词嵌入层。单词嵌入层可以被认为是获取每个单词的学习矢量表示的查找表。神经网络通过数字来学习,所以每个单词都映射到一个具有连续值的向量来表示该单词。

嵌入是最最基础的概念,它的目的是把句子中的每个词转化成对应的向量。我之前写过很多介绍嵌入的博客,请参考:一文了解Word2vec 阐述训练流程【文本分类】深入理解embedding层的模型、结构与文本表示

在这里插入图片描述
??词嵌入只需要引用torch的一个Embedding层就可以实现。

from torch import nn

self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)

二、位置编码 Positional Encoding

??下一步是将位置信息添加到嵌入中。因为变换器编码器不像递归神经网络那样具有递归性,所以我们必须将一些关于位置的信息添加到输入嵌入中。这是使用位置编码完成的。作者想出了一个使用正弦和余弦函数的聪明绝招。

在这里插入图片描述
??对于输入向量的每个奇数索引,使用cos函数创建一个向量。对于每个偶数索引,使用sin函数创建一个向量。然后将这些向量添加到它们相应的输入嵌入中。这成功地给出了关于每个矢量位置的网络信息。选择正弦和余弦函数是因为它们具有线性属性,模型可以很容易地学会处理。

??代码实现:

class PositionalEncoding(nn.Module):
    def __init__(self, max_positions: int, dim_embed: int) -> None:
        
        super().__init__()

        assert dim_embed % 2 == 0

        position = torch.arange(max_positions).unsqueeze(1)
        dim_pair = torch.arange(0, dim_embed, 2)
        div_term = torch.exp(dim_pair * (-math.log(10000.0) / dim_embed))

        pe = torch.zeros(max_positions, dim_embed)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 添加batch维度
        pe = pe.unsqueeze(0)

        # 整个学习阶段, 位置信息是不变的, 注册为不可学习的数据
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        # 计算每个batch的最大句子长度
        max_sequence_length = x.size(1)

        return self.pe[:, :max_sequence_length]

??在上面的代码中,把pe固定到缓存中是因为Transformer的位置向量就是按正弦或余弦函数算出来的固定值。假设PositionalEncodingmax_positions=8dim_embed=8,我们打断点调试可以看到pe是个固定向量,其值为:
在这里插入图片描述

??拓展问题一:为什么要对位置进行编码?

??因为: Attention提取特征的时候,可以获取全局每个词对之间的关系,但是并没有显式保留时序信息,或者说位置信息。就算打乱序列中token的顺序,最后所得到的Attention结果也不会变,这会丢失语言中的时序信息,因此需要额外对位置进行编码以引入时序信息。

??拓展问题二:Transformer的位置编码和BERT的位置编码是一样的吗?

??答: 不一样,不妨去transformers.models.bert.modeling_bert.py中看下源码,会发现BERT的位置编码其实也是个Embedding层,和词嵌入一样。BERT选择这么做的原因可能是,相比于Transformer,BERT训练所用的数据量充足,完全可以让模型自己学习。

三、(整合)Transformer嵌入层 Transformer Embedding

??参考了BERT模型的源码后,决定将词嵌入位置嵌入统一一下称作transformer的嵌入TransformerEmbeddings。最终的向量结果是词嵌入和位置嵌入直接做加法,比较简单。

class TransformerEmbeddings(nn.Module):
    """Construct the embeddings from word, position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = PositionalEncoding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(inputs_embeds)

        embeddings = inputs_embeds + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

??到目前为止,我们完成了以下模块的编码工作:

在这里插入图片描述

四、带缩放的点积注意力机制 Scaled Dot-Product Attention

??在讲解transformer的带缩放点积注意力机制之前,先举个例子简单介绍注意力机制,以下是例子:


??注意力机制的三个输入分别是QKV,即query、key、value。query的含义是要进行查询的数据,(key, value)表示将要被查询的键值对。

??假设现在有一个身高体重的键值对表,然后我拿着一个数据162去查询:

图片参考博主athrunsunny

??通用的做法是用查询数字162,对键值表里面的每一个身高160、166、173进行一些权重运算这个权重我们称之为注意力,这个权重运算函数 F F F我们称之为注意力计算模型。拿到了注意力之后再乘上value,我们就可以估算出查询数字162的体重是多少。

A t t e n t i o n = F ( q , k 1 ) v 1 + F ( q , k 2 ) v 2 + F ( q , k 3 ) v 3 Attention = F(q,k_1)v_1 + F(q,k_2)v_2 + F(q,k_3)v_3 Attention=F(q,k1?)v1?+F(q,k2?)v2?+F(q,k3?)v3?

??设想一下,如果此时QKV全都是同一个向量X,那么Q和K经过注意力模型 F F F之后得到权重,再乘以V,最终得到的就是向量X对自己进行自注意力之后的数据。这就是自注意力的本质。

??在上面的等式中我们把 F F F换成缩放点积模型,这样就变成了transformer论文里面的注意力公式了:

F = s o f t m a x ( Q ? K T / d k ) ? V F = softmax(Q·K^T/\sqrt{d_k})·V F=softmax(Q?KT/dk? ?)?V


??举例结束。其实注意力机制实现起来还略有差别,实际的transformer中会将QKV使用linear做线性变换(可学习参数W),映射到不同的线性空间,并且会将其分成多个head,每个head能学到不同的东西,来增加特征的多样性,从而为模型提供更多的表达能力。transformer中实际注意力计算模型如下:

在这里插入图片描述
??这里的attention机制,相比于经典的Dot-product Attention其实就是多了一个scale项。这里的作用是啥呢?当d比较小的时候,要不要scale都无所谓,但是当d比较大时,内积的值的范围就会变得很大,不同的内积的差距也会拉大,这样的话,再经过softmax进一步的扩大差距,就会使得得到的attention分布很接近one-hot,这样会导致梯度下降困难,模型难以训练。在Transformer中,d=512,算比较大了,因此需要进行scaling。

??代码实现如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

def attention(query: Tensor, key: Tensor, value: Tensor, mask: Tensor=None) -> Tensor:
    sqrt_dim_head = query.shape[-1]**0.5

    scores = torch.matmul(query, key.transpose(-2, -1))
    scores = scores / sqrt_dim_head
    
    if mask is not None:
        scores = scores.masked_fill(mask==0, -1e9)
    
    weight = F.softmax(scores, dim=-1)    
    return torch.matmul(weight, value)

??提示:函数的mask参数可以先忽略一下,在下文会介绍到。这是个很重要的机制。

??在通过线性层学习QKV向量之后,Q和K经历点积矩阵乘法以产生得分矩阵:
在这里插入图片描述
??得分矩阵决定了一个单词在其他单词上的关注程度。因此每个单词都有一个与时间步长中的其他单词相对应的分数。分数越高,注意力越集中。这就是Q映射到K的方式:

在这里插入图片描述

??然后,分数通过除以Q和K的维度的平方根而缩小。这是为了获得更稳定的数据,否则可能会产生爆炸效果。

在这里插入图片描述
??下一步,使用softmax来得到注意力权重,返回0到1之间的概率值。通过做softmax,高分得到提高,低分受到抑制。这样模型可以决定对哪些单词的注意力更高哪些单词的注意力比较低。

在这里插入图片描述
??然后把注意力权重乘以V,得到一个输出向量。softmax分数越高,模型学习的单词值就越重要。较低的分数会淹没不重要的单词。然后把它的输出输入到一个线性层进行处理。

在这里插入图片描述

五、多头注意力 Multi-Head Attention

??要使用多头注意力计算,需要在自我注意力之前将QKV分成N个向量。分裂的向量然后单独经历自我注意过程。每一个自我关注的过程被称为一个头。每个头产生一个输出向量,在通过最终的线性层之前,该向量被连接成一个向量。理论上,每个头将学习不同的东西,因此给编码器模型更多的表示能力。

在这里插入图片描述
??代码如下:

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_embed: int, drop_prob: float) -> None:
        super().__init__()
        assert dim_embed % num_heads == 0

        self.num_heads = num_heads
        self.dim_embed = dim_embed
        self.dim_head = dim_embed // num_heads

        self.query = nn.Linear(dim_embed, dim_embed)
        self.key = nn.Linear(dim_embed, dim_embed)
        self.value = nn.Linear(dim_embed, dim_embed)
        self.output = nn.Linear(dim_embed, dim_embed)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x: Tensor, y: Tensor, mask: Tensor = None) -> Tensor:
        query = self.query(x)
        key = self.key(y)
        value = self.value(y)

        batch_size = x.size(0)
        query = query.view(batch_size, -1, self.num_heads, self.dim_head)
        key = key.view(batch_size, -1, self.num_heads, self.dim_head)
        value = value.view(batch_size, -1, self.num_heads, self.dim_head)

        # Into the number of heads (batch_size, num_heads, -1, dim_head)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)

        attn = attention(query, key, value, mask)
        attn = attn.transpose(1, 2).contiguous().view(batch_size, -1, self.dim_embed)

        out = self.dropout(self.output(attn))

        return out

??对于多头自注意力来说,它的“头”的大小是不影响模型参数量的。 假设你的嵌入层维度是300维,如果你有4个头的话,那就会把你的张量切割成4个75维的张量;同理,如果你有10个头的话,那就会把你的张量切割成10个30维的张量。

??300维的张量如何计算注意力,分割成n个头的张量还是同样的方式去计算注意力,只不过计算完了之后会再重新拼接成300维。所以,多头注意力机制的代码很简单,没必要细看。

六、分位置的前馈机制 Position-wise Feed-Forward

??Position-wise Feed-Forward 给词向量增加了非线性。 词向量的形状是(batch_size, max_sequence_length, dim_embed),很多神经网络处理词向量的时候会进行flatten然后再进入前馈神经网络, 我们并没有将词向量进行flatten, 我们的线性操作是对每个位置进行独立的操作,因此, 这里被称为 Position-wise 。

??代码如下:

import torch.nn as nn
from torch import Tensor

class PositionwiseFeedForward(nn.Module):
    def __init__(self, dim_embed: int, dim_pffn: int, drop_prob: float) -> None:
        super().__init__()
        self.pffn = nn.Sequential(
            nn.Linear(dim_embed, dim_pffn),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_prob),
            nn.Linear(dim_pffn, dim_embed),
            nn.Dropout(drop_prob),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.pffn(x)

??朴实无华,所以这里就不过多介绍了。

??到了这里,所有编码器的组成部分全都介绍完了。

七、序列掩码 Look-Ahead Mask

??目前为止,编码器的组成部分介绍完了。下面介绍解码器。

??解码器的组成部分和编码器一样,只是比编码器多了一个掩码多头注意力Masked Multi-Head Attention在介绍它之前,必须得先说一下作者提出的非常厉害的Look-Ahead Mask的机制。

??这也是在【四、带缩放的点积注意力机制 Scaled Dot-Product Attention 】中实现的attention函数的参数里为什么有mask变量的原因。


??由于解码器是自回归的,并且一个字一个字地生成序列,所以需要防止它在训练时能看到未来的词。例如,当计算单词“am”的注意力分数时,不应该访问单词“fine ”,因为该单词是在之后生成的未来单词。单词“am”应该只能访问它自己和它前面的单词。这对于所有其他的词都是如此,在那里它们只能注意前面的词。

在这里插入图片描述

??我们需要一种方法来防止计算未来单词的注意力分数。这种方法叫做mask。为了防止解码器查看将来的单词,可以使用一个mask向量。在计算softmax之前和缩放分数之后添加mask。让我们来看看这是如何工作的。

??mask是一个矩阵,其大小与填充有0值和负无穷大值的注意力分数相同。当把mask加到缩放的注意力分数上时,会得到一个分数矩阵,右上角的三角形填充了否定的无穷大。

在这里插入图片描述

??mask的原理是一旦取权重分数的softmax,负的无穷大被清零,为未来的词留下等于零的注意力分数。例如下图所看到的,“am”的关注度得分本身及其之前的所有单词都有值,但单词“fine”的关注度得分为零。这实际上是告诉模型不要关注这些单词。

在这里插入图片描述

??此时再用权重矩阵去乘以value矩阵的话,就会发现比如<start>的向量计算时并不会累加它之后词的向量数据,也就是完成了当前词看不到后面词的功能。

虽然mask我们看起来简单,但是能提出这样的想法,并且行之有效我觉得非常了不起。近年来随着ChatGPT的爆火,国内各大公司纷纷跟进大模型,但每每读到国外的这种开创性的论文,和非常具有创新力的想法,再想到国内的学术界和工业界,只能说五味杂陈吧。

??代码实现:

def make_x_mask(self, x):
    x_mask = (x != self.pad_token_id).unsqueeze(1)
    return x_mask

def make_y_mask(self, y):
    N, y_len = y.shape
    y_mask = torch.tril(torch.ones((y_len, y_len))).expand(
        N, y_len, y_len
    )
    return y_mask

??具体的代码实现稍微有一些不一样,我们是创建了一个都为1的下三角矩阵,其他位置为0,在mask的时候把为零位置的数据设置成无穷小,这样节省了相加的步骤会更好一些。

八、掩码多头注意力 Masked Multi-Head Attention

??掩码多头注意力的代码实现和在【四、带缩放的点积注意力机制 Scaled Dot-Product Attention 】中实现的attention函数一样。

在这里插入图片描述

??只不过正常的多头注意力在计算的时候,attention函数的mask参数并不会起作用,但是掩码多头注意力的mask参数会起作用。

??到此为止,我们就介绍完了transformer模型的所有模块,下面我们开始把这些模块都组装起来构建真正的transform模型。

九、(整合)编码器块EncoderBlock

??首先构建transformer模型的编码器部分,编码器是由n个编码器块循环堆叠构成的。所以首先介绍编码器块EncoderBlock

??代码如下。其实他就是把多头注意力Multi-Head Attention和分位置的前馈机制 Position-wise Feed-Forward拼起来。

class EncoderBlock(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()

        # Self-attention
        self.self_atten = MultiHeadAttention(config.num_heads, config.hidden_size, config.attention_probs_dropout_prob)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)

        # Point-wise feed-forward
        self.feed_forward = PositionwiseFeedForward(config.hidden_size, config.dim_pffn,
                                                    config.attention_probs_dropout_prob)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x: Tensor, x_mask: Tensor) -> Tensor:
        # 图中Add节点
        x = x + self.sub_layer1(x, x_mask)
        x = x + self.sub_layer2(x)
        return x

    def sub_layer1(self, x: Tensor, x_mask: Tensor) -> Tensor:
        # 先进行norm
        x = self.layer_norm1(x)
        x = self.self_atten(x, x, x_mask)
        return x

    def sub_layer2(self, x: Tensor) -> Tensor:
        x = self.layer_norm2(x)
        x = self.feed_forward(x)
        return x

十、编码器Encoder

??Transformer使用了多个encoder模块, 下面的代码实现了多个encoder的堆叠。

class Encoder(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()

        self.blocks = nn.ModuleList(
            [EncoderBlock(config)
             for _ in range(config.num_blocks)]
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x: Tensor, x_mask: Tensor):
        for block in self.blocks:
            x = block(x, x_mask)
        x = self.layer_norm(x)
        return x

十一、(整合)解码器块DecoderBlock

??然后构建transformer模型的解码器部分,解码器是由n个解码器块循环堆叠构成的。所以介绍解码器块DecoderBlock

??代码如下。其实是在编码器块之前,再加一个掩码多头注意力 Masked Multi-Head Attention

class DecoderBlock(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()

        # Self-attention
        self.self_attn = MultiHeadAttention(config.num_heads, config.hidden_size, config.attention_probs_dropout_prob)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)

        # Target-source
        self.tgt_src_attn = MultiHeadAttention(config.num_heads, config.hidden_size,
                                               config.attention_probs_dropout_prob)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

        # Position-wise
        self.feed_forward = PositionwiseFeedForward(config.hidden_size, config.dim_pffn,
                                                    config.attention_probs_dropout_prob)
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

    def forward(self, y, y_mask, x, x_mask) -> Tensor:
        # 实现residual connection
        y = y + self.sub_layer1(y, y_mask)
        y = y + self.sub_layer2(y, x, x_mask)
        y = y + self.sub_layer3(y)
        return y

    def sub_layer1(self, y: Tensor, y_mask: Tensor) -> Tensor:
        y = self.layer_norm1(y)
        y = self.self_attn(y, y, y_mask)
        return y

    def sub_layer2(self, y: Tensor, x: Tensor, x_mask: Tensor) -> Tensor:
        y = self.layer_norm2(y)
        y = self.tgt_src_attn(y, x, x_mask)
        return y

    def sub_layer3(self, y: Tensor) -> Tensor:
        y = self.layer_norm3(y)
        y = self.feed_forward(y)
        return y

十二、解码器Decoder

??同样的,解码器也是解码器块循环n次。

class Decoder(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [DecoderBlock(config)
             for _ in range(config.num_blocks)]
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x: Tensor, x_mask: Tensor, y: Tensor, y_mask: Tensor) -> Tensor:
        for block in self.blocks:
            y = block(y, y_mask, x, x_mask)
        y = self.layer_norm(y)
        return y

十三、Transformer

??下面就是将所有模块整合成一个模块的时候了, 所以下面的代码用到了以上所有的代码:

class Transformer(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.pad_token_id = config.pad_token_id

        # Input embeddings, positional encoding, and encoder
        self.input_embedding = TransformerEmbeddings(config)
        self.encoder = Encoder(config)

        # Ouput embeddings, positional encoding, and decoder
        self.output_embedding = TransformerEmbeddings(config)
        self.decoder = Decoder(config)

        self.projection = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize parameters
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def make_x_mask(self, x):
        x_mask = (x != self.pad_token_id).unsqueeze(1)
        return x_mask

    def make_y_mask(self, y):
        N, y_len = y.shape
        y_mask = torch.tril(torch.ones((y_len, y_len))).expand(
            N, y_len, y_len
        )
        return y_mask

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        x_mask = self.make_x_mask(x)
        y_mask = self.make_y_mask(y)

        x = self.encode(x, x_mask)
        y = self.decode(x, y, x_mask, y_mask)
        return y

    def encode(self, x: Tensor, x_mask: Tensor = None) -> Tensor:
        x = self.input_embedding(x)
        x = self.encoder(x, x_mask)
        return x

    def decode(self, x: Tensor, y: Tensor,
               x_mask: Tensor = None, y_mask: Tensor = None) -> Tensor:
        y = self.output_embedding(y)
        y = self.decoder(x, x_mask, y, y_mask)
        return self.projection(y)

??最后一层使用 nn.Linear 将词向量的维数转换为output_vocab_size, 这样就可以使用softmax输出词的概率。

十四、全部代码,开箱即用

??如下:

import math
from typing import Optional, Tuple

import torch
from torch import nn, Tensor
import torch.nn.functional as F


# 计算QKV的注意力机制
def attention(query: Tensor, key: Tensor, value: Tensor, mask: Tensor = None) -> Tensor:
    sqrt_dim_head = query.shape[-1] ** 0.5

    scores = torch.matmul(query, key.transpose(-2, -1))
    scores = scores / sqrt_dim_head

    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)

    weight = F.softmax(scores, dim=-1)
    return torch.matmul(weight, value)


class TransformerConfig:
    def __init__(self, vocab_size, hidden_size, max_position_embeddings, layer_norm_eps, pad_token_id,
                 hidden_dropout_prob, attention_probs_dropout_prob, num_blocks, num_heads, dim_pffn):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.pad_token_id = pad_token_id
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.dim_pffn = dim_pffn


class PositionalEncoding(nn.Module):
    def __init__(self, max_positions: int, dim_embed: int) -> None:
        """https://mlln.cn/2022/12/10/Transformer%E8%AF%A6%E7%BB%86%E8%A7%A3%E8%AF%BB%E5%92%8C%E4%BB%A3%E7%A0%81%E6%A1%88%E4%BE%8B/"""
        super().__init__()

        assert dim_embed % 2 == 0

        position = torch.arange(max_positions).unsqueeze(1)
        dim_pair = torch.arange(0, dim_embed, 2)
        div_term = torch.exp(dim_pair * (-math.log(10000.0) / dim_embed))

        pe = torch.zeros(max_positions, dim_embed)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # 添加batch维度
        pe = pe.unsqueeze(0)

        # 整个学习阶段, 位置信息是不变的, 注册为不可学习的数据
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        # 计算每个batch的最大句子长度
        max_sequence_length = x.size(1)

        return self.pe[:, :max_sequence_length]


class TransformerEmbeddings(nn.Module):
    """Construct the embeddings from word, position embeddings."""

    def __init__(self, config):
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = PositionalEncoding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids: Optional[torch.LongTensor] = None) -> torch.Tensor:
        inputs_embeds = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(inputs_embeds)

        embeddings = inputs_embeds + position_embeddings

        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads: int, dim_embed: int, drop_prob: float) -> None:
        super().__init__()
        assert dim_embed % num_heads == 0

        self.num_heads = num_heads
        self.dim_embed = dim_embed
        self.dim_head = dim_embed // num_heads

        self.query = nn.Linear(dim_embed, dim_embed)
        self.key = nn.Linear(dim_embed, dim_embed)
        self.value = nn.Linear(dim_embed, dim_embed)
        self.output = nn.Linear(dim_embed, dim_embed)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x: Tensor, y: Tensor, mask: Tensor = None) -> Tensor:
        query = self.query(x)
        key = self.key(y)
        value = self.value(y)

        batch_size = x.size(0)
        query = query.view(batch_size, -1, self.num_heads, self.dim_head)
        key = key.view(batch_size, -1, self.num_heads, self.dim_head)
        value = value.view(batch_size, -1, self.num_heads, self.dim_head)

        # Into the number of heads (batch_size, num_heads, -1, dim_head)
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        if mask is not None:
            mask = mask.unsqueeze(1)

        attn = attention(query, key, value, mask)
        attn = attn.transpose(1, 2).contiguous().view(batch_size, -1, self.dim_embed)

        out = self.dropout(self.output(attn))

        return out


class PositionwiseFeedForward(nn.Module):
    def __init__(self, dim_embed: int, dim_pffn: int, drop_prob: float) -> None:
        super().__init__()
        self.pffn = nn.Sequential(
            nn.Linear(dim_embed, dim_pffn),
            nn.ReLU(inplace=True),
            nn.Dropout(drop_prob),
            nn.Linear(dim_pffn, dim_embed),
            nn.Dropout(drop_prob),
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.pffn(x)


class EncoderBlock(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()

        # Self-attention
        self.self_atten = MultiHeadAttention(config.num_heads, config.hidden_size, config.attention_probs_dropout_prob)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)

        # Point-wise feed-forward
        self.feed_forward = PositionwiseFeedForward(config.hidden_size, config.dim_pffn,
                                                    config.attention_probs_dropout_prob)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

    def forward(self, x: Tensor, x_mask: Tensor) -> Tensor:
        # 图中Add节点
        x = x + self.sub_layer1(x, x_mask)
        x = x + self.sub_layer2(x)
        return x

    def sub_layer1(self, x: Tensor, x_mask: Tensor) -> Tensor:
        # 先进行norm
        x = self.layer_norm1(x)
        x = self.self_atten(x, x, x_mask)
        return x

    def sub_layer2(self, x: Tensor) -> Tensor:
        x = self.layer_norm2(x)
        x = self.feed_forward(x)
        return x


class Encoder(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()

        self.blocks = nn.ModuleList(
            [EncoderBlock(config)
             for _ in range(config.num_blocks)]
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x: Tensor, x_mask: Tensor):
        for block in self.blocks:
            x = block(x, x_mask)
        x = self.layer_norm(x)
        return x


class DecoderBlock(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()

        # Self-attention
        self.self_attn = MultiHeadAttention(config.num_heads, config.hidden_size, config.attention_probs_dropout_prob)
        self.layer_norm1 = nn.LayerNorm(config.hidden_size)

        # Target-source
        self.tgt_src_attn = MultiHeadAttention(config.num_heads, config.hidden_size,
                                               config.attention_probs_dropout_prob)
        self.layer_norm2 = nn.LayerNorm(config.hidden_size)

        # Position-wise
        self.feed_forward = PositionwiseFeedForward(config.hidden_size, config.dim_pffn,
                                                    config.attention_probs_dropout_prob)
        self.layer_norm3 = nn.LayerNorm(config.hidden_size)

    def forward(self, y, y_mask, x, x_mask) -> Tensor:
        # 实现residual connection
        y = y + self.sub_layer1(y, y_mask)
        y = y + self.sub_layer2(y, x, x_mask)
        y = y + self.sub_layer3(y)
        return y

    def sub_layer1(self, y: Tensor, y_mask: Tensor) -> Tensor:
        y = self.layer_norm1(y)
        y = self.self_attn(y, y, y_mask)
        return y

    def sub_layer2(self, y: Tensor, x: Tensor, x_mask: Tensor) -> Tensor:
        y = self.layer_norm2(y)
        y = self.tgt_src_attn(y, x, x_mask)
        return y

    def sub_layer3(self, y: Tensor) -> Tensor:
        y = self.layer_norm3(y)
        y = self.feed_forward(y)
        return y


class Decoder(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.blocks = nn.ModuleList(
            [DecoderBlock(config)
             for _ in range(config.num_blocks)]
        )
        self.layer_norm = nn.LayerNorm(config.hidden_size)

    def forward(self, x: Tensor, x_mask: Tensor, y: Tensor, y_mask: Tensor) -> Tensor:
        for block in self.blocks:
            y = block(y, y_mask, x, x_mask)
        y = self.layer_norm(y)
        return y


class Transformer(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.pad_token_id = config.pad_token_id

        # Input embeddings, positional encoding, and encoder
        self.input_embedding = TransformerEmbeddings(config)
        self.encoder = Encoder(config)

        # Ouput embeddings, positional encoding, and decoder
        self.output_embedding = TransformerEmbeddings(config)
        self.decoder = Decoder(config)

        self.projection = nn.Linear(config.hidden_size, config.vocab_size)

        # Initialize parameters
        for param in self.parameters():
            if param.dim() > 1:
                nn.init.xavier_uniform_(param)

    def make_x_mask(self, x):
        x_mask = (x != self.pad_token_id).unsqueeze(1)
        return x_mask

    def make_y_mask(self, y):
        N, y_len = y.shape
        y_mask = torch.tril(torch.ones((y_len, y_len))).expand(
            N, y_len, y_len
        )
        return y_mask

    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        x_mask = self.make_x_mask(x)
        y_mask = self.make_y_mask(y)

        x = self.encode(x, x_mask)
        y = self.decode(x, y, x_mask, y_mask)
        return y

    def encode(self, x: Tensor, x_mask: Tensor = None) -> Tensor:
        x = self.input_embedding(x)
        x = self.encoder(x, x_mask)
        return x

    def decode(self, x: Tensor, y: Tensor,
               x_mask: Tensor = None, y_mask: Tensor = None) -> Tensor:
        y = self.output_embedding(y)
        y = self.decoder(x, x_mask, y, y_mask)
        return self.projection(y)

文章来源:https://blog.csdn.net/qq_43592352/article/details/135596093
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。