torch.optim.Optimizer.state_dict
- Optimizer.state_dict()[源代码]
-
以
dict
形式返回优化器的状态。它包含了两个条目:
-
-
state
: 一个保存当前优化状态的 Dict。其内容包括 -
虽然在不同的优化器类之间存在差异,但也有一些共同的特点。例如,状态是按参数进行保存的,并且不会保存参数本身。
state
是一个字典,将参数ID映射到包含每个参数对应状态信息的字典。
-
-
-
param_groups
: 包含所有参数组的列表,每个参数组内 -
参数组是一个字典。每个参数组包含了特定于优化器的元数据(如学习率和权重衰减),以及该组内参数的ID列表。
-
注意:参数 ID 虽然看起来像索引,但实际上只是用于将状态与 param_group 关联的标识符。在从 state_dict 加载时,优化器会按顺序将 param_group 中的
params
(整数 ID)和优化器中的param_groups
(实际的nn.Parameter
s)进行 zip 操作以匹配状态,而无需额外验证。返回的 state 字典可能看起来像这样:
{ 'state': { 0: {'momentum_buffer': tensor(...), ...}, 1: {'momentum_buffer': tensor(...), ...}, 2: {'momentum_buffer': tensor(...), ...}, 3: {'momentum_buffer': tensor(...), ...} }, 'param_groups': [ { 'lr': 0.01, 'weight_decay': 0, ... 'params': [0] }, { 'lr': 0.001, 'weight_decay': 0.5, ... 'params': [1, 2, 3] } ] }
-