torch.take_along_dim
- torch.take_along_dim(input, indices, dim=None, *, out=None)) → Tensor
-
根据
indices
指定的一维索引,在input
的给定dim
维度上选择相应的值。如果
dim
为 None,输入数组将被视为已被展平为一维。返回维度索引的函数,如
torch.argmax()
和torch.argsort()
,都是为此功能设计的。请参见下面的例子。注意
该函数类似于 NumPy 的 take_along_axis。详情请参阅
torch.gather()
。- 参数
- 关键字参数
-
out (Tensor, 可选) – 指定输出张量。
示例:
>>> t = torch.tensor([[10, 30, 20], [60, 40, 50]]) >>> max_idx = torch.argmax(t) >>> torch.take_along_dim(t, max_idx) tensor([60]) >>> sorted_idx = torch.argsort(t, dim=1) >>> torch.take_along_dim(t, sorted_idx, dim=1) tensor([[10, 20, 30], [40, 50, 60]])