1. torch.stack() 方法:
import torch
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])
# 沿着新的维度(第0维)堆叠
result = torch.stack((tensor1, tensor2))
print(result)
# 输出:
# tensor([[[1, 2],
# [3, 4]],
# [[5, 6]]])
2. torch.cat() 与 torch.unsqueeze() 方法:
torch.unsqueeze()
在现有维度上增加一个维度,然后使用 torch.cat()
连接张量。这样可以在现有维度上进行连接,但增加了一个额外的维度。import torch
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])
# 使用 torch.unsqueeze 在现有维度上增加一个维度,然后使用 torch.cat() 连接
tensor2_expanded = torch.unsqueeze(tensor2, dim=0)
result = torch.cat((tensor1, tensor2_expanded), dim=0)
print(result)
# 输出:
# tensor([[1, 2],
# [3, 4],
# [5, 6]])
3. torch.cat() 与 torch.chunk() 方法:
torch.chunk()
将一个张量沿着指定维度分割成多个张量,然后通过 torch.cat()
连接所需的部分。这样可以在分割点处灵活地进行连接。import torch
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])
# 使用 torch.chunk() 将 tensor1 沿着第0维度分割成两部分,然后通过 torch.cat() 连接所需的部分
chunks = torch.chunk(tensor1, chunks=2, dim=0)
result = torch.cat((chunks[0], tensor2, chunks[1]), dim=0)
print(result)
# 输出:
# tensor([[1, 2],
# [5, 6],
# [3, 4]])
4. torch.cat() 与 Python 列表方法:
torch.cat()
连接列表中的张量。这样提供了一种更灵活的方式来组织和连接张量。import torch
tensor1 = torch.tensor([[1, 2], [3, 4]])
tensor2 = torch.tensor([[5, 6]])
# 使用 Python 列表包含张量,然后使用 torch.cat() 连接列表中的张量
tensors_list = [tensor1, tensor2]
result = torch.cat(tensors_list, dim=0)
print(result)
# 输出:
# tensor([[1, 2],
# [3, 4],
# [5, 6]])
相关链接:torch.cat()函数的理解