TST Adds inplace checks to module_info (#63739)
author Thomas J. Fan <thomasjpfan@gmail.com>
Wed, 8 Sep 2021 18:00:11 +0000 (11:00 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Wed, 8 Sep 2021 18:08:12 +0000 (11:08 -0700)
Summary:
Follow-up to https://github.com/pytorch/pytorch/pull/61935

This PR adds inplace checks to `test_modules`. It inspects each module's constructor signature for an `inplace` parameter and performs the out-of-place vs. inplace consistency check automatically for every module that accepts one.
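
For context, here is a minimal sketch (not part of the patch) of the selection logic the new test relies on: a module qualifies when its constructor signature exposes an `inplace` parameter.

    from inspect import signature
    import torch

    # nn.ReLU's constructor accepts `inplace`, so the new test picks it up.
    print('inplace' in signature(torch.nn.ReLU).parameters)    # True
    # nn.Linear's constructor does not, so it is skipped.
    print('inplace' in signature(torch.nn.Linear).parameters)  # False

Modules without such a parameter are filtered out before the test is instantiated, so no per-module opt-in is needed.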

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63739

Reviewed By: saketh-are

Differential Revision: D30737774

Pulled By: jbschlosser

fbshipit-source-id: 8813534511e9296c8424d1ca878412726ddd4043

test/test_modules.py
torch/testing/_internal/common_modules.py

diff --git a/test/test_modules.py b/test/test_modules.py
index 6d6adbc..37ab347 100644
--- a/test/test_modules.py
+++ b/test/test_modules.py
@@ -1,3 +1,5 @@
+from inspect import signature
+from copy import deepcopy
 import tempfile
 
 import torch
@@ -154,6 +156,55 @@ class TestModule(TestCase):
                     output_from_copy = m_copy(*args, **kwargs)
                     self.assertEqual(output, output_from_copy)
 
+    @modules([module_info for module_info in module_db
+              if 'inplace' in signature(module_info.module_cls).parameters])
+    def test_check_inplace(self, device, dtype, module_info):
+        # Check if the inplace variant of the module gives the same result as the out of place
+        # variant.
+        module_cls = module_info.module_cls
+        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+                                                       requires_grad=True)
+        for module_input in module_inputs:
+            if module_input.forward_input is None:
+                continue
+
+            # === Instantiate the module. ===
+            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+            m_op = module_cls(*args, **kwargs, inplace=False)
+            m_op.to(device).to(dtype)
+            m_inplace = module_cls(*args, **kwargs, inplace=True)
+            m_inplace.to(device).to(dtype)
+
+            # === Inplace modules only support inplace operations on the first argument ===
+            input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+            # ===  Do not allow the first input to be in input_kwargs ===
+            forward_sig = signature(m_op.forward).parameters
+            self.assertGreaterEqual(len(forward_sig), 1)
+            first_param_name = next(iter(forward_sig))
+            self.assertNotIn(first_param_name, input_kwargs)
+
+            # === Out of place operation does not write to original tensor ===
+            self.assertGreaterEqual(len(input_args), 1)
+            input_version = input_args[0]._version
+            with freeze_rng_state():
+                output_op = m_op(*input_args, **input_kwargs)
+            self.assertEqual(input_args[0]._version, input_version)
+
+            # === Check that the inplace operation gives the same result ===
+            input_arg_copy = deepcopy(input_args)
+            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
+            with freeze_rng_state():
+                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
+            self.assertNotEqual(input_arg_clone[0]._version, input_version)
+            self.assertEqual(output_op, output_ip)
+
+            # === Check that the gradients are the same ===
+            grad = output_op.data.clone().normal_()
+            output_op.backward(grad)
+            output_ip.backward(grad)
+            self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
+
 
 instantiate_device_type_tests(TestModule, globals())
 
diff --git a/torch/testing/_internal/common_modules.py b/torch/testing/_internal/common_modules.py
index a1059f6..b1cbbb3 100644
--- a/torch/testing/_internal/common_modules.py
+++ b/torch/testing/_internal/common_modules.py
@@ -274,6 +274,18 @@ def module_inputs_torch_nn_CELU(module_info, device, dtype, requires_grad, **kwa
                     desc='no_batch_dim',
                     reference_fn=no_batch_dim_reference_fn)]
 
+def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, **kwargs):
+    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
+    module_inputs = [
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+        ModuleInput(constructor_input=FunctionInput(),
+                    forward_input=FunctionInput(make_input(4)),
+                    desc='no_batch_dim'),
+    ]
+    return module_inputs
+
 
 def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
     make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -301,5 +313,7 @@ module_db: List[ModuleInfo] = [
     ModuleInfo(torch.nn.Linear,
                module_inputs_func=module_inputs_torch_nn_Linear),
     ModuleInfo(torch.nn.NLLLoss,
-               module_inputs_func=module_inputs_torch_nn_NLLLoss)
+               module_inputs_func=module_inputs_torch_nn_NLLLoss),
+    ModuleInfo(torch.nn.ReLU,
+               module_inputs_func=module_inputs_torch_nn_ReLU),
 ]
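
For illustration only (not part of the patch), this is the behavior the new test asserts, shown here with nn.ReLU:

    import torch

    x = torch.randn(2, 3)
    out = torch.nn.ReLU(inplace=False)(x)             # out-of-place: x is left untouched
    out_ip = torch.nn.ReLU(inplace=True)(x.clone())   # inplace: writes the result into its (cloned) input
    assert torch.equal(out, out_ip)                   # both variants produce the same values

Beyond output equality, the test uses tensor version counters to confirm that only the inplace variant actually mutates its input, and it verifies that gradients from the two variants match.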