torch.argwhere
- torch.argwhere(input) → Tensor
-
返回一个张量,其中包含
input
中所有非零元素的索引。结果中的每一行对应于input
中的一个非零元素的索引。结果按字典顺序排序,最后一个索引变化最快(C风格)。如果
input
有 $n$ 维度,那么结果索引张量out
的大小为 $(z \times n)$,其中 $z$ 是input
张量中非零元素的总数。注意
此函数类似于 NumPy 的 argwhere 函数。
当
input
在 CUDA 上时,此函数会触发主机和设备之间的同步。- 参数
-
{input} –
示例:
>>> t = torch.tensor([1, 0, 1]) >>> torch.argwhere(t) tensor([[0], [2]]) >>> t = torch.tensor([[1, 0, 1], [0, 1, 1]]) >>> torch.argwhere(t) tensor([[0, 0], [0, 2], [1, 1], [1, 2]])