CTC损失函数

classtorch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=False)[源代码]

连接时序分类损失

计算连续(未分割)时间序列和目标序列之间的损失值。CTCLoss 通过对输入与目标可能的对齐概率求和,生成一个相对于每个输入节点可微的损失值。输入到目标的对齐方式为“多对一”,这意味着目标序列的长度必须小于或等于输入长度。

参数
  • blank (int, 可选) – 空白标签,默认值为$0$

  • reduction (str, 可选) – 指定要应用于输出的缩减方式: 'none' | 'mean' | 'sum''none': 不进行任何缩减,'mean': 输出损失将被目标长度除后,在批次上取平均值,'sum': 输出损失将被求和。默认: 'mean'

  • zero_infinity (bool, optional) – 是否将无穷大的损失值设为零及其相关梯度。默认:False 当输入太短无法与目标对齐时,通常会出现无穷大的损失。

形状:
  • Log_probs: 张量大小为 $(T, N, C)$$(T, C)$,其中 $T = \text{输入长度}$$N = \text{批量大小}$$C = \text{类别数量(包括空白)}$。输出的概率值经过对数化处理(例如通过torch.nn.functional.log_softmax()获得)。

  • 目标:大小为$(N, S)$$(\operatorname{sum}(\text{target\_lengths}))$的张量,其中$N = \text{批量大小}$$S = \text{最大目标长度(如果形状为 } (N, S) \text{)}$。它表示目标序列,每个元素是一个类索引,并且不能是空白(默认值=0)。在$(N, S)$形式中,目标被填充到最长序列的长度并堆叠在一起。在$(\operatorname{sum}(\text{target\_lengths}))$形式中,假设目标未被填充并在一个维度内连接。

  • Input_lengths: 元组或大小为$(N)$$()$的张量,其中$N = \text{批量大小}$。它表示输入序列的长度(每个长度必须小于等于$\leq T$)。并且为每个序列指定长度以在假设所有序列已填充到相同长度的情况下实现屏蔽。

  • target_lengths: 元组或大小为 $(N)$$()$ 的张量,其中 $N = \text{批量大小}$。它表示目标序列的长度。为了在假设所有序列都被填充为等长的情况下实现屏蔽功能,需要为每个序列指定其实际长度。如果目标形状是 $(N,S)$,则 target_lengths 实际上是每个目标序列的有效结束索引 $s_n$,使得对于批量中的每个目标有 target_n = targets[n,0:s_n]。这些长度必须各自小于或等于序列的最大长度 $S$。如果目标以一维张量的形式给出,该张量是各个目标的串联,则所有 target_lengths 的总和必须等于整个张量的长度。

  • 输出:如果 reduction'mean'(默认值)或 'sum',则为标量。 如果 reduction'none',输入批量时输出维度为$(N)$,未批量时输出维度为$()$,其中 $N = \text{batch size}$

    输出:如果 reduction'mean'(默认值)或 'sum',则为标量。 如果 reduction'none',输入批量时输出维度为$(N)$,未批量时输出维度为$()$,其中 $N = \text{batch size}$

示例:

>>> # Target are to be padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>> S = 30      # Target sequence length of longest target in batch (padding length)
>>> S_min = 10  # Minimum target length, for demonstration purposes
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>> N = 16      # Batch size
>>>
>>> # Initialize random batch of input vectors, for *size = (T,N,C)
>>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
>>>
>>>
>>> # Target are to be un-padded and unbatched (effectively N=1)
>>> T = 50      # Input sequence length
>>> C = 20      # Number of classes (including blank)
>>>
>>> # Initialize random batch of input vectors, for *size = (T,C)
>>> input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
>>> input_lengths = torch.tensor(T, dtype=torch.long)
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
参考

A. Graves 等人:连接主义时间分类:使用循环神经网络对未分割的序列数据进行标注:https://www.cs.toronto.edu/~graves/icml_2006.pdf

注意

为了使用CuDNN,需要满足以下条件:变量targets 必须以拼接格式存在;所有input_lengths 的值必须为 T$blank=0$target_lengths 不超过256;整数参数的数据类型必须是 torch.int32

常规实现使用(在 PyTorch 中较为常见的)torch.long 数据类型。

注意

在某些情况下,使用 CUDA 后端与 CuDNN 时,此操作符可能会选择一个非确定性算法来提高性能。如果这不可取,你可以尝试通过将 torch.backends.cudnn.deterministic = True 设置为真来使操作具有确定性(可能会影响性能)。请参阅关于可重复性的说明以获取背景信息。

本页目录