torch: 返回最大的几个值--topk()

发布时间:2024年01月15日

torch.topk()?是 PyTorch 中的一个函数,用于从张量(tensor)中选取最大的 k 个值及其对应的索引。这个函数对于需要找到最大值或者对数据进行排序的场景非常有用。

函数的基本语法如下:

torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)

参数解释:

  • input:输入的张量。
  • k:返回最大的 k 个值。
  • dim:在哪个维度上计算 top k。如果为 None,则在整个张量上计算。默认为 None。
  • largest:如果为 True,则返回最大的 k 个值。如果为 False,则返回最小的 k 个值。默认为 True。
  • sorted:如果为 True,则返回的张量是排序后的。如果为 False,则返回的张量可能不是排序后的。默认为 True。
  • out:可选参数,输出结果的张量。如果提供,其形状必须能够容纳返回的结果。

返回值:

  • values:最大的 k 个值。
  • indices:每个最大值在输入张量中的索引。

示例:

import torch  
  
# 创建一个随机的张量  
x = torch.randn(3, 3)  
print(x)  
# 输出:   
# tensor([[ 0.1296, -0.1872,  0.9590],  
#          [-0.1385,  0.3172,  0.4423],  
#          [ 0.5977, -0.5863,  0.1649]])  
  
# 找出最大的 2 个值及其索引  
values, indices = torch.topk(x, 2)  
print(values)  
# 输出:   
# tensor([ 0.9590,  0.4423])  
print(indices)  
# 输出:   
# tensor([[2, 1],  
#          [0, 2]])

文章来源:https://blog.csdn.net/Ethan_Rich/article/details/135587760
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。