torch.optim.Optimizer.state_dict

Optimizer.state_dict()[源代码]

dict形式返回优化器的状态。

它包含了两个条目:

  • state: 一个保存当前优化状态的 Dict。其内容包括

    虽然在不同的优化器类之间存在差异,但也有一些共同的特点。例如,状态是按参数进行保存的,并且不会保存参数本身。state 是一个字典,将参数ID映射到包含每个参数对应状态信息的字典。

  • param_groups: 包含所有参数组的列表,每个参数组内

    参数组是一个字典。每个参数组包含了特定于优化器的元数据(如学习率和权重衰减),以及该组内参数的ID列表。

注意:参数 ID 虽然看起来像索引,但实际上只是用于将状态与 param_group 关联的标识符。在从 state_dict 加载时,优化器会按顺序将 param_group 中的 params(整数 ID)和优化器中的 param_groups (实际的 nn.Parameters)进行 zip 操作以匹配状态,而无需额外验证。

返回的 state 字典可能看起来像这样:

{
    'state': {
        0: {'momentum_buffer': tensor(...), ...},
        1: {'momentum_buffer': tensor(...), ...},
        2: {'momentum_buffer': tensor(...), ...},
        3: {'momentum_buffer': tensor(...), ...}
    },
    'param_groups': [
        {
            'lr': 0.01,
            'weight_decay': 0,
            ...
            'params': [0]
        },
        {
            'lr': 0.001,
            'weight_decay': 0.5,
            ...
            'params': [1, 2, 3]
        }
    ]
}
返回类型

Dict[str, Any]

本页目录