torch.std_mean
- torch.std_mean(input, dim=None, *, correction=1, keepdim=False, out=None)
-
计算由
dim
指定的维度上的标准差和均值。其中dim
可以是单个维度、维度列表,或者None
(表示在所有维度上进行操作)。标准差($\sigma$)是这样计算的:
$\sigma = \sqrt{\frac{1}{\max(0,~N - \delta N)}\sum_{i=0}^{N-1}(x_i-\bar{x})^2}$其中$x$表示元素的样本集,$\bar{x}$表示样本均值,$N$表示样本数量,而$\delta N$则是
校正
。如果
keepdim
是True
,则输出张量与输入张量大小相同,除了在dim
指定的维度上其大小为 1。否则,dim
维度会被挤压(参见torch.squeeze()
),导致输出张量比输入少 1 (或len(dim)
)个维度。- 参数
- 关键字参数
- 返回值
-
一个包含标准差和平均值的元组(std, mean)。
示例
>>> a = torch.tensor( ... [[ 0.2035, 1.2959, 1.8101, -0.4644], ... [ 1.5027, -0.3270, 0.5905, 0.6538], ... [-1.5745, 1.3330, -0.5596, -0.6548], ... [ 0.1264, -0.5080, 1.6420, 0.1992]]) >>> torch.std_mean(a, dim=0, keepdim=True) (tensor([[1.2620, 1.0028, 1.0957, 0.6038]]), tensor([[ 0.0645, 0.4485, 0.8707, -0.0665]]))