torch.max()是PyTorch中的一个函数,用于返回给定张量中的最大值。
函数签名如下:
torch.max(input) -> Tensor
torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor)
其中:
input
(张量)- 输入张量。dim
(int或tuple,可选)- 沿哪个维度进行最大值计算。默认情况下,返回整个张量中的最大值。keepdim
(bool,可选)- 是否保留输出张量的维度。out
(张量,可选)- 输出张量。返回一个元组,包含两个输出:
max_values
(张量)- 输入张量中的最大值。max_indices
(张量)- 输入张量中最大值的索引。如果dim
参数未指定,则此张量将是一维的,包含整个张量中最大值的索引。使用示例:
import torch
# 创建一个随机张量
x = torch.randn(3, 4)
# 计算整个张量的最大值和最大值的索引
max_value, max_index = torch.max(x, dim=None)
print("Max value:", max_value)
print("Max index:", max_index)
# 沿第二个维度计算最大值和最大值的索引
max_value, max_index = torch.max(x, dim=1)
print("Max value along dim=1:", max_value)
print("Max index along dim=1:", max_index)
一次输出:
Max value: tensor(2.2439)
Max index: tensor(11)
Max value along dim=1: tensor([ 1.3841, 1.1835, 2.2439])
Max index along dim=1: tensor([ 0, 1, 3])
关于dim:
在PyTorch中,张量的每个维度都有一个编号,从0开始。dim参数指定在哪个维度上进行最大值计算。以下是一个示例张量:
tensor([[1, 2, 3],
[4, 5, 6]])
dim=0
表示沿着第一个维度(即行)进行最大值计算,返回一个形状为(1, 3)
的张量,其中的元素为每列的最大值。即:tensor([4, 5, 6])
dim=1
表示沿着第二个维度(即列)进行最大值计算,返回一个形状为(2, 1)
的张量,其中的元素为每行的最大值。即:tensor([[3],
[6]])