torch.load

torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=False, mmap=None, **pickle_load_args)[源代码]

从文件中加载一个使用 torch.save() 保存的对象。

torch.load() 使用 Python 的反序列化功能,但会特别处理支撑张量的存储。这些存储首先在 CPU 上进行反序列化,然后移动到它们被保存时所在的设备。如果这失败(例如,因为运行时系统缺少某些设备),则会抛出异常。然而,可以使用 map_location 参数将存储动态映射到一组替代设备。

如果 map_location 是一个可调用对象,它将为每个序列化的存储器调用一次,并带有两个参数:存储和位置。 存储参数将是初始的反序列化存储,位于 CPU 上。 每个序列化的存储都有一个与之关联的位置标签,该标签标识了保存它的设备,这个标签是传递给 map_location 的第二个参数。 内置位置标签包括用于 CPU 张量的 'cpu' 和用于 CUDA 张量的 'cuda:device_id'(例如,'cuda:2')。 如果 map_location 返回一个存储器,则该存储器将作为最终反序列化的对象使用,并且已经移动到了正确的设备上。 否则,torch.load() 将回退到默认行为,就好像没有指定 map_location 一样。

如果 map_location 是一个 torch.device 对象或包含设备标签的字符串,它表示所有张量应被加载的位置。

否则,如果 map_location 是一个字典,它会将其用作重映射:将文件中的位置标签(键)映射到指定存储位置的标签(值)。

用户扩展可以使用 torch.serialization.register_package() 注册自定义的位置标签、标记以及反序列化方法。

参数
  • f (Union[str, PathLike, BinaryIO, IO[bytes]]) – 类文件对象(必须实现 read()readline()tell()seek() 方法),或包含文件名的字符串或 os.PathLike 对象

  • map_location (Optional[Union[Callable[[Storage, str], Storage], device, str, Dict[str, str]]]) – 一个函数、torch.device 对象、字符串或字典,用于指定如何重新映射存储位置。

  • pickle_module (Optional[Any]) – 用于反序列化元数据和对象的模块(必须与用于序列化文件的 pickle_module 匹配)

  • weights_only (Optional[bool]) – 表示 unpickler 是否仅限于加载张量、原始类型、字典以及通过 torch.serialization.add_safe_globals() 添加的任何安全全局类型。

  • mmap (Optional[bool]) – 表示文件是否应该被内存映射(mmap)而不是将所有存储加载到内存中。通常情况下,文件中的张量存储会先从磁盘移动到CPU内存,然后根据保存时标记的位置或map_location指定的位置进行移动。如果最终位置是CPU,则这一步为无操作(no-op)。当设置mmap标志时,在第一步中不会将张量存储从磁盘复制到CPU内存,而是直接对文件进行内存映射。

  • pickle_load_args (Any) – (仅限 Python 3) 可选的参数,传递给 pickle_module.load()pickle_module.Unpickler() 方法,例如 errors=...

返回类型

Any

警告

torch.load() 除非将 weights_only 参数设置为 True,否则会隐式使用被认为不安全的 pickle 模块。恶意的 pickle 数据可以在反序列化时执行任意代码。永远不要在不安全模式下加载可能来自不可信来源或被篡改的数据。只加载你信任的数据

注意

当你调用torch.load()来加载一个包含GPU张量的文件时,默认情况下这些张量会被加载到GPU。为了避免在加载模型检查点时导致GPU内存激增,你可以先调用torch.load(.., map_location='cpu'),然后再调用load_state_dict()

注意

默认情况下,我们将字节字符串解码为utf-8。这是为了避免在Python 3中加载由Python 2保存的文件时出现常见的错误情况:UnicodeDecodeError: 'ascii' codec can't decode byte 0x...。如果此默认设置不正确,您可以使用额外的encoding关键字参数来指定如何加载这些对象。例如,encoding='latin1' 使用latin1编码将它们解码为字符串;而encoding='bytes' 将它们保持为字节数组,稍后可以使用byte_array.decode(...)进行解码。

示例

>>> torch.load("tensors.pt", weights_only=True)
# Load all tensors onto the CPU
>>> torch.load("tensors.pt", map_location=torch.device("cpu"), weights_only=True)
# Load all tensors onto the CPU, using a function
>>> torch.load(
...     "tensors.pt", map_location=lambda storage, loc: storage, weights_only=True
... )
# Load all tensors onto GPU 1
>>> torch.load(
...     "tensors.pt",
...     map_location=lambda storage, loc: storage.cuda(1),
...     weights_only=True,
... )  # type: ignore[attr-defined]
# Map tensors from GPU 1 to GPU 0
>>> torch.load("tensors.pt", map_location={"cuda:1": "cuda:0"}, weights_only=True)
# Load tensor from io.BytesIO object
# Loading from a buffer setting weights_only=False, warning this can be unsafe
>>> with open("tensor.pt", "rb") as f:
...     buffer = io.BytesIO(f.read())
>>> torch.load(buffer, weights_only=False)
# Load a module with 'ascii' encoding for unpickling
# Loading from a module setting weights_only=False, warning this can be unsafe
>>> torch.load("module.pt", encoding="ascii", weights_only=False)
本页目录