torch.einsum用法

发布时间:2024年01月12日

目录

1. 向量点乘(内积)

2. 外积

3. 矩阵和向量的乘法

4. 矩阵乘法

5. 批量矩阵乘法

6. 求和操作


torch.einsum 是 PyTorch 中的一个强大工具,它允许你通过 Einstein summation convention(爱因斯坦求和约定)来执行复杂的张量操作。使用这种约定,你可以用一个字符串来指定张量操作的维度规则。这个函数非常灵活,可以用于实现各种张量运算,如元素相乘、矩阵乘法、批量矩阵乘法、迹等。

torch.einsum的语法格式如下:

torch.einsum(equation, *operands)

????????其中,equation是一个字符串,指定了张量操作的约定;*operands是要操作的张量。

以下是 torch.einsum 的一些常见用法:

1. 向量点乘(内积)

import torch

a = torch.tensor([1, 2, 3])
b = torch.tensor([0, 1, 0])

# 计算两个向量的内积
result = torch.einsum('i,i->', a, b)  # 输出: tensor(2)

这里 'i,i->' 告诉 einsum 计算 ab 的对应元素相乘的和,最后结果是一个标量。

2. 外积

# 计算两个向量的外积
result = torch.einsum('i,j->ij', a, b)  # 输出: 2x3 矩阵

这里 'i,j->ij' 表示 a 的每个元素(索引为 i)与 b 的每个元素(索引为 j)相乘,得到一个二维矩阵。

3. 矩阵和向量的乘法

A = torch.tensor([[1, 2], [0, 1], [2, 0]])
v = torch.tensor([0, 1])

# 矩阵和向量的乘法
result = torch.einsum('ij,j->i', A, v)  # 输出: 1D tensor of size 3

这里 'ij,j->i' 表示 A 的每一行与向量 v 相乘,结果是一个向量。

4. 矩阵乘法

A = torch.tensor([[1, 2], [0, 1]])
B = torch.tensor([[2, 0], [0, 2]])

# 矩阵乘法
result = torch.einsum('ik,kj->ij', A, B)  # 输出: 2x2 矩阵

5. 批量矩阵乘法

 
A = torch.randn(3, 2, 5)
B = torch.randn(3, 5, 3)

# 批量矩阵乘法 (Batched Matrix Multiplication)
result = torch.einsum('bik,bkj->bij', A, B)  # 输出: 3x2x3 矩阵

这里 'bik,bkj->bij' 表示对于批量中的每个矩阵,进行矩阵乘法。

6. 求和操作

A = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 计算矩阵中所有元素的和
result = torch.einsum('ij->', A)  # 输出: tensor(21)

这里 'ij->' 表示对 A 的所有元素进行求和。

torch.einsum 的强大之处在于,你可以通过正确地安排这些字母和箭头来执行非常复杂的操作。需要注意的是,einsum 操作的效率可能不如专门的函数(如 torch.matmul 对于矩阵乘法),但它提供了一种非常简洁和通用的方式来表达复杂的张量计算。

文章来源:https://blog.csdn.net/sinat_34461199/article/details/135440995
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。