torch.repeat_interleave
- torch.repeat_interleave(input, repeats, dim=None, *, output_size=None) → Tensor
-
复制张量的元素。
警告
这与
torch.Tensor.repeat()
不同,但与numpy.repeat
类似。- 参数
- 关键字参数
-
output_size (int, 可选) – 给定轴的总输出大小(例如重复次数之和)。如果提供,可以避免计算张量形状所需的流同步。
- 返回值
-
一个重复的张量,其形状与输入张量相同,只是在指定的轴上有所不同。
- 返回类型
示例:
>>> x = torch.tensor([1, 2, 3]) >>> x.repeat_interleave(2) tensor([1, 1, 2, 2, 3, 3]) >>> y = torch.tensor([[1, 2], [3, 4]]) >>> torch.repeat_interleave(y, 2) tensor([1, 1, 2, 2, 3, 3, 4, 4]) >>> torch.repeat_interleave(y, 3, dim=1) tensor([[1, 1, 1, 2, 2, 2], [3, 3, 3, 4, 4, 4]]) >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0) tensor([[1, 2], [3, 4], [3, 4]]) >>> torch.repeat_interleave(y, torch.tensor([1, 2]), dim=0, output_size=3) tensor([[1, 2], [3, 4], [3, 4]])
如果 repeats 是 tensor([n1, n2, n3, …]),那么输出将会是 tensor([0, 0, …, 1, 1, …, 2, 2, …, …])。其中 0 出现 n1 次,1 出现 n2 次,2 出现 n3 次,以此类推。
- torch.repeat_interleave(repeats) → Tensor
重复 0 次为 repeats[0],重复 1 次为 repeats[1],重复 2 次为 repeats[2],等等。
示例:
>>> torch.repeat_interleave(torch.tensor([1, 2, 3])) tensor([0, 1, 1, 2, 2, 2])