定义: torch.cat(tensors, dim=0, out=None)
→ Tensor
参数:
tensors
(Sequence of Tensors):要连接的张量序列。dim
(int, 可选):沿着此维连接张量序列。当 dim=0
时,torch.cat()
会按行连接多个张量,也就是在第一个维度上进行连接。这将导致张量在垂直方向上叠加。当 dim=1
时,torch.cat()
会按列连接多个张量,也就是在第二个维度上进行连接。这将导致张量在水平方向上叠加。out
(Tensor, 可选):输出张量。返回值:
用途:
torch.cat
用于将给定维度上的一系列张量连接在一起。张量在除连接维以外的所有维度上必须具有相同的形状。定义: torch.stack(tensors, dim=0, out=None)
→ Tensor
参数:
tensors
(Sequence of Tensors):要堆叠的张量序列,所有张量都应有相同的形状。dim
(int, 可选):插入新维度的索引。out
(Tensor, 可选):输出张量。返回值:
用途:
torch.stack
用于创建一个新的维度,并在该维度上堆叠一系列张量。与torch.cat
不同,torch.stack
会增加一个新的维度,所以输出张量的维度会比输入张量多一个。