Added tensor size warning to F.mse_loss() (#18349)
authormc-robinson <matthew.robinson@yale.edu>
Mon, 25 Mar 2019 02:17:00 +0000 (19:17 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Mar 2019 02:22:14 +0000 (19:22 -0700)
Summary:
To address the issue of broadcasting giving the wrong result in `nn.MSELoss()` as mentioned here https://github.com/pytorch/pytorch/issues/16045 . In particular, the issue often arises when computing the loss between tensors with shapes (n, 1) and (n,)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18349

Differential Revision: D14594176

Pulled By: soumith

fbshipit-source-id: f23ae68a4bf42f3554ad7678a314ba2c7532a6db

test/test_nn.py
torch/nn/functional.py

index 897255c..b93f850 100644 (file)
@@ -4363,6 +4363,18 @@ class TestNN(NNTestCase):
     def test_loss_equal_input_target_shape(self):
         self._test_loss_equal_input_target_shape(lambda x: x)
 
+    def test_mse_loss_size_warning(self):
+        i = torch.randn((10, 1), requires_grad=True)
+        t = torch.randn((10,))
+        with warnings.catch_warnings(record=True) as w:
+            # Ensure warnings are being shown
+            warnings.simplefilter("always")
+            # Trigger Warning
+            F.mse_loss(i, t)
+            # Check warning occurs
+            self.assertEqual(len(w), 1)
+            self.assertIn('Please ensure they have the same size.', str(w[0]))
+
     def test_nll_loss_mismatched_batch(self):
         x = torch.randn((10, 3), requires_grad=True)
         # t should have size (10,)
index 9990818..1a08505 100644 (file)
@@ -2071,9 +2071,10 @@ def binary_cross_entropy(input, target, weight=None, size_average=None,
         reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
     else:
         reduction_enum = _Reduction.get_enum(reduction)
-    if not (target.size() == input.size()):
+    if target.size() != input.size():
         warnings.warn("Using a target size ({}) that is different to the input size ({}) is deprecated. "
-                      "Please ensure they have the same size.".format(target.size(), input.size()))
+                      "Please ensure they have the same size.".format(target.size(), input.size()),
+                      stacklevel=2)
     if input.numel() != target.numel():
         raise ValueError("Target and input must have the same number of elements. target nelement ({}) "
                          "!= input nelement ({})".format(target.numel(), input.numel()))
@@ -2204,6 +2205,11 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'):
 
     See :class:`~torch.nn.MSELoss` for details.
     """
+    if not (target.size() == input.size()):
+        warnings.warn("Using a target size ({}) that is different to the input size ({}). "
+                      "This will likely lead to incorrect results due to broadcasting. "
+                      "Please ensure they have the same size.".format(target.size(), input.size()),
+                      stacklevel=2)
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     if target.requires_grad: