PyTorch 还提供的几种连接张量的方法

发布时间:2024年01月16日

1. torch.stack() 方法:

  • 行为: 创建一个新的维度,并在该维度上堆叠张量。结果张量的维度比原始张量多一维。
  • 适用场景: 当需要在新的维度上堆叠张量时,特别是在创建新批次或样本时。
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])

# 沿着新的维度(第0维)堆叠
result = torch.stack((tensor1, tensor2))
print(result)
# 输出:
# tensor([[[1, 2],
#          [3, 4]],
#         [[5, 6]]])

2. torch.cat() 与 torch.unsqueeze() 方法:

  • 行为: 使用 torch.unsqueeze() 在现有维度上增加一个维度,然后使用 torch.cat() 连接张量。这样可以在现有维度上进行连接,但增加了一个额外的维度。
  • 适用场景: 当希望在现有维度上连接张量,但需要处理张量的不同形状时,可以使用此方法。
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])

# 使用 torch.unsqueeze 在现有维度上增加一个维度,然后使用 torch.cat() 连接
tensor2_expanded = torch.unsqueeze(tensor2, dim=0)
result = torch.cat((tensor1, tensor2_expanded), dim=0)
print(result)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

3. torch.cat() 与 torch.chunk() 方法:

  • 行为: 使用 torch.chunk() 将一个张量沿着指定维度分割成多个张量,然后通过 torch.cat() 连接所需的部分。这样可以在分割点处灵活地进行连接。
  • 适用场景: 当需要在连接之前在指定点处对张量进行分割时,特别是处理不同部分的不同操作。
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])

# 使用 torch.chunk() 将 tensor1 沿着第0维度分割成两部分,然后通过 torch.cat() 连接所需的部分
chunks = torch.chunk(tensor1, chunks=2, dim=0)
result = torch.cat((chunks[0], tensor2, chunks[1]), dim=0)
print(result)
# 输出:
# tensor([[1, 2],
#         [5, 6],
#         [3, 4]])

4. torch.cat() 与 Python 列表方法:

  • 行为: 使用 Python 列表来包含张量,然后使用 torch.cat() 连接列表中的张量。这样提供了一种更灵活的方式来组织和连接张量。
  • 适用场景: 当有多个张量需要连接,而这些张量以列表的形式存在时,可以使用此方法。这种方法对于动态构建需要连接的张量集合很有用。
import torch

tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])

# 使用 Python 列表包含张量,然后使用 torch.cat() 连接列表中的张量
tensors_list = [tensor1, tensor2]
result = torch.cat(tensors_list, dim=0)
print(result)
# 输出:
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

相关链接:torch.cat()函数的理解

文章来源:https://blog.csdn.net/weixin_51659315/article/details/135498257
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。