transfomer中正余弦位置编码的源码实现

发布时间:2024年01月17日

简介

Transformer模型抛弃了RNN、CNN作为序列学习的基本模型。循环神经网络本身就是一种顺序结构,天生就包含了词在序列中的位置信息。当抛弃循环神经网络结构,完全采用Attention取而代之,这些词序信息就会丢失,模型就没有办法知道每个词在句子中的相对和绝对的位置信息。因此,有必要把词序信号加到词向量上帮助模型学习这些信息,位置编码(Positional Encoding)就是用来解决这种问题的方法。
关于位置编码更多介绍参考bev感知专栏博客

源码实现:

import torch
import matplotlib.pyplot as plt


def posemb_sincos_2d(h, w, dim, temperature: int = 10000, dtype=torch.float32):
    y, x = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
    assert (dim % 4) == 0, "feature dimension must be multiple of 4 for sincos emb"
    omega = torch.arange(dim // 4) / (dim // 4 - 1)
    omega = 1.0 / (temperature ** omega)

    y = y.flatten()[:, None] * omega[None, :]
    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos(), y.sin(), y.cos()), dim=1)
    return pe.type(dtype)


def posemb_sincos_1d(len, dim, temperature: int = 10000, dtype=torch.float32):
    x = torch.arange(len)
    assert (dim % 2) == 0, "feature dimension must be multiple of 2 for sincos emb"
    omega = torch.arange(dim // 2) / (dim // 2 - 1)
    omega = 1.0 / (temperature ** omega)

    x = x.flatten()[:, None] * omega[None, :]
    pe = torch.cat((x.sin(), x.cos()), dim=1)  # 这里不用担心,不交叉无所谓,
    return pe.type(dtype)


if __name__ == '__main__':
    pos = posemb_sincos_1d(200, 256)
    # pos = posemb_sincos_2d(20,20,256)

    # 创建一个热力图
    plt.imshow(pos, cmap='hot', interpolation='nearest')
    # 添加颜色条
    plt.colorbar()
    # 显示图形
    plt.show()
    pass

可视化结果如下:
在这里插入图片描述

文章来源:https://blog.csdn.net/zwhdldz/article/details/135638500
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。