ModuleDict

torch.nn.ModuleDict(modules=None)[源代码]
将子模块存放在字典里。

ModuleDict 可以像常规的 Python 字典一样进行索引,但其中包含的模块会被正确注册,并且所有 Module 方法都能访问到这些模块。

ModuleDict 是一个有序字典。

  • 插入顺序,以及

  • update()中,合并后的OrderedDict、从Python 3.6开始的dict或另一个ModuleDict(作为update()参数)的顺序。

注意,使用其他无序映射类型(例如,在Python 3.6版本之前的普通dict)调用update()不会保留合并后映射的顺序。

参数

modules (iterable, 可选) – 一个 (字符串: 模块) 形式的映射(字典)或类型为 (字符串, 模块) 的键值对的迭代对象

示例:

class MyModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.choices = nn.ModuleDict({
                'conv': nn.Conv2d(10, 10, 3),
                'pool': nn.MaxPool2d(3)
        })
        self.activations = nn.ModuleDict([
                ['lrelu', nn.LeakyReLU()],
                ['prelu', nn.PReLU()]
        ])

    def forward(self, x, choice, act):
        x = self.choices[choice](x)
        x = self.activations[act](x)
        return x
clear()[源代码]

将ModuleDict中的所有项移除。

items()[源代码]

返回包含 ModuleDict 中键值对的可迭代对象。

返回类型

Iterable[Tuple[str, Module]]

keys()[源代码]

返回包含 ModuleDict 中所有键的可迭代对象。

返回类型

Iterable[\str]

pop(key)[源代码]

从ModuleDict中删除键并返回对应的模块。

参数

key (str) – 需要从 ModuleDict 中移除的键

返回类型

Module

update(modules)[源代码]

使用映射中的键值对来更新ModuleDict,并覆盖已存在的键。

注意

如果 modules 是一个 OrderedDict、一个 ModuleDict 或者是一个键值对的可迭代对象,新元素的顺序会被保留。

参数

modules (可迭代对象) – 一个从字符串到Module的映射(字典),或者是一个类型为 (string, Module) 的键值对的可迭代对象。

values()[源代码]

返回一个包含 ModuleDict 值的迭代器。

返回类型

Iterable[ Module]

本页目录