torch.where()
是 PyTorch 库中的一个函数,用于根据给定的条件返回输入张量中满足条件的元素的索引。
用法一:
torch.where(condition) -> Tensor
示例:
import torch
# 创建一个布尔类型的张量,表示条件
condition = torch.tensor([True, False, True, False])
# 使用 torch.where() 函数获取满足条件的元素索引
result = torch.where(condition)
print(result)
输出结果:
tensor([0, 2])
用法二:
torch.where(condition, x, y) -> Tensor
示例:
import torch
# 创建一个布尔类型的张量,表示条件
condition = torch.tensor([True, False, True, False])
# 创建两个与 condition 形状相同的张量
x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([5, 6, 7, 8])
# 使用 torch.where() 函数获取满足条件的元素索引
result = torch.where(condition, x, y)
print(result)
输出结果:
tensor([1, 6, 3, 8])
参数说明:
condition
:一个布尔类型的张量,表示条件。x
, y
:两个与condition
形状相同的张量,用于在满足和不满足条件时返回对应的元素。