torch.argwhere

torch.argwhere(input) Tensor

返回一个张量,其中包含 input 中所有非零元素的索引。结果中的每一行对应于 input 中的一个非零元素的索引。结果按字典顺序排序,最后一个索引变化最快(C风格)。

如果 input$n$ 维度,那么结果索引张量 out 的大小为 $(z \times n)$,其中 $z$input 张量中非零元素的总数。

注意

此函数类似于 NumPy 的 argwhere 函数。

input 在 CUDA 上时,此函数会触发主机和设备之间的同步。

参数

{input}

示例:

>>> t = torch.tensor([1, 0, 1])
>>> torch.argwhere(t)
tensor([[0],
        [2]])
>>> t = torch.tensor([[1, 0, 1], [0, 1, 1]])
>>> torch.argwhere(t)
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [1, 2]])
本页目录