torch.jit.save

torch.jit.save(m, f, _extra_files=None)[源代码]

保存此模块的离线版本,以便在单独的进程中使用。

保存的模块会序列化此模块的所有方法、子模块、参数和属性。可以使用 C++ API 中的 torch::jit::load(filename) 或 Python API 中的 torch.jit.load 进行加载。

为了能够保存一个模块,它不能调用任何原生的Python函数。这意味着所有的子模块也必须是 ScriptModule 的子类。

危险

无论模块的设备信息如何,在加载时所有模块都会被加载到 CPU 上。这与 torch.load() 的语义不同,未来可能会有所改变。

参数
  • m – 要保存的一个ScriptModule

  • f – 一个类似于文件的对象(需要实现写入和刷新功能)或包含文件名的字符串。

  • _extra_files – 文件名与内容之间的映射,这些内容将作为 f 的一部分进行存储。

注意

torch.jit.save 功能旨在跨不同版本保留某些操作符的行为。例如,在 PyTorch 1.5 中,两个整数张量的除法执行地板除法。如果包含该代码的模块在 PyTorch 1.5 中保存并在 PyTorch 1.6 中加载,则其除法行为将被保留。然而,若同一个模块在 PyTorch 1.6 中保存,并尝试在较早版本如 PyTorch 1.5 中加载时会失败,因为从 PyTorch 1.6 开始,除法的行为发生了变化,而旧版本无法理解并复制新版本中的行为。

示例: ...测试代码:

import torch
import io

class MyModule(torch.nn.Module):
    def forward(self, x):
        return x + 10

m = torch.jit.script(MyModule())

# Save to file
torch.jit.save(m, 'scriptmodule.pt')
# This line is equivalent to the previous
m.save("scriptmodule.pt")

# Save to io.BytesIO buffer
buffer = io.BytesIO()
torch.jit.save(m, buffer)

# Save with extra files
extra_files = {'foo.txt': b'bar'}
torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)
本页目录