torch.nn.utils.stateless.functional_call
是 PyTorch 2.0 中的一个功能,但请注意,这个 API 在 PyTorch 2.0 版本中已被弃用,并计划在未来的版本中移除。取而代之的是 torch.func.functional_call
,它是此 API 的直接替代品。这个函数允许你在不改变模块本身状态的情况下,用提供的参数和缓冲区执行模块的功能调用。
tie_weights
参数控制是否尊重原模型中的权重绑定(权重共享)。strict
参数控制传递给模块的参数和缓冲区是否需要与原模块完全匹配。module
(torch.nn.Module): 要调用的模块。parameters_and_buffers
(dict of str and Tensor): 将在模块调用中使用的参数。args
(Any or tuple): 要传递给模块调用的参数。如果不是元组,则视为单个参数。kwargs
(dict): 要传递给模块调用的关键字参数。tie_weights
(bool, 可选): 如果为True,则在重新参数化版本中将按照原模型中绑定的参数和缓冲区进行处理。默认为True。strict
(bool, 可选): 如果为True,则传递的参数和缓冲区必须与原模块中的参数和缓冲区匹配。默认为False。torch.func.functional_call
。parameters_and_buffers
输入中。import torch
import torch.nn as nn
import torch.nn.utils.stateless as stateless
# 定义一个简单的模块
class MyModule(nn.Module):
def __init__(self):
super(MyModule, self).__init__()
self.param = nn.Parameter(torch.tensor(1.0))
def forward(self, x):
return x + self.param
module = MyModule()
new_params = {"param": torch.tensor(2.0)}
result = stateless.functional_call(module, new_params, args=(torch.tensor(3.0),))
?在这个示例中,通过 stateless.functional_call
,我们能够使用新的参数 new_params
临时调用 module
,而不改变模块自身的状态。这种方法在需要快速测试模块在不同参数配置下的行为时非常有用。
torch.nn.utils.stateless.functional_call
是 PyTorch 中的一个函数,用于在不更改模块(如神经网络层)本身状态的情况下,暂时性地用提供的参数和缓冲区执行该模块的调用。这种方法特别适用于需要评估或测试模型在不同参数配置下的性能时,提供了一种高效、灵活的方式来进行无状态调用。该函数支持权重共享控制和严格性检查,但需要注意的是,这个 API 在 PyTorch 2.0 中已被弃用,建议使用其替代品 torch.func.functional_call
。此功能适用于模型评估和测试,尤其是在需要频繁变更参数的场景中。