+from inspect import signature
+from copy import deepcopy
import tempfile
import torch
output_from_copy = m_copy(*args, **kwargs)
self.assertEqual(output, output_from_copy)
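+    # Only generate this test for modules whose constructor exposes an `inplace`
+    # flag; all other entries in module_db are filtered out.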
+ @modules([module_info for module_info in module_db
+ if 'inplace' in signature(module_info.module_cls).parameters])
+ def test_check_inplace(self, device, dtype, module_info):
+        # Check that the in-place variant of the module gives the same result as the
+        # out-of-place variant.
+ module_cls = module_info.module_cls
+ module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
+ requires_grad=True)
+ for module_input in module_inputs:
+ if module_input.forward_input is None:
+ continue
+
+ # === Instantiate the module. ===
+ args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
+ m_op = module_cls(*args, **kwargs, inplace=False)
+ m_op.to(device).to(dtype)
+ m_inplace = module_cls(*args, **kwargs, inplace=True)
+ m_inplace.to(device).to(dtype)
+
+            # === In-place modules only support in-place operations on the first argument ===
+ input_args, input_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
+
+ # === Do not allow the first input to be in input_kwargs ===
+            # Inspect forward directly; signature(m_op) resolves through nn.Module.__call__.
+            forward_sig = signature(m_op.forward).parameters
+            self.assertGreaterEqual(len(forward_sig), 1)
+            # Iterating the mapping yields parameter names (not (name, Parameter) pairs).
+            first_param_name = next(iter(forward_sig))
+            self.assertNotIn(first_param_name, input_kwargs)
+
+            # === Out-of-place operation does not write to the original tensor ===
+ self.assertGreaterEqual(len(input_args), 1)
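+            # `_version` is autograd's per-tensor mutation counter, bumped on every
+            # in-place write; the out-of-place forward must leave it unchanged.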
+ input_version = input_args[0]._version
+ with freeze_rng_state():
+ output_op = m_op(*input_args, **input_kwargs)
+ self.assertEqual(input_args[0]._version, input_version)
+
+ # === Check that the inplace operation gives the same result ===
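+            # Feed clones to the in-place module so it mutates fresh tensors rather
+            # than the originals used for the out-of-place run.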
+ input_arg_copy = deepcopy(input_args)
+            input_arg_clone = tuple(i.clone() for i in input_arg_copy)
+            # Baseline the clone's own version counter, not the original input's.
+            input_clone_version = input_arg_clone[0]._version
+            with freeze_rng_state():
+                output_ip = m_inplace(*input_arg_clone, **input_kwargs)
+            self.assertNotEqual(input_arg_clone[0]._version, input_clone_version)
+ self.assertEqual(output_op, output_ip)
+
+ # === Check that the gradients are the same ===
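+            # The clones fed to m_inplace are non-leaf, so its gradients accumulate
+            # on the deepcopied leaves in input_arg_copy.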
+            grad = output_op.detach().clone().normal_()
+ output_op.backward(grad)
+ output_ip.backward(grad)
+ self.assertEqual(input_args[0].grad, input_arg_copy[0].grad)
+
instantiate_device_type_tests(TestModule, globals())
desc='no_batch_dim',
reference_fn=no_batch_dim_reference_fn)]
+def module_inputs_torch_nn_ReLU(module_info, device, dtype, requires_grad, **kwargs):
+ make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
+
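+    # Cover a standard batched input as well as a 1-D input for the
+    # no-batch-dim case.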
+ module_inputs = [
+ ModuleInput(constructor_input=FunctionInput(),
+ forward_input=FunctionInput(make_input((2, 3, 4, 5)))),
+ ModuleInput(constructor_input=FunctionInput(),
+ forward_input=FunctionInput(make_input(4)),
+ desc='no_batch_dim'),
+ ]
+ return module_inputs
+
def module_inputs_torch_nn_L1Loss(module_info, device, dtype, requires_grad, **kwargs):
make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
ModuleInfo(torch.nn.Linear,
module_inputs_func=module_inputs_torch_nn_Linear),
ModuleInfo(torch.nn.NLLLoss,
- module_inputs_func=module_inputs_torch_nn_NLLLoss)
+ module_inputs_func=module_inputs_torch_nn_NLLLoss),
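+    # Registering ReLU opts it into the generic module tests, including
+    # test_check_inplace via its `inplace` constructor flag.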
+ ModuleInfo(torch.nn.ReLU,
+ module_inputs_func=module_inputs_torch_nn_ReLU),
]