Make debug subgraph inlining thread local (#19136)
author     Zachary DeVito <zdevito@fb.com>
Sat, 13 Apr 2019 15:28:11 +0000 (08:28 -0700)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 13 Apr 2019 15:42:14 +0000 (08:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19136
ghimport-source-id: 3a24ab36aa753ce5cce7bba3467bdbe88e5c7f60
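
This change replaces the per-executor debug_disable_autodiff_subgraph_inlining() hooks (on GraphExecutor, on Function in script/compilation_unit.h, and on the Module/Method pybind bindings) with a single thread-local flag in graph_executor.cpp. The flag is set through debugSetAutodiffSubgraphInlining(bool) and exposed to Python as torch._C._debug_set_autodiff_subgraph_inlining; test_jit.py toggles it via a small disable_autodiff_subgraph_inlining() context manager. A minimal sketch of the new test-side pattern, using the raw binding directly (names are the ones added in this diff):

    import torch

    @torch.jit.script
    def f(x):
        return x * 2

    # Disable autodiff subgraph inlining for the current thread, run the
    # code under test, then restore the default. test_jit.py wraps this
    # pattern in the disable_autodiff_subgraph_inlining() context manager.
    torch._C._debug_set_autodiff_subgraph_inlining(False)
    try:
        x = torch.randn(2, 2, requires_grad=True)
        f(x)
    finally:
        torch._C._debug_set_autodiff_subgraph_inlining(True)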

Reviewed By: jamesr66a

Differential Revision: D14885051

Pulled By: zdevito

fbshipit-source-id: b39c6ceef73ad9caefcbf8f40dd1b9132bba03c2

test/test_jit.py
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/graph_executor.h
torch/csrc/jit/script/compilation_unit.h
torch/csrc/jit/script/init.cpp

test/test_jit.py
index ea9c853..015c67c 100644
@@ -252,6 +252,13 @@ def enable_cpu_fuser(fn):
             torch._C._jit_override_can_fuse_on_cpu(False)
     return wrapper
 
+# note: not re-entrant, use unnested only
+@contextmanager
+def disable_autodiff_subgraph_inlining(enabled=True):
+    torch._C._debug_set_autodiff_subgraph_inlining(not enabled)
+    yield
+    torch._C._debug_set_autodiff_subgraph_inlining(True)
+
 
 # helper function to get sum of List[Tensor]
 def _sum_of_list(tensorlist):
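
The "not re-entrant" note on the helper above follows from the unconditional re-enable on exit: leaving an inner context turns inlining back on while an outer context is still active. A short sketch of the failure mode (hypothetical nesting; the test suite only uses the helper unnested):

    with disable_autodiff_subgraph_inlining():
        with disable_autodiff_subgraph_inlining():
            pass
        # Inlining has already been re-enabled at this point, even though the
        # outer block still expects it to be disabled.
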
@@ -3744,20 +3751,20 @@ a")
         @torch.jit.script
         def func2(x, y):
             return torch.cat((x, x), y)
-        func2.debug_disable_autodiff_subgraph_inlining()
+        with disable_autodiff_subgraph_inlining():
 
-        x = torch.rand([2, 2]).requires_grad_()
-        y = torch.tensor(1)
+            x = torch.rand([2, 2]).requires_grad_()
+            y = torch.tensor(1)
 
-        output = func2(x, y)
-        output_ref = torch.cat((x, x), y)
-        self.assertEqual(output, output_ref)
+            output = func2(x, y)
+            output_ref = torch.cat((x, x), y)
+            self.assertEqual(output, output_ref)
 
-        self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
+            self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::cat'], [])
 
-        grad = torch.autograd.grad(output.sum(), x)
-        grad_ref = torch.autograd.grad(output_ref.sum(), x)
-        self.assertEqual(grad, grad_ref)
+            grad = torch.autograd.grad(output.sum(), x)
+            grad_ref = torch.autograd.grad(output_ref.sum(), x)
+            self.assertEqual(grad, grad_ref)
 
     def test_cat_lifts(self):
         @torch.jit.script
@@ -3787,60 +3794,57 @@ a")
         def func2(x, y):
             return torch.stack((x, y), dim=0)
 
-        func2.debug_disable_autodiff_subgraph_inlining()
-
-        x = torch.randn([2, 2]).requires_grad_()
-        y = torch.randn([2, 2]).requires_grad_()
+        with disable_autodiff_subgraph_inlining():
+            x = torch.randn([2, 2]).requires_grad_()
+            y = torch.randn([2, 2]).requires_grad_()
 
-        output = func2(x, y)
-        output_ref = torch.stack((x, y), 0)
-        self.assertEqual(output, output_ref)
+            output = func2(x, y)
+            output_ref = torch.stack((x, y), 0)
+            self.assertEqual(output, output_ref)
 
-        self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
+            self.assertAutodiffNode(func2.graph_for(x, y), True, ['aten::stack'], [])
 
-        grads = torch.autograd.grad(output.sum(), (x, y))
-        grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
-        self.assertEqual(grads, grads_ref)
+            grads = torch.autograd.grad(output.sum(), (x, y))
+            grads_ref = torch.autograd.grad(output_ref.sum(), (x, y))
+            self.assertEqual(grads, grads_ref)
 
     def test_unbind(self):
         @torch.jit.script
         def func(x, y):
             # type: (Tensor, int) -> List[Tensor]
             return torch.unbind(x, y)
-        func.debug_disable_autodiff_subgraph_inlining()
-
-        x = torch.rand([2, 2]).requires_grad_()
-        y = 0
-        outputs = func(x, y)
-        outputs_ref = torch.unbind(x, dim=y)
-        self.assertEqual(outputs, outputs_ref)
+        with disable_autodiff_subgraph_inlining():
+            x = torch.rand([2, 2]).requires_grad_()
+            y = 0
+            outputs = func(x, y)
+            outputs_ref = torch.unbind(x, dim=y)
+            self.assertEqual(outputs, outputs_ref)
 
-        self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], [])
+            self.assertAutodiffNode(func.graph_for(x, y), True, ['aten::unbind'], [])
 
-        grad = torch.autograd.grad(_sum_of_list(outputs), x)
-        grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
-        self.assertEqual(grad, grad_ref)
+            grad = torch.autograd.grad(_sum_of_list(outputs), x)
+            grad_ref = torch.autograd.grad(_sum_of_list(outputs_ref), x)
+            self.assertEqual(grad, grad_ref)
 
     def test_meshgrid(self):
         @torch.jit.script
         def func(a):
             # type: (List[Tensor]) -> List[Tensor]
             return torch.meshgrid(a)
-        func.debug_disable_autodiff_subgraph_inlining()
+        with disable_autodiff_subgraph_inlining():
+            a = torch.tensor([1.0, 2, 3]).requires_grad_()
+            b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
+            inputs = [a, b]
 
-        a = torch.tensor([1.0, 2, 3]).requires_grad_()
-        b = torch.tensor([1.0, 2, 3, 4]).requires_grad_()
-        inputs = [a, b]
+            outputs_ref = torch.meshgrid(inputs)
+            outputs = func(inputs)
+            self.assertEqual(outputs, outputs_ref)
 
-        outputs_ref = torch.meshgrid(inputs)
-        outputs = func(inputs)
-        self.assertEqual(outputs, outputs_ref)
+            self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], [])
 
-        self.assertAutodiffNode(func.graph_for(inputs), True, ['aten::meshgrid'], [])
-
-        grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
-        grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
-        self.assertEqual(grads, grads_ref)
+            grads = torch.autograd.grad(_sum_of_list(outputs), inputs)
+            grads_ref = torch.autograd.grad(_sum_of_list(outputs_ref), inputs)
+            self.assertEqual(grads, grads_ref)
 
     def test_list_literal(self):
         def reassign():
@@ -5785,14 +5789,13 @@ a")
         def func2(t, t_ref):
             return t.to(t_ref)
 
-        func2.debug_disable_autodiff_subgraph_inlining()
-
-        t_ref = torch.tensor(4).double()
-        out_ref = t.to(t_ref)
-        out = func2(t, t_ref)
-        grad_ref = torch.autograd.grad(out_ref.sum(), t)
-        grad = torch.autograd.grad(out.sum(), t)
-        self.assertEqual(grad_ref, grad)
+        with disable_autodiff_subgraph_inlining():
+            t_ref = torch.tensor(4).double()
+            out_ref = t.to(t_ref)
+            out = func2(t, t_ref)
+            grad_ref = torch.autograd.grad(out_ref.sum(), t)
+            grad = torch.autograd.grad(out.sum(), t)
+            self.assertEqual(grad_ref, grad)
 
     @unittest.skipIf(not RUN_CUDA, "No CUDA")
     def test_tensor_number_math_cuda(self):
@@ -12063,17 +12066,11 @@ def partial_apply_nontensors(fn, args, **kwargs):
 
 
 # create a trace function from input fn
-#
-# disable_autodiff_subgraph_inlining:
-#   Don't inline autodiff subgraphs so we can test autodiff
-def create_traced_fn(self, fn,
-                     disable_autodiff_subgraph_inlining=False):
+def create_traced_fn(self, fn):
     def traced_fn(*inputs, **kwargs):
         fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
         traced = torch.jit.trace(fn_tensors, inputs_tensors)
         self.assertExportImport(traced.graph, inputs_tensors)
-        if disable_autodiff_subgraph_inlining:
-            traced.debug_disable_autodiff_subgraph_inlining()
         output = traced(*inputs_tensors)
         traced_fn.last_graph = traced.graph_for(*inputs_tensors)
         return output
@@ -12118,8 +12115,7 @@ def get_script_args(args):
 # create a script function from (name, func_type, output_process_fn),
 # returns a function takes in (args, kwargs) and runs the compiled function and
 # then applies the post process fn to the outputs
-def create_script_fn(self, method_name, func_type, output_process_fn,
-                     disable_autodiff_subgraph_inlining=False):
+def create_script_fn(self, method_name, func_type, output_process_fn):
     def script_fn(*args, **kwargs):
         formals, tensors, actuals = get_script_args(args)
         kwargs_str = ''
@@ -12137,8 +12133,6 @@ def create_script_fn(self, method_name, func_type, output_process_fn,
         script = script_template.format(', '.join(formals), call)
 
         CU = torch.jit.CompilationUnit(script)
-        if disable_autodiff_subgraph_inlining:
-            CU.the_method.debug_disable_autodiff_subgraph_inlining()
         self.assertExportImport(CU.the_method.graph, tensors)
         output = output_process_fn(CU.the_method(*tensors))
         script_fn.last_graph = CU.the_method.graph_for(*tensors)
@@ -12249,11 +12243,11 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
     # TODO: It is better if we can test directly on graphs instead of the current
     # end-to-end fashion.
     def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
-        ge = torch.jit.script(fn)
-        ge.debug_disable_autodiff_subgraph_inlining()
-        inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
-        ge(*inputs)
-        return ge.graph_for(*inputs)
+        with disable_autodiff_subgraph_inlining():
+            ge = torch.jit.script(fn)
+            inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
+            ge(*inputs)
+            return ge.graph_for(*inputs)
 
     def assertGraphSize(self, graph, size):
         self.assertEqual(len(list(graph.nodes())), size)
@@ -12265,9 +12259,9 @@ class TestAutodiffSubgraphSlicing(JitTestCase):
             return (x1, x2)
 
         input = torch.rand(6, 10).requires_grad_()
-        func.debug_disable_autodiff_subgraph_inlining()
-        output = func(input)
-        self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
+        with disable_autodiff_subgraph_inlining():
+            output = func(input)
+            self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
 
     def test_simple_merge(self):
         # o --> o
@@ -12850,53 +12844,51 @@ def add_autograd_test(
                     return output_process_fn(output)
 
                 check_types = test_name not in EXCLUDE_TYPE_CHECK
-
-                if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name):
-                    # Test with disable_autodiff_subgraph_inlining, which forces the graph
-                    # to contain DifferentiableGraph nodes whenever possible. This allows us
-                    # to test autodiff; we assume that autograd is correct and use autodiff for backprop
-                    should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
-                    if test_name not in EXCLUDE_TRACED:
-                        traced_fn = create_traced_fn(self, fn, disable_autodiff_subgraph_inlining=True)
-
-                        check_against_reference(self, traced_fn,
-                                                fn, (self_variable,) + args_variable, kwargs_variable,
-                                                check_types=check_types)
-                        self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
-
-                    if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
-                        script_fn = create_script_fn(self, name, 'method', output_process_fn,
-                                                     disable_autodiff_subgraph_inlining=True)
-                        check_against_reference(self, script_fn,
-                                                fn, (self_variable,) + args_variable, kwargs_variable,
-                                                check_types=check_types)
-
-                        self.assertAutodiffNode(script_fn.last_graph,
-                                                should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
-                                                autodiff_nodes,
-                                                fusible_nodes)
-
-                # functional interface tests
-                if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
-                    def fn(*inputs, **kwargs):
-                        output = getattr(torch, name)(*inputs, **kwargs)
-                        return output_process_fn(output)
-
-                    f_args_variable = (self_variable,) + args_variable
-                    f_args_tensor = (self_tensor,) + args_tensor
-
-                    if not is_inplace and test_name not in EXCLUDE_TRACED:
-                        check_against_reference(self,
-                                                create_traced_fn(self, fn,
-                                                                 disable_autodiff_subgraph_inlining=True),
-                                                fn, f_args_variable, kwargs_variable, check_types=check_types)
-
-                    if not is_inplace and test_name not in EXCLUDE_SCRIPT:
-                        check_against_reference(self,
-                                                create_script_fn(self, name, 'functional', output_process_fn,
-                                                                 disable_autodiff_subgraph_inlining=True),
-                                                fn, f_args_variable, kwargs_variable,
-                                                check_types=check_types)
+                with disable_autodiff_subgraph_inlining():
+                    if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name):
+                        # Test with disable_autodiff_subgraph_inlining, which forces the graph
+                        # to contain DifferentiableGraph nodes whenever possible. This allows us
+                        # to test autodiff; we assume that autograd is correct and use autodiff for backprop
+                        should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
+
+                        if test_name not in EXCLUDE_TRACED:
+                            traced_fn = create_traced_fn(self, fn)
+
+                            check_against_reference(self, traced_fn,
+                                                    fn, (self_variable,) + args_variable, kwargs_variable,
+                                                    check_types=check_types)
+                            self.assertAutodiffNode(traced_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
+
+                        if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
+                            script_fn = create_script_fn(self, name, 'method', output_process_fn)
+                            check_against_reference(self, script_fn,
+                                                    fn, (self_variable,) + args_variable, kwargs_variable,
+                                                    check_types=check_types)
+
+                            self.assertAutodiffNode(script_fn.last_graph,
+                                                    should_autodiff_node and test_name not in EXCLUDE_SCRIPT_AD_CHECK,
+                                                    autodiff_nodes,
+                                                    fusible_nodes)
+
+                    # functional interface tests
+                    if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
+                        def fn(*inputs, **kwargs):
+                            output = getattr(torch, name)(*inputs, **kwargs)
+                            return output_process_fn(output)
+
+                        f_args_variable = (self_variable,) + args_variable
+                        f_args_tensor = (self_tensor,) + args_tensor
+
+                        if not is_inplace and test_name not in EXCLUDE_TRACED:
+                            check_against_reference(self,
+                                                    create_traced_fn(self, fn),
+                                                    fn, f_args_variable, kwargs_variable, check_types=check_types)
+
+                        if not is_inplace and test_name not in EXCLUDE_SCRIPT:
+                            check_against_reference(self,
+                                                    create_script_fn(self, name, 'functional', output_process_fn),
+                                                    fn, f_args_variable, kwargs_variable,
+                                                    check_types=check_types)
 
                 # alias annotation testing
                 if is_inplace and test_name not in EXCLUDE_SCRIPT:
@@ -12954,11 +12946,11 @@ def add_nn_functional_test(name, self_size, args, variant_name='', check_ad=(),
         should_autodiff_node, autodiff_nodes, fusible_nodes = normalize_check_ad(check_ad, name)
         if test_name not in EXCLUDE_SCRIPT:
             def run_test():
-                script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
-                                             disable_autodiff_subgraph_inlining=should_autodiff_node)
-                check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
-                # For tests we disabled AD subgraph inlining, make sure it's not falling back to autograd
-                self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
+                with disable_autodiff_subgraph_inlining(should_autodiff_node):
+                    script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn)
+                    check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
+                    # For tests we disabled AD subgraph inlining, make sure it's not falling back to autograd
+                    self.assertAutodiffNode(script_fn.last_graph, should_autodiff_node, autodiff_nodes, fusible_nodes)
 
             if test_name in EXCLUDE_PYTHON_PRINT:
                 with self.disableModuleHook():
torch/csrc/jit/graph_executor.cpp
index 966246c..92d149f 100644
 namespace torch {
 namespace jit {
 
+// for debugging it is helpful to be able to force autodiff subgraphs
+// to be created, to check their correctness, even when the
+// size of the subgraph is too small to be profitable.
+thread_local bool autodiff_subgraph_inlining = true;
+void debugSetAutodiffSubgraphInlining(bool state) {
+  autodiff_subgraph_inlining = state;
+}
+
 namespace {
 
 using tensor_list = std::vector<at::Tensor>;
 using Variable = autograd::Variable;
 using autograd::variable_list;
 
+// Tunable parameters for deciding when to create/keep subgraphs of
+// differentiable code
+
+const size_t autodiffSubgraphNodeThreshold = 2;
+const size_t autodiffSubgraphInlineThreshold = 5;
+
 struct ExecutionPlan {
   ExecutionPlan() = default;
   ExecutionPlan(std::shared_ptr<Graph> graph)
@@ -480,9 +494,9 @@ struct GraphExecutorImpl {
         num_inputs(this->graph->inputs().size()),
         arg_spec_creator_(*graph),
         num_outputs(this->graph->outputs().size()) {
-          logging::getLogger()->addStatValue(
-              logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
-        }
+    logging::getLogger()->addStatValue(
+        logging::runtime_counters::GRAPH_EXECUTORS_CONSTRUCTED, 1.0);
+  }
 
   // entry point where execution begins
   void run(Stack& stack) {
@@ -533,14 +547,6 @@ struct GraphExecutorImpl {
     return state;
   }
 
-  // This function should be used only for testing purposes
-  void debugDisableAutodiffSubgraphInlining() {
-    // Allow single-node autodiff subgraphs
-    autodiffSubgraphNodeThreshold = 1;
-    // Don't inline autodiff subgraphs into autograd functions
-    autodiffSubgraphInlineThreshold = 1;
-  }
-
  private:
   friend struct GraphExecutor;
 
@@ -603,15 +609,18 @@ struct GraphExecutorImpl {
     // Phase 5. Apply non-differentiable optimizations to the graphs we've found
 //          (or the whole graph if we know we won't need its derivative).
     if (needsGradient(opt_graph)) {
-      auto diff_nodes =
-          CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
+      auto diff_nodes = CreateAutodiffSubgraphs(
+          opt_graph,
+          autodiff_subgraph_inlining ? autodiffSubgraphNodeThreshold : 1);
       for (Node* dnode : diff_nodes) {
         auto diff_graph = std::move(dnode->g(attr::Subgraph));
         Gradient gradient = differentiate(diff_graph);
         runNondiffOptimization(gradient.f);
         packGradient(gradient, dnode);
       }
-      InlineAutodiffSubgraphs(opt_graph, autodiffSubgraphInlineThreshold);
+      InlineAutodiffSubgraphs(
+          opt_graph,
+          autodiff_subgraph_inlining ? autodiffSubgraphInlineThreshold : 1);
     } else {
       runNondiffOptimization(opt_graph);
     }
@@ -731,10 +740,6 @@ struct GraphExecutorImpl {
   // GraphExecutors can be accessed from multiple threads, so this thread needs
   // to be held every time we access the fallback or plan_cache.
   std::mutex compile_mutex;
-
-  // Some tunable parameters
-  size_t autodiffSubgraphNodeThreshold = 2;
-  size_t autodiffSubgraphInlineThreshold = 5;
 };
 
 GraphExecutor::GraphExecutor(std::shared_ptr<Graph> graph, bool optimize)
@@ -756,10 +761,6 @@ GraphExecutorState GraphExecutor::getDebugState() {
   return pImpl->getDebugState();
 }
 
-void GraphExecutor::debugDisableAutodiffSubgraphInlining() {
-  return pImpl->debugDisableAutodiffSubgraphInlining();
-}
-
 void runRequiredPasses(const std::shared_ptr<Graph>& g) {
   specializeAutogradZero(*g);
   LowerGradOf(*g);
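
With the thread-local flag cleared, both tunables are forced to 1: CreateAutodiffSubgraphs creates differentiable subgraphs even around single nodes, and InlineAutodiffSubgraphs leaves small subgraphs in place instead of folding them back into the surrounding graph. The effect is observable from Python; a rough illustration along these lines (the exact graph shape is not guaranteed, so treat this as a sketch rather than an exact test):

    import torch

    torch._C._debug_set_autodiff_subgraph_inlining(False)
    try:
        @torch.jit.script
        def f(x):
            return x * 2 + 1

        x = torch.randn(3, requires_grad=True)
        f(x)
        graph = f.graph_for(x)
        # With inlining disabled, even this tiny differentiable region should
        # survive as a prim::DifferentiableGraph node.
        assert any(n.kind() == "prim::DifferentiableGraph" for n in graph.nodes())
    finally:
        torch._C._debug_set_autodiff_subgraph_inlining(True)
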
torch/csrc/jit/graph_executor.h
index e6564e5..9066d9f 100644
@@ -37,7 +37,6 @@ struct TORCH_API GraphExecutor {
   std::shared_ptr<Graph> graph() const;
   std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
   GraphExecutorState getDebugState();
-  void debugDisableAutodiffSubgraphInlining();
 
  private:
   std::shared_ptr<GraphExecutorImpl> pImpl;
@@ -47,6 +46,8 @@ struct TORCH_API GraphExecutor {
 // regardless of whether sizes have been specialized or not.
 TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
 
+TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
+
 namespace detail {
 
 GraphExecutor* getGradExecutor(Operation& op);
torch/csrc/jit/script/compilation_unit.h
index 345431b..434b862 100644
@@ -108,10 +108,6 @@ struct TORCH_API Function {
     return get_executor().getDebugState();
   }
 
-  void debugDisableAutodiffSubgraphInlining() {
-    return get_executor().debugDisableAutodiffSubgraphInlining();
-  }
-
   bool is_optimized() const {
     return optimize_;
   }
torch/csrc/jit/script/init.cpp
index 1aeadea..668faa5 100644
@@ -925,14 +925,6 @@ void initJitScriptBindings(PyObject* module) {
                 "Attempted to call get_debug_state on a Module without a compiled forward()");
           })
       .def(
-          "debug_disable_autodiff_subgraph_inlining",
-          [](Module& self) {
-            if (self.find_method("forward")) {
-              Method& m = self.get_method("forward");
-              m.get_executor().debugDisableAutodiffSubgraphInlining();
-            }
-          })
-      .def(
           "forward",
           [](py::args args, py::kwargs kwargs) {
             // We implement this in C++ to avoid incurring the pybind11 dispatch
@@ -1040,11 +1032,6 @@ void initJitScriptBindings(PyObject* module) {
             return self.graph_for(createStackForSchema(
                 self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
           })
-      .def(
-          "debug_disable_autodiff_subgraph_inlining",
-          [](Method& m) {
-            return m.get_executor().debugDisableAutodiffSubgraphInlining();
-          })
       .def("schema", &Method::getSchema)
       .def(
           "pretty_print_schema",
@@ -1142,6 +1129,8 @@ void initJitScriptBindings(PyObject* module) {
   m.def("_jit_import_methods", import_methods);
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
   m.def("_jit_clear_class_registry", ClassType::clearRegistry);
+  m.def(
+      "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
 
   py::class_<testing::FileCheck>(m, "FileCheck")
       .def(py::init<>())