def test_loss_equal_input_target_shape(self):
# Losses should accept a target shaped identically to the input;
# exercised via the shared helper with an identity transform on the input.
self._test_loss_equal_input_target_shape(lambda x: x)
+ def test_mse_loss_size_warning(self):
+ i = torch.randn((10, 1), requires_grad=True)
+ t = torch.randn((10,))
+ with warnings.catch_warnings(record=True) as w:
+ # Ensure warnings are being shown
+ warnings.simplefilter("always")
+ # Trigger Warning
+ F.mse_loss(i, t)
+ # Check warning occurs
+ self.assertEqual(len(w), 1)
+ self.assertIn('Please ensure they have the same size.', str(w[0]))
+
def test_nll_loss_mismatched_batch(self):
x = torch.randn((10, 3), requires_grad=True)
# t should have size (10,)
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
- if not (target.size() == input.size()):
+ if target.size() != input.size():
warnings.warn("Using a target size ({}) that is different to the input size ({}) is deprecated. "
- "Please ensure they have the same size.".format(target.size(), input.size()))
+ "Please ensure they have the same size.".format(target.size(), input.size()),
+ stacklevel=2)
if input.numel() != target.numel():
raise ValueError("Target and input must have the same number of elements. target nelement ({}) "
"!= input nelement ({})".format(target.numel(), input.numel()))
See :class:`~torch.nn.MSELoss` for details.
"""
+ if target.size() != input.size():
+ warnings.warn("Using a target size ({}) that is different to the input size ({}). "
+ "This will likely lead to incorrect results due to broadcasting. "
+ "Please ensure they have the same size.".format(target.size(), input.size()),
+ stacklevel=2)
if size_average is not None or reduce is not None:
reduction = _Reduction.legacy_get_string(size_average, reduce)
if target.requires_grad: