目录
load_state_dict(state_dict, strict=True, assign=False)
named_buffers(prefix='', recurse=True, remove_duplicate=True)
named_modules(memo=None, prefix='', remove_duplicate=True)
named_parameters(prefix='', recurse=True, remove_duplicate=True)
register_buffer(name, tensor, persistent=True)
register_forward_pre_hook(hook)
register_full_backward_hook(hook)
register_parameter(name, param)
????????在当今快速发展的人工智能领域,深度学习已成为其中最引人注目的技术之一。PyTorch 作为一种流行的深度学习框架,因其灵活性和易用性而受到广泛欢迎。在 PyTorch 的众多组件中,torch.nn
模块无疑是构建复杂深度学习模型的基石。本文将深入探讨 torch.nn
模块的功能、优势和使用技巧,旨在为读者提供一个清晰的理解和应用指南。torch.nn
提供了构建神经网络所需的所有基本构建块,包括各种类型的层(如卷积层、池化层、激活函数)、损失函数和容器。这些组件不仅是模块化和可重用的,而且也支持灵活的网络架构设计。通过本文,我们将逐一解析这些组件的特性和使用场景,并分享一些实用的技巧来优化网络性能。无论是新手还是有经验的开发者,都可以从中获得宝贵的见解,以更好地利用这个强大的模块来设计和实现高效的深度学习模型。
????????接下来的章节将从 torch.nn
的基础知识开始,逐步深入到更高级的主题,包括定制网络层、优化技巧和最佳实践。准备好,让我们开始这次深入浅出的 torch.nn
之旅吧!
? ? torch.nn.parameter.Parameter
是 PyTorch 深度学习框架中的一个重要类,用于表示神经网络中的参数。这个类是 Tensor
的子类,它在与 Module
(模块)一起使用时具有特殊属性。当 Parameter
被赋值为 Module
的属性时,它自动被添加到模块的参数列表中,并且会出现在例如 parameters()
迭代器中。这与普通的 Tensor
不同,因为 Tensor
赋值给模块时不会有这样的效果。
Parameter
主要用于将张量标记为模块的参数。这对于模型的训练和参数更新至关重要,因为只有被标记为 Parameter
的张量才会在模型训练时更新。Parameter
来定义可训练的参数(如权重和偏置)。这些参数在训练过程中会通过反向传播进行优化。Parameter
类对其进行初始化,从而确保这些参数会被识别并在训练过程中更新。requires_grad
参数,可以控制特定参数是否需要在反向传播中计算梯度。这对于冻结模型的部分参数或进行特定的优化策略非常有用。以下是 torch.nn.parameter.Parameter
的使用示例:
import torch
import torch.nn as nn
# 定义一个自定义的线性层
class CustomLinearLayer(nn.Module):
def __init__(self, in_features, out_features):
super(CustomLinearLayer, self).__init__()
# 定义权重为一个可训练的参数
self.weight = nn.Parameter(torch.randn(out_features, in_features))
# 定义偏置为一个可训练的参数
self.bias = nn.Parameter(torch.randn(out_features))
def forward(self, x):
# 实现前向传播
return torch.matmul(x, self.weight.t()) + self.bias
# 创建一个自定义的线性层实例
layer = CustomLinearLayer(5, 3)
print(list(layer.parameters()))
????????在上述代码中,CustomLinearLayer
类中定义了两个 Parameter
对象:weight
和 bias
。这些参数在模块被实例化时自动注册,并在训练过程中会被优化。通过打印 layer.parameters()
,可以看到这些被注册的参数。
torch.nn.parameter.UninitializedParameter
是 PyTorch 中的一个特殊类,用于表示尚未初始化的参数。这个类是 torch.nn.Parameter
的一个特殊情况,其主要特点是在创建时数据的形状(shape)还未知。
torch.nn.Parameter
不同,UninitializedParameter
不持有任何数据。这意味着在初始化之前,试图访问某些属性(如它们的形状)会引发运行时错误。UninitializedParameter
允许在模型定义阶段创建参数,而不必立即指定它们的大小或形状。这在某些情况下非常有用,例如,当参数的大小依赖于运行时才可知的因素时。UninitializedParameter
的数据类型。UninitializedParameter
移动到不同的设备(例如从 CPU 移到 GPU)。UninitializedParameter
转换为常规的 torch.nn.Parameter
,此时需要指定其形状和数据。在下面的示例中,将展示如何创建一个未初始化的参数,并在稍后将其转换为常规参数:
import torch
import torch.nn as nn
class CustomLayer(nn.Module):
def __init__(self):
super(CustomLayer, self).__init__()
# 创建一个未初始化的参数
self.uninitialized_param = nn.parameter.UninitializedParameter()
def forward(self, x):
# 在前向传播中使用参数前必须先初始化
if isinstance(self.uninitialized_param, nn.parameter.UninitializedParameter):
# 初始化参数
self.uninitialized_param = nn.Parameter(torch.randn(x.size(1), x.size(1)))
return torch.matmul(x, self.uninitialized_param.t())
# 创建自定义层的实例
layer = CustomLayer()
# 假设输入x
x = torch.randn(10, 5)
# 使用自定义层
output = layer(x)
print(output)
????????在这个例子中,CustomLayer
在初始化时创建了一个 UninitializedParameter
。在进行前向传播时,检查这个参数是否已初始化,如果没有,则对其进行初始化,并将其转换为常规的 Parameter
。这种方式在处理动态大小的输入时特别有用。
????????torch.nn.parameter.UninitializedBuffer
是 PyTorch 中的一个特殊类,它代表一个尚未初始化的缓冲区。这个类是 torch.Tensor
的一个特殊情形,其主要特点是在创建时数据的形状(shape)还未知。
torch.Tensor
不同,UninitializedBuffer
不持有任何数据。这意味着在初始化之前,尝试访问某些属性(如它们的形状)会引发运行时错误。UninitializedBuffer
适用于那些在模型定义阶段需要创建缓冲区,但其大小或形状取决于后来才可知的数据或配置的情况。UninitializedBuffer
的数据类型。UninitializedBuffer
移动到不同的设备(例如从 CPU 移到 GPU)。UninitializedBuffer
转换为常规的 torch.Tensor
,此时需要指定其形状和数据。在下面的示例中,将展示如何创建一个未初始化的缓冲区,并在稍后将其转换为常规张量:
import torch
import torch.nn as nn
class CustomLayer(nn.Module):
def __init__(self):
super(CustomLayer, self).__init__()
# 创建一个未初始化的缓冲区
self.uninitialized_buffer = nn.parameter.UninitializedBuffer()
def forward(self, x):
# 在前向传播中使用缓冲区前必须先初始化
if isinstance(self.uninitialized_buffer, nn.parameter.UninitializedBuffer):
# 初始化缓冲区
self.uninitialized_buffer = torch.Tensor(x.size(0), x.size(1))
# 在这里可以使用缓冲区进行计算或其他操作
return x + self.uninitialized_buffer
# 创建自定义层的实例
layer = CustomLayer()
# 假设输入x
x = torch.randn(10, 5)
# 使用自定义层
output = layer(x)
print(output)
????????在这个例子中,CustomLayer
在初始化时创建了一个 UninitializedBuffer
。在进行前向传播时,检查这个缓冲区是否已初始化,如果没有,则对其进行初始化,并将其转换为常规的 Tensor
。这种方法在动态处理数据大小时非常有用,特别是在需要临时存储数据但在模型定义阶段无法确定其大小的情况下。??
?????????torch.nn.Module
是 PyTorch 中用于构建所有神经网络模型的基类。几乎所有的 PyTorch 神经网络模型都是通过继承 torch.nn.Module
来构建的。这个类提供了模型需要的基本功能,如参数管理、模型保存和加载、设备转移(例如,从 CPU 到 GPU)等。
Module
可以包含其他 Module
,形成一个嵌套的树状结构。这允许用户以模块化的方式构建复杂的神经网络。Module
自动管理其属性中的所有 Parameter
和 Buffer
对象。这包括注册参数、转移到不同设备、保存和加载模型状态等。forward
方法,以定义其在接收输入时的计算过程。以下是一个基本的 torch.nn.Module
子类的示例:
import torch.nn as nn
import torch.nn.functional as F
class SimpleModel(nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5)
def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
model = SimpleModel()
????????在这个例子中,SimpleModel
继承了 torch.nn.Module
。在其构造函数中,定义了两个卷积层 conv1
和 conv2
,并在 forward
方法中定义了模型的前向传播逻辑。
torch.nn.Module
主要方法详解add_module(name, module)
name
: 子模块的名称。# 定义一个自定义模块
class CustomModule(nn.Module):
def __init__(self):
super(CustomModule, self).__init__()
# 创建一个线性层
linear = nn.Linear(10, 5)
# 使用 add_module 添加线性层作为子模块
self.add_module('linear', linear)
apply(fn)
fn
应用于每个子模块及其自身。fn
: 要应用的函数,通常用于初始化参数。# 定义一个初始化权重的函数
def init_weights(m):
if type(m) == nn.Linear:
nn.init.uniform_(m.weight)
# 应用 init_weights 函数初始化模型的权重
model = CustomModule()
model.apply(init_weights)
bfloat16()
# 将模型的参数和缓冲区转换为 bfloat16 数据类型
model.bfloat16()
buffers(recurse=True)
recurse
: 如果为 True,则遍历此模块及所有子模块的缓冲区。# 遍历模型的所有缓冲区
for buf in model.buffers():
print(buf.size())
children()
# 遍历模型的直接子模块
for child in model.children():
print(child)
cpu()
# 将模型移动到 CPU
model.cpu()
cuda(device=None)
device
: 指定 GPU 设备。# 将模型移动到 GPU
model.cuda()
double()
# 将模型的参数和缓冲区转换为 double 数据类型
model.double()
eval()
# 将模型设置为评估模式
model.eval()
extra_repr()
# 自定义模型的额外表示
class CustomModule(nn.Module):
def __init__(self):
super(CustomModule, self).__init__()
def extra_repr(self):
return '自定义信息'
model = CustomModule()
print(model)
float()
# 将模型的参数和缓冲区转换为 float 数据类型
model.float()
forward(*input)
# 定义模型的前向传播
class CustomModule(nn.Module):
def __init__(self):
super(CustomModule, self).__init__()
self.linear = nn.Linear(10, 5)
def forward(self, x):
return self.linear(x)
model = CustomModule()
input = torch.randn(1, 10)
output = model(input)
get_buffer(target)
# 获取特定名称的缓冲区
buffer = model.get_buffer('buffer_name')
get_parameter(target)
# 获取特定名称的参数
parameter = model.get_parameter('param_name')
half()
# 将模型的参数和缓冲区转换为半精度 (half) 数据类型
model.half()
load_state_dict(state_dict, strict=True, assign=False)
state_dict
中复制参数和缓冲区到此模块及其后代。state_dict
: 包含参数和持久缓冲区的字典。strict
: 是否严格匹配 state_dict
和模块的键。# 从 state_dict 加载模型状态
state_dict = {'linear.weight': torch.randn(5, 10), 'linear.bias': torch.randn(5)}
model.load_state_dict(state_dict, strict=False)
modules()
# 遍历网络中的所有模块
for module in model.modules():
print(module)
named_buffers(prefix='', recurse=True, remove_duplicate=True)
# 遍历模型的所有缓冲区,同时提供缓冲区的名称
for name, buf in model.named_buffers():
print(f"Buffer name: {name}, Buffer: {buf}")
named_children()
# 遍历模型的直接子模块,同时提供子模块的名称
for name, child in model.named_children():
print(f"Child name: {name}, Child module: {child}")
named_modules(memo=None, prefix='', remove_duplicate=True)
# 遍历网络中的所有模块,同时提供模块的名称
for name, module in model.named_modules():
print(f"Module name: {name}, Module: {module}")
named_parameters(prefix='', recurse=True, remove_duplicate=True)
# 遍历模型的所有参数,同时提供参数的名称
for name, param in model.named_parameters():
print(f"Parameter name: {name}, Parameter: {param}")
parameters(recurse=True)
# 遍历模型的所有参数
for param in model.parameters():
print(param)
register_backward_hook(hook)
# 注册一个反向传播钩子
def backward_hook(module, grad_input, grad_output):
print(f"Backward hook in {module}")
model.register_backward_hook(backward_hook)
register_buffer(name, tensor, persistent=True)
# 向模块添加一个缓冲区
model.register_buffer('new_buffer', torch.randn(5))
register_forward_hook(hook)
# 注册一个前向传播钩子
def forward_hook(module, input, output):
print(f"Forward hook in {module}")
model.register_forward_hook(forward_hook)
register_forward_pre_hook(hook)
# 注册一个前向传播钩子
def forward_hook(module, input, output):
print(f"Forward hook in {module}")
model.register_forward_hook(forward_hook)
register_full_backward_hook(hook)
# 注册一个完整的反向传播钩子
def full_backward_hook(module, grad_input, grad_output):
print(f"Full backward hook in {module}")
model.register_full_backward_hook(full_backward_hook)
register_parameter(name, param)
# 向模块添加一个参数
param = nn.Parameter(torch.randn(5))
model.register_parameter('new_param', param)
state_dict()
# 获取模块所有状态信息的字典
state_dict = model.state_dict()
to(*args, **kwargs)
# 移动和/或转换参数和缓冲区
# 移动模型到 GPU 并转换为 double 类型
model.to('cuda', dtype=torch.double)
train(mode=True)
# 将模块设置为训练模式
model.train()
type(dst_type)
# 将所有参数和缓冲区转换为指定类型
model.type(torch.float32)
zero_grad(set_to_none=True)
# 重置所有模型参数的梯度
model.zero_grad()
?????????这些示例涵盖了 torch.nn.Module
类中的大多数主要方法,展示了如何在实际情况中使用它们。
? ?torch.nn.Sequential
是 PyTorch 中的一个容器模块,用于按顺序封装一系列子模块。它简化了模型的构建过程,使得将多个模块组合成一个单独的序列变得容易和直观。
Sequential
按照它们在构造函数中传递的顺序,依次处理每个子模块。输入数据首先被传递到第一个模块,然后依次传递到每个后续模块。Sequential
允许将整个容器视为单一模块,对其进行的任何转换都适用于它存储的每个模块(每个模块都是 Sequential
的一个注册子模块)。torch.nn.ModuleList
的区别torch.nn.ModuleList
仅仅是一个存储子模块的列表,而 Sequential
中的层是级联连接的。在 ModuleList
中,层之间没有直接的数据流动关联,而在 Sequential
中,一个层的输出直接成为下一个层的输入。
使用 Sequential 创建一个简单的模型:
model = nn.Sequential(
nn.Conv2d(1, 20, 5),
nn.ReLU(),
nn.Conv2d(20, 64, 5),
nn.ReLU()
)
在这个例子中,输入数据首先通过一个 Conv2d
层,然后是 ReLU
层,接着是第二个 Conv2d
层,最后是另一个 ReLU
层。
使用带有 OrderedDict
的 Sequential:
from collections import OrderedDict
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1, 20, 5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20, 64, 5)),
('relu2', nn.ReLU())
]))
?????????使用 OrderedDict
允许为每个模块指定一个唯一的名称。这在需要引用特定层或在打印模型结构时提高了可读性。
append(module)
方法module
(nn.Module
): 要附加的模块。Sequential
这种方式构建的模型可以简化前向传播的实现,使得模型的构建和理解更加直观。
torch.nn.ModuleList
是 PyTorch 中用于存储子模块的列表容器。它类似于 Python 的常规列表,但具有额外的功能,使其能够适当地注册其中包含的模块,并使它们对所有 Module
方法可见。
ModuleList
提供了一个列表式的结构来保存模块,允许通过索引或迭代器访问这些模块。.parameters()
或 .to(device)
等 Module
方法时,这些子模块也会被考虑在内。class MyModule(nn.Module):
def __init__(self):
super().__init__()
# 使用 ModuleList 创建一个线性层的列表
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
def forward(self, x):
# ModuleList 可以作为迭代器,也可以使用索引访问
for i, l in enumerate(self.linears):
x = self.linears[i // 2](x) + l(x)
return x
?在这个例子中,MyModule
创建了一个 ModuleList
,其中包含了 10 个 nn.Linear(10, 10)
层。在 forward
方法中,使用了两种不同的方式来访问和应用这些层。
ModuleList
的方法append(module)
module
(nn.Module
):要添加的模块。extend(modules)
modules
(iterable):可迭代的模块对象。insert(index, module)
index
(int):插入的索引。module
(nn.Module
):要插入的模块。ModuleList
提供了灵活的方式来管理子模块的集合,特别是当模型的某些部分是动态的或者模型结构中的层的数量在初始化时未知时非常有用。
torch.nn.ModuleDict
是 PyTorch 中的一个容器模块,用于以字典形式保存子模块。它类似于 Python 的常规字典,但其包含的模块会被正确注册,并且对所有 Module
方法可见。
ModuleDict
提供了一个字典式的结构来保存模块,允许通过键值对访问这些模块。ModuleDict
是一个有序字典,它会保留插入顺序和合并顺序。class MyModule(nn.Module):
def __init__(self):
super().__init__()
# 使用 ModuleDict 创建一个由不同层组成的字典
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
# 可以使用列表初始化 ModuleDict
self.activations = nn.ModuleDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
def forward(self, x, choice, act):
# 通过键值访问 ModuleDict 中的模块
x = self.choices[choice](x)
x = self.activations[act](x)
return x
在这个例子中,MyModule
创建了两个 ModuleDict
,一个用于保存卷积层和池化层,另一个用于保存激活层。
ModuleDict
的方法clear()
ModuleDict
中的所有项。items()
ModuleDict
中的键/值对的迭代器。keys()
ModuleDict
键的迭代器。pop(key)
ModuleDict
中移除键并返回其模块。key
(str):要从 ModuleDict
中弹出的键。update(modules)
ModuleDict
,覆盖现有的键。modules
(iterable):从字符串到模块的映射(字典),或键值对的迭代器。values()
ModuleDict
中模块值的迭代器。ModuleDict
提供了一个灵活的方式来管理具有特定键的子模块的集合。这在模型设计中特别有用,尤其是当模型的不同部分需要根据键动态选择时。
torch.nn.ParameterList
是 PyTorch 中的一个容器模块,用于按列表形式保存参数(Parameter
对象)。它类似于 Python 的常规列表,但其特殊之处在于其中包含的 Tensor
对象会被转换为 Parameter
对象,并正确注册,使得这些参数对所有 Module
方法可见。
ParameterList
提供了一个列表式的结构来保存参数,允许通过索引或迭代器访问这些参数。Tensor
对象会被自动转换为 Parameter
对象,确保它们可以被 PyTorch 的优化器等模块正确处理。class MyModule(nn.Module):
def __init__(self):
super().__init__()
# 使用 ParameterList 创建一个包含多个参数的列表
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
def forward(self, x):
# ParameterList 可以作为迭代器,也可以使用索引访问
for i, p in enumerate(self.params):
x = self.params[i // 2].mm(x) + p.mm(x)
return x
在这个例子中,MyModule
创建了一个 ParameterList
,其中包含了 10 个形状为 10x10
的随机参数。在 forward
方法中,这些参数被用于矩阵乘法操作。
ParameterList
的方法append(value)
Parameter
)。value
(Any):要添加的值。extend(values)
Parameter
)。values
(iterable):要添加的值的可迭代对象。ParameterList
提供了一种灵活的方式来管理模型中的参数集合,特别是当模型的某些部分参数数量动态变化时非常有用。通过使用 ParameterList
,您可以确保模型的所有参数都正确注册,并且可以通过标准的 PyTorch 方法进行访问和优化。
torch.nn.ParameterDict
是 PyTorch 中用于以字典形式保存参数(Parameter
对象)的容器模块。它类似于 Python 的常规字典,但其特殊之处在于其中包含的参数被正确注册,并对所有 Module
方法可见。
ParameterDict
提供了一个字典式的结构来保存参数,允许通过键值对访问这些参数。ParameterDict
是一个有序字典,它保留插入顺序和合并顺序(对于 OrderedDict
或另一个 ParameterDict
)。class MyModule(nn.Module):
def __init__(self):
super().__init__()
# 使用 ParameterDict 创建一个由不同参数组成的字典
self.params = nn.ParameterDict({
'left': nn.Parameter(torch.randn(5, 10)),
'right': nn.Parameter(torch.randn(5, 10))
})
def forward(self, x, choice):
# 通过键值访问 ParameterDict 中的参数
x = self.params[choice].mm(x)
return x
在这个例子中,MyModule
创建了一个 ParameterDict
,其中包含了两个名为 'left' 和 'right' 的参数。在 forward
方法中,根据传入的 choice
键来选择相应的参数进行矩阵乘法操作。
ParameterDict
的方法clear()
ParameterDict
中的所有项。copy()
ParameterDict
实例的副本。fromkeys(keys, default=None)
ParameterDict
。keys
(iterable, string):用于创建新 ParameterDict
的键。default
(Parameter, 可选):为所有键设置的默认值。get(key, default=None)
key
相关联的参数。否则,如果提供了 default
,则返回 default
;如果没有提供,则返回 None
。items()
ParameterDict
键/值对的迭代器。keys()
ParameterDict
键的迭代器。pop(key)
ParameterDict
中移除键并返回其参数。key
(str):要从 ParameterDict
中弹出的键。popitem()
ParameterDict
中移除并返回最后插入的 (键, 参数) 对。setdefault(key, default=None)
key
在 ParameterDict
中,则返回其值。如果不是,插入 key
与参数 default
并返回 default
。default
默认为 None
。update(parameters)
ParameterDict
,覆盖现有的键。values()
ParameterDict
中参数值的迭代器。ParameterDict
提供了一种灵活的方式来管理模型中具有特定键的参数集合。这在模型设计中特别有用,尤其是当模型的不同部分需要根据键动态选择参数时。
????????本文深入探索了 PyTorch 框架中的 torch.nn
模块,这是构建和实现高效深度学习模型的核心组件。我们详细介绍了 torch.nn
的关键类别和功能,包括 Parameter
, Module
, Sequential
, ModuleList
, ModuleDict
, ParameterList
和 ParameterDict
,为读者提供了一个全面的理解和应用指南。这篇博客仅仅是torch.nn的一部分功能,后续我这边会继续更新这个模块的其他相关功能。