From 96456bfa4cf9394c9c926b143cf724a09901908d Mon Sep 17 00:00:00 2001
From: ryan
Date: Sat, 30 Mar 2019 01:20:55 -0700
Subject: [PATCH] Update documentation for CTCLoss (#18415)

Summary:
This is meant to resolve #18249, where I pointed out a few things that could improve the CTCLoss docs. My main goal was to clarify:

- Target sequences are sequences of class indices, excluding the blank index.
- Lengths of `target` and `input` are needed for masking unequal-length sequences, and do not necessarily equal S, the length of the longest sequence in the batch.

I thought about Thomas's suggestion to link the distill.pub article, but I'm not sure about it. I think that should be up to y'all to decide. I have no experience with .rst, so it might not render as expected :)

t-vi ezyang
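To illustrate the second point, here is a minimal sketch (illustration only, not part of the diff; the sizes are arbitrary) using the 1-d concatenated-targets form, where `target_lengths` alone tells CTCLoss where each unpadded target ends:

```python
import torch
import torch.nn as nn

T, C, N = 50, 20, 16  # input length, classes (including blank), batch size

# Log-probabilities over C classes at each of T time steps, for N sequences.
log_probs = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
input_lengths = torch.full((N,), T, dtype=torch.long)

# 1-d concatenated targets: no padding, so target_lengths is the only thing
# telling CTCLoss where each sequence's target ends. Class indices are drawn
# from 1..C-1, since 0 is the blank index and may not appear in targets.
target_lengths = torch.randint(10, 30, (N,), dtype=torch.long)
targets = torch.randint(1, C, (int(target_lengths.sum()),), dtype=torch.long)

loss = nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths)
loss.backward()
```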
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18415

Differential Revision: D14691969

Pulled By: soumith

fbshipit-source-id: 381a2d52307174661c58053ae9dfae6e40cbfd46
---
 torch/nn/modules/loss.py | 72 ++++++++++++++++++++++++++++++++++++------------
 1 file changed, 54 insertions(+), 18 deletions(-)

diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index 2cb4069..7ba5ffe 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -1221,38 +1221,75 @@ class TripletMarginLoss(_Loss):
 class CTCLoss(_Loss):
     r"""The Connectionist Temporal Classification loss.
 
-    Args:
-        blank (int, optional): blank label. Default :math:`0`.
+    Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the
+    probability of possible alignments of input to target, producing a loss value which is differentiable
+    with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which
+    limits the length of the target sequence such that it must be :math:`\leq` the input length.
+
+    **Args:**
+        **blank** (int, optional): blank label. Default :math:`0`.
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
             ``'mean'``: the output losses will be divided by the target lengths and
            then the mean over the batch is taken. Default: ``'mean'``
-        zero_infinity (bool, optional):
+
+        **zero_infinity** (bool, optional):
             Whether to zero infinite losses and the associated gradients.
             Default: ``False``
             Infinite losses mainly occur when the inputs are too short
             to be aligned to the targets.
 
-    Inputs:
-        log_probs: Tensor of size :math:`(T, N, C)` where `C = number of characters in alphabet including blank`,
-            `T = input length`, and `N = batch size`.
+    **Inputs:**
+        **log_probs**: Tensor of size :math:`(T, N, C)`
+            | :math:`T = \text{input length}`
+            | :math:`N = \text{batch size}`
+            | :math:`C = \text{number of classes (including blank)}`
             The logarithmized probabilities of the outputs
             (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
-        targets: Tensor of size :math:`(N, S)` or `(sum(target_lengths))`.
-            Targets (cannot be blank). In the second form, the targets are assumed to be concatenated.
-        input_lengths: Tuple or tensor of size :math:`(N)`.
-            Lengths of the inputs (must each be :math:`\leq T`)
-        target_lengths: Tuple or tensor of size :math:`(N)`.
-            Lengths of the targets
+        **targets**: Tensor of size :math:`(N, S)` or ``(sum(target_lengths))``
+            | :math:`N = \text{batch size}`
+            | :math:`S = \text{max target length, if shape is } (N, S)`
+
+            | Target sequences. Each element in the target sequence is a class index. Target
+            indices cannot be the blank index (default :math:`0`).
+
+            | In the :math:`(N, S)` form, targets are padded to the length of the longest sequence, and stacked.
+            | In the ``(sum(target_lengths))`` form, the targets are assumed to be unpadded and concatenated
+            within one dimension.
+        **input_lengths**: Tuple or tensor of size :math:`(N)`.
+            Lengths of the inputs (must each be :math:`\leq T`).
+            Lengths are specified for each sequence to achieve masking under the
+            assumption that sequences are padded to equal lengths.
+        **target_lengths**: Tuple or tensor of size :math:`(N)`.
+            | Lengths of the targets. Lengths are specified for each sequence to achieve masking under the
+            assumption that sequences are padded to equal lengths.
+
+            | If target shape is :math:`(N, S)`, target_lengths are effectively the stop indices
+            :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for
+            each target in a batch. Lengths must each be :math:`\leq S`.
+
+            | If the targets are given as a 1-d tensor that is the concatenation of individual targets,
+            the target_lengths must add up to the total length of the tensor.
 
     Example::
 
+        >>> T = 50      # Input sequence length
+        >>> C = 20      # Number of classes (including blank)
+        >>> N = 16      # Batch size
+        >>> S = 30      # Target sequence length of longest target in batch
+        >>> S_min = 10  # Minimum target length, for demonstration purposes
+        >>>
+        >>> # Initialize random batch of input vectors, with size (T, N, C)
+        >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
+        >>>
+        >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
+        >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
+        >>>
+        >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
+        >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
         >>> ctc_loss = nn.CTCLoss()
-        >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
-        >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
-        >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
-        >>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
-        >>> loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
+        >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
         >>> loss.backward()
 
     Reference:
@@ -1271,7 +1308,6 @@ class CTCLoss(_Loss):
 
     .. include:: cudnn_deterministic.rst
 
-
     """
     __constants__ = ['blank', 'reduction']
 
-- 
2.7.4