torch.nn.functional.ctc_loss
- torch.nn.functional.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='mean', zero_infinity=False)[源代码]
-
应用连接主义时序分类损失。
详情请参阅
CTCLoss
。注意
在某些情况下,当张量位于 CUDA 设备上并使用 CuDNN 时,此操作符可能会选择一个非确定性算法以提高性能。如果你不希望这样,可以通过将
torch.backends.cudnn.deterministic = True
设置为True
来使操作具有确定性(这可能会影响性能)。有关更多信息,请参阅可重复性。注意
当给定的张量位于CUDA设备上时,此操作可能会产生非确定性的梯度。更多详细信息请参见重现性。
- 参数
-
-
log_probs (Tensor) – $(T, N, C)$ 或 $(T, C)$,其中 C = 字母表中字符的数量(包括空白字符), T = 输入长度, 和 N = 批量大小。输出的概率对数值(例如,可以通过
torch.nn.functional.log_softmax()
获得)。 -
targets (Tensor) – 形状为 $(N, S)$ 或 (sum(target_lengths))。Targets不能为空白,在第二种形式中,假设targets是拼接在一起的。
-
input_lengths (Tensor) – 形状为 $(N)$ 或 $()$。表示输入序列的长度(每个长度必须小于等于$\leq T$)
-
target_lengths (Tensor) – 形状为 $(N)$ 或 $()$。表示目标的长度。
-
blank (int, optional) – 空白标签,默认值为$0$。
-
reduction (str, 可选) – 指定要应用于输出的缩减方式:
'none'
|'mean'
|'sum'
。'none'
: 不进行任何缩减,'mean'
: 输出损失将被目标长度除后,在批次上取平均值,'sum'
: 输出将会被求和。默认:'mean'
-
zero_infinity (bool, optional) – 是否将无穷大的损失值设为零及其相关梯度。默认:
False
当输入太短无法与目标对齐时,通常会出现无穷大的损失。
-
- 返回类型
示例:
>>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_() >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long) >>> input_lengths = torch.full((16,), 50, dtype=torch.long) >>> target_lengths = torch.randint(10, 30, (16,), dtype=torch.long) >>> loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths) >>> loss.backward()