From ba70cf22fa867ada25f0dace95868999e53ef91f Mon Sep 17 00:00:00 2001 From: Elias Ellison Date: Tue, 4 Dec 2018 12:27:22 -0800 Subject: [PATCH] Loss (#14720) Summary: Adding Loss modules to script. Some of the modules have an optional tensor parameter. I will wait until wanchao's diff to support optional tensors is landed before landing this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14720 Differential Revision: D13317990 Pulled By: eellison fbshipit-source-id: 535925bdf126d28d9e7d64077b83ebd836a5beba --- test/test_jit.py | 161 ++++++++++++++++++++++++++++++++++++++++++++++- torch/nn/modules/loss.py | 71 ++++++++++++++++++++- 2 files changed, 228 insertions(+), 4 deletions(-) diff --git a/test/test_jit.py b/test/test_jit.py index b75b133..9106b26 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -9729,6 +9729,23 @@ EXCLUDE_SCRIPT_MODULES = { 'test_nn_AdaptiveAvgPool3d_tuple_none', 'test_nn_AdaptiveMaxPool2d_tuple_none', 'test_nn_AdaptiveMaxPool3d_tuple_none', + 'test_nn_LayerNorm_1d_elementwise_affine', + 'test_nn_LayerNorm_1d_no_elementwise_affine', + 'test_nn_LayerNorm_3d_elementwise_affine', + 'test_nn_LayerNorm_3d_no_elementwise_affine', + 'test_nn_Linear_no_bias', + + # unsupported None parameter + 'test_nn_BCELoss_weights', + 'test_nn_CrossEntropyLoss', + 'test_nn_NLLLoss_weights', + 'test_nn_NLLLoss_ignore_index', + 'test_nn_NLLLoss', + 'test_nn_MultiMarginLoss', + 'test_nn_NLLLoss_weights_ignore_index', + 'test_nn_NLLLoss_weights_ignore_index_neg', + 'test_nn_BCEWithLogitsLoss_weights', + 'test_nn_BCELoss', } DISABLE_AUTODIFF_SUBGRAPH_INLINING = { @@ -10467,6 +10484,141 @@ additional_module_tests = [ input_size=(S, S), extra_args=((S, S),) ), + dict( # noqa: C408 + module_name='L1Loss', + input_fn=lambda: ((2, 3, 4), (2, 3, 4)), + ), + dict( # noqa: C408 + module_name='NLLLoss', + input_fn=lambda: (torch.rand(15, 10).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), + check_sum_reduction=True + ), + dict( # noqa: C408 + module_name='NLLLoss', + constructor_args=(None, None, 2), + input_fn=lambda: (torch.rand(15, 10).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), + desc='ignore_index' + ), + dict( # noqa: C408 + module_name='NLLLoss', + constructor_args_fn=lambda: (torch.rand(10),), + input_fn=lambda: (torch.rand(15, 10).add(1e-2).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), + desc='weights', + ), + dict( # noqa: C408 + module_name='NLLLoss', + constructor_args_fn=lambda: (torch.rand(10), None, 2), + input_fn=lambda: (torch.rand(15, 10).add(1e-2).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), + desc='weights_ignore_index' + ), + dict( # noqa: C408 + module_name='NLLLoss', + constructor_args_fn=lambda: (torch.rand(10), None, -1), + input_fn=lambda: + (torch.rand(15, 10).add(1e-2).log(), + torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1), + desc='weights_ignore_index_neg' + ), + dict( # noqa: C408 + module_name='KLDivLoss', + input_fn=lambda: (torch.rand(10, 10).log(), torch.rand(10, 10)), + ), + dict( # noqa: C408 + module_name='MSELoss', + input_fn=lambda: ((2, 3, 4, 5), (2, 3, 4, 5)), + ), + dict( # noqa: C408 + module_name='BCELoss', + input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), + no_grad=True, + ), + dict( # noqa: C408 + module_name='BCELoss', + constructor_args_fn=lambda: (torch.rand(10),), + input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), + desc='weights', + no_grad=True, + ), + dict( # 
noqa: C408 + module_name='BCEWithLogitsLoss', + constructor_args=(torch.rand(10), False, None, 'mean', torch.rand(10)), + input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), + no_grad=True, + ), + dict( # noqa: C408 + module_name='BCEWithLogitsLoss', + constructor_args=(torch.rand(15, 10), False), + input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), + desc='weights', + ), + dict( # noqa: C408 + module_name='HingeEmbeddingLoss', + input_fn=lambda: (torch.randn(10), torch.randn(10).gt(0).double().mul_(2).sub(1)), + no_grad=True, + ), + dict( # noqa: C408 + module_name='HingeEmbeddingLoss', + constructor_args=(0.5,), + input_fn=lambda: (torch.randn(10), torch.randn(10).gt(0).double().mul_(2).sub(1)), + desc='margin', + no_grad=True, + ), + dict( # noqa: C408 + module_name='MultiLabelMarginLoss', + input_fn=lambda: (torch.rand(10,), torch.rand(10).mul(10).floor().long()), + no_grad=True, + ), + dict( # noqa: C408 + module_name='SmoothL1Loss', + input_fn=lambda: ((5, 10), (5, 10)), + ), + dict( # noqa: C408 + module_name='SoftMarginLoss', + input_fn=lambda: (torch.randn(5, 5).sign(), torch.randn(5, 5).sign()), + no_grad=True, + ), + dict( # noqa: C408 + module_name='CrossEntropyLoss', + input_fn=lambda: (torch.randn(15, 10), torch.Tensor(15).uniform_().mul(10).floor().long()), + ), + dict( # noqa: C408 + module_name='MultiLabelSoftMarginLoss', + constructor_args=(torch.rand(10),), + input_fn=lambda: (torch.randn(5, 10), torch.rand(5, 10).mul(2).floor()), + no_grad=True, + ), + dict( # noqa: C408 + module_name='CosineEmbeddingLoss', + input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10), torch.randn(15).sign()), + no_grad=True, + ), + dict( # noqa: C408 + module_name='MarginRankingLoss', + input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10), torch.randn(50).sign()), + ), + dict( # noqa: C408 + module_name='TripletMarginLoss', + input_fn=lambda: (torch.randn(5, 10, requires_grad=True), torch.randn(5, 10, requires_grad=True), + torch.randn(5, 10, requires_grad=True)), + ), + dict( # noqa: C408 + module_name='MultiMarginLoss', + input_fn=lambda: (torch.randn(5, 10), torch.rand(5).mul(8).floor().long()), + no_grad=True, + ), + dict( # noqa: C408 + module_name='PoissonNLLLoss', + input_fn=lambda:(torch.randn(2, 3, 4, 5), torch.randn(2, 3, 4, 5).floor_().abs_()), + ), + dict( + module_name='CTCLoss', + constructor_args=(14,), + input_fn=lambda: (torch.randn(50, 16, 20).log_softmax(2), + torch.randint(1, 20, (16, 30), dtype=torch.long), + torch.full((16,), 50, dtype=torch.long), + torch.randint(10, 30, (16,), dtype=torch.long)), + no_grad=True, + ), ] @@ -10621,6 +10773,8 @@ def add_nn_module_test(*args, **kwargs): elif 'constructor' in kwargs: name = kwargs['constructor'].__name__ + no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad'] + module_name = name.split("_")[0] module = getattr(torch.nn, module_name, None) @@ -10643,6 +10797,10 @@ def add_nn_module_test(*args, **kwargs): nn_module = kwargs['constructor'] else: nn_module = getattr(torch.nn, name) + + if "FunctionalModule" in str(nn_module): + return + constructor_args = kwargs.get('constructor_args', ()) # Construct a script module that passes arguments through @@ -10699,7 +10857,7 @@ def add_nn_module_test(*args, **kwargs): args_variable, kwargs_variable = create_input(input) f_args_variable = deepcopy(unpack_variables(args_variable)) - check_against_reference(self, create_script_module, create_nn_module, 
f_args_variable) + check_against_reference(self, create_script_module, create_nn_module, f_args_variable, no_grad=no_grad) post_add_test(test_name, (), do_test) @@ -10859,6 +11017,7 @@ class TestAsync(JitTestCase): self.assertEqual(y2, foo2(x1, x2)) self.assertEqual(y3, foo3(x1, x2, x3)) + for test in autograd_method_tests: add_autograd_test(*test) diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index cdc3a7a..1f539e3 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -6,6 +6,7 @@ from .container import Sequential from .activation import LogSoftmax from .. import functional as F from .. import _reduction as _Reduction +from ..._jit_internal import weak_module, weak_script_method class _Loss(Module): @@ -23,6 +24,7 @@ class _WeightedLoss(_Loss): self.register_buffer('weight', weight) +@weak_module class L1Loss(_Loss): r"""Creates a criterion that measures the mean absolute error (MAE) between each element in the input `x` and target `y`. @@ -81,13 +83,17 @@ class L1Loss(_Loss): >>> output = loss(input, target) >>> output.backward() """ + __constants__ = ['reduction'] + def __init__(self, size_average=None, reduce=None, reduction='mean'): super(L1Loss, self).__init__(size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.l1_loss(input, target, reduction=self.reduction) +@weak_module class NLLLoss(_WeightedLoss): r"""The negative log likelihood loss. It is useful to train a classification problem with `C` classes. @@ -192,16 +198,19 @@ class NLLLoss(_WeightedLoss): >>> output = loss(m(conv(data)), target) >>> output.backward() """ + __constants__ = ['ignore_index', 'weight', 'reduction'] def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): super(NLLLoss, self).__init__(weight, size_average, reduce, reduction) self.ignore_index = ignore_index + @weak_script_method def forward(self, input, target): return F.nll_loss(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) +@weak_module class NLLLoss2d(NLLLoss): def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): @@ -211,6 +220,7 @@ class NLLLoss2d(NLLLoss): super(NLLLoss2d, self).__init__(weight, size_average, ignore_index, reduce, reduction) +@weak_module class PoissonNLLLoss(_Loss): r"""Negative log likelihood loss with Poisson distribution of target. 
@@ -261,6 +271,8 @@ class PoissonNLLLoss(_Loss): >>> output = loss(log_input, target) >>> output.backward() """ + __constants__ = ['log_input', 'full', 'eps', 'reduction'] + def __init__(self, log_input=True, full=False, size_average=None, eps=1e-8, reduce=None, reduction='mean'): super(PoissonNLLLoss, self).__init__(size_average, reduce, reduction) @@ -268,11 +280,13 @@ class PoissonNLLLoss(_Loss): self.full = full self.eps = eps + @weak_script_method def forward(self, log_input, target): return F.poisson_nll_loss(log_input, target, log_input=self.log_input, full=self.full, eps=self.eps, reduction=self.reduction) +@weak_module class KLDivLoss(_Loss): r"""The `Kullback-Leibler divergence`_ Loss @@ -343,13 +357,17 @@ class KLDivLoss(_Loss): the same shape as the input """ + __constants__ = ['reduction'] + def __init__(self, size_average=None, reduce=None, reduction='mean'): super(KLDivLoss, self).__init__(size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.kl_div(input, target, reduction=self.reduction) +@weak_module class MSELoss(_Loss): r"""Creates a criterion that measures the mean squared error (squared L2 norm) between each element in the input `x` and target `y`. @@ -407,13 +425,17 @@ class MSELoss(_Loss): >>> output = loss(input, target) >>> output.backward() """ + __constants__ = ['reduction'] + def __init__(self, size_average=None, reduce=None, reduction='mean'): super(MSELoss, self).__init__(size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.mse_loss(input, target, reduction=self.reduction) +@weak_module class BCELoss(_WeightedLoss): r"""Creates a criterion that measures the Binary Cross Entropy between the target and the output: @@ -472,13 +494,17 @@ class BCELoss(_WeightedLoss): >>> output = loss(m(input), target) >>> output.backward() """ + __constants__ = ['reduction', 'weight'] + def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): super(BCELoss, self).__init__(weight, size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction) +@weak_module class BCEWithLogitsLoss(_Loss): r"""This loss combines a `Sigmoid` layer and the `BCELoss` in one single class. This version is more numerically stable than using a plain `Sigmoid` @@ -554,11 +580,14 @@ class BCEWithLogitsLoss(_Loss): >>> output = loss(input, target) >>> output.backward() """ + __constants__ = ['weight', 'pos_weight', 'reduction'] + def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None): super(BCEWithLogitsLoss, self).__init__(size_average, reduce, reduction) self.register_buffer('weight', weight) self.register_buffer('pos_weight', pos_weight) + @weak_script_method def forward(self, input, target): return F.binary_cross_entropy_with_logits(input, target, self.weight, @@ -566,6 +595,7 @@ class BCEWithLogitsLoss(_Loss): reduction=self.reduction) +@weak_module class HingeEmbeddingLoss(_Loss): r"""Measures the loss given an input tensor `x` and a labels tensor `y` containing values (`1` or `-1`). @@ -614,15 +644,18 @@ class HingeEmbeddingLoss(_Loss): - Target: Same shape as input. - Output: scalar. 
If reduce is ``False``, then same shape as the input """ + __constants__ = ['margin', 'reduction'] def __init__(self, margin=1.0, size_average=None, reduce=None, reduction='mean'): super(HingeEmbeddingLoss, self).__init__(size_average, reduce, reduction) self.margin = margin + @weak_script_method def forward(self, input, target): return F.hinge_embedding_loss(input, target, margin=self.margin, reduction=self.reduction) +@weak_module class MultiLabelMarginLoss(_Loss): r"""Creates a criterion that optimizes a multi-class multi-classification hinge loss (margin-based loss) between input `x` (a 2D mini-batch `Tensor`) @@ -667,13 +700,17 @@ class MultiLabelMarginLoss(_Loss): - Target: :math:`(C)` or :math:`(N, C)`, same shape as the input. - Output: scalar. If `reduce` is False, then `(N)`. """ + __constants__ = ['reduction'] + def __init__(self, size_average=None, reduce=None, reduction='mean'): super(MultiLabelMarginLoss, self).__init__(size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.multilabel_margin_loss(input, target, reduction=self.reduction) +@weak_module class SmoothL1Loss(_Loss): r"""Creates a criterion that uses a squared term if the absolute element-wise error falls below 1 and an L1 term otherwise. @@ -723,13 +760,17 @@ class SmoothL1Loss(_Loss): :math:`(N, *)`, same shape as the input """ + __constants__ = ['reduction'] + def __init__(self, size_average=None, reduce=None, reduction='mean'): super(SmoothL1Loss, self).__init__(size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.smooth_l1_loss(input, target, reduction=self.reduction) +@weak_module class SoftMarginLoss(_Loss): r"""Creates a criterion that optimizes a two-class classification logistic loss between input tensor `x` and target tensor `y` (containing 1 or @@ -761,13 +802,17 @@ class SoftMarginLoss(_Loss): - Output: scalar. If reduce is ``False``, then same shape as the input """ + __constants__ = ['reduction'] + def __init__(self, size_average=None, reduce=None, reduction='mean'): super(SoftMarginLoss, self).__init__(size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.soft_margin_loss(input, target, reduction=self.reduction) +@weak_module class CrossEntropyLoss(_WeightedLoss): r"""This criterion combines :func:`nn.LogSoftmax` and :func:`nn.NLLLoss` in one single class. @@ -846,17 +891,20 @@ class CrossEntropyLoss(_WeightedLoss): >>> output = loss(input, target) >>> output.backward() """ + __constants__ = ['weight', 'ignore_index', 'reduction'] def __init__(self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean'): super(CrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) self.ignore_index = ignore_index + @weak_script_method def forward(self, input, target): return F.cross_entropy(input, target, weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) +@weak_module class MultiLabelSoftMarginLoss(_WeightedLoss): r"""Creates a criterion that optimizes a multi-label one-versus-all loss based on max-entropy, between input `x` and target `y` of size `(N, C)`. @@ -893,14 +941,17 @@ class MultiLabelSoftMarginLoss(_WeightedLoss): - Target: :math:`(N, C)`, same shape as the input. - Output: scalar. If `reduce` is False, then `(N)`. 
""" + __constants__ = ['weight', 'reduction'] def __init__(self, weight=None, size_average=None, reduce=None, reduction='mean'): super(MultiLabelSoftMarginLoss, self).__init__(weight, size_average, reduce, reduction) + @weak_script_method def forward(self, input, target): return F.multilabel_soft_margin_loss(input, target, weight=self.weight, reduction=self.reduction) +@weak_module class CosineEmbeddingLoss(_Loss): r"""Creates a criterion that measures the loss given input tensors :math:`x_1`, :math:`x_2` and a `Tensor` label `y` with values 1 or -1. @@ -936,15 +987,18 @@ class CosineEmbeddingLoss(_Loss): and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: 'mean' """ + __constants__ = ['margin', 'reduction'] - def __init__(self, margin=0, size_average=None, reduce=None, reduction='mean'): + def __init__(self, margin=0., size_average=None, reduce=None, reduction='mean'): super(CosineEmbeddingLoss, self).__init__(size_average, reduce, reduction) self.margin = margin + @weak_script_method def forward(self, input1, input2, target): return F.cosine_embedding_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) +@weak_module class MarginRankingLoss(_Loss): r"""Creates a criterion that measures the loss given inputs `x1`, `x2`, two 1D mini-batch `Tensor`s, @@ -981,15 +1035,18 @@ class MarginRankingLoss(_Loss): - Target: :math:`(N)` - Output: scalar. If `reduce` is False, then `(N)`. """ + __constants__ = ['margin', 'reduction'] - def __init__(self, margin=0, size_average=None, reduce=None, reduction='mean'): + def __init__(self, margin=0., size_average=None, reduce=None, reduction='mean'): super(MarginRankingLoss, self).__init__(size_average, reduce, reduction) self.margin = margin + @weak_script_method def forward(self, input1, input2, target): return F.margin_ranking_loss(input1, input2, target, margin=self.margin, reduction=self.reduction) +@weak_module class MultiMarginLoss(_WeightedLoss): r"""Creates a criterion that optimizes a multi-class classification hinge loss (margin-based loss) between input `x` (a 2D mini-batch `Tensor`) and @@ -1035,6 +1092,7 @@ class MultiMarginLoss(_WeightedLoss): and :attr:`reduce` are in the process of being deprecated, and in the meantime, specifying either of those two args will override :attr:`reduction`. Default: 'mean' """ + __constants__ = ['p', 'margin', 'weight', 'reduction'] def __init__(self, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean'): @@ -1045,11 +1103,13 @@ class MultiMarginLoss(_WeightedLoss): self.p = p self.margin = margin + @weak_script_method def forward(self, input, target): return F.multi_margin_loss(input, target, p=self.p, margin=self.margin, weight=self.weight, reduction=self.reduction) +@weak_module class TripletMarginLoss(_Loss): r"""Creates a criterion that measures the triplet loss given an input tensors x1, x2, x3 and a margin with a value greater than 0. @@ -1109,8 +1169,9 @@ class TripletMarginLoss(_Loss): .. 
_Learning shallow convolutional feature descriptors with triplet losses: http://www.iis.ee.ic.ac.uk/%7Evbalnt/shallow_descr/TFeat_paper.pdf """ + __constants__ = ['margin', 'p', 'eps', 'swap', 'reduction'] - def __init__(self, margin=1.0, p=2, eps=1e-6, swap=False, size_average=None, + def __init__(self, margin=1.0, p=2., eps=1e-6, swap=False, size_average=None, reduce=None, reduction='mean'): super(TripletMarginLoss, self).__init__(size_average, reduce, reduction) self.margin = margin @@ -1118,11 +1179,13 @@ class TripletMarginLoss(_Loss): self.eps = eps self.swap = swap + @weak_script_method def forward(self, anchor, positive, negative): return F.triplet_margin_loss(anchor, positive, negative, margin=self.margin, p=self.p, eps=self.eps, swap=self.swap, reduction=self.reduction) +@weak_module class CTCLoss(_Loss): r"""The Connectionist Temporal Classification loss. @@ -1174,11 +1237,13 @@ class CTCLoss(_Loss): """ + __constants__ = ['blank', 'reduction'] def __init__(self, blank=0, reduction='mean'): super(CTCLoss, self).__init__(reduction=reduction) self.blank = blank + @weak_script_method def forward(self, log_probs, targets, input_lengths, target_lengths): return F.ctc_loss(log_probs, targets, input_lengths, target_lengths, self.blank, self.reduction) -- 2.7.4
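
Note (not part of the patch above): a minimal sketch of how a loss module made weak-scriptable by this change could be used from TorchScript. It assumes the PyTorch 1.0-era `torch.jit.ScriptModule` / `@torch.jit.script_method` API and that the `@weak_module` mechanism compiles an `nn` loss when it is assigned as a submodule of a script module; the class name `ScriptedWithLoss` is purely illustrative.

    import torch
    import torch.nn as nn

    class ScriptedWithLoss(torch.jit.ScriptModule):
        def __init__(self):
            super(ScriptedWithLoss, self).__init__()
            # nn.MSELoss is decorated with @weak_module in this patch, so it is
            # compiled lazily when used from a ScriptModule; 'reduction' is
            # usable inside the scripted forward because it is in __constants__.
            self.loss = nn.MSELoss(reduction='sum')

        @torch.jit.script_method
        def forward(self, input, target):
            return self.loss(input, target)

    m = ScriptedWithLoss()
    print(m(torch.randn(3, 4), torch.randn(3, 4)))

Losses whose constructor takes an optional tensor argument (e.g. the `weight` of NLLLoss or BCELoss) are excluded from the script-module tests above until optional-tensor support lands, as noted in the summary.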