torch.index_select
- torch.index_select(input, dim, index, *, out=None) → Tensor
-
沿维度
dim
使用index
(一个 LongTensor)对input
张量进行索引,并返回一个新的张量。返回的张量与原始张量 (
input
) 具有相同的维度数量。第dim
维度的大小等于index
的长度;其他维度的大小则与原始张量中的相同。注意
返回的张量不会与原始张量共享同一存储空间。如果
out
的形状与预期不符,我们会自动调整其形状为正确形式,并在需要时重新分配底层存储。- 参数
- 关键字参数
-
out (Tensor, 可选) – 指定输出张量。
示例:
>>> x = torch.randn(3, 4) >>> x tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-0.4664, 0.2647, -0.1228, -1.1068], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> indices = torch.tensor([0, 2]) >>> torch.index_select(x, 0, indices) tensor([[ 0.1427, 0.0231, -0.5414, -1.0009], [-1.1734, -0.6571, 0.7230, -0.6004]]) >>> torch.index_select(x, 1, indices) tensor([[ 0.1427, -0.5414], [-0.4664, -0.1228], [-1.1734, 0.7230]])