torch.diagflat
- torch.diagflat(input, offset=0) → Tensor
-
-
如果
input
是一个向量(即一维张量),则返回一个二维的方形张量,其对角线上的元素与input
中的元素相同。 -
如果
input
是一个多维张量,那么返回一个二维张量,其对角线元素等于展平后的input
。
参数
offset
控制要考虑哪一条对角线:-
当
offset
= 0 时,表示为主对角线。 -
如果
offset
大于 0,它位于主对角线的上方。 -
如果
offset
小于 0,它位于主对角线之下。
示例:
>>> a = torch.randn(3) >>> a tensor([-0.2956, -0.9068, 0.1695]) >>> torch.diagflat(a) tensor([[-0.2956, 0.0000, 0.0000], [ 0.0000, -0.9068, 0.0000], [ 0.0000, 0.0000, 0.1695]]) >>> torch.diagflat(a, 1) tensor([[ 0.0000, -0.2956, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.9068, 0.0000], [ 0.0000, 0.0000, 0.0000, 0.1695], [ 0.0000, 0.0000, 0.0000, 0.0000]]) >>> a = torch.randn(2, 2) >>> a tensor([[ 0.2094, -0.3018], [-0.1516, 1.9342]]) >>> torch.diagflat(a) tensor([[ 0.2094, 0.0000, 0.0000, 0.0000], [ 0.0000, -0.3018, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.1516, 0.0000], [ 0.0000, 0.0000, 0.0000, 1.9342]])
-