Add new reduction mode in kl_div (#14457)
authorAiling Zhang <ailzhang@fb.com>
Tue, 4 Dec 2018 20:21:17 +0000 (12:21 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 20:24:28 +0000 (12:24 -0800)
Summary:
Fixes #6622 .
We used to average over all elements for kl divergence, which is not aligned with its math definition.
This PR corrects the default reduction behavior of KL divergence that it now naverages over batch dimension.

- In KL, default behavior `reduction=mean` averages over batch dimension. While for most other loss functions, `reduction=mean` averages over all elements.
- We used to support scalar tensor as well. For BC purpose, we still support it, no reduction is performed on scalar tensor.
- Added a new reduction mode called `batchmean` which has the correct behavior for KL. Add a warning to make `batchmean` as default for KL instead of `mean` in next major release.
- [deprecated]I chose to not add a new reduction option, since "mean over batch dimension" is kinda special, and it only makes sense in few cases like KL. We don't want to explain why there's a option "batchmean" but it's not applicable for all other functions. I'm open to discussion on this one, as I cannot think of a perfect solution for this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14457

Differential Revision: D13236016

Pulled By: ailzhang

fbshipit-source-id: 905cc7b3bfc35a11d7cf098b1ebc382170a087a7

test/common_nn.py
test/test_nn.py
torch/nn/functional.py
torch/nn/modules/loss.py

index 36efff2..eebd597 100644 (file)
@@ -2208,6 +2208,8 @@ def kldivloss_reference(input, target, reduction='mean'):
         return result.mean()
     elif reduction == 'sum':
         return result.sum()
+    elif reduction == 'batchmean' and results.dim() != 0:
+        return result.sum() / result.size(0)
     return result
 
 
index e3645ef..b349e49 100644 (file)
@@ -4016,6 +4016,19 @@ class TestNN(NNTestCase):
         with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
             F.nll_loss(x, t)
 
+    def test_KLDivLoss_batch_mean(self):
+        input_shape = (2, 5)
+        log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
+        prob2 = F.softmax(torch.randn(input_shape), 1)
+
+        loss = nn.KLDivLoss(reduction='batchmean')
+        l = loss(log_prob1, prob2)
+
+        loss_none_reduce = nn.KLDivLoss(reduction='sum')(log_prob1, prob2)
+        expected = loss_none_reduce / input_shape[0]
+
+        self.assertEqual(l, expected)
+
     @unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 7000), "needs cudnn >= 7.0")
     def test_CTCLoss_cudnn(self):
         target_lengths = [30, 25, 20]
index d690000..883d5ab 100644 (file)
@@ -1885,17 +1885,40 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
             on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
             batch element instead and ignores :attr:`size_average`. Default: ``True``
         reduction (string, optional): Specifies the reduction to apply to the output:
-            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
-            'mean': the sum of the output will be divided by the number of
-            elements in the output, 'sum': the output will be summed. Note: :attr:`size_average`
-            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
-            specifying either of those two args will override :attr:`reduction`. Default: 'mean'
+            'none' | 'batchmean' | 'sum' | 'mean'.
+            'none': no reduction will be applied
+            'batchmean': the sum of the output will be divided by the batchsize
+            'sum': the output will be summed
+            'mean': the output will be divided by the number of elements in the output
+            Default: 'mean'
+
+        .. note:: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
+            and in the meantime, specifying either of those two args will override :attr:`reduction`.
+
+        .. note:: `reduction='mean'` doesn't return the true kl divergence value, please use
+            `reduction='batchmean'` which aligns with KL math definition.
+            In the next major release, 'mean' will be changed to be the same as 'batchmean'.
     """
     if size_average is not None or reduce is not None:
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
-        reduction_enum = _Reduction.get_enum(reduction)
-    return torch.kl_div(input, target, reduction_enum)
+        if reduction == 'mean':
+            warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size."
+                          "'batchmean' divides only by the batch size, and aligns with the KL div math definition."
+                          "'mean' will be changed to behave the same as 'batchmean' in the next major release.")
+
+        # special case for batchmean
+        if reduction == 'batchmean':
+            reduction_enum = _Reduction.get_enum('sum')
+        else:
+            reduction_enum = _Reduction.get_enum(reduction)
+
+    reduced = torch.kl_div(input, target, reduction_enum)
+
+    if reduction == 'batchmean' and input.dim() != 0:
+        reduced = reduced / input.size()[0]
+
+    return reduced
 
 
 @torch._jit_internal.weak_script
index 472bb53..cdc3a7a 100644 (file)
@@ -282,7 +282,7 @@ class KLDivLoss(_Loss):
 
     As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain
     *log-probabilities*. However, unlike :class:`~torch.nn.NLLLoss`, `input` is not
-    restricted to a 2D Tensor, because the criterion is applied element-wise.
+    restricted to a 2D Tensor.
     The targets are given as *probabilities* (i.e. without taking the logarithm).
 
     This criterion expects a `target` `Tensor` of the same size as the
@@ -303,31 +303,14 @@ class KLDivLoss(_Loss):
             \operatorname{sum}(L),  & \text{if}\; \text{size\_average} = \text{False}.
         \end{cases}
 
-    By default, the losses are averaged for each minibatch over observations
-    **as well as** over dimensions. However, if the field
-    :attr:`size_average` is set to ``False``, the losses are instead summed.
+    In default reduction mode 'mean', the losses are averaged for each minibatch over observations
+    **as well as** over dimensions. 'batchmean' mode gives the correct KL divergence where losses
+    are averaged over batch dimension only. 'mean' mode's behavior will be changed to the same as
+    'batchmean' in the next major release.
 
     .. _Kullback-Leibler divergence:
         https://en.wikipedia.org/wiki/Kullback-Leibler_divergence
 
-    .. note:: The default averaging means that the loss is actually **not** the
-          KL Divergence because the terms are already probability weighted.
-          A future release of PyTorch may move the default loss closer to the
-          mathematical definition.
-
-          To get the real KL Divergence, use ``size_average=False``, and
-          then divide the output by the batch size.
-
-          Example::
-
-            >>> loss = nn.KLDivLoss(size_average=False)
-            >>> batch_size = 5
-            >>> log_probs1 = F.log_softmax(torch.randn(batch_size, 10), 1)
-            >>> probs2 = F.softmax(torch.randn(batch_size, 10), 1)
-            >>> loss(log_probs1, probs2) / batch_size
-            tensor(0.7142)
-
-
     Args:
         size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
             the losses are averaged over each loss element in the batch. Note that for
@@ -339,11 +322,18 @@ class KLDivLoss(_Loss):
             on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
             batch element instead and ignores :attr:`size_average`. Default: ``True``
         reduction (string, optional): Specifies the reduction to apply to the output:
-            'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
-            'mean': the sum of the output will be divided by the number of
-            elements in the output, 'sum': the output will be summed. Note: :attr:`size_average`
-            and :attr:`reduce` are in the process of being deprecated, and in the meantime,
-            specifying either of those two args will override :attr:`reduction`. Default: 'mean'
+            'none' | 'batchmean' | 'sum' | 'mean'.
+            'none': no reduction will be applied.
+            'batchmean': the sum of the output will be divided by batchsize.
+            'sum': the output will be summed.
+            'mean': the output will be divided by the number of elements in the output.
+            Default: 'mean'
+        .. note:: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
+            and in the meantime, specifying either of those two args will override :attr:`reduction`.
+        .. note:: `reduction='mean'` doesn't return the true kl divergence value, please use
+            `reduction='batchmean'` which aligns with KL math definition.
+            In the next major release, 'mean' will be changed to be the same as 'batchmean'.
+
 
     Shape:
         - input: :math:`(N, *)` where `*` means, any number of additional