torch.unravel_index
- torch.unravel_index(indices, shape)[源代码]
-
将扁平索引的张量转换为坐标张量的元组,以便对具有指定形状的任意张量进行索引。
- 参数
-
-
indices (Tensor) – 一个包含任意形状为
shape
的张量的展平版本中的索引的整数张量。所有元素必须在范围[0, prod(shape) - 1]
内。 -
shape (int, ints 序列,或 torch.Size) – 任意张量的形状。所有元素必须是非负数。
-
- 返回值
-
输出中的每个第
i
个张量与shape
的第i
维相对应。每个张量的形状与indices
相同,并且包含针对维度i
的一个索引,对应于由indices
提供的每个扁平化索引。 - 返回类型
-
元组 包含 Tensors
示例:
>>> import torch >>> torch.unravel_index(torch.tensor(4), (3, 2)) (tensor(2), tensor(0)) >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2)) (tensor([2, 0]), tensor([0, 1])) >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2)) (tensor([0, 0, 1, 1, 2, 2]), tensor([0, 1, 0, 1, 0, 1])) >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10)) (tensor([1, 5]), tensor([2, 6]), tensor([3, 7]), tensor([4, 8])) >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10)) (tensor([[1], [5]]), tensor([[2], [6]]), tensor([[3], [7]]), tensor([[4], [8]])) >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100)) (tensor([[12], [56]]), tensor([[34], [78]]))