【PyTorch】PyTorch之包装容器

发布时间:2024年01月23日


前言

介绍pytorch关于model的包装容器。

Containers

1. torch.nn.Sequential(arg: OrderedDict[str, Module])

Sequential是一个顺序容器。模块将按照它们在构造函数中传递的顺序添加到其中。另外,可以传递一个包含模块的 OrderedDict。
Sequential 的 forward() 方法接受任何输入,并将其转发到它包含的第一个模块。然后,对于每个后续模块,它将输出顺序链接到输入,最终返回最后一个模块的输出。
与手动调用一系列模块相比,Sequential 提供的价值在于它允许将整个容器视为单个模块,这样在 Sequential 上执行转换会应用于它存储的每个模块(它们各自是 Sequential 的已注册子模块)。
Sequential 和 torch.nn.ModuleList 之间有什么区别?ModuleList 正是其字面意思 - 用于存储模块的列表!另一方面,Sequential 中的层以级联的方式连接。

# 使用 Sequential 创建一个小模型。当运行 model 时,
# 输入首先会传递给 Conv2d(1,20,5)。然后,Conv2d(1,20,5) 的输出将用作第一个 ReLU 的输入;
# 第一个 ReLU 的输出将成为 Conv2d(20,64,5) 的输入。最后,Conv2d(20,64,5) 的输出将用作第二个 # ReLU 的输入
model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU(),
          nn.Conv2d(20,64,5),
          nn.ReLU()
        )

# 使用带有 OrderedDict 的 Sequential。这在功能上与上面的代码相同
model = nn.Sequential(OrderedDict([
          ('conv1', nn.Conv2d(1,20,5)),
          ('relu1', nn.ReLU()),
          ('conv2', nn.Conv2d(20,64,5)),
          ('relu2', nn.ReLU())
        ]))

Sequential类的方法:
append(module)
Parameters:
module (nn.Module) – 要添加的module
Return type:
Sequential

将给定的模块追加到末尾。

2. torch.nn.ModuleList(modules=None)

Parameters:
modules (iterable, optional) – 要添加的模块的可迭代对象。

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])

    def forward(self, x):
        # ModuleList can act as an iterable, or be indexed using ints
        for i, l in enumerate(self.linears):
            x = self.linears[i // 2](x) + l(x)
        return x

ModuleList的方法列表:

append(module)

Parameters:
module (nn.Module) – 要添加的模块
Return type:
ModuleList

将给定的模块追加到列表的末尾。

extend(module)

Parameters:
modules (iterable) – i要追加的模块的可迭代对象。
Return type:
ModuleList

从 Python 可迭代对象中追加模块到列表的末尾。

insert(index, module)

Parameters:
index (int) – 要插入的索引。
module (nn.Module) – 要插入的模块。

在列表的给定索引之前插入给定的模块。

3. torch.nn.ModuleDict(modules=None)

Parameters:
modules (iterable, optional) – 一个映射(字典),格式为 (string: module) 或一个类型为 (string, module) 的键值对的可迭代对象。

ModuleDict 可以像常规的 Python 字典一样进行索引,但它包含的模块已经被正确注册,并且将被所有 Module 方法看到。
ModuleDict 是一个有序字典,它遵循以下顺序:

  1. 插入的顺序
  2. 在 update() 中,遵循合并的 OrderedDict、dict(从 Python 3.6 开始)或另一个
    ModuleDict(update() 的参数)

请注意,对于其他无序映射类型的 update()(例如,在 Python 版本 3.6 之前的普通 dict),不保留合并映射的顺序。

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x

ModuleDict类的方法:

clear()

移除ModuleDict中的所有项目。

items()

Return type:
Iterable[Tuple[str, Module]]

返回ModuleDict键/值对的可迭代对象。

keys()

Return type:
Iterable[str]

返回ModuleDict键的可迭代对象。

pop(key)

Parameters:
key (str) – key to pop from the ModuleDict
Return type:
Module

从ModuleDict中移除key并返回它的模块。

update(modules)

Parameters:
modules (iterable) – 一个从字符串到 Module 的映射(字典),或者一个键值对类型为 (string, Module) 的可迭代对象。

使用映射或可迭代对象的键值对来更新 ModuleDict,覆盖现有的键。
注意:
如果 modules 是 OrderedDict、ModuleDict 或键值对的可迭代对象,其中新元素的顺序将被保留。

values()

Return type:
Iterable[Module]

返回ModuleDict值的可迭代对象。

4. torch.nn.ParameterList(values=None)

Parameters:
parameters (iterable, optional) – 要添加到列表中的元素的可迭代对象.

ParameterList 可以像常规的 Python 列表一样使用,但作为 Parameter 的 Tensor 已经被正确注册,并且将被所有 Module 方法看到。
请注意,构造函数、分配列表的元素、append() 方法和 extend() 方法将把任何 Tensor 转换为 Parameter。

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])

    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i // 2].mm(x) + p.mm(x)
        return x

ParameterList类的方法:

append(value)

Parameters:
value (Any) – 将要添加的value
Return type:
ParameterList

将给定的值追加到列表的末尾。

extend(values)

Parameters:
values (iterable) – iterable of values to append
Return type:
ParameterList

将Python可迭代对象中的值附加到列表的末尾。

5. torch.nn.ParameterDict(parameters=None)

Parameters:
values (iterable, optional) – 一个映射(字典),其元素为 (string : Any) 或一个类型为 (string, Any) 的键值对的可迭代对象。

ParameterDict 在一个字典中保存参数。
ParameterDict 可以像常规的 Python 字典一样进行索引,但它包含的 Parameters 已经被正确注册,并且将被所有 Module 方法看到。其他对象将被处理得像常规的 Python 字典一样。
ParameterDict 是一个有序字典。使用其他无序映射类型(例如 Python 的普通 dict)进行的 update() 操作不会保留合并映射的顺序。另一方面,OrderedDict 或另一个 ParameterDict 将保留它们的顺序。
请注意,构造函数、分配字典的元素和 update() 方法将把任何 Tensor 转换为 Parameter。

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.params = nn.ParameterDict({
                'left': nn.Parameter(torch.randn(5, 10)),
                'right': nn.Parameter(torch.randn(5, 10))
        })

    def forward(self, x, choice):
        x = self.params[choice].mm(x)
        return x

ParameterDict类的方法:

clear()

移除ParameterDict容器中的所有项目。

copy()

Return type:
ParameterDict

返回这个ParameterDict实例的副本。

fromkeys(keys, default=None)

Parameters:
keys (iterable, string) – 用于创建新 ParameterDict 的键
default (Parameter, optional) – 为所有键设置的值
Return type:
ParameterDict

返回一个带有参数keys的新的 ParameterDict。

get(key, default=None)

Parameters:
key (str) – 要从 ParameterDict 获取的键
default (Parameter, optional) – 如果键不存在时返回的值
Return type:
Any

如果存在与 key 关联的参数,则返回该参数。否则,如果提供了 default,则返回 default,如果没有提供则返回 None。

items()

Return type:
Iterable[Tuple[str, Any]]

返回 ParameterDict 键/值对的可迭代对象。

keys()

Return type:
Iterable[str]

返回 ParameterDict 的键的可迭代对象。

pop(key)

Parameters:
key (str) – 要从 ParameterDict 弹出的键
Return type:
Any

从 ParameterDict 中删除键并返回其参数。

popitem()

Return type:
Tuple[str, Any]

从 ParameterDict 中删除并返回最后插入的(键,参数)对。

setdefault(key, default=None)

Parameters:
key (str) – 要为其设置默认值的键
default (Any) – 设置为键的参数
Return type:
Any

如果 key 在 ParameterDict 中,则返回其值。如果不在,则插入 key 并将其参数设置为 default,并返回 default。default 默认为 None。

update(parameters)

Parameters:
parameters (iterable) – 从字符串到 Parameter 的映射(字典),或类型为(字符串,Parameter)的键值对的可迭代对象。

使用映射或可迭代对象中的键值对更新 ParameterDict,覆盖现有键。
注意:
如果 parameters 是 OrderedDict、ParameterDict 或键值对的可迭代对象,则其中新元素的顺序将被保留。

values()

Return type:
Iterable[Any]

返回 ParameterDict 值的可迭代对象。

文章来源:https://blog.csdn.net/weixin_44984705/article/details/135763976
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。