首先给出官方对此函数的定义网页:torch.clamp — PyTorch 2.1 documentation
torch.clamp(input, min=None, max=None, *, out=None) → Tensor
其中:
input
: 输入张量,即需要进行元素限制的张量。min
: 张量中的元素的最小值。如果元素小于这个值,将被替换为这个最小值。max
: 张量中的元素的最大值。如果元素大于这个值,将被替换为这个最大值。out
(可选): 输出张量,用于保存结果。如果没有提供,函数会创建一个新的张量来保存结果。将元素限制在指定范围内: 对于输入张量 input
中的每个元素,torch.clamp
将其限制在指定的范围 [min, max]
内。如果元素小于 min
,就被替换为 min
;如果元素大于 max
,就被替换为 max
。
示例:
import torch
x = torch.tensor([1, 5, 10, -3, 8])
result = torch.clamp(x, min=2, max=8)
print(result)
输出:
tensor([2, 5, 8, 2, 8])
在这个例子中,torch.clamp
将张量 x
中的元素限制在范围 [2, 8]
内,小于2的元素被替换为2,大于8的元素被替换为8。
如果参数中未指定min
,则不限制张量的下边界;如果参数中未指定max
,则不限制张量的上边界;如果min
和max
均未提供,则不进行任何限制,函数返回的张量将和原始张量保持一致。
示例
import torch
x = torch.tensor([1, 5, 10, -3, 8])
result = torch.clamp(x, max=8)#未指定min值,则不限制下边界
print(result)
输出
tensor([ 1, 5, 8, -3, 8])
min和max的指定并不要求为整数,可以为浮点数,如下示例中,张量的元素被限制在[-2.5,8.7]
内
示例
import torch
x = torch.tensor([1, 5, 10, -3, 8])
result = torch.clamp(x, min=2.5, max=8.7)
print(result)
输出
tensor([2.5000, 5.0000, 8.7000, 2.5000, 8.0000])