torch.combinations
- torch.combinations(input, r=2, with_replacement=False) → seq
-
计算给定张量长度为$r$的组合。当with_replacement设置为False时,行为类似于Python中的itertools.combinations;当with_replacement设置为True时,则类似于itertools.combinations_with_replacement。
- 参数
- 返回值
-
将所有的输入张量转换成列表,然后使用itertools.combinations或itertools.combinations_with_replacement对这些列表进行操作,最后将生成的列表再转换成张量。
- 返回类型
示例:
>>> a = [1, 2, 3] >>> list(itertools.combinations(a, r=2)) [(1, 2), (1, 3), (2, 3)] >>> list(itertools.combinations(a, r=3)) [(1, 2, 3)] >>> list(itertools.combinations_with_replacement(a, r=2)) [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] >>> tensor_a = torch.tensor(a) >>> torch.combinations(tensor_a) tensor([[1, 2], [1, 3], [2, 3]]) >>> torch.combinations(tensor_a, r=3) tensor([[1, 2, 3]]) >>> torch.combinations(tensor_a, with_replacement=True) tensor([[1, 1], [1, 2], [1, 3], [2, 2], [2, 3], [3, 3]])