PyTorch中torch.cat函数和torch.stack函数说明

发布时间:2024年01月19日

torch.cat

  • 定义torch.cat(tensors, dim=0, out=None) → Tensor

  • 参数

    • tensors (Sequence of Tensors):要连接的张量序列。
    • dim (int, 可选):沿着此维连接张量序列。当 dim=0 时,torch.cat() 会按行连接多个张量,也就是在第一个维度上进行连接。这将导致张量在垂直方向上叠加。当 dim=1 时,torch.cat() 会按列连接多个张量,也就是在第二个维度上进行连接。这将导致张量在水平方向上叠加。
    • out (Tensor, 可选):输出张量。
  • 返回值

    • 一个新的张量,它是输入张量在指定维度上的连接。
  • 用途

    • torch.cat 用于将给定维度上的一系列张量连接在一起。张量在除连接维以外的所有维度上必须具有相同的形状。

torch.stack

  • 定义torch.stack(tensors, dim=0, out=None) → Tensor

  • 参数

    • tensors (Sequence of Tensors):要堆叠的张量序列,所有张量都应有相同的形状。
    • dim (int, 可选):插入新维度的索引。
    • out (Tensor, 可选):输出张量。
  • 返回值

    • 一个新的张量,它沿着新维度对输入张量序列进行堆叠。
  • 用途

    • torch.stack 用于创建一个新的维度,并在该维度上堆叠一系列张量。与torch.cat不同,torch.stack会增加一个新的维度,所以输出张量的维度会比输入张量多一个。
文章来源:https://blog.csdn.net/qq_61980594/article/details/135701032
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。