torch.nonzero
- torch.nonzero(input, *, out=None, as_tuple=False) → LongTensor 或 LongTensors 元组
-
注意
torch.nonzero(..., as_tuple=False)
(默认情况下)返回一个二维张量,其中每一行表示一个非零元素的索引。torch.nonzero(..., as_tuple=True)
返回一个包含1-D索引张量的元组,支持高级索引操作。因此,x[x.nonzero(as_tuple=True)]
可以获取张量x
中的所有非零值。每个返回的索引张量包含特定维度上的非零索引。关于这两种行为的详情,请参见下文。
当
input
在 CUDA 上时,torch.nonzero()
会触发主机和设备之间的同步。当
as_tuple
为False
(默认情况下):返回一个张量,其中包含
input
中所有非零元素的索引。结果中的每一行对应于input
中的一个非零元素的索引。结果按字典顺序排序,最后一个索引变化最快(C风格)。如果
input
有 $n$ 维度,那么结果索引张量out
的大小为 $(z \times n)$,其中 $z$ 是input
张量中非零元素的总数。当
as_tuple
为True
:返回一个元组,其中包含与
input
每个维度对应的1-D张量。每个张量包含了input
中该维度上所有非零元素的索引。如果
input
有 $n$ 维度,那么结果元组将包含 $n$ 个大小为 $z$ 的张量,其中 $z$ 是input
张量中非零元素的总数。作为一种特殊情况,当
input
是零维度但有非零标量值时,它会被视为一个包含单个元素的一维张量。- 参数
-
input (Tensor) – 需要输入的张量。
- 关键字参数
-
out (LongTensor, optional) – 可选的输出张量,用于存储索引
- 返回值
-
如果
as_tuple
为False
,输出将是一个包含索引的张量。如果as_tuple
为True
,则对于每个维度会有一个1-D张量,其中包含该维度上所有非零元素的索引。 - 返回类型
-
LongTensor 或包含 LongTensor 的 元组
示例:
>>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1])) tensor([[ 0], [ 1], [ 2], [ 4]]) >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], ... [0.0, 0.4, 0.0, 0.0], ... [0.0, 0.0, 1.2, 0.0], ... [0.0, 0.0, 0.0,-0.4]])) tensor([[ 0, 0], [ 1, 1], [ 2, 2], [ 3, 3]]) >>> torch.nonzero(torch.tensor([1, 1, 1, 0, 1]), as_tuple=True) (tensor([0, 1, 2, 4]),) >>> torch.nonzero(torch.tensor([[0.6, 0.0, 0.0, 0.0], ... [0.0, 0.4, 0.0, 0.0], ... [0.0, 0.0, 1.2, 0.0], ... [0.0, 0.0, 0.0,-0.4]]), as_tuple=True) (tensor([0, 1, 2, 3]), tensor([0, 1, 2, 3])) >>> torch.nonzero(torch.tensor(5), as_tuple=True) (tensor([0]),)