torch.nn.functional.gumbel_softmax
- torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[源代码]
-
从 Gumbel-Softmax 分布 (链接 1 链接 2) 中采样,并可选择进行离散化。
- 参数
- 返回值
-
从 Gumbel-Softmax 分布中采样一个与 logits 具有相同形状的张量。如果
hard=True
,返回的样本将是一-hot 编码形式;否则它们将是沿 dim 维度上和为1的概率分布。 - 返回类型
注意
该函数为了兼容旧版本而保留,将来可能从 nn.Functional 中删除。
注意
对于hard,主要的技巧是执行y_hard - y_soft.detach() + y_soft
它实现了以下两个目标:1. 将输出值变为完全的one-hot形式(因为我们在加完y_soft值后再减去它)。2. 让梯度与y_soft的梯度相等(因为我们消去了所有其他的梯度影响)
- 示例:
-
>>> logits = torch.randn(20, 32) >>> # Sample soft categorical using reparametrization trick: >>> F.gumbel_softmax(logits, tau=1, hard=False) >>> # Sample hard categorical using "Straight-through" trick: >>> F.gumbel_softmax(logits, tau=1, hard=True)