序列化语义

此笔记介绍了如何在Python中保存和加载PyTorch张量和模块状态,以及如何将Python模块进行序列化以供C++加载使用。

保存和加载张量

torch.save()torch.load() 可以让你轻松地保存和加载张量:

>>> t = torch.tensor([1., 2.])
>>> torch.save(t, 'tensor.pt')
>>> torch.load('tensor.pt')
tensor([1., 2.])

根据惯例,PyTorch 文件通常使用 ‘.pt’ 或 ‘.pth’ 作为扩展名。

torch.save()torch.load() 默认使用 Python 的 pickle,因此你可以将多个张量作为元组、列表和字典等 Python 对象的一部分进行保存。

>>> d = {'a': torch.tensor([1., 2.]), 'b': torch.tensor([3., 4.])}
>>> torch.save(d, 'tensor_dict.pt')
>>> torch.load('tensor_dict.pt')
{'a': tensor([1., 2.]), 'b': tensor([3., 4.])}

如果数据结构是可pickle的,包含PyTorch张量的自定义数据结构也可以被保存。

保存和加载张量时会保留视图

保存张量时会保持它们的视图关系:

>>> numbers = torch.arange(1, 10)
>>> evens = numbers[1::2]
>>> torch.save([numbers, evens], 'tensors.pt')
>>> loaded_numbers, loaded_evens = torch.load('tensors.pt')
>>> loaded_evens *= 2
>>> loaded_numbers
tensor([ 1,  4,  3,  8,  5, 12,  7, 16,  9])

背后来看,这些张量共用同一个“存储空间”。关于视图和存储空间的更多细节,请参阅张量视图

当PyTorch保存张量时,它会分别保存它们的存储对象和元数据。这是一个实现细节,未来可能会发生变化,但这样做通常可以节省空间,并且可以让PyTorch轻松地重建加载后张量之间的视图关系。例如,在上述代码片段中,只有一个存储被写入到‘tensors.pt’文件。

然而,在某些情况下,保存当前的存储对象可能是不必要的,并会导致生成过大的文件。例如,在下面的代码片段中,一个远远超过所保存张量大小的存储空间被写入了文件。

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small, 'small.pt')
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
999

不是只将 small 张量中的五个值保存到 'small.pt',而是将与 large 共享的 999 个存储值进行了保存和加载。

当需要保存的张量元素少于其存储对象时,可以先克隆张量以减小保存文件的大小。克隆张量会生成一个新张量,这个新张量的存储对象只包含原张量中的值。

>>> large = torch.arange(1, 1000)
>>> small = large[0:5]
>>> torch.save(small.clone(), 'small.pt')  # saves a clone of small
>>> loaded_small = torch.load('small.pt')
>>> loaded_small.storage().size()
5

由于克隆的张量彼此独立,它们之间没有原始张量的视图关系。如果在保存小于其存储对象的张量时,文件大小和视图关系都很重要,则必须谨慎构造新的张量,在保存之前确保这些张量的存储对象尽可能小但仍具有所需的视图关系。

保存和加载 torch.nn.Module

参见:教程:模型的保存与加载

在 PyTorch 中,模块的状态通常通过“状态字典”来序列化。状态字典包含了模块的所有参数和持久缓冲区。

>>> bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> list(bn.named_parameters())
[('weight', Parameter containing: tensor([1., 1., 1.], requires_grad=True)),
 ('bias', Parameter containing: tensor([0., 0., 0.], requires_grad=True))]

>>> list(bn.named_buffers())
[('running_mean', tensor([0., 0., 0.])),
 ('running_var', tensor([1., 1., 1.])),
 ('num_batches_tracked', tensor(0))]

>>> bn.state_dict()
OrderedDict([('weight', tensor([1., 1., 1.])),
             ('bias', tensor([0., 0., 0.])),
             ('running_mean', tensor([0., 0., 0.])),
             ('running_var', tensor([1., 1., 1.])),
             ('num_batches_tracked', tensor(0))])

出于兼容性的原因,建议不要直接保存模块,而是只保存其状态字典。Python 模块甚至提供了一个函数 load_state_dict(),可以从状态字典中恢复模块的状态:

>>> torch.save(bn.state_dict(), 'bn.pt')
>>> bn_state_dict = torch.load('bn.pt')
>>> new_bn = torch.nn.BatchNorm1d(3, track_running_stats=True)
>>> new_bn.load_state_dict(bn_state_dict)
<All keys matched successfully>

注意,状态字典首先通过torch.load()从文件中加载,然后使用load_state_dict()恢复状态。

即使是自定义模块或包含其他模块的模块,也都有状态字典,并可以使用这种模式:

# A module with two linear layers
>>> class MyModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> m = MyModule()
>>> m.state_dict()
OrderedDict([('l0.weight', tensor([[ 0.1400, 0.4563, -0.0271, -0.4406],
                                   [-0.3289, 0.2827, 0.4588, 0.2031]])),
             ('l0.bias', tensor([ 0.0300, -0.1316])),
             ('l1.weight', tensor([[0.6533, 0.3413]])),
             ('l1.bias', tensor([-0.1112]))])

>>> torch.save(m.state_dict(), 'mymodule.pt')
>>> m_state_dict = torch.load('mymodule.pt')
>>> new_m = MyModule()
>>> new_m.load_state_dict(m_state_dict)
<All keys matched successfully>

torch.save 的序列化文件格式

自从 PyTorch 1.6.0 版本以来,torch.save 默认会生成一个未压缩的 ZIP64 归档文件,除非用户设置了 _use_new_zipfile_serialization=False

在此存档中,文件的排列顺序如下:
checkpoint.pth
├── data.pkl
├── byteorder  # added in PyTorch 2.1.0
├── data/
│   ├── 0
│   ├── 1
│   ├── 2
│   └── …
└── version
条目如下:
  • data.pkl 是对传递给 torch.save 的对象进行序列化后的结果,但不包含其中的 torch.Storage 对象。

  • byteorder 包含一个字符串,表示保存时的 sys.byteorder 值(“little” 或 “big”)

  • data/ 文件夹包含对象中的所有存储,每个存储都是一个单独的文件。

  • version 在保存时会包含一个版本号,加载时可以使用这个版本号。

在保存时,PyTorch 会确保每个文件的本地文件头填充至偏移量为 64 字节倍数的位置,从而保证每个文件的偏移量都与 64 字节对齐。

注意

在某些设备(如XLA)上,张量会被序列化为pickled numpy数组,而它们的存储信息不会被序列化。因此,在这种情况下,data/ 可能不会出现在检查点中。

在C++中序列化和加载torch.nn.Module

参见:教程:使用 C++ 加载 TorchScript 模型

ScriptModules 可以被序列化为一个 TorchScript 程序,并通过 torch.jit.load() 加载。这种序列化方式包含了所有模块的方法、子模块、参数和属性,使得生成的程序可以在 C++ 中运行而无需 Python。

torch.jit.save()torch.save() 之间的区别可能不太明显。torch.save() 使用 pickle 来保存 Python 对象,这在原型设计、研究和训练时非常有用。而 torch.jit.save() 则将 ScriptModules 序列化为可以在 Python 或 C++ 中加载的格式,这对于保存和加载 C++ 模块或在部署 PyTorch 模型时使用 C++ 运行在 Python 中训练的模块非常有用。

在Python中编写、序列化并加载模块:

>>> scripted_module = torch.jit.script(MyModule())
>>> torch.jit.save(scripted_module, 'mymodule.pt')
>>> torch.jit.load('mymodule.pt')
RecursiveScriptModule( original_name=MyModule
                      (l0): RecursiveScriptModule(original_name=Linear)
                      (l1): RecursiveScriptModule(original_name=Linear) )

可以使用torch.jit.save()保存追踪的模块,但需要注意的是,只有被追踪的代码路径会被序列化。以下示例演示了这一点:

# A module with control flow
>>> class ControlFlowModule(torch.nn.Module):
      def __init__(self):
        super().__init__()
        self.l0 = torch.nn.Linear(4, 2)
        self.l1 = torch.nn.Linear(2, 1)

      def forward(self, input):
        if input.dim() > 1:
            return torch.tensor(0)

        out0 = self.l0(input)
        out0_relu = torch.nn.functional.relu(out0)
        return self.l1(out0_relu)

>>> traced_module = torch.jit.trace(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(traced_module, 'controlflowmodule_traced.pt')
>>> loaded = torch.jit.load('controlflowmodule_traced.pt')
>>> loaded(torch.randn(2, 4)))
tensor([[-0.1571], [-0.3793]], grad_fn=<AddBackward0>)

>>> scripted_module = torch.jit.script(ControlFlowModule(), torch.randn(4))
>>> torch.jit.save(scripted_module, 'controlflowmodule_scripted.pt')
>>> loaded = torch.jit.load('controlflowmodule_scripted.pt')
>> loaded(torch.randn(2, 4))
tensor(0)

上述模块包含一个不会被追踪输入触发的 if 语句,因此它不被视为被追踪的部分,并且也不会与该部分一起序列化。然而,脚本化的模块则包含了这个 if 语句并与之一起进行序列化。有关脚本编写和追踪的更多信息,请参阅TorchScript 文档

最后,在C++中加载模块:

>>> torch::jit::script::Module module;
>>> module = torch::jit::load('controlflowmodule_scripted.pt');

参见PyTorch C++ API文档,了解更多关于如何在C++中使用PyTorch模块的信息。

跨不同版本的PyTorch保存和加载ScriptModules

PyTorch团队建议使用相同版本的PyTorch来保存和加载模块。较旧版本的PyTorch可能不支持新模块,而新版本可能会移除或修改旧的行为。这些变化在PyTorch发行说明中明确列出,并且依赖于已更改功能的模块可能需要更新以继续正常工作。在某些有限的情况下(如下所述),PyTorch会保留序列化ScriptModules的历史行为,因此它们不需要进行更新。

torch.div进行整数除法运算

在 PyTorch 1.5 及其之前的版本中,torch.div() 函数在接受两个整数输入时会进行地板除法运算。

# PyTorch 1.5 (and earlier)
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1)

然而,在 PyTorch 1.7 中,torch.div() 将始终对其输入执行真正的除法操作,就像在 Python 3 中一样进行除法运算:

# PyTorch 1.7
>>> a = torch.tensor(5)
>>> b = torch.tensor(3)
>>> a / b
tensor(1.6667)

torch.div() 的行为在序列化后的 ScriptModules 中得以保留。具体来说,使用 PyTorch 1.6 版本之前的版本进行序列化的 ScriptModules 在加载时仍然会看到 torch.div() 在给定两个整数输入的情况下执行地板除法,即使是在使用较新版本的 PyTorch 加载也是如此。然而,使用 torch.div() 并在 PyTorch 1.6 及以后版本序列化的 ScriptModules 在早期版本的 PyTorch 中无法加载,因为那些早期版本不理解新的行为。

torch.full 总是被推断为浮点数据类型

在 PyTorch 1.5 及更早版本中,torch.full() 函数总是返回一个浮点张量,不论指定的填充值为何:

# PyTorch 1.5 and earlier
>>> torch.full((3,), 1)  # Note the integer fill value...
tensor([1., 1., 1.])     # ...but float tensor!

然而,在 PyTorch 1.7 中,torch.full() 函数将根据填充值来推断返回的张量的数据类型。

# PyTorch 1.7
>>> torch.full((3,), 1)
tensor([1, 1, 1])

>>> torch.full((3,), True)
tensor([True, True, True])

>>> torch.full((3,), 1.)
tensor([1., 1., 1.])

>>> torch.full((3,), 1 + 1j)
tensor([1.+1.j, 1.+1.j, 1.+1.j])

torch.full() 的行为在序列化的 ScriptModules 中得以保留。也就是说,在 PyTorch 1.6 版本之前进行序列化的 ScriptModules 将继续默认返回浮点张量,即使给定的是布尔值或整数值也是如此。然而,使用 torch.full() 并在 PyTorch 1.6 及更高版本上序列化的 ScriptModules 在早期版本的 PyTorch 中无法加载,因为那些早期版本不支持新的行为。

实用函数

以下实用函数与序列化有关:

torch.serialization.register_package(priority, tagger, deserializer)[源代码]

使用相关联的优先级注册用于标记和反序列化存储对象的可调用函数。标记在保存时将设备与存储对象关联起来,而反序列化则在加载时将存储对象移动到适当的设备。taggerdeserializer 按其 priority 顺序运行,直到某个标记器或反序列化器返回的值不是None

要覆盖全球注册表中设备的反序列化行为,可以注册一个优先级高于当前标签器的新标签器。

此函数也可以用来为新设备注册标签器和反序列化器。

参数
返回值

示例

>>> def ipu_tag(obj):
>>>     if obj.device.type == 'ipu':
>>>         return 'ipu'
>>> def ipu_deserialize(obj, location):
>>>     if location.startswith('ipu'):
>>>         ipu = getattr(torch, "ipu", None)
>>>         assert ipu is not None, "IPU device module is not loaded"
>>>         assert torch.ipu.is_available(), "ipu is not available"
>>>         return obj.ipu(location)
>>> torch.serialization.register_package(11, ipu_tag, ipu_deserialize)
torch.serialization.get_default_load_endianness()[源代码]

获取文件加载的备选字节顺序

如果保存的检查点中不存在字节顺序标记,则使用该字节顺序作为备选方案。默认情况下,它采用“native”字节顺序。

返回值

可选[LoadEndianness]

返回类型

默认加载字节序

torch.serialization.set_default_load_endianness(endianness)[源代码]

设置加载文件时的字节顺序备选方案

如果保存的检查点中不存在字节顺序标记,则使用该字节顺序作为备选方案。默认情况下,它采用“native”字节顺序。

参数

字节序 — 新的默认字节顺序

torch.serialization.get_default_mmap_options()[源代码]

使用mmap=Truetorch.load() 获取默认的 mmap 选项。

默认值为 mmap.MAP_PRIVATE

返回值

整型

返回类型

默认的 mmap 选项

torch.serialization.set_default_mmap_options(flags)[源代码]

提供一个上下文管理器或函数,用于为 torch.load()mmap=True 设置默认的内存映射选项。

目前仅支持 mmap.MAP_PRIVATEmmap.MAP_SHARED 中的一个。如果需要其他选项,请提交一个问题。

注意

此功能目前不适用于Windows系统。

参数

flags (int) – 可以是 mmap.MAP_PRIVATEmmap.MAP_SHARED

torch.serialization.add_safe_globals(safe_globals)[源代码]

将给定的全局变量标记为在weights_only加载时是安全的。例如,添加到此列表中的函数可以在反序列化过程中被调用,类也可以被实例化并设置状态。

参数

safe_globals (List[Any]) – 指定为安全的全局变量列表

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     torch.serialization.add_safe_globals([MyTensor])
...     torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
torch.serialization.clear_safe_globals()[源代码]

清除在weights_only加载时被认为是安全的全局变量列表。

torch.serialization.get_safe_globals()[源代码]

返回一个用户添加的安全全局变量列表,这些变量适用于weights_only加载。

返回类型

`List[Any]`

torch.serialization.safe_globals(safe_globals)[源代码]

一个上下文管理器,将某些全局变量标记为weights_only加载时的安全项。

参数

safe_globals (List[Any]) – 权重-only 加载的全局变量列表。

示例

>>> import tempfile
>>> class MyTensor(torch.Tensor):
...     pass
>>> t = MyTensor(torch.randn(2, 3))
>>> with tempfile.NamedTemporaryFile() as f:
...     torch.save(t, f.name)
# Running `torch.load(f.name, weights_only=True)` will fail with
# Unsupported global: GLOBAL __main__.MyTensor was not an allowed global by default.
# Check the code and make sure MyTensor is safe to be used when loaded from an arbitrary checkpoint.
...     with torch.serialization.safe_globals([MyTensor]):
...         torch.load(f.name, weights_only=True)
# MyTensor([[-0.5024, -1.8152, -0.5455],
#          [-0.8234,  2.0500, -0.3657]])
>>> assert torch.serialization.get_safe_globals() == []
torch.serialization.skip_data(materialize_fake_tensors=False)[源代码]

一个上下文管理器,用于跳过为 torch.save 调用写入存储字节。

数据存储仍将保留,但其字节通常会被写入的空间将会是空闲状态。这些存储的字节可以在后续的独立过程中进行填充。

警告

skip_data 上下文管理器是一个早期原型,可能还会发生变化。

参数

materialize_fake_tensors (bool) – 是否将 FakeTensors 实例化。

示例

>>> import tempfile
>>> t = torch.randn(2, 3)
>>> with tempfile.NamedTemporaryFile() as f:
...     with torch.serialization.skip_data():
...         torch.save(t, f.name)
...     torch.load(f.name, weights_only=True)
tensor([[0., 0., 0.],
        [0., 0., 0.]])
本页目录