torch.nn.utils.skip_init

torch.nn.utils.skip_init(module_cls, *args, **kwargs)[源代码]

给定一个模块类对象和参数 args/kwargs,无需初始化参数或缓冲区来实例化该模块。

如果初始化过程较慢或需要进行自定义初始化(从而使默认初始化变得不再必要),这种情况会非常有用。不过,由于该函数的实现方式,有一些需要注意的地方。

1. 该模块必须在其构造函数中接受一个device参数,并将此参数传递给在构造过程中创建的任何参数或缓冲区。

2. 该模块的构造函数除了调用初始化函数(如 torch.nn.init 中的函数)外,不得对参数进行任何计算。

如果满足这些条件,模块可以像使用 torch.empty() 创建时一样,用未初始化的参数和缓冲区值进行实例化。

参数
  • module_cls - 类对象;应为 torch.nn.Module 的一个子类

  • args - 模块构造函数的参数

  • kwargs - 模块构造函数的参数

返回值

实例化了一个参数和缓冲区都未初始化的模块

示例:

>>> import torch
>>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
>>> m.weight
Parameter containing:
tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
       requires_grad=True)
>>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
>>> m2.weight
Parameter containing:
tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
          4.5915e-41]], requires_grad=True)