torch.save

torch.save(obj, f, pickle_module=pickle, pickle_protocol=DEFAULT_PROTOCOL, _use_new_zipfile_serialization=True)[源代码]

将对象存储到磁盘文件中。

另请参阅: 保存和加载张量

参数
  • obj (对象) — 保存的目标对象

  • f (Union[str, PathLike, BinaryIO, IO[bytes]]) – 类文件对象(需要实现 write 和 flush 方法)或包含文件名的字符串或 os.PathLike 对象

  • pickle_module (Any) – 用于存储和序列化元数据及对象的模块

  • pickle_protocol (int) – 指定此参数可以覆盖默认的协议

注意

在 PyTorch 中,常用的做法是使用 .pt 文件扩展名来保存张量。

注意

PyTorch 在序列化过程中保留存储共享。关于更多细节,请参阅 保存和加载张量时保留视图

注意

PyTorch 1.6 版本将 torch.save 的文件格式切换为基于 zip 文件的新格式。而 torch.load 仍然可以加载旧的文件格式。
如果需要让 torch.save 使用旧格式,可以在调用时传递关键字参数 _use_new_zipfile_serialization=False

示例

>>> # Save to file
>>> x = torch.tensor([0, 1, 2, 3, 4])
>>> torch.save(x, "tensor.pt")
>>> # Save to io.BytesIO buffer
>>> buffer = io.BytesIO()
>>> torch.save(x, buffer)
本页目录