torch.repeat_interleave

torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) Tensor

复制张量的元素。

警告

这与torch.Tensor.repeat()不同,但与numpy.repeat类似。

参数
  • input (Tensor) – 需要输入的张量。

  • repeats (Tensorint) – 指定每个元素需要重复的次数。参数 repeats 会根据给定轴的形状进行广播。

  • dim (int, 可选) – 指定重复值的维度。默认情况下,使用展平的输入数组,并返回一个扁平的输出数组。

关键字参数

output_size (int, 可选) – 给定轴的总输出大小(例如重复次数之和)。如果提供,可以避免计算张量形状所需的流同步。

返回值

一个重复的张量,其形状与输入张量相同,只是在指定的轴上有所不同。

返回类型

张量

示例:

>>> x = torch.tensor([1, 2, 3])
>>> x.repeat_interleave(2)
tensor([1, 1, 2, 2, 3, 3])
>>> y = torch.tensor([[1, 2], [3, 4]])
>>> torch.repeat_interleave(y, 2)
tensor([1, 1, 2, 2, 3, 3, 4, 4])
>>> torch.repeat_interleave(y, 3, dim=1)
tensor([[1, 1, 1, 2, 2, 2],
        [3, 3, 3, 4, 4, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0)
tensor([[1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3)
tensor([[1, 2],
        [3, 4],
        [3, 4]])

如果 repeatstensor([n1, n2, n3, …]),那么输出将会是 tensor([0, 0, …, 1, 1, …, 2, 2, …, …])。其中 0 出现 n1 次,1 出现 n2 次,2 出现 n3 次,以此类推。

torch.repeat_interleave(repeats) Tensor

重复 0 次为 repeats[0],重复 1 次为 repeats[1],重复 2 次为 repeats[2],等等。

参数

repeats (Tensor) – 表示每个元素需要重复的次数。

返回值

sum(repeats)大小的重复张量。

返回类型

张量

示例:

>>> torch.repeat_interleave(torch.tensor([1, 2, 3]))
tensor([0, 1, 1, 2, 2, 2])
本页目录