torch.nn.functional.gumbel_softmax

torch.nn.functional.gumbel_softmax(logits, tau=1, hard=False, eps=1e-10, dim=-1)[源代码]

从 Gumbel-Softmax 分布 (链接 1 链接 2) 中采样,并可选择进行离散化。

参数
  • logits (Tensor) – 形状为 […, num_features] 的张量,包含未归一化的对数概率值

  • tau (float) – 表示非负标量温度

  • hard (bool) – 如果为 True,返回的样本将被离散化为 one-hot 向量,但在自动微分过程中会被视为软样本进行处理。

  • dim (int) – 计算softmax的维度。默认值:-1。

返回值

从 Gumbel-Softmax 分布中采样一个与 logits 具有相同形状的张量。如果 hard=True,返回的样本将是一-hot 编码形式;否则它们将是沿 dim 维度上和为1的概率分布。

返回类型

Tensor

注意

该函数为了兼容旧版本而保留,将来可能从 nn.Functional 中删除。

注意

对于hard,主要的技巧是执行y_hard - y_soft.detach() + y_soft

它实现了以下两个目标:1. 将输出值变为完全的one-hot形式(因为我们在加完y_soft值后再减去它)。2. 让梯度与y_soft的梯度相等(因为我们消去了所有其他的梯度影响)

示例:
>>> logits = torch.randn(20, 32)
>>> # Sample soft categorical using reparametrization trick:
>>> F.gumbel_softmax(logits, tau=1, hard=False)
>>> # Sample hard categorical using "Straight-through" trick:
>>> F.gumbel_softmax(logits, tau=1, hard=True)
本页目录