torch.where()函数

发布时间:2024年01月22日

torch.where() 是 PyTorch 库中的一个函数,用于根据给定的条件返回输入张量中满足条件的元素的索引。
用法一:

torch.where(condition) -> Tensor

示例:

import torch

# 创建一个布尔类型的张量,表示条件
condition = torch.tensor([True, False, True, False])

# 使用 torch.where() 函数获取满足条件的元素索引
result = torch.where(condition)

print(result)

输出结果:

tensor([0, 2])

用法二:

torch.where(condition, x, y) -> Tensor

示例:

import torch

# 创建一个布尔类型的张量,表示条件
condition = torch.tensor([True, False, True, False])

# 创建两个与 condition 形状相同的张量
x = torch.tensor([1, 2, 3, 4])
y = torch.tensor([5, 6, 7, 8])

# 使用 torch.where() 函数获取满足条件的元素索引
result = torch.where(condition, x, y)

print(result)

输出结果:

tensor([1, 6, 3, 8])

参数说明:

  • condition:一个布尔类型的张量,表示条件。
  • x, y:两个与condition形状相同的张量,用于在满足和不满足条件时返回对应的元素。
文章来源:https://blog.csdn.net/weixin_43941438/article/details/135742809
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。