torch._C._jit_override_can_fuse_on_cpu(False)
return wrapper
+# note: not re-entrant, use unnested only
+@contextmanager
+def disable_autodiff_subgraph_inlining(enabled=True):
+ torch._C._debug_set_autodiff_subgraph_inlining(not enabled)
+ try:
+     yield
+ finally:
+     # re-enable inlining even if the body raises, so a failing test does not leave the global flag flipped
+     torch._C._debug_set_autodiff_subgraph_inlining(True)
+
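+# A minimal usage sketch (scripted_fn and x are illustrative, not part of the suite):
+#
+#     with disable_autodiff_subgraph_inlining():
+#         out = scripted_fn(x)            # runs with subgraph inlining disabled
+#         g = scripted_fn.graph_for(x)    # graph keeps its DifferentiableGraph nodes
+#     # inlining is switched back on when the block exits
+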
# helper function to get sum of List[Tensor]
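# (used below to reduce a list of outputs to a scalar so it can be fed to
#  torch.autograd.grad, e.g. grad = torch.autograd.grad(_sum_of_list(outputs), x))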
def _sum_of_list(tensorlist):
@torch.jit.script
def func2(x, y):
return torch.cat((x, x), y)
- func2.debug_disable_autodiff_subgraph_inlining()
+ with disable_autodiff_subgraph_inlining():
- x = torch.rand([2, 2]).requires_grad_()
- y = torch.tensor(1)
+ x = torch.rand([2, 2]).requires_grad_()
+ y = torch.tensor(1)
- output = func2(x, y)
- output_ref = torch.cat((x, x), y)
- self.assertEqual(output, output_ref)
+ output = func2(x, y)
+ output_ref = torch.cat((x, x), y)
+ self.assertEqual(output, output_ref)
- self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
+ self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
- grad = torch.autograd.grad(output.sum(), x)
- grad_ref = torch.autograd.grad(output_ref.sum(), x)
- self.assertEqual(grad, grad_ref)
+ grad = torch.autograd.grad(output.sum(), x)
+ grad_ref = torch.autograd.grad(output_ref.sum(), x)
+ self.assertEqual(grad, grad_ref)
def test_cat_lifts(self):
@torch.jit.script
def func2(x, y):
return torch.stack((x, y), dim=0)
- func2.debug_disable_autodiff_subgraph_inlining()
-
- x = torch.randn([2, 2]).requires_grad_()
- y = torch.randn([2, 2]).requires_grad_()
+ with disable_autodiff_subgraph_inlining():
+ x = torch.randn([2, 2]).requires_grad_()
+ y = torch.randn([2, 2]).requires_grad_()
- output = func2(x, y)
- output_ref = torch.stack((x, y), 0)
- self.assertEqual(output, output_ref)
+ output = func2(x, y)
+ output_ref = torch.stack((x, y), 0)
+ self.assertEqual(output, output_ref)
- self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
+ self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
- grads = torch.autograd.grad(output.sum(), (x, y))
- grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
- self.assertEqual(grads, grads_ref)
+ grads = torch.autograd.grad(output.sum(), (x, y))
+ grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
+ self.assertEqual(grads, grads_ref)
def test_unbind(self):
@torch.jit.script
def func(x, y):
# type: (Tensor, int) -> List[Tensor]
return torch.unbind(x, y)
- func.debug_disable_autodiff_subgraph_inlining()
-
- x = torch.rand([2, 2]).requires_grad_()
- y = 0
- outputs = func(x, y)
- outputs_ref = torch.unbind(x, dim=y)
- self.assertEqual(outputs, outputs_ref)
+ with disable_autodiff_subgraph_inlining():
+ x = torch.rand([2, 2]).requires_grad_()
+ y = 0
+ outputs = func(x, y)
+ outputs_ref = torch.unbind(x, dim=y)
+ self.assertEqual(outputs, outputs_ref)
- self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], [])
+ self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], [])
- grad = torch.autograd.grad(_sum_of_list(outputs), x)
- grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
- self.assertEqual(grad, grad_ref)
+ grad = torch.autograd.grad(_sum_of_list(outputs), x)
+ grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
+ self.assertEqual(grad, grad_ref)
def test_meshgrid(self):
@torch.jit.script
def func(a):
# type: (List[Tensor]) -> List[Tensor]
return torch.meshgrid(a)
- func.debug_disable_autodiff_subgraph_inlining()
+ with disable_autodiff_subgraph_inlining():
- a = torch.tensor([1.0, 2, 3]).requires_grad_()
- b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
- inputs = [a, b]
+ a = torch.tensor([1.0, 2, 3]).requires_grad_()
+ b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
+ inputs = [a, b]
- outputs_ref = torch.meshgrid(inputs)
- outputs = func(inputs)
- self.assertEqual(outputs, outputs_ref)
+ outputs_ref = torch.meshgrid(inputs)
+ outputs = func(inputs)
+ self.assertEqual(outputs, outputs_ref)
- self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], [])
-
+ self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], [])
- grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
- grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
- self.assertEqual(grads, grads_ref)
+ grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
+ grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
+ self.assertEqual(grads, grads_ref)
def test_list_literal(self):
def reassign():
def func2(t, t_ref):
return t.to(t_ref)
- func2.debug_disable_autodiff_subgraph_inlining()
-
- t_ref = torch.tensor(4).double()
- out_ref = t.to(t_ref)
- out = func2(t, t_ref)
- grad_ref = torch.autograd.grad(out_ref.sum(), t)
- grad = torch.autograd.grad(out.sum(), t)
- self.assertEqual(grad_ref, grad)
+ with disable_autodiff_subgraph_inlining():
+ t_ref = torch.tensor(4).double()
+ out_ref = t.to(t_ref)
+ out = func2(t, t_ref)
+ grad_ref = torch.autograd.grad(out_ref.sum(), t)
+ grad = torch.autograd.grad(out.sum(), t)
+ self.assertEqual(grad_ref, grad)
@unittest.skipIf(not RUN_CUDA, "No CUDA")
def test_tensor_number_math_cuda(self):
# create a traced function from the input fn
-#
-# disable_autodiff_subgraph_inlining:
-# Don't inline autodiff subgraphs so we can test autodiff
-def create_traced_fn(self, fn,
- disable_autodiff_subgraph_inlining=False):
+def create_traced_fn(self, fn):
def traced_fn(*inputs, **kwargs):
fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
traced = torch.jit.trace(fn_tensors, inputs_tensors)
self.assertExportImport(traced.graph, inputs_tensors)
- if disable_autodiff_subgraph_inlining:
- traced.debug_disable_autodiff_subgraph_inlining()
output = traced(*inputs_tensors)
traced_fn.last_graph = traced.graph_for(*inputs_tensors)
return output
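# Typical use in the checks below (a sketch of the pattern that appears later in this file):
#     traced_fn = create_traced_fn(self, fn)
#     check_against_reference(self, traced_fn, fn, (self_variable,) + args_variable, kwargs_variable)
#     self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)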
# create a script function from (name, func_type, output_process_fn);
# returns a function that takes (args, kwargs), runs the compiled function,
# and then applies the post-process fn to the outputs
-def create_script_fn(self, method_name, func_type, output_process_fn,
- disable_autodiff_subgraph_inlining=False):
+def create_script_fn(self, method_name, func_type, output_process_fn):
def script_fn(*args, **kwargs):
formals, tensors, actuals = get_script_args(args)
kwargs_str = ''
script = script_template.format(', '.join(formals), call)
CU = torch.jit.CompilationUnit(script)
- if disable_autodiff_subgraph_inlining:
- CU.the_method.debug_disable_autodiff_subgraph_inlining()
self.assertExportImport(CU.the_method.graph, tensors)
output = output_process_fn(CU.the_method(*tensors))
script_fn.last_graph = CU.the_method.graph_for(*tensors)
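# Used analogously to create_traced_fn above: run the returned script_fn through
# check_against_reference, then inspect script_fn.last_graph with assertAutodiffNode.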
# TODO: It would be better if we could test directly on graphs instead of in the
# current end-to-end fashion.
def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
- ge = torch.jit.script(fn)
- ge.debug_disable_autodiff_subgraph_inlining()
- inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
- ge(*inputs)
- return ge.graph_for(*inputs)
+ with disable_autodiff_subgraph_inlining():
+ ge = torch.jit.script(fn)
+ inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
+ ge(*inputs)
+ return ge.graph_for(*inputs)
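+# Sketch of a caller (the fn and sizes here are illustrative, not from this file):
+#     def fn(x, y):
+#         return x * y
+#     graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
+#     self.assertGraphSize(graph, 1)  # expect a single prim::DifferentiableGraph node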
def assertGraphSize(self, graph, size):
self.assertEqual(len(list(graph.nodes())), size)
return (x1, x2)
input = torch.rand(6, 10).requires_grad_()
- func.debug_disable_autodiff_subgraph_inlining()
- output = func(input)
- self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
+ with disable_autodiff_subgraph_inlining():
+ output = func(input)
+ self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
def test_simple_merge(self):
# o --> o
return output_process_fn(output)
check_types = test_name not in EXCLUDE_TYPE_CHECK
-
- if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name):
- # Test with disable_autodiff_subgraph_inlining, which forces the graph
- # to contain DifferentiableGraph nodes whenever possible. This allows us
- # to test autodiff; we assume that autograd is correct and use autodiff for backprop
- should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
- if test_name not in EXCLUDE_TRACED:
- traced_fn = create_traced_fn(self, fn, disable_autodiff_subgraph_inlining=True)
-
- check_against_reference(self, traced_fn,
- fn, (self_variable,) + args_variable, kwargs_variable,
- check_types=check_types)
- self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
-
- if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
- script_fn = create_script_fn(self, name, 'method', output_process_fn,
- disable_autodiff_subgraph_inlining=True)
- check_against_reference(self, script_fn,
- fn, (self_variable,) + args_variable, kwargs_variable,
- check_types=check_types)
-
- self.assertAutodiffNode(script_fn.last_graph,
- should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
- autodiff_nodes,
- fusible_nodes)
-
- # functional interface tests
- if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
- def fn(*inputs, **kwargs):
- output = getattr(torch, name)(*inputs, **kwargs)
- return output_process_fn(output)
-
- f_args_variable = (self_variable,) + args_variable
- f_args_tensor = (self_tensor,) + args_tensor
-
- if not is_inplace and test_name not in EXCLUDE_TRACED:
- check_against_reference(self,
- create_traced_fn(self, fn,
- disable_autodiff_subgraph_inlining=True),
- fn, f_args_variable, kwargs_variable, check_types=check_types)
-
- if not is_inplace and test_name not in EXCLUDE_SCRIPT:
- check_against_reference(self,
- create_script_fn(self, name, 'functional', output_process_fn,
- disable_autodiff_subgraph_inlining=True),
- fn, f_args_variable, kwargs_variable,
- check_types=check_types)
+ with disable_autodiff_subgraph_inlining():
+ if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name):
+ # Test with disable_autodiff_subgraph_inlining, which forces the graph
+ # to contain DifferentiableGraph nodes whenever possible. This allows us
+ # to test autodiff; we assume that autograd is correct and use autodiff for backprop
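+ # (e.g. for the cat tests above, assertAutodiffNode verifies that the differentiable
+ # graph produced for the call actually contains an aten::cat node)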
+ should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
+
+ if test_name not in EXCLUDE_TRACED:
+ traced_fn = create_traced_fn(self, fn)
+
+ check_against_reference(self, traced_fn,
+ fn, (self_variable,) + args_variable, kwargs_variable,
+ check_types=check_types)
+ self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
+
+ if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
+ script_fn = create_script_fn(self, name, 'method', output_process_fn)
+ check_against_reference(self, script_fn,
+ fn, (self_variable,) + args_variable, kwargs_variable,
+ check_types=check_types)
+
+ self.assertAutodiffNode(script_fn.last_graph,
+ should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
+ autodiff_nodes,
+ fusible_nodes)
+
+ # functional interface tests
+ if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
+ def fn(*inputs, **kwargs):
+ output = getattr(torch, name)(*inputs, **kwargs)
+ return output_process_fn(output)
+
+ f_args_variable = (self_variable,) + args_variable
+ f_args_tensor = (self_tensor,) + args_tensor
+
+ if not is_inplace and test_name not in EXCLUDE_TRACED:
+ check_against_reference(self,
+ create_traced_fn(self, fn),
+ fn, f_args_variable, kwargs_variable, check_types=check_types)
+
+ if not is_inplace and test_name not in EXCLUDE_SCRIPT:
+ check_against_reference(self,
+ create_script_fn(self, name, 'functional', output_process_fn),
+ fn, f_args_variable, kwargs_variable,
+ check_types=check_types)
# alias annotation testing
if is_inplace and test_name not in EXCLUDE_SCRIPT:
should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
if test_name not in EXCLUDE_SCRIPT:
def run_test():
- script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
- disable_autodiff_subgraph_inlining=should_autodiff_node)
- check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
- # For tests where we disabled AD subgraph inlining, make sure the graph is not falling back to autograd
- self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
+ with disable_autodiff_subgraph_inlining(should_autodiff_node):
+ script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn)
+ check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+ # For tests where we disabled AD subgraph inlining, make sure the graph is not falling back to autograd
+ self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
if test_name in EXCLUDE_PYTHON_PRINT:
with self.disableModuleHook():