From: mc-robinson Date: Mon, 25 Mar 2019 02:17:00 +0000 (-0700) Subject: Added tensor size warning to F.mse_loss() (#18349) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~656 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=8bc5b867093f1b44f5168f3e41b1bdf0ba3e05a7;p=platform%2Fupstream%2Fpytorch.git Added tensor size warning to F.mse_loss() (#18349) Summary: To address the issue of broadcasting giving the wrong result in `nn.MSELoss()` as mentioned here https://github.com/pytorch/pytorch/issues/16045 . In particular, the issue often arises when computing the loss between tensors with shapes (n, 1) and (n,) Pull Request resolved: https://github.com/pytorch/pytorch/pull/18349 Differential Revision: D14594176 Pulled By: soumith fbshipit-source-id: f23ae68a4bf42f3554ad7678a314ba2c7532a6db --- diff --git a/test/test_nn.py b/test/test_nn.py index 897255c..b93f850 100644 --- a/test/test_nn.py +++ b/test/test_nn.py @@ -4363,6 +4363,18 @@ class TestNN(NNTestCase): def test_loss_equal_input_target_shape(self): self._test_loss_equal_input_target_shape(lambda x: x) + def test_mse_loss_size_warning(self): + i = torch.randn((10, 1), requires_grad=True) + t = torch.randn((10,)) + with warnings.catch_warnings(record=True) as w: + # Ensure warnings are being shown + warnings.simplefilter("always") + # Trigger Warning + F.mse_loss(i, t) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertIn('Please ensure they have the same size.', str(w[0])) + def test_nll_loss_mismatched_batch(self): x = torch.randn((10, 3), requires_grad=True) # t should have size (10,) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 9990818..1a08505 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2071,9 +2071,10 @@ def binary_cross_entropy(input, target, weight=None, size_average=None, reduction_enum = _Reduction.legacy_get_enum(size_average, reduce) else: reduction_enum = _Reduction.get_enum(reduction) - if not (target.size() == input.size()): + if target.size() != input.size(): warnings.warn("Using a target size ({}) that is different to the input size ({}) is deprecated. " - "Please ensure they have the same size.".format(target.size(), input.size())) + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2) if input.numel() != target.numel(): raise ValueError("Target and input must have the same number of elements. target nelement ({}) " "!= input nelement ({})".format(target.numel(), input.numel())) @@ -2204,6 +2205,11 @@ def mse_loss(input, target, size_average=None, reduce=None, reduction='mean'): See :class:`~torch.nn.MSELoss` for details. """ + if not (target.size() == input.size()): + warnings.warn("Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) if target.requires_grad: