元设备

“元”设备是一个抽象的概念,代表一个仅记录元数据而不包含实际数据的张量。元张量主要有两个应用场景:

  • 模型可以在元设备上加载,这样你就可以在不实际将参数加载到内存的情况下,加载模型的表示形式。如果你需要在加载实际数据之前对模型进行一些变换操作,这样做会非常有帮助。

  • 大多数操作都可以在元张量上执行,并生成新的元张量来描述实际张量上的操作结果。你可以使用这个功能来进行抽象分析,而无需花费时间或空间来表示实际的张量。由于元张量没有真实的数据,你不能执行像 torch.nonzero()item() 这样的数据依赖操作。在某些情况下,不同的设备类型(例如 CPU 和 CUDA)对于同一个操作可能不会产生完全相同的输出元数据;在这种情况下,我们通常更倾向于准确地表示 CUDA 的行为。

警告

尽管原则上元张量计算应该比等效的CPU/CUDA计算更快,但许多元张量实现是在Python中完成,并未移植到C++以提升性能。因此,你可能会发现使用小规模CPU张量时,框架的绝对延迟较低。

处理元张量的常用表达方式

可以使用torch.load()并指定map_location='meta'将对象加载到元设备上:

>>> torch.save(torch.randn(2), 'foo.pt')
>>> torch.load('foo.pt', map_location='meta')
tensor(..., device='meta', size=(2,))

如果你的代码在创建张量时未明确指定设备,可以使用torch.device()上下文管理器将其改为在元设备上构建张量。

>>> with torch.device('meta'):
...     print(torch.randn(30, 30))
...
tensor(..., device='meta', size=(30, 30))

这在构建NN模块时特别有用,因为在这些情况下你通常无法明确地传入设备进行初始化。

>>> from torch.nn.modules import Linear
>>> with torch.device('meta'):
...     print(Linear(20, 30))
...
Linear(in_features=20, out_features=30, bias=True)

你无法直接将元张量转换为CPU/CUDA张量,因为元张量不包含实际数据,我们也无从得知新张量应有的正确数据值。

>>> torch.ones(5, device='meta').to("cpu")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
NotImplementedError: Cannot copy out of meta tensor; no data!

使用像torch.empty_like()这样的工厂函数,来明确指定你希望如何填充缺失的数据。

NN 模块提供了一个方便的方法 torch.nn.Module.to_empty(),允许你将模块迁移到另一个设备上而不初始化参数。你需要手动重新初始化这些参数。

>>> from torch.nn.modules import Linear
>>> with torch.device('meta'):
...     m = Linear(20, 30)
>>> m.to_empty(device="cpu")
Linear(in_features=20, out_features=30, bias=True)

torch._subclasses.meta_utils 提供了一些未记录的工具,可以将任意张量转换为高保真度的元张量。这些 API 是实验性的,并且可能在任何时候以破坏兼容性的方式进行更改。

本页目录