From 8bc5b867093f1b44f5168f3e41b1bdf0ba3e05a7 Mon Sep 17 00:00:00 2001 From: mc-robinson Date: Sun, 24 Mar 2019 19:17:00 -0700 Subject: [PATCH] Added tensor size warning to F.mse_loss() (#18349) Summary: To address the issue of broadcasting giving the wrong result in `nn.MSELoss()` as mentioned here https://github.com/pytorch/pytorch/issues/16045 . In particular, the issue often arises when computing the loss between tensors with shapes (n, 1) and (n,) Pull Request resolved: https://github.com/pytorch/pytorch/pull/18349 Differential Revision: D14594176 Pulled By: soumith fbshipit-source-id: f23ae68a4bf42f3554ad7678a314ba2c7532a6db --- test/test_nn.py | 12 ++++++++++++ torch/nn/functional.py | 10 ++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/test/test_nn.py b/test/test_nn.py index 897255c..b93f850 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4363,6 +4363,18 @@ class TestNN(NNTestCase): def test_loss_equal_input_target_shape(self): self._test_loss_equal_input_target_shape(lambda x: x) + def test_mse_loss_size_warning(self): + i = torch.randn((10, 1), requires_grad=True) + t = torch.randn((10,)) + with warnings.catch_warnings(record=True) as w: + # Ensure warnings are being shown + warnings.simplefilter("always") + # Trigger Warning + F.mse_loss(i, t) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertIn('Please ensure they have the same size.', str(w[0])) + def test_nll_loss_mismatched_batch(self): x = torch.randn((10, 3), requires_grad=True) # t should have size (10,) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 9990818..1a08505 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2071,9 +2071,10 @@ def binary_cross_entropy(input, target, weight=None, size_average=None, reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - if not (target.size() == input.size()): + if target.size() != input.size(): warnings.warn("Using a target size ({}) that is different to the input size ({}) is deprecated. " - "Please ensure they have the same size.".format(target.size(), input.size())) + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2) if input.numel() != target.numel(): raise ValueError("Target and input must have the same number of elements. target nelement ({}) " "!= input nelement ({})".format(target.numel(), input.numel())) @@ -2204,6 +2205,11 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): See :class:`~torch.nn.MSELoss` for details. """ + if not (target.size() == input.size()): + warnings.warn("Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) if target.requires_grad: -- 2.7.4