torch.topk()
?是 PyTorch 中的一个函数,用于从张量(tensor)中选取最大的 k 个值及其对应的索引。这个函数对于需要找到最大值或者对数据进行排序的场景非常有用。
函数的基本语法如下:
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
参数解释:
input
:输入的张量。k
:返回最大的 k 个值。dim
:在哪个维度上计算 top k。如果为 None,则在整个张量上计算。默认为 None。largest
:如果为 True,则返回最大的 k 个值。如果为 False,则返回最小的 k 个值。默认为 True。sorted
:如果为 True,则返回的张量是排序后的。如果为 False,则返回的张量可能不是排序后的。默认为 True。out
:可选参数,输出结果的张量。如果提供,其形状必须能够容纳返回的结果。返回值:
values
:最大的 k 个值。indices
:每个最大值在输入张量中的索引。示例:
import torch
# 创建一个随机的张量
x = torch.randn(3, 3)
print(x)
# 输出:
# tensor([[ 0.1296, -0.1872, 0.9590],
# [-0.1385, 0.3172, 0.4423],
# [ 0.5977, -0.5863, 0.1649]])
# 找出最大的 2 个值及其索引
values, indices = torch.topk(x, 2)
print(values)
# 输出:
# tensor([ 0.9590, 0.4423])
print(indices)
# 输出:
# tensor([[2, 1],
# [0, 2]])