LazyModuleMixin

torch.nn.modules.lazy.LazyModuleMixin(*args, **kwargs)[源代码]

这是一个用于模块的混入类,这些模块会延迟初始化参数,也被称为“懒惰模块”。

懒惰初始化参数的模块(即“懒惰模块”)会根据正向方法的第一个输入来推导参数的形状。在第一次正向传递之前,这些模块包含不应被访问或使用的torch.nn.UninitializedParameters;而在之后,则包含常规的torch.nn.Parameters。懒惰模块很方便,因为它们不需要预先计算某些参数值,例如在典型的torch.nn.Linear中需要的in_features参数。

构建完成后,带有延迟模块的网络应首先转换为所需的数据类型(dtype)并放置在预期设备上。这是因为延迟模块仅执行形状推断,因此通常的数据类型和设备放置行为仍然适用。接下来,延迟模块应该进行“干运行”,以初始化模块中的所有组件。“干运行”会将正确大小、数据类型和设备的输入通过网络发送到每个延迟模块。完成这些步骤后,网络可以像平常一样使用。

>>> class LazyMLP(torch.nn.Module):
...    def __init__(self) -> None:
...        super().__init__()
...        self.fc1 = torch.nn.LazyLinear(10)
...        self.relu1 = torch.nn.ReLU()
...        self.fc2 = torch.nn.LazyLinear(1)
...        self.relu2 = torch.nn.ReLU()
...
...    def forward(self, input):
...        x = self.relu1(self.fc1(input))
...        y = self.relu2(self.fc2(x))
...        return y
>>> # constructs a network with lazy modules
>>> lazy_mlp = LazyMLP()
>>> # transforms the network's device and dtype
>>> # NOTE: these transforms can and should be applied after construction and before any 'dry runs'
>>> lazy_mlp = lazy_mlp.cuda().double()
>>> lazy_mlp
LazyMLP( (fc1): LazyLinear(in_features=0, out_features=10, bias=True)
  (relu1): ReLU()
  (fc2): LazyLinear(in_features=0, out_features=1, bias=True)
  (relu2): ReLU()
)
>>> # performs a dry run to initialize the network's lazy modules
>>> lazy_mlp(torch.ones(10,10).cuda())
>>> # after initialization, LazyLinear modules become regular Linear modules
>>> lazy_mlp
LazyMLP(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (relu1): ReLU()
  (fc2): Linear(in_features=10, out_features=1, bias=True)
  (relu2): ReLU()
)
>>> # attaches an optimizer, since parameters can now be used as usual
>>> optim = torch.optim.SGD(mlp.parameters(), lr=0.01)

在使用延迟模块时需要注意的是,网络参数的初始化顺序可能会发生变化,因为延迟模块总是会在其他模块之后进行初始化。例如,如果上述定义的LazyMLP类中首先有一个torch.nn.LazyLinear 模块,然后是一个常规的 torch.nn.Linear 模块,在构造时第二个模块会被初始化,而第一个模块则会在第一次干运行时被初始化。这可能导致使用延迟模块的网络参数与不使用延迟模块的网络参数有不同的初始化顺序,因为参数初始化的顺序通常依赖于有状态的随机数生成器,并且不同。详情请参阅可重复性

惰性模块可以像其他模块一样使用状态字典进行序列化。例如:

>>> lazy_mlp = LazyMLP()
>>> # The state dict shows the uninitialized parameters
>>> lazy_mlp.state_dict()
OrderedDict([('fc1.weight', Uninitialized parameter),
             ('fc1.bias',
              tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                       4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
             ('fc2.weight', Uninitialized parameter),
             ('fc2.bias', tensor([0.0019]))])

惰性模块可以加载常规的 torch.nn.Parameter(也就是说,你可以序列化和反序列化已初始化的惰性模块,它们会保持已初始化的状态)

>>> full_mlp = LazyMLP()
>>> # Dry run to initialize another module
>>> full_mlp.forward(torch.ones(10, 1))
>>> # Load an initialized state into a lazy module
>>> lazy_mlp.load_state_dict(full_mlp.state_dict())
>>> # The state dict now holds valid values
>>> lazy_mlp.state_dict()
OrderedDict([('fc1.weight',
              tensor([[-0.3837],
                      [ 0.0907],
                      [ 0.6708],
                      [-0.5223],
                      [-0.9028],
                      [ 0.2851],
                      [-0.4537],
                      [ 0.6813],
                      [ 0.5766],
                      [-0.8678]])),
             ('fc1.bias',
              tensor([-1.8832e+25,  4.5636e-41, -1.8832e+25,  4.5636e-41, -6.1598e-30,
                       4.5637e-41, -1.8788e+22,  4.5636e-41, -2.0042e-31,  4.5637e-41])),
             ('fc2.weight',
              tensor([[ 0.1320,  0.2938,  0.0679,  0.2793,  0.1088, -0.1795, -0.2301,  0.2807,
                        0.2479,  0.1091]])),
             ('fc2.bias', tensor([0.0019]))])

需要注意的是,如果在加载状态时参数已经被初始化,那么在进行“dry run”(试运行)时,这些参数将不会被替换。这样可以避免在同一模块的不同上下文中使用已经初始化的模块。

has_uninitialized_params()[源代码]

检查模块是否有未初始化的参数。

initialize_parameters(*args, **kwargs)[源代码]

根据输入批次的属性来初始化参数。

这添加了一个接口,以便在进行参数形状推断时将参数初始化与前向传递分离。

本页目录