torch.diagonal_scatter
- torch.diagonal_scatter(input, src, offset=0, dim1=0, dim2=1) → Tensor
-
将
src
张量的值沿input
中相对于dim1
和dim2
的对角线元素嵌入到input
中。此函数返回一个具有新存储空间的张量,而不是返回视图。
参数
offset
控制要考虑哪一条对角线:-
当
offset
= 0 时,表示为主对角线。 -
如果
offset
大于 0,它位于主对角线的上方。 -
如果
offset
小于 0,它位于主对角线之下。
- 参数
注意
src
必须与input
具有相同的大小,才能被嵌入。具体来说,src
的形状应与torch.diagonal(input, offset, dim1, dim2)
相同。示例:
>>> a = torch.zeros(3, 3) >>> a tensor([[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]) >>> torch.diagonal_scatter(a, torch.ones(3), 0) tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) >>> torch.diagonal_scatter(a, torch.ones(2), 1) tensor([[0., 1., 0.], [0., 0., 1.], [0., 0., 0.]])
-