From: David Riazati Date: Wed, 5 Dec 2018 02:32:05 +0000 (-0800) Subject: Enable testing on Loss modules (#14778) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~2472 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a66669a110a6eeeef8e08293174138ca87dc1a66;p=platform%2Fupstream%2Fpytorch.git Enable testing on Loss modules (#14778) Summary: This PR adds `None` buffers as parameters (similarly to #14715). It also cleans up a bunch of the `test_jit.py` tests that should be covered by `common_nn.py` and brings in `criterion_tests` to test loss functions. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14778 Differential Revision: D13330849 Pulled By: driazati fbshipit-source-id: 924cc4cf94e0dcd11e811a55222fd2ebc42a9e76 --- diff --git a/test/common_nn.py b/test/common_nn.py index eebd597..60a17c6 100644 --- a/test/common_nn.py +++ b/test/common_nn.py @@ -2639,7 +2639,7 @@ criterion_tests = [ ), dict( module_name='MultiMarginLoss', - constructor_args=(1, 1, torch.rand(10)), + constructor_args=(1, 1., torch.rand(10)), legacy_constructor_args=(1, torch.rand(10)), input_size=(5, 10), target_fn=lambda: torch.rand(5).mul(8).floor().long(), diff --git a/test/test_jit.py b/test/test_jit.py index b3a249b..0821fa4 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -14,7 +14,7 @@ from torch._six import inf, PY2 from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \ skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \ freeze_rng_state -from common_nn import module_tests, new_module_tests +from common_nn import module_tests, new_module_tests, criterion_tests from textwrap import dedent import os import io @@ -9838,23 +9838,6 @@ EXCLUDE_SCRIPT_MODULES = { 'test_nn_AdaptiveAvgPool3d_tuple_none', 'test_nn_AdaptiveMaxPool2d_tuple_none', 'test_nn_AdaptiveMaxPool3d_tuple_none', - 'test_nn_LayerNorm_1d_elementwise_affine', - 'test_nn_LayerNorm_1d_no_elementwise_affine', - 'test_nn_LayerNorm_3d_elementwise_affine', - 'test_nn_LayerNorm_3d_no_elementwise_affine', - 'test_nn_Linear_no_bias', - - # unsupported None parameter - 'test_nn_BCELoss_weights', - 'test_nn_CrossEntropyLoss', - 'test_nn_NLLLoss_weights', - 'test_nn_NLLLoss_ignore_index', - 'test_nn_NLLLoss', - 'test_nn_MultiMarginLoss', - 'test_nn_NLLLoss_weights_ignore_index', - 'test_nn_NLLLoss_weights_ignore_index_neg', - 'test_nn_BCEWithLogitsLoss_weights', - 'test_nn_BCELoss', } DISABLE_AUTODIFF_SUBGRAPH_INLINING = { @@ -10056,7 +10039,7 @@ def check_against_reference(self, func, reference_func, args, kwargs=None, self.assertEqual(outputs, outputs_test) self.assertEqual(grads, grads_test) for g2, g2_test in zip(grads2, grads2_test): - if g2 is None and g2_ge is None: + if g2 is None and g2_test is None: continue self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4)) @@ -10378,65 +10361,6 @@ EXCLUDE_MODULE_EXPORT_IMPORT = { 'AdaptiveAvgPool3d', } -local_module_tests = [] - - -def to_module_test_format(tup): - dic = dict(module_name=tup[0], constructor_args=tup[1], input_fn=lambda: tup[2]) - if len(tup) >= 5: - dic['desc'] = tup[4] - local_module_tests.append(dic) - - -def add_interpolate_module_tests(): - # logic from test_interpolate in test_nn.py - def _make_input(dim): - size = [1, 1] - size += [2] * dim - return torch.ones(size, requires_grad=True) - - i = 0 - size = None - for scale_factor in [0.5, 1.5, 2.0]: - for mode in ['nearest', 'area']: - args = (size, scale_factor, mode) - for input in [_make_input(1), _make_input(2), _make_input(3)]: - to_module_test_format(('Upsample', args, input, False, str(i))) - i = i + 1 - - for align_corners in [True, False]: - args = (size, scale_factor, 'linear', align_corners) - to_module_test_format(('Upsample', args, _make_input(1), False, str(i))) - i = i + 1 - - args = (size, scale_factor, 'bilinear', align_corners) - to_module_test_format(('Upsample', args, _make_input(2), False, str(i))) - i = i + 1 - - args = (size, scale_factor, 'trilinear', align_corners) - to_module_test_format(('Upsample', args, _make_input(3), False, str(i))) - i = i + 1 - - # test_upsamplingTrilinear3d_spatial_invariance - scale_factor = 3. - args = (size, scale_factor, 'trilinear', False) - in_t_9 = torch.zeros(1, 1, 9, 9, 9) - in_t_9[:, :, :4, :4, :4].normal_() - to_module_test_format(('Upsample', args, in_t_9, False, str(i))) - i = i + 1 - - # testing where size is not none test_upsamplingNearest2d - size = 4 - scale_factor = None - in_t = torch.ones(1, 1, 2, 2) - - args = (size, scale_factor) - to_module_test_format(('UpsamplingNearest2d', args, Variable(in_t), False,)) - to_module_test_format(('UpsamplingBilinear2d', args, Variable(in_t), False,)) - - -add_interpolate_module_tests() - # NB: JIT script tests for all nn functional interfaces, script mode does # not support in_place operations yet, so no inplace operation tests added. # removed all the deprecated functions @@ -10556,7 +10480,7 @@ nn_functional_tests = [ ('gumbel_softmax', (S, S), (2., True,), 'hard'), ('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),), ('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), \ - 1, 1, non_differentiable(torch.randn(S))),), + 1, 1., non_differentiable(torch.randn(S))),), ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)), \ non_differentiable(torch.randn(3, 2))),), ('binary_cross_entropy', torch.randn(3, 2).sigmoid(), @@ -10593,141 +10517,6 @@ additional_module_tests = [ input_size=(S, S), extra_args=((S, S),) ), - dict( # noqa: C408 - module_name='L1Loss', - input_fn=lambda: ((2, 3, 4), (2, 3, 4)), - ), - dict( # noqa: C408 - module_name='NLLLoss', - input_fn=lambda: (torch.rand(15, 10).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), - check_sum_reduction=True - ), - dict( # noqa: C408 - module_name='NLLLoss', - constructor_args=(None, None, 2), - input_fn=lambda: (torch.rand(15, 10).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), - desc='ignore_index' - ), - dict( # noqa: C408 - module_name='NLLLoss', - constructor_args_fn=lambda: (torch.rand(10),), - input_fn=lambda: (torch.rand(15, 10).add(1e-2).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), - desc='weights', - ), - dict( # noqa: C408 - module_name='NLLLoss', - constructor_args_fn=lambda: (torch.rand(10), None, 2), - input_fn=lambda: (torch.rand(15, 10).add(1e-2).log(), torch.Tensor(15).uniform_().mul(10).floor().long()), - desc='weights_ignore_index' - ), - dict( # noqa: C408 - module_name='NLLLoss', - constructor_args_fn=lambda: (torch.rand(10), None, -1), - input_fn=lambda: - (torch.rand(15, 10).add(1e-2).log(), - torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1), - desc='weights_ignore_index_neg' - ), - dict( # noqa: C408 - module_name='KLDivLoss', - input_fn=lambda: (torch.rand(10, 10).log(), torch.rand(10, 10)), - ), - dict( # noqa: C408 - module_name='MSELoss', - input_fn=lambda: ((2, 3, 4, 5), (2, 3, 4, 5)), - ), - dict( # noqa: C408 - module_name='BCELoss', - input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), - no_grad=True, - ), - dict( # noqa: C408 - module_name='BCELoss', - constructor_args_fn=lambda: (torch.rand(10),), - input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), - desc='weights', - no_grad=True, - ), - dict( # noqa: C408 - module_name='BCEWithLogitsLoss', - constructor_args=(torch.rand(10), False, None, 'mean', torch.rand(10)), - input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), - no_grad=True, - ), - dict( # noqa: C408 - module_name='BCEWithLogitsLoss', - constructor_args=(torch.rand(15, 10), False), - input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()), - desc='weights', - ), - dict( # noqa: C408 - module_name='HingeEmbeddingLoss', - input_fn=lambda: (torch.randn(10), torch.randn(10).gt(0).double().mul_(2).sub(1)), - no_grad=True, - ), - dict( # noqa: C408 - module_name='HingeEmbeddingLoss', - constructor_args=(0.5,), - input_fn=lambda: (torch.randn(10), torch.randn(10).gt(0).double().mul_(2).sub(1)), - desc='margin', - no_grad=True, - ), - dict( # noqa: C408 - module_name='MultiLabelMarginLoss', - input_fn=lambda: (torch.rand(10,), torch.rand(10).mul(10).floor().long()), - no_grad=True, - ), - dict( # noqa: C408 - module_name='SmoothL1Loss', - input_fn=lambda: ((5, 10), (5, 10)), - ), - dict( # noqa: C408 - module_name='SoftMarginLoss', - input_fn=lambda: (torch.randn(5, 5).sign(), torch.randn(5, 5).sign()), - no_grad=True, - ), - dict( # noqa: C408 - module_name='CrossEntropyLoss', - input_fn=lambda: (torch.randn(15, 10), torch.Tensor(15).uniform_().mul(10).floor().long()), - ), - dict( # noqa: C408 - module_name='MultiLabelSoftMarginLoss', - constructor_args=(torch.rand(10),), - input_fn=lambda: (torch.randn(5, 10), torch.rand(5, 10).mul(2).floor()), - no_grad=True, - ), - dict( # noqa: C408 - module_name='CosineEmbeddingLoss', - input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10), torch.randn(15).sign()), - no_grad=True, - ), - dict( # noqa: C408 - module_name='MarginRankingLoss', - input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10), torch.randn(50).sign()), - ), - dict( # noqa: C408 - module_name='TripletMarginLoss', - input_fn=lambda: (torch.randn(5, 10, requires_grad=True), torch.randn(5, 10, requires_grad=True), - torch.randn(5, 10, requires_grad=True)), - ), - dict( # noqa: C408 - module_name='MultiMarginLoss', - input_fn=lambda: (torch.randn(5, 10), torch.rand(5).mul(8).floor().long()), - no_grad=True, - ), - dict( # noqa: C408 - module_name='PoissonNLLLoss', - input_fn=lambda:(torch.randn(2, 3, 4, 5), torch.randn(2, 3, 4, 5).floor_().abs_()), - ), - dict( - module_name='CTCLoss', - constructor_args=(14,), - input_fn=lambda: (torch.randn(50, 16, 20).log_softmax(2), - torch.randint(1, 20, (16, 30), dtype=torch.long), - torch.full((16,), 50, dtype=torch.long), - torch.randint(10, 30, (16,), dtype=torch.long)), - no_grad=True, - ), ] @@ -10910,7 +10699,10 @@ def add_nn_module_test(*args, **kwargs): if "FunctionalModule" in str(nn_module): return - constructor_args = kwargs.get('constructor_args', ()) + if 'constructor_args_fn' in kwargs: + constructor_args = kwargs['constructor_args_fn']() + else: + constructor_args = kwargs.get('constructor_args', ()) # Construct a script module that passes arguments through # to self.submodule @@ -10954,18 +10746,27 @@ def add_nn_module_test(*args, **kwargs): module = nn_module(*constructor_args) return module(*args) - # Check against Python module as reference + # Set up inputs from tuple of sizes or constructor fn if 'input_fn' in kwargs: input = kwargs['input_fn']() else: input = (kwargs['input_size'],) + # Extra parameters to forward() if 'extra_args' in kwargs: input = input + kwargs['extra_args'] + if 'target_size' in kwargs: + input = input + (kwargs['target_size'],) + elif 'target_fn' in kwargs: + if torch.is_tensor(input): + input = (input,) + input = input + (kwargs['target_fn'](),) + args_variable, kwargs_variable = create_input(input) f_args_variable = deepcopy(unpack_variables(args_variable)) + # Check against Python module as reference check_against_reference(self, create_script_module, create_nn_module, f_args_variable, no_grad=no_grad) post_add_test(test_name, (), do_test) @@ -11133,7 +10934,11 @@ for test in autograd_method_tests: for test in nn_functional_tests: add_nn_functional_test(*test) -for test in module_tests + new_module_tests + additional_module_tests + local_module_tests: +for test in module_tests + new_module_tests + additional_module_tests: + add_nn_module_test(**test) + +for test in criterion_tests: + test['no_grad'] = True add_nn_module_test(**test) if __name__ == '__main__': diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 7146903..7ae2f67 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -1159,7 +1159,10 @@ if _enabled: elif isinstance(item, Parameter) or (isinstance(item, Module) and item is not self): ScriptModule.__setattr__(self, name, item) for name in original._buffers: - self.register_buffer(name, original._buffers[name]) + if original._buffers[name] is None: + object.__setattr__(self, name, None) + else: + self.register_buffer(name, original._buffers[name]) # Copy constants self.__dict__["_constants_set"] = set(getattr(original, "__constants__", [])) diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 883d5ab..a3789f7 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2264,9 +2264,9 @@ def cosine_embedding_loss(input1, input2, target, margin=0, size_average=None, @torch._jit_internal.weak_script -def multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, +def multi_margin_loss(input, target, p=1, margin=1., weight=None, size_average=None, reduce=None, reduction='mean'): - # type: (Tensor, Tensor, int, int, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor + # type: (Tensor, Tensor, int, float, Optional[Tensor], Optional[bool], Optional[bool], str) -> Tensor r"""multi_margin_loss(input, target, p=1, margin=1, weight=None, size_average=None, reduce=None, reduction='mean') -> Tensor diff --git a/torch/nn/modules/loss.py b/torch/nn/modules/loss.py index 1f539e3..df6bb17 100644 --- a/torch/nn/modules/loss.py +++ b/torch/nn/modules/loss.py @@ -1094,7 +1094,7 @@ class MultiMarginLoss(_WeightedLoss): """ __constants__ = ['p', 'margin', 'weight', 'reduction'] - def __init__(self, p=1, margin=1, weight=None, size_average=None, + def __init__(self, p=1, margin=1., weight=None, size_average=None, reduce=None, reduction='mean'): super(MultiMarginLoss, self).__init__(weight, size_average, reduce, reduction) if p != 1 and p != 2: