From: Aurélien Roy Date: Fri, 29 Mar 2019 03:46:03 +0000 (-0700) Subject: Target and input sizes mismatch warning in L1 Loss / L1 Smooth Loss (#18565) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~571 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=12abc8a99a5fc60603b3aecf5faa37600ad4fff6;p=platform%2Fupstream%2Fpytorch.git Target and input sizes mismatch warning in L1 Loss / L1 Smooth Loss (#18565) Summary: Addind the same warning message already present in the mse_loss function to the L1 losses when input and target sizes are different. Pull Request resolved: https://github.com/pytorch/pytorch/pull/18565 Differential Revision: D14671415 Pulled By: soumith fbshipit-source-id: 01f5e1fb1ea119dbb2aecf1d94d0cb462f284982 --- diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 1a08505..3517699 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2163,6 +2163,11 @@ def smooth_l1_loss(input, target, size_average=None, reduce=None, reduction='mea See :class:`~torch.nn.SmoothL1Loss` for details. """ + if not (target.size() == input.size()): + warnings.warn("Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) if target.requires_grad: @@ -2184,6 +2189,11 @@ def l1_loss(input, target, size_average=None, reduce=None, reduction='mean'): See :class:`~torch.nn.L1Loss` for details. """ + if not (target.size() == input.size()): + warnings.warn("Using a target size ({}) that is different to the input size ({}). " + "This will likely lead to incorrect results due to broadcasting. " + "Please ensure they have the same size.".format(target.size(), input.size()), + stacklevel=2) if size_average is not None or reduce is not None: reduction = _Reduction.legacy_get_string(size_average, reduce) if target.requires_grad: