Target and input sizes mismatch warning in L1 Loss / L1 Smooth Loss (#18565)
authorAurélien Roy <aurelien_roy@outlook.com>
Fri, 29 Mar 2019 03:46:03 +0000 (20:46 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 29 Mar 2019 03:49:51 +0000 (20:49 -0700)
Summary:
Addind the same warning message already present in the mse_loss function to the L1 losses when input and target sizes are different.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18565

Differential Revision: D14671415

Pulled By: soumith

fbshipit-source-id: 01f5e1fb1ea119dbb2aecf1d94d0cb462f284982

torch/nn/functional.py

index 1a08505..3517699 100644 (file)
@@ -2163,6 +2163,11 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea
 
     See :class:`~torch.nn.SmoothL1Loss` for details.
     """
+    if not (target.size() == input.size()):
+        warnings.warn("Using a target size ({}) that is different to the input size ({}). "
+                      "This will likely lead to incorrect results due to broadcasting. "
+                      "Please ensure they have the same size.".format(target.size(), input.size()),
+                      stacklevel=2)
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     if target.requires_grad:
@@ -2184,6 +2189,11 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'):
 
     See :class:`~torch.nn.L1Loss` for details.
     """
+    if not (target.size() == input.size()):
+        warnings.warn("Using a target size ({}) that is different to the input size ({}). "
+                      "This will likely lead to incorrect results due to broadcasting. "
+                      "Please ensure they have the same size.".format(target.size(), input.size()),
+                      stacklevel=2)
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
     if target.requires_grad: