from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
freeze_rng_state
-from common_nn import module_tests, new_module_tests
+from common_nn import module_tests, new_module_tests, criterion_tests
from textwrap import dedent
import os
import io
'test_nn_AdaptiveAvgPool3d_tuple_none',
'test_nn_AdaptiveMaxPool2d_tuple_none',
'test_nn_AdaptiveMaxPool3d_tuple_none',
- 'test_nn_LayerNorm_1d_elementwise_affine',
- 'test_nn_LayerNorm_1d_no_elementwise_affine',
- 'test_nn_LayerNorm_3d_elementwise_affine',
- 'test_nn_LayerNorm_3d_no_elementwise_affine',
- 'test_nn_Linear_no_bias',
-
- # unsupported None parameter
- 'test_nn_BCELoss_weights',
- 'test_nn_CrossEntropyLoss',
- 'test_nn_NLLLoss_weights',
- 'test_nn_NLLLoss_ignore_index',
- 'test_nn_NLLLoss',
- 'test_nn_MultiMarginLoss',
- 'test_nn_NLLLoss_weights_ignore_index',
- 'test_nn_NLLLoss_weights_ignore_index_neg',
- 'test_nn_BCEWithLogitsLoss_weights',
- 'test_nn_BCELoss',
}
DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
self.assertEqual(outputs, outputs_test)
self.assertEqual(grads, grads_test)
for g2, g2_test in zip(grads2, grads2_test):
- if g2 is None and g2_ge is None:
+ if g2 is None and g2_test is None:
continue
self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
'AdaptiveAvgPool3d',
}
-local_module_tests = []
-
-
-def to_module_test_format(tup):
- dic = dict(module_name=tup[0], constructor_args=tup[1], input_fn=lambda: tup[2])
- if len(tup) >= 5:
- dic['desc'] = tup[4]
- local_module_tests.append(dic)
-
-
-def add_interpolate_module_tests():
- # logic from test_interpolate in test_nn.py
- def _make_input(dim):
- size = [1, 1]
- size += [2] * dim
- return torch.ones(size, requires_grad=True)
-
- i = 0
- size = None
- for scale_factor in [0.5, 1.5, 2.0]:
- for mode in ['nearest', 'area']:
- args = (size, scale_factor, mode)
- for input in [_make_input(1), _make_input(2), _make_input(3)]:
- to_module_test_format(('Upsample', args, input, False, str(i)))
- i = i + 1
-
- for align_corners in [True, False]:
- args = (size, scale_factor, 'linear', align_corners)
- to_module_test_format(('Upsample', args, _make_input(1), False, str(i)))
- i = i + 1
-
- args = (size, scale_factor, 'bilinear', align_corners)
- to_module_test_format(('Upsample', args, _make_input(2), False, str(i)))
- i = i + 1
-
- args = (size, scale_factor, 'trilinear', align_corners)
- to_module_test_format(('Upsample', args, _make_input(3), False, str(i)))
- i = i + 1
-
- # test_upsamplingTrilinear3d_spatial_invariance
- scale_factor = 3.
- args = (size, scale_factor, 'trilinear', False)
- in_t_9 = torch.zeros(1, 1, 9, 9, 9)
- in_t_9[:, :, :4, :4, :4].normal_()
- to_module_test_format(('Upsample', args, in_t_9, False, str(i)))
- i = i + 1
-
- # testing where size is not none test_upsamplingNearest2d
- size = 4
- scale_factor = None
- in_t = torch.ones(1, 1, 2, 2)
-
- args = (size, scale_factor)
- to_module_test_format(('UpsamplingNearest2d', args, Variable(in_t), False,))
- to_module_test_format(('UpsamplingBilinear2d', args, Variable(in_t), False,))
-
-
-add_interpolate_module_tests()
-
# NB: JIT script tests for all nn functional interfaces, script mode does
# not support in_place operations yet, so no inplace operation tests added.
# removed all the deprecated functions
('gumbel_softmax', (S, S), (2., True,), 'hard'),
('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)), \
- 1, 1, non_differentiable(torch.randn(S))),),
+ 1, 1., non_differentiable(torch.randn(S))),),
('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)), \
non_differentiable(torch.randn(3, 2))),),
('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
input_size=(S, S),
extra_args=((S, S),)
),
- dict( # noqa: C408
- module_name='L1Loss',
- input_fn=lambda: ((2, 3, 4), (2, 3, 4)),
- ),
- dict( # noqa: C408
- module_name='NLLLoss',
- input_fn=lambda: (torch.rand(15, 10).log(), torch.Tensor(15).uniform_().mul(10).floor().long()),
- check_sum_reduction=True
- ),
- dict( # noqa: C408
- module_name='NLLLoss',
- constructor_args=(None, None, 2),
- input_fn=lambda: (torch.rand(15, 10).log(), torch.Tensor(15).uniform_().mul(10).floor().long()),
- desc='ignore_index'
- ),
- dict( # noqa: C408
- module_name='NLLLoss',
- constructor_args_fn=lambda: (torch.rand(10),),
- input_fn=lambda: (torch.rand(15, 10).add(1e-2).log(), torch.Tensor(15).uniform_().mul(10).floor().long()),
- desc='weights',
- ),
- dict( # noqa: C408
- module_name='NLLLoss',
- constructor_args_fn=lambda: (torch.rand(10), None, 2),
- input_fn=lambda: (torch.rand(15, 10).add(1e-2).log(), torch.Tensor(15).uniform_().mul(10).floor().long()),
- desc='weights_ignore_index'
- ),
- dict( # noqa: C408
- module_name='NLLLoss',
- constructor_args_fn=lambda: (torch.rand(10), None, -1),
- input_fn=lambda:
- (torch.rand(15, 10).add(1e-2).log(),
- torch.Tensor(15).uniform_().mul(10 + 1).floor().long() - 1),
- desc='weights_ignore_index_neg'
- ),
- dict( # noqa: C408
- module_name='KLDivLoss',
- input_fn=lambda: (torch.rand(10, 10).log(), torch.rand(10, 10)),
- ),
- dict( # noqa: C408
- module_name='MSELoss',
- input_fn=lambda: ((2, 3, 4, 5), (2, 3, 4, 5)),
- ),
- dict( # noqa: C408
- module_name='BCELoss',
- input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='BCELoss',
- constructor_args_fn=lambda: (torch.rand(10),),
- input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()),
- desc='weights',
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='BCEWithLogitsLoss',
- constructor_args=(torch.rand(10), False, None, 'mean', torch.rand(10)),
- input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='BCEWithLogitsLoss',
- constructor_args=(torch.rand(15, 10), False),
- input_fn=lambda: (torch.rand(15, 10).clamp_(1e-2, 1 - 1e-2), torch.randn(15, 10).gt(0).double()),
- desc='weights',
- ),
- dict( # noqa: C408
- module_name='HingeEmbeddingLoss',
- input_fn=lambda: (torch.randn(10), torch.randn(10).gt(0).double().mul_(2).sub(1)),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='HingeEmbeddingLoss',
- constructor_args=(0.5,),
- input_fn=lambda: (torch.randn(10), torch.randn(10).gt(0).double().mul_(2).sub(1)),
- desc='margin',
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='MultiLabelMarginLoss',
- input_fn=lambda: (torch.rand(10,), torch.rand(10).mul(10).floor().long()),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='SmoothL1Loss',
- input_fn=lambda: ((5, 10), (5, 10)),
- ),
- dict( # noqa: C408
- module_name='SoftMarginLoss',
- input_fn=lambda: (torch.randn(5, 5).sign(), torch.randn(5, 5).sign()),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='CrossEntropyLoss',
- input_fn=lambda: (torch.randn(15, 10), torch.Tensor(15).uniform_().mul(10).floor().long()),
- ),
- dict( # noqa: C408
- module_name='MultiLabelSoftMarginLoss',
- constructor_args=(torch.rand(10),),
- input_fn=lambda: (torch.randn(5, 10), torch.rand(5, 10).mul(2).floor()),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='CosineEmbeddingLoss',
- input_fn=lambda: (torch.rand(15, 10), torch.rand(15, 10), torch.randn(15).sign()),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='MarginRankingLoss',
- input_fn=lambda: (torch.randn(50).mul(10), torch.randn(50).mul(10), torch.randn(50).sign()),
- ),
- dict( # noqa: C408
- module_name='TripletMarginLoss',
- input_fn=lambda: (torch.randn(5, 10, requires_grad=True), torch.randn(5, 10, requires_grad=True),
- torch.randn(5, 10, requires_grad=True)),
- ),
- dict( # noqa: C408
- module_name='MultiMarginLoss',
- input_fn=lambda: (torch.randn(5, 10), torch.rand(5).mul(8).floor().long()),
- no_grad=True,
- ),
- dict( # noqa: C408
- module_name='PoissonNLLLoss',
- input_fn=lambda:(torch.randn(2, 3, 4, 5), torch.randn(2, 3, 4, 5).floor_().abs_()),
- ),
- dict(
- module_name='CTCLoss',
- constructor_args=(14,),
- input_fn=lambda: (torch.randn(50, 16, 20).log_softmax(2),
- torch.randint(1, 20, (16, 30), dtype=torch.long),
- torch.full((16,), 50, dtype=torch.long),
- torch.randint(10, 30, (16,), dtype=torch.long)),
- no_grad=True,
- ),
]
if "FunctionalModule" in str(nn_module):
return
- constructor_args = kwargs.get('constructor_args', ())
+ if 'constructor_args_fn' in kwargs:
+ constructor_args = kwargs['constructor_args_fn']()
+ else:
+ constructor_args = kwargs.get('constructor_args', ())
# Construct a script module that passes arguments through
# to self.submodule
module = nn_module(*constructor_args)
return module(*args)
- # Check against Python module as reference
+ # Set up inputs from tuple of sizes or constructor fn
if 'input_fn' in kwargs:
input = kwargs['input_fn']()
else:
input = (kwargs['input_size'],)
+ # Extra parameters to forward()
if 'extra_args' in kwargs:
input = input + kwargs['extra_args']
+ if 'target_size' in kwargs:
+ input = input + (kwargs['target_size'],)
+ elif 'target_fn' in kwargs:
+ if torch.is_tensor(input):
+ input = (input,)
+ input = input + (kwargs['target_fn'](),)
+
args_variable, kwargs_variable = create_input(input)
f_args_variable = deepcopy(unpack_variables(args_variable))
+ # Check against Python module as reference
check_against_reference(self, create_script_module, create_nn_module, f_args_variable, no_grad=no_grad)
post_add_test(test_name, (), do_test)
for test in nn_functional_tests:
add_nn_functional_test(*test)
-for test in module_tests + new_module_tests + additional_module_tests + local_module_tests:
+for test in module_tests + new_module_tests + additional_module_tests:
+ add_nn_module_test(**test)
+
+for test in criterion_tests:
+ test['no_grad'] = True
add_nn_module_test(**test)
if __name__ == '__main__':