torch.hub

Pytorch Hub 是一个预训练模型库,旨在促进研究的可重复性。

发布模型

PyTorch Hub 支持通过添加一个简单的 hubconf.py 文件,将预训练模型(包括模型定义和预训练权重)发布到 GitHub 仓库。

hubconf.py 文件可以包含多个入口点,每个入口点是一个 Python 函数,用于定义你想要发布的预训练模型等。

def entrypoint_name(*args, **kwargs):
    # args & kwargs are optional, for models which take positional/keyword arguments.
    ...

如何实现入口函数?

这是一个代码片段,指定了resnet18模型的入口点。如果我们扩展了pytorch/vision/hubconf.py中的实现,在大多数情况下直接在hubconf.py中导入正确的函数就足够了。这里我们只是想使用扩展版本作为示例来展示其工作原理。你可以在pytorch/vision 仓库中查看完整的脚本。

dependencies = ['torch']
from torchvision.models.resnet import resnet18 as _resnet18

# resnet18 is the name of entrypoint
def resnet18(pretrained=False, **kwargs):
""" # This docstring shows up in hub.help()
    Resnet18 model
    pretrained (bool): kwargs, load pretrained weights into the model
    """
    # Call the model, load pretrained weights
    model = _resnet18(pretrained=pretrained, **kwargs)
    return model
  • dependencies 变量是一个包含加载模型所需包名的列表。请注意,这些依赖项可能与训练模型时所需的有所不同。

  • argskwargs 将被传递给实际的可调用函数。

  • 函数的文档字符串用作帮助信息,解释模型的功能及允许的位置参数和关键字参数,并强烈建议添加一些示例。

  • 入口函数可以返回一个模型(如 nn.module),或提供一些辅助工具来简化用户的工作流程,例如分词器。

  • 以下划线开头的可调用函数会被视为辅助函数,在torch.hub.list()中不会显示。

  • 预训练权重可以存储在 GitHub 仓库的本地,或通过torch.hub.load_state_dict_from_url() 加载。如果文件大小小于2GB,建议将其附加到项目发布 并使用该发布的 URL。在上面的例子中,torchvision.models.resnet.resnet18 处理了 pretrained 参数,或者你也可以将以下逻辑放在入口点定义中。

if pretrained:
    # For checkpoint saved in local GitHub repo, e.g. <RELATIVE_PATH_TO_CHECKPOINT>=weights/save.pth
    dirname = os.path.dirname(__file__)
    checkpoint = os.path.join(dirname, <RELATIVE_PATH_TO_CHECKPOINT>)
    state_dict = torch.load(checkpoint)
    model.load_state_dict(state_dict)

    # For checkpoint saved elsewhere
    checkpoint = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
    model.load_state_dict(torch.hub.load_state_dict_from_url(checkpoint, progress=False))

重要通知

  • 发布的模型应位于某个分支或标签中,而不能是任意的一个提交。

从Hub加载模型

PyTorch Hub 提供了方便的 API,可以通过 torch.hub.list() 探索所有可用的模型,通过 torch.hub.help() 显示文档字符串和示例,并使用 torch.hub.load() 加载预训练模型。

torch.hub.list(github, force_reload=False, skip_validation=False, trust_repo=None, verbose=True)[源代码]

列出在指定的 github 仓库中所有可用的可调用入口点。

参数
  • github (str) – 格式为“repo_owner/repo_name[:ref]”的字符串,其中 ref(标签或分支)是可选的。如果未指定 ref,则默认分支假设为 main(如果存在),否则为 master。示例:‘pytorch/vision:0.10’

  • force_reload (bool, 可选) – 是否强制丢弃现有缓存并进行全新下载。默认值为 False

  • skip_validation (bool, optional) – 如果设置为 False,torchhub 将验证由 github 参数指定的分支或提交是否属于仓库所有者。这将向 GitHub API 发送请求;你可以通过设置环境变量 GITHUB_TOKEN 来使用非默认的 GitHub 令牌。默认值为 False

  • trust_repo (bool, strNone) –

    "check", True, FalseNone。该参数在 v1.12 版本中引入,有助于确保用户仅运行来自他们信任的仓库中的代码。

    • 如果为 False,会提示用户是否信任该仓库。

    • 如果为 True,仓库将会被加入到受信任的列表中,并自动加载而无需明确确认。

    • 如果为"check",则会检查仓库是否在缓存的受信任仓库列表中。如果没有找到,则会退回到trust_repo=False选项。

    • 如果为 None: 这将引发一个警告,提示用户将 trust_repo 设置为 FalseTrue"check" 中的一个值。这只是为了向后兼容而暂时保留,并将在 v2.0 版本中移除。

    默认值为None,在v2.0版本中将会更改为"check"

  • verbose (bool, optional) – 如果为 False,则不会显示关于命中本地缓存的消息。请注意,默认情况下首次下载的消息无法被静音。默认值是 True

返回值

可用的 callable 入口点

返回类型

列表

示例

>>> entrypoints = torch.hub.list("pytorch/vision", force_reload=True)
torch.hub.help(github, model, force_reload=False, skip_validation=False, trust_repo=None)[源代码]

显示 model 入口点的文档字符串。

参数
  • github (str) – 一个格式为<repo_owner/repo_name[:ref]>的字符串,其中 ref 是可选的(可以是标签或分支)。如果未指定ref,则默认分支假设为main(如果存在),否则为master。示例:‘pytorch/vision:0.10’

  • model (str) – 仓库中 hubconf.py 文件里定义的入口点名称的字符串

  • force_reload (bool, 可选) – 是否强制丢弃现有缓存并进行全新下载。默认值为 False

  • skip_validation (bool, optional) – 如果设置为 False,torchhub 将验证通过 github 参数指定的 ref 是否正确属于仓库所有者。这将向 GitHub API 发送请求;你可以通过设置环境变量 GITHUB_TOKEN 来使用非默认的 GitHub token。默认值为 False

  • trust_repo (bool, strNone) –

    "check", True, FalseNone。该参数在 v1.12 版本中引入,有助于确保用户仅运行来自他们信任的仓库中的代码。

    • 如果为 False,会提示用户是否信任该仓库。

    • 如果为 True,仓库将会被加入到受信任的列表中,并自动加载而无需明确确认。

    • 如果为"check",则会检查仓库是否在缓存的受信任仓库列表中。如果没有找到,则会退回到trust_repo=False选项。

    • 如果为 None: 这将引发一个警告,提示用户将 trust_repo 设置为 FalseTrue"check" 中的一个值。这只是为了向后兼容而暂时保留,并将在 v2.0 版本中移除。

    默认值为None,在v2.0版本中将会更改为"check"

示例

>>> print(torch.hub.help("pytorch/vision", "resnet18", force_reload=True))
torch.hub.load(repo_or_dir, model, *args, source='github', trust_repo=None, force_reload=False, verbose=True, skip_validation=False, **kwargs)[源代码]
从 GitHub 仓库或本地目录加载模型。

注意:加载模型是最常见的用例,但也适用于加载其他对象,如分词器、损失函数等。

如果 source 是 ‘github’,repo_or_dir 应该是形式为 repo_owner/repo_name[:ref] 的字符串,并且可选地包含一个 ref(如标签或分支)。

如果 source 是 ‘local’,则 repo_or_dir 应该是本地目录的路径。

参数
  • repo_or_dir (str) – 如果 source 是 ‘github’,则应对应于一个格式为 repo_owner/repo_name[:ref] 的 GitHub 仓库,并且可选地包含 ref(标签或分支),例如 ‘pytorch/vision:0.10’。如果未指定 ref,则默认分支为 main(若存在的话),否则为 master。如果 source 是 ‘local’,则应是一个指向本地目录的路径。

  • model (str) – 仓库或目录中 hubconf.py 文件里定义的一个可调用函数(入口点)的名称。

  • *args (可选) – 用于可调用对象 model 的参数。

  • source (str, 可选) – “github” 或 “local”。用于指定如何解释 repo_or_dir。默认值为 “github”。

  • trust_repo (bool, strNone) –

    "check", True, FalseNone。该参数在 v1.12 版本中引入,有助于确保用户仅运行来自他们信任的仓库中的代码。

    • 如果为 False,会提示用户是否信任该仓库。

    • 如果为 True,仓库将会被加入到受信任的列表中,并自动加载而无需明确确认。

    • 如果为"check",则会检查仓库是否在缓存的受信任仓库列表中。如果没有找到,则会退回到trust_repo=False选项。

    • 如果为 None: 这将引发一个警告,提示用户将 trust_repo 设置为 FalseTrue"check" 中的一个值。这只是为了向后兼容而暂时保留,并将在 v2.0 版本中移除。

    默认值为None,在v2.0版本中将会更改为"check"

  • force_reload (bool, 可选) – 是否无条件强制重新下载 github 仓库。如果 source = 'local',则不会产生任何效果。默认值为 False

  • verbose (bool, optional) – 如果设置为 False,则不会显示关于命中本地缓存的消息。需要注意的是,首次下载的消息无法被静音。当 source = 'local' 时,此参数无效。默认值为 True

  • skip_validation (bool, optional) – 如果设置为 False,torchhub 将验证由 github 参数指定的分支或提交是否属于仓库所有者。这将向 GitHub API 发送请求;你可以通过设置环境变量 GITHUB_TOKEN 来使用非默认的 GitHub 令牌。默认值为 False

  • **kwargs (可选) – 可调用对象 model 的相关关键字参数。

返回值

使用给定的 *args**kwargs 调用 model 时的返回结果。

示例

>>> # from a github repo
>>> repo = "pytorch/vision"
>>> model = torch.hub.load(
...     repo, "resnet50", weights="ResNet50_Weights.IMAGENET1K_V1"
... )
>>> # from a local directory
>>> path = "/some/local/path/pytorch/vision"
>>> model = torch.hub.load(path, "resnet50", weights="ResNet50_Weights.DEFAULT")
torch.hub.download_url_to_file(url, dst, hash_prefix=None, progress=True)[源代码]

将给定网址的对象下载到本地路径。

参数
  • url (str) – 需要下载的对象的 URL

  • dst (str) – 对象将被保存的完整路径,例如 /tmp/temporary_file

  • hash_prefix (str, 可选) – 如果不为 None,则下载文件的 SHA256 哈希值应以 hash_prefix 开头。默认:None

  • progress (bool, optional) – 是否在 stderr 显示进度条,默认为 True

示例

>>> torch.hub.download_url_to_file(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth",
...     "/tmp/temporary_file",
... )
torch.hub.load_state_dict_from_url(url, model_dir=None, map_location=None, progress=True, check_hash=False, file_name=None, weights_only=False)[源代码]

从给定的URL加载Torch序列化对象。

如果是zip文件,它将会被自动解压。

如果对象已经在model_dir中存在,它会被反序列化并返回。默认情况下,model_dir 的值为 <hub_dir>/checkpoints,其中 hub_dir 是由get_dir() 函数返回的目录。

参数
  • url (str) – 需要下载的对象的 URL

  • model_dir (str, 可选) – 用于保存对象的目录

  • map_location (可选) – 用于指定如何重新映射存储位置的函数或字典(详见 torch.load)

  • progress (bool, optional) – 是否显示进度条(默认为标准错误输出)。默认值:True

  • check_hash (bool, 可选) – 如果为 True,URL 的文件名部分应遵循命名约定 filename-<sha256>.ext,其中 <sha256> 是文件内容的 SHA256 哈希值的前八个或更多位数字。哈希用于确保名称唯一并验证文件的内容。默认值:False

  • file_name (str, 可选) – 下载文件的名称。如果没有设置,则使用 url 中的文件名。

  • weights_only (bool, optional) – 如果为 True,则仅加载权重而不加载复杂的 pickled 对象。推荐用于不信任的来源。更多详细信息请参见load()

返回类型

Dict[str, Any]

示例

>>> state_dict = torch.hub.load_state_dict_from_url(
...     "https://s3.amazonaws.com/pytorch/models/resnet18-5c106cde.pth"
... )

运行加载的模型

注意,在torch.hub.load() 中的 *args**kwargs 用于实例化一个模型。加载完模型后,如何了解可以对模型进行哪些操作?建议的工作流程是:

  • 使用 dir(model) 查看模型的所有可用方法。

  • help(model.foo) 用于查看 model.foo 的参数

为了帮助用户不用频繁查阅文档就能轻松探索,我们强烈建议仓库所有者使函数的帮助信息清晰简洁,并提供一个最小工作示例。

我的下载模型存放在哪里?

按以下顺序使用位置

  • 调用 hub.set_dir(<PATH_TO_HUB_DIR>)

  • 如果设置了环境变量 TORCH_HOME,则使用 $TORCH_HOME/hub

  • 如果环境变量 XDG_CACHE_HOME 已设置,则使用 $XDG_CACHE_HOME/torch/hub

  • ~/.cache/torch/hub

torch.hub.get_dir()[源代码]

获取用于存储下载模型和权重的Torch Hub缓存目录。

如果没有调用set_dir(),默认路径是$TORCH_HOME/hub。其中环境变量$TORCH_HOME 默认值为 $XDG_CACHE_HOME/torch。如果未设置环境变量,则$XDG_CACHE_HOME 按照 Linux 文件系统布局的 X Design Group 规范,默认值为 ~/.cache

torch.hub.set_dir(d)[源代码]

可选地设置用于保存下载模型和权重的Torch Hub目录。

参数

d (str) – 本地文件夹的路径,用于保存下载的模型和权重。

缓存逻辑

默认情况下,我们在加载文件后不会进行清理。如果缓存已存在get_dir()返回的目录中,Hub 会默认使用该缓存。

用户可以通过调用hub.load(..., force_reload=True)来强制重新加载。这将删除现有的GitHub文件夹和已下载的权重,并重新初始化一个新的下载过程。当相同的分支发布了更新时,此操作可以帮助用户跟上最新版本。

已知限制

Torch hub 通过像安装一样导入包来工作。在 Python 中,导入会引入一些副作用,例如你会看到新的项目出现在 sys.modulessys.path_importer_cache 缓存中,这是正常的 Python 行为。这意味着当你从不同仓库中导入不同的模型时,如果这些仓库有相同的子包名(通常是一个 model 子包),你可能会遇到导入错误。解决这类问题的一个方法是从 sys.modules 字典中移除有问题的子包;更多详情可以在这个 GitHub 问题 中找到。

这里有一个重要的限制需要注意:用户不能在同一 Python 进程中同时加载同一个仓库的不同分支。这就像在 Python 中安装同名的两个包,这是不允许的。如果你尝试这样做,可能会遇到缓存带来的意外问题。当然,在不同的进程中加载它们是没有问题的。

本页目录