torch.repeat_interleave() 是 PyTorch 中的一个函数,用于按指定的方式重复张量中的元素。
以下是该函数的详细说明:
torch.repeat_interleave()
的原理是将输入张量中的每个元素重复指定的次数,并将这些重复的元素拼接成一个新的张量。
torch.repeat_interleave(input, repeats, dim=None)
示例1:
import torch
# 创建一个示例张量
tensor = torch.tensor([1, 2, 3])
# 重复每个元素两次
result = torch.repeat_interleave(tensor, repeats=2)
print(result)
示例说明:
上述示例创建了一个张量 [1, 2, 3],并使用 torch.repeat_interleave() 将每个元素重复了两次。因此,输出将是一个新的张量 [1, 1, 2, 2, 3, 3]。
输出结果:
tensor([1, 1, 2, 2, 3, 3])
这个函数在处理序列数据、生成数据扩充样本等场景中很有用。
示例2:
假设有一个二维张量,并且想要沿着某个维度重复每行的元素不同的次数。
import torch
# 创建一个二维张量
matrix = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 指定每行的重复次数
repeats_per_row = torch.tensor([2, 3, 1])
# 沿着第一维度重复
result = torch.repeat_interleave(matrix, repeats=repeats_per_row, dim=0)
print(result)
在这个例子中,我们有一个二维张量 matrix,以及一个指定每行重复次数的张量 repeats_per_row。通过使用 torch.repeat_interleave() 沿着第一维度(行)重复每行的元素,我们得到了一个新的张量。
输出结果:
tensor([[1, 2, 3],
[1, 2, 3],
[4, 5, 6],
[4, 5, 6],
[4, 5, 6],
[7, 8, 9]])
在这个例子中,第一行的元素被重复了两次,第二行的元素被重复了三次,而第三行的元素被重复了一次。这样,我们就实现了按照指定方式重复每行的元素。