From: Zachary DeVito Date: Sat, 13 Apr 2019 15:28:11 +0000 (-0700) Subject: Make debug subgraph inlining thread local (#19136) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~242 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=1827ca4c3590959940a311a2f08be02c82e79bd5;p=platform%2Fupstream%2Fpytorch.git Make debug subgraph inlining thread local (#19136) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/19136 ghimport-source-id: 3a24ab36aa753ce5cce7bba3467bdbe88e5c7f60 Reviewed By: jamesr66a Differential Revision: D14885051 Pulled By: zdevito fbshipit-source-id: b39c6ceef73ad9caefcbf8f40dd1b9132bba03c2 --- diff --git a/test/test_jit.py b/test/test_jit.py index ea9c853..015c67c 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -252,6 +252,13 @@ def enable_cpu_fuser(fn): torch._C._jit_override_can_fuse_on_cpu(False) return wrapper +# note: not re-entrant, use unnested only +@contextmanager +def disable_autodiff_subgraph_inlining(enabled=True): + torch._C._debug_set_autodiff_subgraph_inlining(not enabled) + yield + torch._C._debug_set_autodiff_subgraph_inlining(True) + # helper function to get sum of List[Tensor] def _sum_of_list(tensorlist): @@ -3744,20 +3751,20 @@ a") @torch.jit.script def func2(x, y): return torch.cat((x, x), y) - func2.debug_disable_autodiff_subgraph_inlining() + with disable_autodiff_subgraph_inlining(): - x = torch.rand([2, 2]).requires_grad_() - y = torch.tensor(1) + x = torch.rand([2, 2]).requires_grad_() + y = torch.tensor(1) - output = func2(x, y) - output_ref = torch.cat((x, x), y) - self.assertEqual(output, output_ref) + output = func2(x, y) + output_ref = torch.cat((x, x), y) + self.assertEqual(output, output_ref) - self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], []) + self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], []) - grad = torch.autograd.grad(output.sum(), x) - grad_ref = torch.autograd.grad(output_ref.sum(), x) - self.assertEqual(grad, grad_ref) + grad = torch.autograd.grad(output.sum(), x) + grad_ref = torch.autograd.grad(output_ref.sum(), x) + self.assertEqual(grad, grad_ref) def test_cat_lifts(self): @torch.jit.script @@ -3787,60 +3794,57 @@ a") def func2(x, y): return torch.stack((x, y), dim=0) - func2.debug_disable_autodiff_subgraph_inlining() - - x = torch.randn([2, 2]).requires_grad_() - y = torch.randn([2, 2]).requires_grad_() + with disable_autodiff_subgraph_inlining(): + x = torch.randn([2, 2]).requires_grad_() + y = torch.randn([2, 2]).requires_grad_() - output = func2(x, y) - output_ref = torch.stack((x, y), 0) - self.assertEqual(output, output_ref) + output = func2(x, y) + output_ref = torch.stack((x, y), 0) + self.assertEqual(output, output_ref) - self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], []) + self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], []) - grads = torch.autograd.grad(output.sum(), (x, y)) - grads_ref = torch.autograd.grad(output_ref.sum(), (x, y)) - self.assertEqual(grads, grads_ref) + grads = torch.autograd.grad(output.sum(), (x, y)) + grads_ref = torch.autograd.grad(output_ref.sum(), (x, y)) + self.assertEqual(grads, grads_ref) def test_unbind(self): @torch.jit.script def func(x, y): # type: (Tensor, int) -> List[Tensor] return torch.unbind(x, y) - func.debug_disable_autodiff_subgraph_inlining() - - x = torch.rand([2, 2]).requires_grad_() - y = 0 - outputs = func(x, y) - outputs_ref = torch.unbind(x, dim=y) - self.assertEqual(outputs, outputs_ref) + with disable_autodiff_subgraph_inlining(): + x = 
torch.rand([2, 2]).requires_grad_() + y = 0 + outputs = func(x, y) + outputs_ref = torch.unbind(x, dim=y) + self.assertEqual(outputs, outputs_ref) - self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], []) + self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], []) - grad = torch.autograd.grad(_sum_of_list(outputs), x) - grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x) - self.assertEqual(grad, grad_ref) + grad = torch.autograd.grad(_sum_of_list(outputs), x) + grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x) + self.assertEqual(grad, grad_ref) def test_meshgrid(self): @torch.jit.script def func(a): # type: (List[Tensor]) -> List[Tensor] return torch.meshgrid(a) - func.debug_disable_autodiff_subgraph_inlining() + with disable_autodiff_subgraph_inlining(): + a = torch.tensor([1.0, 2, 3]).requires_grad_() + b = torch.tensor([1.0, 2, 3, 4]).requires_grad_() + inputs = [a, b] - a = torch.tensor([1.0, 2, 3]).requires_grad_() - b = torch.tensor([1.0, 2, 3, 4]).requires_grad_() - inputs = [a, b] + outputs_ref = torch.meshgrid(inputs) + outputs = func(inputs) + self.assertEqual(outputs, outputs_ref) - outputs_ref = torch.meshgrid(inputs) - outputs = func(inputs) - self.assertEqual(outputs, outputs_ref) + self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], []) - self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], []) - - grads = torch.autograd.grad(_sum_of_list(outputs), inputs) - grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs) - self.assertEqual(grads, grads_ref) + grads = torch.autograd.grad(_sum_of_list(outputs), inputs) + grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs) + self.assertEqual(grads, grads_ref) def test_list_literal(self): def reassign(): @@ -5785,14 +5789,13 @@ a") def func2(t, t_ref): return t.to(t_ref) - func2.debug_disable_autodiff_subgraph_inlining() - - t_ref = torch.tensor(4).double() - out_ref = t.to(t_ref) - out = func2(t, t_ref) - grad_ref = torch.autograd.grad(out_ref.sum(), t) - grad = torch.autograd.grad(out.sum(), t) - self.assertEqual(grad_ref, grad) + with disable_autodiff_subgraph_inlining(): + t_ref = torch.tensor(4).double() + out_ref = t.to(t_ref) + out = func2(t, t_ref) + grad_ref = torch.autograd.grad(out_ref.sum(), t) + grad = torch.autograd.grad(out.sum(), t) + self.assertEqual(grad_ref, grad) @unittest.skipIf(not RUN_CUDA, "No CUDA") def test_tensor_number_math_cuda(self): @@ -12063,17 +12066,11 @@ def partial_apply_nontensors(fn, args, **kwargs): # create a trace function from input fn -# -# disable_autodiff_subgraph_inlining: -# Don't inline autodiff subgraphs so we can test autodiff -def create_traced_fn(self, fn, - disable_autodiff_subgraph_inlining=False): +def create_traced_fn(self, fn): def traced_fn(*inputs, **kwargs): fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs) traced = torch.jit.trace(fn_tensors, inputs_tensors) self.assertExportImport(traced.graph, inputs_tensors) - if disable_autodiff_subgraph_inlining: - traced.debug_disable_autodiff_subgraph_inlining() output = traced(*inputs_tensors) traced_fn.last_graph = traced.graph_for(*inputs_tensors) return output @@ -12118,8 +12115,7 @@ def get_script_args(args): # create a script function from (name, func_type, output_process_fn), # returns a function takes in (args, kwargs) and runs the compiled function and # then applies the post process fn to the outputs -def create_script_fn(self, method_name, func_type, output_process_fn, - 
disable_autodiff_subgraph_inlining=False): +def create_script_fn(self, method_name, func_type, output_process_fn): def script_fn(*args, **kwargs): formals, tensors, actuals = get_script_args(args) kwargs_str = '' @@ -12137,8 +12133,6 @@ def create_script_fn(self, method_name, func_type, output_process_fn, script = script_template.format(', '.join(formals), call) CU = torch.jit.CompilationUnit(script) - if disable_autodiff_subgraph_inlining: - CU.the_method.debug_disable_autodiff_subgraph_inlining() self.assertExportImport(CU.the_method.graph, tensors) output = output_process_fn(CU.the_method(*tensors)) script_fn.last_graph = CU.the_method.graph_for(*tensors) @@ -12249,11 +12243,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase): # TODO: It is better if we can test directly on graphs instead of the current # end-to-end fashion. def _perform_ad_subgraph_slicing(self, fn, *input_sizes): - ge = torch.jit.script(fn) - ge.debug_disable_autodiff_subgraph_inlining() - inputs = [torch.randn(size, requires_grad=True) for size in input_sizes] - ge(*inputs) - return ge.graph_for(*inputs) + with disable_autodiff_subgraph_inlining(): + ge = torch.jit.script(fn) + inputs = [torch.randn(size, requires_grad=True) for size in input_sizes] + ge(*inputs) + return ge.graph_for(*inputs) def assertGraphSize(self, graph, size): self.assertEqual(len(list(graph.nodes())), size) @@ -12265,9 +12259,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase): return (x1, x2) input = torch.rand(6, 10).requires_grad_() - func.debug_disable_autodiff_subgraph_inlining() - output = func(input) - self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], []) + with disable_autodiff_subgraph_inlining(): + output = func(input) + self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], []) def test_simple_merge(self): # o --> o @@ -12850,53 +12844,51 @@ def add_autograd_test( return output_process_fn(output) check_types = test_name not in EXCLUDE_TYPE_CHECK - - if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name): - # Test with disable_autodiff_subgraph_inlining, which forces the graph - # to contain DifferentiableGraph nodes whenever possible. 
This allows us - # to test autodiff; we assume that autograd is correct and use autodiff for backprop - should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name) - if test_name not in EXCLUDE_TRACED: - traced_fn = create_traced_fn(self, fn, disable_autodiff_subgraph_inlining=True) - - check_against_reference(self, traced_fn, - fn, (self_variable,) + args_variable, kwargs_variable, - check_types=check_types) - self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes) - - if not is_magic_method and test_name not in EXCLUDE_SCRIPT: - script_fn = create_script_fn(self, name, 'method', output_process_fn, - disable_autodiff_subgraph_inlining=True) - check_against_reference(self, script_fn, - fn, (self_variable,) + args_variable, kwargs_variable, - check_types=check_types) - - self.assertAutodiffNode(script_fn.last_graph, - should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK, - autodiff_nodes, - fusible_nodes) - - # functional interface tests - if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL: - def fn(*inputs, **kwargs): - output = getattr(torch, name)(*inputs, **kwargs) - return output_process_fn(output) - - f_args_variable = (self_variable,) + args_variable - f_args_tensor = (self_tensor,) + args_tensor - - if not is_inplace and test_name not in EXCLUDE_TRACED: - check_against_reference(self, - create_traced_fn(self, fn, - disable_autodiff_subgraph_inlining=True), - fn, f_args_variable, kwargs_variable, check_types=check_types) - - if not is_inplace and test_name not in EXCLUDE_SCRIPT: - check_against_reference(self, - create_script_fn(self, name, 'functional', output_process_fn, - disable_autodiff_subgraph_inlining=True), - fn, f_args_variable, kwargs_variable, - check_types=check_types) + with disable_autodiff_subgraph_inlining(): + if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name): + # Test with disable_autodiff_subgraph_inlining, which forces the graph + # to contain DifferentiableGraph nodes whenever possible. 
This allows us
+                # to test autodiff; we assume that autograd is correct and use autodiff for backprop
+                should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
+
+                if test_name not in EXCLUDE_TRACED:
+                    traced_fn = create_traced_fn(self, fn)
+
+                    check_against_reference(self, traced_fn,
+                                            fn, (self_variable,) + args_variable, kwargs_variable,
+                                            check_types=check_types)
+                    self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
+
+                if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
+                    script_fn = create_script_fn(self, name, 'method', output_process_fn)
+                    check_against_reference(self, script_fn,
+                                            fn, (self_variable,) + args_variable, kwargs_variable,
+                                            check_types=check_types)
+
+                    self.assertAutodiffNode(script_fn.last_graph,
+                                            should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
+                                            autodiff_nodes,
+                                            fusible_nodes)
+
+            # functional interface tests
+            if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
+                def fn(*inputs, **kwargs):
+                    output = getattr(torch, name)(*inputs, **kwargs)
+                    return output_process_fn(output)
+
+                f_args_variable = (self_variable,) + args_variable
+                f_args_tensor = (self_tensor,) + args_tensor
+
+                if not is_inplace and test_name not in EXCLUDE_TRACED:
+                    check_against_reference(self,
+                                            create_traced_fn(self, fn),
+                                            fn, f_args_variable, kwargs_variable, check_types=check_types)
+
+                if not is_inplace and test_name not in EXCLUDE_SCRIPT:
+                    check_against_reference(self,
+                                            create_script_fn(self, name, 'functional', output_process_fn),
+                                            fn, f_args_variable, kwargs_variable,
+                                            check_types=check_types)
 
         # alias annotation testing
         if is_inplace and test_name not in EXCLUDE_SCRIPT:
@@ -12954,11 +12946,11 @@ def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(),
         should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
         if test_name not in EXCLUDE_SCRIPT:
             def run_test():
-                script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
-                                             disable_autodiff_subgraph_inlining=should_autodiff_node)
-                check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
-                # For tests we disabled AD subgraph inlining, make sure it's not falling back to autograd
-                self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
+                with disable_autodiff_subgraph_inlining(should_autodiff_node):
+                    script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn)
+                    check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+                    # For tests we disabled AD subgraph inlining, make sure it's not falling back to autograd
+                    self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
         if test_name in EXCLUDE_PYTHON_PRINT:
             with self.disableModuleHook():
diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 966246c..92d149f 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -46,12 +46,26 @@ namespace torch {
 namespace jit {
 
+// for debugging it is helpful to be able to force autodiff subgraphs
+// to be created, to check their correctness, even when the
+// size of the subgraph is too small to be profitable.
+thread_local bool autodiff_subgraph_inlining = true;
+void debugSetAutodiffSubgraphInlining(bool state) {
+  autodiff_subgraph_inlining = state;
+}
+
 namespace {
 
 using tensor_list = std::vector<at::Tensor>;
 using Variable = autograd::Variable;
 using autograd::variable_list;
 
+// Tunable parameters for deciding when to create/keep subgraphs of
+// differentiable code
+
+const size_t autodiffSubgraphNodeThreshold = 2;
+const size_t autodiffSubgraphInlineThreshold = 5;
+
 struct ExecutionPlan {
   ExecutionPlan() = default;
   ExecutionPlan(std::shared_ptr<Graph> graph)
@@ -480,9 +494,9 @@ struct GraphExecutorImpl {
         num_inputs(this->graph->inputs().size()),
         arg_spec_creator_(*graph),
         num_outputs(this->graph->outputs().size()) {
-    logging::getLogger()->addStatValue(
-        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
-  }
+        logging::getLogger()->addStatValue(
+            logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+      }
 
   // entry point where execution begins
   void run(Stack& stack) {
@@ -533,14 +547,6 @@ struct GraphExecutorImpl {
     return state;
   }
 
-  // This function should be used only for testing purposes
-  void debugDisableAutodiffSubgraphInlining() {
-    // Allow single-node autodiff subgraphs
-    autodiffSubgraphNodeThreshold = 1;
-    // Don't inline autodiff subgraphs into autograd functions
-    autodiffSubgraphInlineThreshold = 1;
-  }
-
  private:
   friend struct GraphExecutor;
 
@@ -603,15 +609,18 @@ struct GraphExecutorImpl {
     // Phase 5. Apply non-differentiable optimizations to the graphs we've found
     // (or the whole graph if we know we won't need its derivative).
     if (needsGradient(opt_graph)) {
-      auto diff_nodes =
-          CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
+      auto diff_nodes = CreateAutodiffSubgraphs(
+          opt_graph,
+          autodiff_subgraph_inlining ? autodiffSubgraphNodeThreshold : 1);
       for (Node* dnode : diff_nodes) {
         auto diff_graph = std::move(dnode->g(attr::Subgraph));
         Gradient gradient = differentiate(diff_graph);
         runNondiffOptimization(gradient.f);
         packGradient(gradient, dnode);
       }
-      InlineAutodiffSubgraphs(opt_graph, autodiffSubgraphInlineThreshold);
+      InlineAutodiffSubgraphs(
+          opt_graph,
+          autodiff_subgraph_inlining ? autodiffSubgraphInlineThreshold : 1);
     } else {
       runNondiffOptimization(opt_graph);
     }
@@ -731,10 +740,6 @@ struct GraphExecutorImpl {
   // GraphExecutors can be accessed from multiple threads, so this thread needs
   // to be held every time we access the fallback or plan_cache.
   std::mutex compile_mutex;
-
-  // Some tunable parameters
-  size_t autodiffSubgraphNodeThreshold = 2;
-  size_t autodiffSubgraphInlineThreshold = 5;
 };
 
 GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
@@ -756,10 +761,6 @@ GraphExecutorState GraphExecutor::getDebugState() {
   return pImpl->getDebugState();
 }
 
-void GraphExecutor::debugDisableAutodiffSubgraphInlining() {
-  return pImpl->debugDisableAutodiffSubgraphInlining();
-}
-
 void runRequiredPasses(const std::shared_ptr<Graph>& g) {
   specializeAutogradZero(*g);
   LowerGradOf(*g);
diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h
index e6564e5..9066d9f 100644
--- a/torch/csrc/jit/graph_executor.h
+++ b/torch/csrc/jit/graph_executor.h
@@ -37,7 +37,6 @@ struct TORCH_API GraphExecutor {
   std::shared_ptr<Graph> graph() const;
   std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
   GraphExecutorState getDebugState();
-  void debugDisableAutodiffSubgraphInlining();
 
  private:
   std::shared_ptr<GraphExecutorImpl> pImpl;
@@ -47,6 +46,8 @@
 // regardless of whether sizes have been specialized or not.
 TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
+TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
+
 namespace detail {
 
 GraphExecutor* getGradExecutor(Operation& op);
diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h
index 345431b..434b862 100644
--- a/torch/csrc/jit/script/compilation_unit.h
+++ b/torch/csrc/jit/script/compilation_unit.h
@@ -108,10 +108,6 @@ struct TORCH_API Function {
     return get_executor().getDebugState();
   }
 
-  void debugDisableAutodiffSubgraphInlining() {
-    return get_executor().debugDisableAutodiffSubgraphInlining();
-  }
-
   bool is_optimized() const {
     return optimize_;
   }
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index 1aeadea..668faa5 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -925,14 +925,6 @@ void initJitScriptBindings(PyObject* module) {
                 "Attempted to call get_debug_state on a Module without a compiled forward()");
           })
       .def(
-          "debug_disable_autodiff_subgraph_inlining",
-          [](Module& self) {
-            if (self.find_method("forward")) {
-              Method& m = self.get_method("forward");
-              m.get_executor().debugDisableAutodiffSubgraphInlining();
-            }
-          })
-      .def(
           "forward",
           [](py::args args, py::kwargs kwargs) {
             // We implement this in C++ to avoid incurring the pybind11 dispatch
@@ -1040,11 +1032,6 @@ void initJitScriptBindings(PyObject* module) {
             return self.graph_for(createStackForSchema(
                 self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
           })
-      .def(
-          "debug_disable_autodiff_subgraph_inlining",
-          [](Method& m) {
-            return m.get_executor().debugDisableAutodiffSubgraphInlining();
-          })
       .def("schema", &Method::getSchema)
       .def(
           "pretty_print_schema",
@@ -1142,6 +1129,8 @@ void initJitScriptBindings(PyObject* module) {
   m.def("_jit_import_methods", import_methods);
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
   m.def("_jit_clear_class_registry", ClassType::clearRegistry);
+  m.def(
+      "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
   py::class_<testing::FileCheck>(m, "FileCheck")
       .def(py::init<>())
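
Usage sketch (illustrative, not part of the patch): after this change the per-Module/Method
debug_disable_autodiff_subgraph_inlining() bindings are gone, and the knob is a thread-local
flag toggled through the new torch._C._debug_set_autodiff_subgraph_inlining binding, normally
via the disable_autodiff_subgraph_inlining context manager added to test/test_jit.py above.
A minimal example of how a test uses it, assuming a build that includes this patch:

    from contextlib import contextmanager

    import torch

    # Mirrors the helper added to test/test_jit.py in this patch.
    # Note: not re-entrant; the exit path unconditionally re-enables inlining.
    @contextmanager
    def disable_autodiff_subgraph_inlining(enabled=True):
        torch._C._debug_set_autodiff_subgraph_inlining(not enabled)
        yield
        torch._C._debug_set_autodiff_subgraph_inlining(True)

    @torch.jit.script
    def func(x, y):
        return torch.cat((x, x), y)

    x = torch.rand(2, 2, requires_grad=True)
    y = torch.tensor(1)

    with disable_autodiff_subgraph_inlining():
        out = func(x, y)
        # With inlining disabled, differentiable regions stay wrapped in
        # prim::DifferentiableGraph nodes instead of being inlined back into the
        # graph, so tests can assert on them via graph_for / assertAutodiffNode.
        print(func.graph_for(x, y))

Because the flag is thread_local rather than per-GraphExecutor state, flipping it in a test
only affects compilations triggered on the calling thread and no longer requires a handle to
each script function's executor.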