torch.Size

torch.Size 是调用 torch.Tensor.size() 的结果类型。它描述了原始张量所有维度的大小。作为 tuple 的子类,torch.Size 支持常见的序列操作,如索引和长度计算。

示例:

>>> x = torch.ones(10, 20, 30)
>>> s = x.size()
>>> s
torch.Size([10, 20, 30])
>>> s[1]
20
>>> len(s)
3
classtorch.Size(iterable=(), /)
count(value, /)

返回值的数量。

index(value, start=0, stop=9223372036854775807, /)

返回值的首个索引。

若值不存在,则抛出 ValueError 异常。

numel() int

返回给定大小的 torch.Tensor 包含的元素数量。

更正式地说,对于一个大小为 torch.Size([10, 10]) 的张量 x = tensor.ones(10, 10),表达式 x.numel() == x.size().numel() == s.numel() == 100 是成立的。

示例:
>>> x=torch.ones(10, 10)
>>> s=x.size()
>>> s
torch.Size([10, 10])
>>> s.numel()
100
>>> x.numel() == s.numel()
True

警告

此函数不返回torch.Size描述的维度数量,而是返回一个具有该尺寸的torch.Tensor所包含的元素数量。

本页目录