张量视图
PyTorch 允许一个张量成为现有张量的视图(View
)。视图张量与基础张量共享相同的底层数据。支持 View
可以避免显式的数据复制,从而实现快速且内存高效的重塑、切片和逐元素操作。
例如,要查看现有张量 t
,可以调用 t.view(...)
。
>>> t = torch.rand(4, 4) >>> b = t.view(2, 8) >>> t.storage().data_ptr() == b.storage().data_ptr() # `t` and `b` share the same underlying data. True # Modifying view tensor changes base tensor as well. >>> b[0][0] = 3.14 >>> t[0][0] tensor(3.14)
因为视图与基础张量共用底层数据,所以如果修改了视图中的数据,基础张量中的数据也会随之改变。
通常,PyTorch 操作会返回一个新的张量作为输出,例如 add()
。但在视图操作的情况下,输出是输入张量的视图,以避免不必要的数据复制。创建视图时不会发生数据移动,只是改变了对相同数据的解释方式。取连续张量的视图可能会产生非连续张量。用户应注意,连续性可能会影响性能。transpose()
是一个常见的例子。
>>> base = torch.tensor([[0, 1],[2, 3]]) >>> base.is_contiguous() True >>> t = base.transpose(0, 1) # `t` is a view of `base`. No data movement happened here. # View tensors might be non-contiguous. >>> t.is_contiguous() False # To get a contiguous tensor, call `.contiguous()` to enforce # copying data when `t` is not contiguous. >>> c = t.contiguous()
作为参考,以下是 PyTorch 中所有视图操作的完整列表:
-
基本切片和索引操作,例如
tensor[0, 2:, 1:7:2]
返回基础tensor
的视图。请参见下面的说明。 -
view_as_real()
-
split_with_sizes()
-
indices()
(仅适用于稀疏张量) -
values()
(仅适用于稀疏张量)
注意
当通过索引访问张量的内容时,PyTorch遵循Numpy的行为:基本索引返回视图,高级索引返回副本。无论是基本索引还是高级索引的赋值操作都是就地进行的。更多示例请参见Numpy 索引文档。
也值得一提的是几个具有特殊行为的操作。
-
reshape()
、reshape_as()
和flatten()
可能会返回一个视图或新的张量,用户代码不应依赖于具体是哪种情况。 -
contiguous()
如果输入张量已经是连续的,则返回它本身,否则通过复制数据返回一个新的连续张量。
了解更多关于 PyTorch 内部实现的详细内容,请参阅 ezyang 的 PyTorch 内部实现博客文章。