torch.jit.load

torch.jit.load(f, map_location=None, _extra_files=None, _restore_shapes=False)[源代码]

加载一个之前使用 torch.jit.save 保存的 ScriptModuleScriptFunction

所有先前保存的模块,不论其所在设备为何,都将先被加载到CPU上,然后移回原设备。若此过程失败(如运行时系统缺少所需设备),将引发异常。
参数
  • f – 一个类文件对象(需要实现 read、readline、tell 和 seek 方法),或者是一个包含文件名的字符串

  • map_location (stringtorch.device) – 在torch.jit.save中,map_location的简化版本,用于将存储动态重映射到一组替代设备。

  • _extra_files字典,文件名到内容的映射)– 映射中指定的额外文件将被加载,其内容将会存储在提供的映射中。

  • _restore_shapes (bool) – 是否在加载模块时使用存储的输入进行重新跟踪

返回值

一个ScriptModule对象。

示例: ...测试代码:

import torch
import io

torch.jit.load('scriptmodule.pt')

# Load ScriptModule from io.BytesIO object
with open('scriptmodule.pt', 'rb') as f:
    buffer = io.BytesIO(f.read())

# Load all tensors to the original device
torch.jit.load(buffer)

# Load all tensors onto CPU, using a device
buffer.seek(0)
torch.jit.load(buffer, map_location=torch.device('cpu'))

# Load all tensors onto CPU, using a string
buffer.seek(0)
torch.jit.load(buffer, map_location='cpu')

# Load with extra files.
extra_files = {'foo.txt': ''}  # values will be replaced with data
torch.jit.load('scriptmodule.pt', _extra_files=extra_files)
print(extra_files['foo.txt'])
本页目录