参数说明 -
- input(Tensor) - 输入张量。
- k(int) - 最大或最小的前k个。
- dim(int, optional) - 默认是-1, 按照dim进行排列的逻辑是:其他dim的索引相同元素属于一组进行比较。
- largest(bool, optional) - 默认是True。
- sorted(bool, optional) - 返回的最大的几个k值,再进行大小排列。
使用说明 -
- 返回 - (values, indices)
- 返回values(Tensor) - 维数不变,dim维大小是k。例如形状是(12, 14, 15)的Tensor按照dim=-1, k=9进行排列,返回的values形状是(12, 14, 9)
- 返回indices(LongTensor) - 和values形状一样,但是indices的值输入Tensor的dim索引。
使用举例
- values.size() - (2, 3, 3)
- indices.size() - (2, 3, 3)
import torch
a = torch.randint(0, 20, (2, 3, 8))
values, indices = torch.topk(a, dim=-1, k=3, sorted=True)
参数说明 -
- input(Tensor) - 输入张量。
- dim(int) - 按照dim进行reduce, reduce的方式时返回最大值的索引。默认是None,将整个Tensor拉平之后返回最大索引。
- keepdim(Bool) - 保持dim数量不变。
使用说明 -
- 一般情况是用于分类模型的预测,例如二分类模型输出形状为(16, 2)的Tensor, 进行argmax(dim=-1)之后,返回一个长度为16的列表,列表中就是预测的分类编号。
- 相当于只有torch.topk的indices功能。
使用举例 -
- resutl - tensor([1, 3, 0, 2])
import torch
a = torch.randint(0, 20, (4, 4))
result = torch.argmax(a, dim=-1)