From d3bcba5f85f97ef273109924c695f33bf739e115 Mon Sep 17 00:00:00 2001
From: "Thomas J. Fan"
Date: Sun, 29 Aug 2021 23:31:42 -0700
Subject: [PATCH] ENH Adds label_smoothing to cross entropy loss (#63122)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/7455

Partially resolves pytorch/vision#4281

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63122

Reviewed By: iramazanli

Differential Revision: D30586076

Pulled By: jbschlosser

fbshipit-source-id: 06afc3aa1f8b9edb07fe9ed68c58968ad1926924
---
 aten/src/ATen/native/LossNLL.cpp                  |  79 ++++++++-
 aten/src/ATen/native/native_functions.yaml        |   2 +-
 test/cpp/api/functional.cpp                       |  14 ++
 test/cpp/api/modules.cpp                          |  25 +++
 test/test_nn.py                                   |  72 ++++++++
 torch/csrc/api/include/torch/nn/functional/loss.h |   9 +-
 torch/csrc/api/include/torch/nn/options/loss.h    |   2 +
 torch/csrc/api/src/nn/modules/loss.cpp            |   3 +-
 torch/nn/functional.py                            |   8 +-
 torch/nn/functional.pyi.in                        |   3 +-
 torch/nn/modules/loss.py                          |  13 +-
 torch/onnx/symbolic_opset12.py                    |   6 +-
 torch/overrides.py                                |   2 +-
 torch/testing/_internal/common_nn.py              | 199 +++++++++++++++++++++-
 14 files changed, 412 insertions(+), 25 deletions(-)

diff --git a/aten/src/ATen/native/LossNLL.cpp b/aten/src/ATen/native/LossNLL.cpp
index c7c65f7..83f1699 100644
--- a/aten/src/ATen/native/LossNLL.cpp
+++ b/aten/src/ATen/native/LossNLL.cpp
@@ -459,9 +459,10 @@ TORCH_IMPL_FUNC(nll_loss_backward_out_cpu)
 
 Tensor cross_entropy_loss_prob_target(
     const Tensor& self,
-    const Tensor& target,
+    const Tensor& target_,
     const Tensor& weight,
-    int64_t reduction) {
+    int64_t reduction,
+    double label_smoothing) {
   const auto n_classes = self.size(1);
   TORCH_CHECK(
       !weight.defined() || (weight.dim() == 1 && weight.numel() == n_classes),
@@ -472,6 +473,15 @@ Tensor cross_entropy_loss_prob_target(
       weight.sizes());
 
   auto input = at::log_softmax(self, 1, self.scalar_type());
+  Tensor target;
+
+  if (label_smoothing > 0.0) {
+    TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);
+    target = target_ * (1 - label_smoothing) + label_smoothing / n_classes;
+  } else {
+    target = target_;
+  }
+
   if (weight.defined()) {
     // Expand weight to the correct number of dims for broadcasting with input / target
     auto weight_broadcast_shape = SmallBuffer<int64_t>(input.dim());
@@ -503,12 +513,66 @@ Tensor cross_entropy_loss_prob_target(
   }
 }
 
+Tensor cross_entropy_loss_label_smoothing(
+    const Tensor& self,
+    const Tensor& target,
+    const Tensor& weight,
+    int64_t reduction,
+    int64_t ignore_index,
+    double label_smoothing) {
+
+  auto input = at::log_softmax(self, 1, self.scalar_type());
+  auto nllloss = at::nll_loss_nd(input, target, weight, reduction, ignore_index);
+
+  auto n_classes = input.size(1);
+
+  Tensor smooth_loss;
+  if (weight.defined()) {
+    // Expand weight to the correct number of dims for broadcasting with input / target
+    auto weight_broadcast_shape = SmallBuffer<int64_t>(input.dim());
+    std::fill(weight_broadcast_shape.begin(), weight_broadcast_shape.end(), 1);
+    weight_broadcast_shape[1] = weight.size(0);
+    Tensor weight_ = weight.view(weight_broadcast_shape);
+
+    smooth_loss = -(input * weight_).sum(1);
+  } else {
+    smooth_loss = -input.sum(1);
+  }
+
+  if (ignore_index >= 0) {
+    smooth_loss.index_put_({target == ignore_index}, 0.0);
+  }
+
+  Tensor ret;
+  switch (reduction) {
+    case Reduction::Mean:
+      if (weight.defined()) {
+        // TODO: This code path can be removed if #61309 is resolved
+        // loss is normalized by the weights to be consistent with nll_loss_nd
+        ret = smooth_loss.sum() / weight.gather(0, target.flatten()).sum();
+      } else {
+        ret = smooth_loss.mean();
+      }
+      break;
+    case Reduction::Sum:
+      ret = smooth_loss.sum();
+      break;
+    case Reduction::None:
+      ret = smooth_loss;
+      break;
+    default:
+      TORCH_CHECK(false, "Invalid reduction type encountered in cross_entropy: ", reduction);
+  }
+  return (1 - label_smoothing) * nllloss + ret * (label_smoothing / n_classes);
+}
+
 Tensor cross_entropy_loss(
     const Tensor& self,
     const Tensor& target,
     const c10::optional<Tensor>& weight,
     int64_t reduction,
-    int64_t ignore_index) {
+    int64_t ignore_index,
+    double label_smoothing) {
   Tensor ret;
   if (self.sizes() == target.sizes()) {
     // Assume soft targets when input and target shapes are the same
@@ -519,7 +583,14 @@ Tensor cross_entropy_loss(
     // See [Note: hacky wrapper removal for optional tensor]
     c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
     const Tensor& weight_ = *weight_maybe_owned;
-    ret = cross_entropy_loss_prob_target(self, target, weight_, reduction);
+    ret = cross_entropy_loss_prob_target(self, target, weight_, reduction, label_smoothing);
+  } else if (label_smoothing > 0.0) {
+    TORCH_CHECK(label_smoothing <= 1.0, "label_smoothing must be between 0.0 and 1.0. Got: ", label_smoothing);
+
+    // See [Note: hacky wrapper removal for optional tensor]
+    c10::MaybeOwned<Tensor> weight_maybe_owned = at::borrow_from_optional_tensor(weight);
+    const Tensor& weight_ = *weight_maybe_owned;
+    ret = cross_entropy_loss_label_smoothing(self, target, weight_, reduction, ignore_index, label_smoothing);
   } else {
     ret = at::nll_loss_nd(
         at::log_softmax(self, 1, self.scalar_type()),
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 224d850..688763e 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -6652,7 +6652,7 @@
   device_check: NoCheck   # TensorIterator
   variants: method
 
-- func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100) -> Tensor
+- func: cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, int ignore_index=-100, float label_smoothing=0.0) -> Tensor
   python_module: nn
 
 - func: lstsq.X(Tensor self, Tensor A, *, Tensor(a!) X, Tensor(b!) qr) -> (Tensor(a!) solution, Tensor(b!) QR)
diff --git a/test/cpp/api/functional.cpp b/test/cpp/api/functional.cpp
index 2ecb8418..8b7889f 100644
--- a/test/cpp/api/functional.cpp
+++ b/test/cpp/api/functional.cpp
@@ -792,6 +792,20 @@ TEST_F(FunctionalTest, CrossEntropy) {
 
   ASSERT_TRUE(output.allclose(expected, 1e-04));
   ASSERT_TRUE(F::cross_entropy(input, target).allclose(expected, 1e-04));
+
+  // label smoothing with class indices
+  input = torch::tensor({{3., 1.}, {1., 2.}}, torch::kFloat);
+  output = F::cross_entropy(
+    input, target, F::CrossEntropyFuncOptions().label_smoothing(0.15).reduction(torch::kMean));
+  expected = torch::tensor(0.3326, torch::kFloat);
+  ASSERT_TRUE(output.allclose(expected, 1e-04));
+
+  // label smoothing with target probabilities
+  target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
+  output = F::cross_entropy(
+    input, target, F::CrossEntropyFuncOptions().label_smoothing(0.2).reduction(torch::kMean));
+  expected = torch::tensor(0.5701, torch::kFloat);
+  ASSERT_TRUE(output.allclose(expected, 1e-04));
 }
 
 TEST_F(FunctionalTest, MaxUnpool1d) {
diff --git a/test/cpp/api/modules.cpp b/test/cpp/api/modules.cpp
index 23d75ef..927d884 100644
--- a/test/cpp/api/modules.cpp
+++ b/test/cpp/api/modules.cpp
@@ -2315,6 +2315,31 @@ TEST_F(ModulesTest, CrossEntropyLoss) {
   ASSERT_TRUE(
       CrossEntropyLoss(CrossEntropyLossOptions().ignore_index(-100).reduction(torch::kMean))
       ->forward(input, target).allclose(expected, 1e-04));
+
+  // label smoothing with class indices
+  loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kMean));
+  input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
+  target = torch::tensor({0, 1}, torch::kLong);
+  output = loss->forward(input, target);
+  expected = torch::tensor(0.3326, torch::kFloat);
+  s = output.sum();
+  s.backward();
+
+  ASSERT_TRUE(output.allclose(expected, 1e-04));
+  ASSERT_EQ(input.sizes(), input.grad().sizes());
+
+  // label smoothing with target probabilities
+  loss = CrossEntropyLoss(CrossEntropyLossOptions().label_smoothing(0.2).reduction(torch::kMean));
+  input = torch::tensor({{3., 1.}, {1., 2.}}, torch::dtype(torch::kFloat).requires_grad(true));
+  target = torch::tensor({{0.8, 0.2}, {0.1, 0.9}}, torch::kFloat);
+  output = loss->forward(input, target);
+  expected = torch::tensor(0.5701, torch::kFloat);
+  s = output.sum();
+  s.backward();
+
+  ASSERT_TRUE(output.allclose(expected, 1e-04));
+  ASSERT_EQ(input.sizes(), input.grad().sizes());
+
 }
 
 TEST_F(ModulesTest, CosineSimilarity) {
diff --git a/test/test_nn.py b/test/test_nn.py
index 7d26246..bb4dd59 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -17183,6 +17183,78 @@ class TestNNDeviceType(NNTestCase):
             output_one_hot = m(input, target_one_hot)
             self.assertEqual(output, output_one_hot)
 
+    def test_cross_entropy_label_smoothing_errors(self, device):
+        N, C = 3, 4
+        input_args = [
+            (torch.randn((N, C), device=device), torch.arange(0, C, device=device)),
+            (torch.randn((N, C), device=device), torch.randn(N, C, device=device))
+        ]
+        for input_arg in input_args:
+            loss = nn.CrossEntropyLoss(label_smoothing=1.2)
+            with self.assertRaisesRegex(RuntimeError,
+                                        r"label_smoothing must be between 0\.0"):
+                loss(*input_arg)
+
+    def test_cross_entropy_label_smoothing_consistent_index_target_and_probs(self, device):
+        N, C = 10, 4
+        ks = range(5)
+        reductions = ['none', 'mean', 'sum']
+        label_smoothings = [0.05, 0.15]
+
+        for k, reduction, label_smoothing in product(ks, reductions, label_smoothings):
+            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
+            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
+            target = torch.empty(N, *other_dims, dtype=torch.long, device=device).random_(0, C)
+
+            # construct a target probability that should produce the same result as label_smoothing
+            target_proba = F.one_hot(target, num_classes=C)
+            # Need to put the C dim at index 1.
+            target_proba = target_proba.permute(0, -1, *range(1, target_proba.dim() - 1))
+            target_mask = (target_proba == 1)
+            target_proba = target_proba.to(dtype=input.dtype)
+
+            # y_k^ls = y_k * (1 - label_smoothing) + label_smoothing / n_classes
+            # Apply the smoothing to the one-hot representation of the target.
+            target_proba.masked_fill_(target_mask, 1 - label_smoothing + label_smoothing / C)
+            target_proba.masked_fill_(~target_mask, label_smoothing / C)
+
+            loss = nn.CrossEntropyLoss(reduction=reduction)
+            output_with_prob = loss(input, target_proba)
+
+            loss = nn.CrossEntropyLoss(
+                reduction=reduction, label_smoothing=label_smoothing)
+            output_with_index = loss(input, target)
+
+            self.assertEqual(output_with_prob, output_with_index,
+                             rtol=1e-07, atol=1e-05)
+
+    def test_cross_entropy_label_smoothing_with_probs(self, device):
+        N, C = 10, 4
+        ks = range(5)
+        reductions = ['none', 'mean', 'sum']
+        label_smoothings = [0.05, 0.15]
+
+        # Test with k-dimensional loss.
+        for k, label_smoothing in product(ks, label_smoothings):
+            other_dims = [torch.randint(2, 5, size=(1,)).item() for _ in range(k)]
+            input = torch.randn(N, C, *other_dims, device=device, requires_grad=True)
+            target = F.log_softmax(torch.randn(N, C, *other_dims, device=device), dim=1)
+
+            for reduction in reductions:
+                # use with label_smoothing
+                loss = nn.CrossEntropyLoss(reduction=reduction, label_smoothing=label_smoothing)
+                output_with_smoothing = loss(input, target)
+
+                # manually smooth the target
+                # class_proba^ls = class_proba * (1 - label_smoothing) +
+                #                  label_smoothing / n_classes
+                target_with_smoothing = target * (1 - label_smoothing) + label_smoothing / C
+                loss = nn.CrossEntropyLoss(reduction=reduction)
+                output_with_manual_smoothing = loss(input, target_with_smoothing)
+
+                self.assertEqual(output_with_smoothing, output_with_manual_smoothing)
+
+
     def test_softshrink_negative(self, device):
         input = torch.randn(5, device=device, requires_grad=True)
         m = torch.nn.Softshrink(-1)
diff --git a/torch/csrc/api/include/torch/nn/functional/loss.h b/torch/csrc/api/include/torch/nn/functional/loss.h
index ea2f606..1fa91ad 100644
--- a/torch/csrc/api/include/torch/nn/functional/loss.h
+++ b/torch/csrc/api/include/torch/nn/functional/loss.h
@@ -824,13 +824,15 @@ inline Tensor cross_entropy(
     const Tensor& target,
     const Tensor& weight,
     int64_t ignore_index,
-    CrossEntropyFuncOptions::reduction_t reduction) {
+    CrossEntropyFuncOptions::reduction_t reduction,
+    double label_smoothing) {
   return torch::cross_entropy_loss(
       input,
       target,
       weight,
       enumtype::reduction_get_enum(reduction),
-      ignore_index);
+      ignore_index,
+      label_smoothing);
 }
 } // namespace detail
 #endif /* DOXYGEN_SHOULD_SKIP_THIS */
@@ -855,7 +857,8 @@ inline Tensor cross_entropy(
       target,
       options.weight(),
       options.ignore_index(),
-      options.reduction());
+      options.reduction(),
+      options.label_smoothing());
 }
 
 // ============================================================================
diff --git a/torch/csrc/api/include/torch/nn/options/loss.h b/torch/csrc/api/include/torch/nn/options/loss.h
index d8ffd15..1479de5 100644
--- a/torch/csrc/api/include/torch/nn/options/loss.h
+++ b/torch/csrc/api/include/torch/nn/options/loss.h
@@ -662,6 +662,8 @@ struct TORCH_API CrossEntropyLossOptions {
   TORCH_ARG(int64_t, ignore_index) = -100;
   /// Specifies the reduction to apply to the output. Default: Mean
   TORCH_ARG(reduction_t, reduction) = torch::kMean;
+  /// Specifies the amount of smoothing when computing the loss. Default: 0.0
+  TORCH_ARG(double, label_smoothing) = 0.0;
 };
 
 namespace functional {
diff --git a/torch/csrc/api/src/nn/modules/loss.cpp b/torch/csrc/api/src/nn/modules/loss.cpp
index d5d8c68..dda67fe 100644
--- a/torch/csrc/api/src/nn/modules/loss.cpp
+++ b/torch/csrc/api/src/nn/modules/loss.cpp
@@ -378,7 +378,8 @@ Tensor CrossEntropyLossImpl::forward(
       target,
       weight,
       options.ignore_index(),
-      options.reduction());
+      options.reduction(),
+      options.label_smoothing());
 }
 
 // ============================================================================
diff --git a/torch/nn/functional.py b/torch/nn/functional.py
index 5212586..c11e261 100644
--- a/torch/nn/functional.py
+++ b/torch/nn/functional.py
@@ -2772,6 +2772,7 @@ def cross_entropy(
     ignore_index: int = -100,
     reduce: Optional[bool] = None,
     reduction: str = "mean",
+    label_smoothing: float = 0.0,
 ) -> Tensor:
     r"""This criterion computes the cross entropy loss between input and target.
 
@@ -2808,6 +2809,10 @@ def cross_entropy(
             elements in the output, ``'sum'``: the output will be summed. Note: :attr:`size_average`
             and :attr:`reduce` are in the process of being deprecated, and in the meantime,
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
+            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
+            become a mixture of the original ground truth and a uniform distribution as described in
+            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.
 
     Examples::
 
@@ -2834,10 +2839,11 @@ def cross_entropy(
             ignore_index=ignore_index,
             reduce=reduce,
             reduction=reduction,
+            label_smoothing=label_smoothing,
         )
     if size_average is not None or reduce is not None:
         reduction = _Reduction.legacy_get_string(size_average, reduce)
-    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
+    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
 
 
 def binary_cross_entropy(
diff --git a/torch/nn/functional.pyi.in b/torch/nn/functional.pyi.in
index 828f8df..cbd05d7 100644
--- a/torch/nn/functional.pyi.in
+++ b/torch/nn/functional.pyi.in
@@ -239,7 +239,8 @@ def kl_div(input: Tensor, target: Tensor, size_average: Optional[bool] = ..., re
 def cross_entropy(input: Tensor, target: Tensor, weight: Optional[Tensor] = ..., size_average: Optional[bool] = ...,
-                  ignore_index: int = ..., reduce: Optional[bool] = ..., reduction: str = ...) -> Tensor: ...
+                  ignore_index: int = ..., reduce: Optional[bool] = ..., reduction: str = ...,
+                  label_smoothing: float = ...) -> Tensor: ...
 
 
 def binary_cross_entropy(input: Tensor, target: Tensor, weight: Optional[Tensor] = ...,
diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py
index af1da83..d72c614 100644
--- a/torch/nn/modules/loss.py
+++ b/torch/nn/modules/loss.py
@@ -1104,6 +1104,10 @@ class CrossEntropyLoss(_WeightedLoss):
             and :attr:`reduce` are in the process of being deprecated, and in the meantime,
             specifying either of those two args will override :attr:`reduction`. Default: ``'mean'``
+        label_smoothing (float, optional): A float in [0.0, 1.0]. Specifies the amount
+            of smoothing when computing the loss, where 0.0 means no smoothing. The targets
+            become a mixture of the original ground truth and a uniform distribution as described in
+            `Rethinking the Inception Architecture for Computer Vision <https://arxiv.org/abs/1512.00567>`__. Default: :math:`0.0`.
 
     Shape:
         - Input: :math:`(N, C)` where `C = number of classes`, or
@@ -1132,17 +1136,20 @@ class CrossEntropyLoss(_WeightedLoss):
         >>> output = loss(input, target)
         >>> output.backward()
     """
-    __constants__ = ['ignore_index', 'reduction']
+    __constants__ = ['ignore_index', 'reduction', 'label_smoothing']
     ignore_index: int
+    label_smoothing: float
 
     def __init__(self, weight: Optional[Tensor] = None, size_average=None, ignore_index: int = -100,
-                 reduce=None, reduction: str = 'mean') -> None:
+                 reduce=None, reduction: str = 'mean', label_smoothing: float = 0.0) -> None:
         super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction)
         self.ignore_index = ignore_index
+        self.label_smoothing = label_smoothing
 
     def forward(self, input: Tensor, target: Tensor) -> Tensor:
         return F.cross_entropy(input, target, weight=self.weight,
-                               ignore_index=self.ignore_index, reduction=self.reduction)
+                               ignore_index=self.ignore_index, reduction=self.reduction,
+                               label_smoothing=self.label_smoothing)
 
 
 class MultiLabelSoftMarginLoss(_WeightedLoss):
diff --git a/torch/onnx/symbolic_opset12.py b/torch/onnx/symbolic_opset12.py
index d8f9541..ab39325 100644
--- a/torch/onnx/symbolic_opset12.py
+++ b/torch/onnx/symbolic_opset12.py
@@ -65,7 +65,7 @@ def nll_loss_nd(g, self, target, weight, reduction, ignore_index):
     return nll_loss(g, self, target, weight, reduction, ignore_index)
 
 
-def cross_entropy_loss(g, self, target, weight, reduction, ignore_index):
+def cross_entropy_loss(g, self, target, weight, reduction, ignore_index, label_smoothing):
     # none reduction : onnx::Constant[value={0}]
     # mean reduction : onnx::Constant[value={1}]
     # sum reduction : onnx::Constant[value={2}]
@@ -73,6 +73,10 @@ def cross_entropy_loss(g, self, target, weight, reduction, ignore_index):
     reduction_vals = ["none", "mean", "sum"]
     reduction = reduction_vals[reduction]
 
+    label_smoothing = sym_help._maybe_get_const(label_smoothing, "f")
+    if label_smoothing > 0.0:
+        raise RuntimeError("Unsupported: ONNX does not support label_smoothing")
+
     # in onnx SoftmaxCrossEntropyLoss specification, ignore_index is optional without default value.
     # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100).
     ignore_index = sym_help._maybe_get_const(ignore_index, "i")
diff --git a/torch/overrides.py b/torch/overrides.py
index 09748b9..64b18b8 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -677,7 +677,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.nn.functional.cosine_embedding_loss: (lambda input1, input2, target, margin=0, size_average=None,
                                                 reduce=None, reduction='mean': -1),
     torch.nn.functional.cross_entropy: (lambda input, target, weight=None, size_average=None, ignore_index=-100,
-                                        reduce=None, reduction="mean": -1),
+                                        reduce=None, reduction="mean", label_smoothing=0.0: -1),
     torch.nn.functional.ctc_loss: (lambda log_probs, targets, input_lengths, target_lengths, blank=0,
                                    reduction='mean', zero_infinity=False: -1),
     torch.nn.functional.dropout: lambda input, p=0.5, training=True, inplace=False: -1,
diff --git a/torch/testing/_internal/common_nn.py b/torch/testing/_internal/common_nn.py
index e0d09b7..73233df 100644
--- a/torch/testing/_internal/common_nn.py
+++ b/torch/testing/_internal/common_nn.py
@@ -4103,7 +4103,8 @@ def nlllossNd_reference(input, target, weight=None, ignore_index=-100,
     return output
 
 
-def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean'):
+def cross_entropy_loss_prob_target_reference(input, target, weight=None, reduction='mean',
+                                             label_smoothing=0.0):
     assert input.dim() >= 2
 
     input = torch.log_softmax(input, 1)
@@ -4112,6 +4113,10 @@ def cross_entropy_loss_prob_target_reference(input, target, weight=None, reducti
         weight = torch.ones(C).type_as(input)
     weight = weight.view(1, C, *(1 for _ in input.shape[2:]))
 
+    if label_smoothing > 0.0:
+        assert label_smoothing <= 1.0
+        target = (target * (1 - label_smoothing) + label_smoothing / C)
+
     output = -(input * target * weight).sum(dim=1)
     if reduction == 'mean':
         return output.mean()
@@ -4120,20 +4125,61 @@
     return output
 
 
-def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean'):
+def cross_entropy_loss_indices_target_reference(input, target, weight=None, ignore_index=-100,
+                                                reduction='mean', label_smoothing=0.0):
+    log_softmax_input = torch.log_softmax(input, 1)
+    nllloss = F.nll_loss(
+        log_softmax_input,
+        target,
+        weight,
+        ignore_index=ignore_index,
+        reduction=reduction)
+
+    if label_smoothing == 0.0:
+        return nllloss
+
+    assert 0.0 < label_smoothing <= 1.0
+
+    input = torch.log_softmax(input, 1)
+    C = input.size(1)
+    if weight is not None:
+        input = input * weight.view(1, C, *(1 for _ in input.shape[2:]))
+
+    smooth_loss = -torch.sum(input, 1)
+
+    if ignore_index >= 0:
+        ignore_mask = target == ignore_index
+        smooth_loss.masked_fill_(ignore_mask, 0.0)
+
+    if reduction == 'mean':
+        if weight is not None:
+            # TODO: This code path can be removed if #61309 is resolved
+            # loss is normalized by the weights to be consistent with nll_loss_nd
+            ret = torch.sum(smooth_loss) / weight.gather(0, target.flatten()).sum()
+        else:
+            ret = torch.mean(smooth_loss)
+    elif reduction == 'sum':
+        ret = torch.sum(smooth_loss)
+    else:
+        ret = smooth_loss
+
+    return (1 - label_smoothing) * nllloss + ret * (label_smoothing / C)
+
+
+def cross_entropy_loss_reference(input, target, weight=None, ignore_index=-100, reduction='mean',
+                                 label_smoothing=0.0):
     if input.shape == target.shape:
         return cross_entropy_loss_prob_target_reference(
             input,
             target,
             weight=weight,
-            reduction=reduction)
+            reduction=reduction,
+            label_smoothing=label_smoothing)
     else:
-        return nlllossNd_reference(
-            torch.log_softmax(input, 1),
-            target,
-            weight,
-            ignore_index=ignore_index,
-            reduction=reduction)
+        return cross_entropy_loss_indices_target_reference(
+            input, target, weight=weight, reduction=reduction,
+            ignore_index=ignore_index, label_smoothing=label_smoothing
+        )
 
 
 def nllloss_reference(input, target, weight=None, ignore_index=-100,
@@ -4894,6 +4940,141 @@ criterion_tests = [
         check_bfloat16=False,
     ),
     dict(
+        fullname='CrossEntropyLoss_2d_prob_target_smoothing_sum_reduction',
+        constructor=lambda *args, **kwargs: nn.CrossEntropyLoss(reduction='sum',
+                                                                label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
+        input_size=(5, 3),
+        target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_2d_prob_target_smoothing',
+        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
+        input_size=(5, 3),
+        target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_2d_prob_target_smoothing_weight',
+        constructor_args_fn=lambda: (torch.rand(3).abs(),),
+        constructor=lambda weight: nn.CrossEntropyLoss(weight, label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).weight(torch::rand(3).abs())',
+        input_size=(5, 3),
+        target_fn=lambda: torch.rand(5, 3).softmax(dim=1),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_3d_prob_target_smoothing_sum_reduction',
+        constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum',
+                                                      label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
+        input_size=(5, 3, 4),
+        target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_3d_prob_target_smoothing',
+        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
+        input_size=(5, 3, 4),
+        target_fn=lambda: torch.rand(5, 3, 4).softmax(dim=1),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_3d_indices_target_smoothing',
+        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
+        input_size=(2, 3, 5),
+        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_3d_indices_target_smoothing_ignore_index',
+        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15, ignore_index=1),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).ignore_index(1)',
+        input_size=(2, 3, 5),
+        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=1),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_3d_indices_target_smoothing_sum_reduction',
+        constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
+        input_size=(2, 3, 5),
+        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_3d_indices_target_smoothing_sum_reduction_ignore_index',
+        constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15,
+                                                      ignore_index=1),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum).ignore_index(1)',
+        input_size=(2, 3, 5),
+        target_fn=lambda: torch.rand(2, 5).mul(3).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=1),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_2d_indices_target_smoothing',
+        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15)',
+        input_size=(15, 10),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_2d_indices_target_smoothing_sum_reduction',
+        constructor=lambda *args: nn.CrossEntropyLoss(reduction='sum', label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).reduction(torch::kSum)',
+        input_size=(15, 10),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_2d_indices_target_smoothing_ignore_index',
+        constructor=lambda *args: nn.CrossEntropyLoss(label_smoothing=0.15, ignore_index=3),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).ignore_index(3)',
+        input_size=(15, 10),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), label_smoothing=0.15, ignore_index=3),
+        check_bfloat16=False,
+    ),
+    dict(
+        fullname='CrossEntropyLoss_2d_indices_target_smoothing_weight',
+        constructor_args_fn=lambda: (torch.rand(10).abs(),),
+        constructor=lambda weight: nn.CrossEntropyLoss(weight, label_smoothing=0.15),
+        cpp_constructor_args='torch::nn::CrossEntropyLossOptions().label_smoothing(0.15).weight(torch::rand(10).abs())',
+        input_size=(15, 10),
+        target_fn=lambda: torch.empty(15).uniform_().mul(10).floor().long(),
+        reference_fn=lambda i, t, m:
+            loss_reference_fns['CrossEntropyLoss'](i, t, reduction=get_reduction(m), weight=get_weight(m), label_smoothing=0.15),
+        check_bfloat16=False,
+    ),
+    dict(
         module_name='CrossEntropyLoss',
         constructor_args_fn=lambda: (torch.rand(3),),
         cpp_constructor_args='torch::nn::CrossEntropyLossOptions().weight(torch::rand(3))',
-- 
2.7.4
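
Usage note (outside the diff, after the patch signature): a minimal sketch of the
new argument, assuming a PyTorch build that includes this change. The input
tensors and the 0.1 smoothing value below are illustrative, not taken from the
test suite; the manual check mirrors the Mean-reduction, unweighted branch of
cross_entropy_loss_label_smoothing above.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    logits = torch.randn(4, 3)            # N=4 samples, C=3 classes
    target = torch.tensor([0, 2, 1, 0])   # class indices

    # Smoothed loss via the new keyword argument.
    loss_smoothed = F.cross_entropy(logits, target, label_smoothing=0.1)

    # Manual reference: with smoothing eps, the effective target distribution
    # is y_ls = y * (1 - eps) + eps / C, so the loss decomposes into
    # (1 - eps) * NLL + (eps / C) * sum_k (-log_softmax)_k.
    eps, C = 0.1, logits.size(1)
    log_probs = F.log_softmax(logits, dim=1)
    nll = F.nll_loss(log_probs, target)
    smooth = -log_probs.sum(dim=1).mean()  # uniform-target component
    manual = (1 - eps) * nll + (eps / C) * smooth

    assert torch.allclose(loss_smoothed, manual, atol=1e-6)

The same keyword is exposed on the module form, nn.CrossEntropyLoss(label_smoothing=0.1),
and both class-index and probability targets are supported, as the tests in this patch
exercise.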