From 43c0f033fc39cd81c5577d90f4454969753706ef Mon Sep 17 00:00:00 2001 From: "Thomas J. Fan" Date: Wed, 8 Sep 2021 11:00:11 -0700 Subject: [PATCH] TST Adds inplace checks to module_info (#63739) Summary: Follow up to https://github.com/pytorch/pytorch/pull/61935 This PR adds inplace checks to `test_modules`. This version checks the constructor for `inplace` and performs the check automatically. Pull Request resolved: https://github.com/pytorch/pytorch/pull/63739 Reviewed By: saketh-are Differential Revision: D30737774 Pulled By: jbschlosser fbshipit-source-id: 8813534511e9296c8424d1ca878412726ddd4043 --- test/test_modules.py | 51 +++++++++++++++++++++++++++++++ torch/testing/_internal/common_modules.py | 16 +++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/test/test_modules.py b/test/test_modules.py index 6d6adbc..37ab347 100644 --- a/test/test_modules.py +++ b/test/test_modules.py @@ -1,3 +1,5 @@ +from inspect import signature +from copy import deepcopy import tempfile import torch @@ -154,6 +156,55 @@ class TestModule(TestCase): output_from_copy = m_copy(*args, **kwargs) self.assertEqual(output, output_from_copy) + @modules([module_info for module_info in module_db + if 'inplace' in signature(module_info.module_cls).parameters]) + def test_check_inplace(self, device, dtype, module_info): + # Check if the inplace variant of the module gives the same result as the out of place + # variant. + module_cls = module_info.module_cls + module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, + requires_grad=True) + for module_input in module_inputs: + if module_input.forward_input is None: + continue + + # === Instantiate the module. === + args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs + m_op = module_cls(*args, **kwargs, inplace=False) + m_op.to(device).to(dtype) + m_inplace = module_cls(*args, **kwargs, inplace=True) + m_inplace.to(device).to(dtype) + + # === Inplace modules only supports inplace operations on the first argument === + input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs + + # === Do not allow the first input to be in input_kwargs === + forward_sig = signature(m_op).parameters + self.assertGreaterEqual(len(forward_sig), 1) + first_param_name = next(iter(forward_sig.items())) + self.assertNotIn(first_param_name, input_kwargs) + + # === Out of place operation does not write to original tensor === + self.assertGreaterEqual(len(input_args), 1) + input_version = input_args[0]._version + with freeze_rng_state(): + output_op = m_op(*input_args, **input_kwargs) + self.assertEqual(input_args[0]._version, input_version) + + # === Check that the inplace operation gives the same result === + input_arg_copy = deepcopy(input_args) + input_arg_clone = tuple(i.clone() for i in input_arg_copy) + with freeze_rng_state(): + output_ip = m_inplace(*input_arg_clone, **input_kwargs) + self.assertNotEqual(input_arg_clone[0]._version, input_version) + self.assertEqual(output_op, output_ip) + + # === Check that the gradients are the same === + grad = output_op.data.clone().normal_() + output_op.backward(grad) + output_ip.backward(grad) + self.assertEqual(input_args[0].grad, input_arg_copy[0].grad) + instantiate_device_type_tests(TestModule, globals()) diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py index a1059f6..b1cbbb3 100644 --- a/torch/testing/_internal/common_modules.py +++ b/torch/testing/_internal/common_modules.py @@ -274,6 +274,18 @@ def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwa desc='no_batch_dim', reference_fn=no_batch_dim_reference_fn)] +def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad): + make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) + + module_inputs = [ + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input((2, 3, 4, 5)))), + ModuleInput(constructor_input=FunctionInput(), + forward_input=FunctionInput(make_input(4)), + desc='no_batch_dim'), + ] + return module_inputs + def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs): make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) @@ -301,5 +313,7 @@ module_db: List[ModuleInfo] = [ ModuleInfo(torch.nn.Linear, module_inputs_func=module_inputs_torch_nn_Linear), ModuleInfo(torch.nn.NLLLoss, - module_inputs_func=module_inputs_torch_nn_NLLLoss) + module_inputs_func=module_inputs_torch_nn_NLLLoss), + ModuleInfo(torch.nn.ReLU, + module_inputs_func=module_inputs_torch_nn_ReLU), ] -- 2.7.4