The difference between torch.cat and torch.stack
torch.cat
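- Definition: concatenates a sequence of tensors along an existing, specified dimension. The inputs must have the same shape in every dimension except the one being concatenated, and the result has the same number of dimensions as the inputs.
- In practice: for example, with a BiLSTM, the forward and backward output vectors are joined with torch.cat along the last (feature) dimension.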
import torch
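# Forward and backward BiLSTM outputs, e.g. of shape (batch, seq_len, hidden)
forward_lstm = torch.randn((2, 10, 3))
backward_lstm = torch.randn((2, 10, 3))
# Join along the last dimension: hidden sizes add up (3 + 3 = 6)
lstm_emd = torch.cat((forward_lstm, backward_lstm), dim=-1)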
print(lstm_emd.size())
'''
torch.Size([2, 10, 6])
'''
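To make the shape rule for torch.cat concrete, here is a small sketch (the tensor sizes are arbitrary choices for illustration). Only the size of the concatenated dimension changes, no dimension is added, and every other dimension must match.

```python
import torch

# Two tensors that agree in every dimension except the last one
a = torch.randn(2, 10, 3)
b = torch.randn(2, 10, 5)

# Concatenating along the last dimension: sizes add up (3 + 5 = 8)
print(torch.cat((a, b), dim=-1).size())  # torch.Size([2, 10, 8])

# Concatenating along dim 0 requires all other dims to match, which they don't here
try:
    torch.cat((a, b), dim=0)
except RuntimeError as err:
    print("cat failed:", err)

# Same shapes, concatenated along dim 0: only the first size grows, no new dim
c = torch.cat((a, a), dim=0)
print(c.size())  # torch.Size([4, 10, 3])
print(c.dim())   # 3, same rank as the inputs
```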
torch.stack
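- Definition: the official documentation describes it as concatenating a sequence of tensors along a new dimension. "Stacking" here means a new dimension is added, so all inputs must have exactly the same shape.
- Under the hood: equivalent to calling unsqueeze(dim) on each tensor and then torch.cat(..., dim=dim).
- In practice: combining several tensors of the same shape into a batch.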
import torch
batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))
batch = torch.stack((batch_1, batch_2), dim=0)
print(batch.size())
'''
torch.Size([2, 10, 3])
'''
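The new dimension created by torch.stack is inserted at position `dim`, and its size equals the number of input tensors. The sketch below (arbitrary example sizes) prints the resulting shape for several values of `dim` and shows that tensors of different shapes cannot be stacked.

```python
import torch

x = torch.randn(10, 3)
y = torch.randn(10, 3)

# The new dimension of size 2 (one slot per input tensor) appears at position `dim`
for d in (0, 1, 2, -1):
    print(d, torch.stack((x, y), dim=d).size())
# 0 torch.Size([2, 10, 3])
# 1 torch.Size([10, 2, 3])
# 2 torch.Size([10, 3, 2])
# -1 torch.Size([10, 3, 2])

# stack requires identical shapes; tensors differing in any dim raise an error
try:
    torch.stack((x, torch.randn(10, 4)), dim=0)
except RuntimeError as err:
    print("stack failed:", err)
```

Indexing the new dimension recovers the original inputs, e.g. `torch.stack((x, y), dim=1)[:, 0]` equals `x`.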
- Reproducing torch.stack with torch.unsqueeze and torch.cat
import torch
batch_1 = torch.randn((10, 3))
batch_2 = torch.randn((10, 3))
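# Add a leading dimension of size 1: (10, 3) -> (1, 10, 3)
batch_1 = torch.unsqueeze(batch_1, dim=0)
batch_2 = torch.unsqueeze(batch_2, dim=0)
# Concatenate along that new dimension: (1, 10, 3) + (1, 10, 3) -> (2, 10, 3)
batch = torch.cat((batch_1, batch_2), dim=0)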
print(batch.size())
'''
torch.Size([2, 10, 3])
'''
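The same equivalence holds for any valid `dim`, not just 0. As a sketch, the hypothetical helper `stack_via_cat` below (not part of PyTorch) unsqueezes each tensor at `dim` and concatenates along that same `dim`, then checks the result against torch.stack for every dimension allowed for 2-D inputs.

```python
import torch

def stack_via_cat(tensors, dim=0):
    """Hypothetical helper: mimic torch.stack by unsqueezing each tensor at `dim`
    and then concatenating along that same dim."""
    return torch.cat([t.unsqueeze(dim) for t in tensors], dim=dim)

tensors = [torch.randn(10, 3) for _ in range(4)]

# Valid dims for stacking 2-D tensors are -3..2 (the result is 3-D)
for d in range(-3, 3):
    expected = torch.stack(tensors, dim=d)
    actual = stack_via_cat(tensors, dim=d)
    assert torch.equal(expected, actual), d
print("stack_via_cat matches torch.stack for every dim")
```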