torch.nn.utils.parametrize.register_parametrization

torch.nn.utils.parametrize.register_parametrization(module, tensor_name, parametrization, *, unsafe=False)[源代码]

在模块中为张量注册参数化。

假设 tensor_name="weight" 以简化说明。当访问 module.weight 时,模块将返回参数化后的版本 parametrization(module.weight)。如果原始张量需要计算梯度,在反向传播过程中会通过 parametrization 进行求导,并且优化器会相应地更新该张量。

当一个模块第一次注册参数化时,此函数会向该模块添加一个 parametrizations 属性,其类型为 ParametrizationList

张量 weight 上的参数化可以在 module.parametrizations.weight 下访问。

原始张量可以通过 module.parametrizations.weight.original 访问。

可以通过在同一属性上注册多个参数化来连接它们。

在注册时,注册的参数化训练模式会更新,以匹配宿主模块的训练模式。

参数化参数和缓冲区内置了缓存系统,可以使用上下文管理器 cached() 来激活。

一个parametrization可以选择实现具有以下签名的方法

def right_inverse(self, X: Tensor) -> Union[Tensor, Sequence[Tensor]]

当首次进行参数化注册时,此方法会在未参数化的张量上调用,用于计算原始张量的初始值。如果没有实现该方法,那么原始张量就等同于未参数化的张量。

如果张量上注册的所有参数化都实现了right_inverse,那么就可以通过赋值来初始化一个参数化的张量,如下例所示。

第一个参数化可能依赖于多个输入。这可以通过在right_inverse中返回一个张量元组来实现(参见下面的RankOne参数化的示例实现)。

在这种情况下,未约束的张量也位于module.parametrizations.weight下,其名称分别为original0original1等。

注意

如果 unsafe=False(默认设置),将分别调用 forward 和 right_inverse 方法进行一致性检查。如果 unsafe=True,则仅当张量未被参数化时才会调用 right_inverse 方法,其他情况下不执行任何操作。

注意

在大多数情况下,right_inverse 将是一个函数,使得 forward(right_inverse(X)) == X(参见右逆函数)。有时,在参数化不是满射时,可以适当放宽这一要求。

警告

如果参数化依赖于多个输入,register_parametrization() 会注册一些新的参数。如果在创建优化器之后进行参数化注册,则需要手动将这些新参数添加到优化器中。参见 torch.Optimizer.add_param_group()

参数
  • module (nn.Module) – 需要注册参数化的模块

  • tensor_name (str) – 需要注册参数化的张量的名称

  • parametrization (nn.Module) — 需要注册的参数化方法

关键字参数

unsafe (bool) – 一个布尔标志,表示参数化是否可能更改张量的数据类型和形状。默认值为 False。警告:在注册时不会验证参数化的连贯性,请自行承担启用此标志的风险。

异常

ValueError – 如果模块不存在名为tensor_name的参数或缓冲区

返回类型

Module

示例

>>> import torch
>>> import torch.nn as nn
>>> import torch.nn.utils.parametrize as P
>>>
>>> class Symmetric(nn.Module):
>>>     def forward(self, X):
>>>         return X.triu() + X.triu(1).T  # Return a symmetric matrix
>>>
>>>     def right_inverse(self, A):
>>>         return A.triu()
>>>
>>> m = nn.Linear(5, 5)
>>> P.register_parametrization(m, "weight", Symmetric())
>>> print(torch.allclose(m.weight, m.weight.T))  # m.weight is now symmetric
True
>>> A = torch.rand(5, 5)
>>> A = A + A.T   # A is now symmetric
>>> m.weight = A  # Initialize the weight to be the symmetric matrix A
>>> print(torch.allclose(m.weight, A))
True
>>> class RankOne(nn.Module):
>>>     def forward(self, x, y):
>>>         # Form a rank 1 matrix multiplying two vectors
>>>         return x.unsqueeze(-1) @ y.unsqueeze(-2)
>>>
>>>     def right_inverse(self, Z):
>>>         # Project Z onto the rank 1 matrices
>>>         U, S, Vh = torch.linalg.svd(Z, full_matrices=False)
>>>         # Return rescaled singular vectors
>>>         s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>>         return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1
本页目录