torch.combinations

torch.combinations(input, r=2, with_replacement=False) seq

计算给定张量长度为$r$的组合。当with_replacement设置为False时,行为类似于Python中的itertools.combinations;当with_replacement设置为True时,则类似于itertools.combinations_with_replacement

参数
  • 输入 (Tensor) – 一个 1D 向量。

  • r (int, 可选) – 指定要组合的元素的数量

  • with_replacement (bool, optional) – 是否允许组合中有重复的元素

返回值

将所有的输入张量转换成列表,然后使用itertools.combinationsitertools.combinations_with_replacement对这些列表进行操作,最后将生成的列表再转换成张量。

返回类型

张量

示例:

>>> a = [1, 2, 3]
>>> list(itertools.combinations(a, r=2))
[(1, 2), (1, 3), (2, 3)]
>>> list(itertools.combinations(a, r=3))
[(1, 2, 3)]
>>> list(itertools.combinations_with_replacement(a, r=2))
[(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)]
>>> tensor_a = torch.tensor(a)
>>> torch.combinations(tensor_a)
tensor([[1, 2],
        [1, 3],
        [2, 3]])
>>> torch.combinations(tensor_a, r=3)
tensor([[1, 2, 3]])
>>> torch.combinations(tensor_a, with_replacement=True)
tensor([[1, 1],
        [1, 2],
        [1, 3],
        [2, 2],
        [2, 3],
        [3, 3]])
本页目录