元设备
“元”设备是一个抽象的概念,代表一个仅记录元数据而不包含实际数据的张量。元张量主要有两个应用场景:
-
模型可以在元设备上加载,这样你就可以在不实际将参数加载到内存的情况下,加载模型的表示形式。如果你需要在加载实际数据之前对模型进行一些变换操作,这样做会非常有帮助。
-
大多数操作都可以在元张量上执行,并生成新的元张量来描述实际张量上的操作结果。你可以使用这个功能来进行抽象分析,而无需花费时间或空间来表示实际的张量。由于元张量没有真实的数据,你不能执行像
torch.nonzero()
或item()
这样的数据依赖操作。在某些情况下,不同的设备类型(例如 CPU 和 CUDA)对于同一个操作可能不会产生完全相同的输出元数据;在这种情况下,我们通常更倾向于准确地表示 CUDA 的行为。
警告
尽管原则上元张量计算应该比等效的CPU/CUDA计算更快,但许多元张量实现是在Python中完成,并未移植到C++以提升性能。因此,你可能会发现使用小规模CPU张量时,框架的绝对延迟较低。
处理元张量的常用表达方式
可以使用torch.load()
并指定map_location='meta'
将对象加载到元设备上:
>>> torch.save(torch.randn(2), 'foo.pt') >>> torch.load('foo.pt', map_location='meta') tensor(..., device='meta', size=(2,))
如果你的代码在创建张量时未明确指定设备,可以使用torch.device()
上下文管理器将其改为在元设备上构建张量。
>>> with torch.device('meta'): ... print(torch.randn(30, 30)) ... tensor(..., device='meta', size=(30, 30))
这在构建NN模块时特别有用,因为在这些情况下你通常无法明确地传入设备进行初始化。
>>> from torch.nn.modules import Linear >>> with torch.device('meta'): ... print(Linear(20, 30)) ... Linear(in_features=20, out_features=30, bias=True)
你无法直接将元张量转换为CPU/CUDA张量,因为元张量不包含实际数据,我们也无从得知新张量应有的正确数据值。
>>> torch.ones(5, device='meta').to("cpu") Traceback (most recent call last): File "<stdin>", line 1, in <module> NotImplementedError: Cannot copy out of meta tensor; no data!
使用像torch.empty_like()
这样的工厂函数,来明确指定你希望如何填充缺失的数据。
NN 模块提供了一个方便的方法 torch.nn.Module.to_empty()
,允许你将模块迁移到另一个设备上而不初始化参数。你需要手动重新初始化这些参数。
>>> from torch.nn.modules import Linear >>> with torch.device('meta'): ... m = Linear(20, 30) >>> m.to_empty(device="cpu") Linear(in_features=20, out_features=30, bias=True)
torch._subclasses.meta_utils
提供了一些未记录的工具,可以将任意张量转换为高保真度的元张量。这些 API 是实验性的,并且可能在任何时候以破坏兼容性的方式进行更改。