Update documentation for CTCLoss (#18415)
authorryan <rlorigro@ucsc.edu>
Sat, 30 Mar 2019 08:20:55 +0000 (01:20 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 30 Mar 2019 08:26:34 +0000 (01:26 -0700)
Summary:
This is meant to resolve #18249, where I pointed out a few things that could improve the CTCLoss docs.

My main goal was to clarify:
- Target sequences are sequences of class indices, excluding the blank index
- Lengths of `target` and `input` are needed for masking unequal length sequences, and do not necessarily = S, which is the length of the longest sequence in the batch.

I thought about Thomas's suggestion to link the distill.pub article, but I'm not sure about it. I think that should be up to y'all to decide.

I have no experience with .rst, so it might not render as expected :)

t-vi ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18415

Differential Revision: D14691969

Pulled By: soumith

fbshipit-source-id: 381a2d52307174661c58053ae9dfae6e40cbfd46

torch/nn/modules/loss.py

index 2cb4069..7ba5ffe 100644 (file)
@@ -1221,38 +1221,75 @@ class TripletMarginLoss(_Loss):
 class CTCLoss(_Loss):
     r"""The Connectionist Temporal Classification loss.
 
-    Args:
-        blank (int, optional): blank label. Default :math:`0`.
+    Calculates loss between a continuous (unsegmented) time series and a target sequence. CTCLoss sums over the
+    probability of possible alignments of input to target, producing a loss value which is differentiable
+    with respect to each input node. The alignment of input to target is assumed to be "many-to-one", which
+    limits the length of the target sequence such that it must be :math: `\leq` the input length.
+
+    **Args:**
+        **blank** (int, optional): blank label. Default :math:`0`.
         reduction (string, optional): Specifies the reduction to apply to the output:
             ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction will be applied,
             ``'mean'``: the output losses will be divided by the target lengths and
             then the mean over the batch is taken. Default: ``'mean'``
-        zero_infinity (bool, optional):
+
+        **zero_infinity** (bool, optional):
             Whether to zero infinite losses and the associated gradients.
             Default: ``False``
             Infinite losses mainly occur when the inputs are too short
             to be aligned to the targets.
 
-    Inputs:
-        log_probs: Tensor of size :math:`(T, N, C)` where `C = number of characters in alphabet including blank`,
-            `T = input length`, and `N = batch size`.
+    **Inputs:**
+        **log_probs**: Tensor of size :math:`(T, N, C)`
+            | :math:`T = input length`
+            | :math:`N = batch size`
+            | :math:`C = number of classes (including blank)`
+
             The logarithmized probabilities of the outputs
             (e.g. obtained with :func:`torch.nn.functional.log_softmax`).
-        targets: Tensor of size :math:`(N, S)` or `(sum(target_lengths))`.
-            Targets (cannot be blank). In the second form, the targets are assumed to be concatenated.
-        input_lengths: Tuple or tensor of size :math:`(N)`.
-            Lengths of the inputs (must each be :math:`\leq T`)
-        target_lengths: Tuple or tensor of size  :math:`(N)`.
-            Lengths of the targets
+        **targets**: Tensor of size :math:`(N, S)` or `(sum(target_lengths))`
+            | :math:`N = batch size`
+            | :math:`S = max target length, if shape is (N, S)`.
+
+            | Target sequences. Each element in the target sequence is a class index. Target index
+              cannot be blank (default=0).
+
+            | In the :math:`(N, S)` form, targets are padded to the length of the longest sequence, and stacked.
+            | In the :math:`(sum(target_lengths))` form, the targets are assumed to be un-padded and concatenated
+              within 1 dimension.
+        **input_lengths**: Tuple or tensor of size :math:`(N)`.
+            Lengths of the inputs (must each be :math:`\leq T`).
+            Lengths are specified for each sequence to achieve masking under the
+            assumption that sequences are padded to equal lengths.
+        **target_lengths**: Tuple or tensor of size  :math:`(N)`.
+            | Lengths of the targets. Lengths are specified for each sequence to achieve masking under the
+              assumption that sequences are padded to equal lengths.
+
+            | If target shape is :math:`(N,S)`, target_lengths are effectively the stop index
+              :math:`s_n` for each target sequence, such that ``target_n = targets[n,0:s_n]`` for
+              each target in a batch. Lengths must each be :math:`\leq S`
+
+            | If the targets are given as a 1d tensor that is the concatenation of individual targets,
+              the target_lengths must add up to the total length of the tensor.
 
     Example::
 
+        >>> T = 50      # Input sequence length
+        >>> C = 20      # Number of classes (excluding blank)
+        >>> N = 16      # Batch size
+        >>> S = 30      # Target sequence length of longest target in batch
+        >>> S_min = 10  # Minimum target length, for demonstration purposes
+        >>>
+        >>> # Initialize random batch of input vectors, for *size = (T,N,C)
+        >>> input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()
+        >>>
+        >>> # Initialize random batch of targets (0 = blank, 1:C+1 = classes)
+        >>> target = torch.randint(low=1, high=C+1, size=(N, S), dtype=torch.long)
+        >>>
+        >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
+        >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
         >>> ctc_loss = nn.CTCLoss()
-        >>> log_probs = torch.randn(50, 16, 20).log_softmax(2).detach().requires_grad_()
-        >>> targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
-        >>> input_lengths = torch.full((16,), 50, dtype=torch.long)
-        >>> target_lengths = torch.randint(10,30,(16,), dtype=torch.long)
-        >>> loss = ctc_loss(log_probs, targets, input_lengths, target_lengths)
+        >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
         >>> loss.backward()
 
     Reference:
@@ -1271,7 +1308,6 @@ class CTCLoss(_Loss):
 
     .. include:: cudnn_deterministic.rst
 
-
     """
     __constants__ = ['blank', 'reduction']