torch.all

torch.all(input) Tensor

判断input中的所有元素是否都为True

注意

此函数的行为与 NumPy 一致:对于所有受支持的数据类型(除 uint8 之外),返回的输出数据类型为 bool。而对于 uint8 数据类型,输出的数据类型仍然是 uint8

示例:

>>> a = torch.rand(1, 2).bool()
>>> a
tensor([[False, True]], dtype=torch.bool)
>>> torch.all(a)
tensor(False, dtype=torch.bool)
>>> a = torch.arange(0, 3)
>>> a
tensor([0, 1, 2])
>>> torch.all(a)
tensor(False)
torch.all(input, dim, keepdim=False, *, out=None) Tensor

对于给定维度 diminput 的每一行,如果该行的所有元素都为True,则返回True,否则返回False

如果 keepdimTrue,则输出张量与输入张量大小相同,除了在 dim 指定的维度上其大小为 1。否则,dim 维度会被挤压(参见 torch.squeeze()),导致输出张量比输入少 1 (或 len(dim))个维度。

参数
  • input (Tensor) – 需要输入的张量。

  • dim (int元组 of ints) – 需要减少的维度。

  • keepdim (bool) – 是否在输出张量中保留dim维度。

关键字参数

out (Tensor, 可选) – 指定输出张量。

示例:

>>> a = torch.rand(4, 2).bool()
>>> a
tensor([[True, True],
        [True, False],
        [True, True],
        [True, True]], dtype=torch.bool)
>>> torch.all(a, dim=1)
tensor([ True, False,  True,  True], dtype=torch.bool)
>>> torch.all(a, dim=0)
tensor([ True, False], dtype=torch.bool)
本页目录