torch.sort
- torch.sort(input, dim=-1, descending=False, stable=False, *, out=None)
-
沿指定维度,按值升序对
input
张量的元素进行排序。如果没有提供
dim
,则默认选择输入的最后一个维度。如果
descending
为True
,则元素将按照值的降序进行排序。如果
stable
为True
,则排序过程将变为稳定的,保持等价元素原有的顺序。返回一个包含两个元素的元组(namedtuple)(values, indices),其中 values 是排序后的值,indices 是原始 input 张量中对应元素的索引。
- 参数
- 关键字参数
-
out (元组, 可选) – 可选的输出元组,包含 Tensor 和 LongTensor,可以用作输出缓冲区
示例:
>>> x = torch.randn(3, 4) >>> sorted, indices = torch.sort(x) >>> sorted tensor([[-0.2162, 0.0608, 0.6719, 2.3332], [-0.5793, 0.0061, 0.6058, 0.9497], [-0.5071, 0.3343, 0.9553, 1.0960]]) >>> indices tensor([[ 1, 0, 2, 3], [ 3, 1, 0, 2], [ 0, 3, 1, 2]]) >>> sorted, indices = torch.sort(x, 0) >>> sorted tensor([[-0.5071, -0.2162, 0.6719, -0.5793], [ 0.0608, 0.0061, 0.9497, 0.3343], [ 0.6058, 0.9553, 1.0960, 2.3332]]) >>> indices tensor([[ 2, 0, 0, 1], [ 0, 1, 1, 2], [ 1, 2, 2, 0]]) >>> x = torch.tensor([0, 1] * 9) >>> x.sort() torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 2, 16, 4, 6, 14, 8, 0, 10, 12, 9, 17, 15, 13, 11, 7, 5, 3, 1])) >>> x.sort(stable=True) torch.return_types.sort( values=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]), indices=tensor([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 1, 3, 5, 7, 9, 11, 13, 15, 17]))