tensor.topk 以及tensor.argmax

发布时间:2024年01月12日

tensor.topk 以及tensor.argmax

torch.topk(input, k, dim=None, largest=True, sorted=True, ***, out=None)

参数说明 -

  1. input(Tensor) - 输入张量。
  2. k(int) - 最大或最小的前k个。
  3. dim(int, optional) - 默认是-1, 按照dim进行排列的逻辑是:其他dim的索引相同元素属于一组进行比较。
  4. largest(bool, optional) - 默认是True。
  5. sorted(bool, optional) - 返回的最大的几个k值,再进行大小排列。

使用说明 -

  1. 返回 - (values, indices)
  2. 返回values(Tensor) - 维数不变,dim维大小是k。例如形状是(12, 14, 15)的Tensor按照dim=-1, k=9进行排列,返回的values形状是(12, 14, 9)
  3. 返回indices(LongTensor) - 和values形状一样,但是indices的值输入Tensor的dim索引。

使用举例

  • values.size() - (2, 3, 3)
  • indices.size() - (2, 3, 3)
import torch

a = torch.randint(0, 20, (2, 3, 8))

values, indices = torch.topk(a, dim=-1, k=3, sorted=True)
torch.argmax(input, dim, keepdim=False) → LongTensor

参数说明 -

  1. input(Tensor) - 输入张量。
  2. dim(int) - 按照dim进行reduce, reduce的方式时返回最大值的索引。默认是None,将整个Tensor拉平之后返回最大索引。
  3. keepdim(Bool) - 保持dim数量不变。

使用说明 -

  1. 一般情况是用于分类模型的预测,例如二分类模型输出形状为(16, 2)的Tensor, 进行argmax(dim=-1)之后,返回一个长度为16的列表,列表中就是预测的分类编号。
  2. 相当于只有torch.topk的indices功能。

使用举例 -

  • resutl - tensor([1, 3, 0, 2])
import torch

a = torch.randint(0, 20, (4, 4))

result = torch.argmax(a, dim=-1)
文章来源:https://blog.csdn.net/Akun_2217/article/details/135554257
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。