graph_for based on last_optimized_executed_graph (#19142)
authorZachary DeVito <zdevito@fb.com>
Tue, 16 Apr 2019 16:01:03 +0000 (09:01 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 16 Apr 2019 16:17:53 +0000 (09:17 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19142
ghimport-source-id: 822013fb7e93032c74867fc77c6774c680aef6d1

Differential Revision: D14888703

Pulled By: zdevito

fbshipit-source-id: a2ad65a042d08b1adef965c2cceef37bb5d26ba9

torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/graph_executor.h
torch/csrc/jit/script/compilation_unit.h
torch/csrc/jit/script/init.cpp
torch/csrc/jit/script/module.h
torch/jit/__init__.py

index 92d149f..9ce52ee 100644 (file)
@@ -54,6 +54,11 @@ void debugSetAutodiffSubgraphInlining(bool state) {
   autodiff_subgraph_inlining = state;
 }
 
+thread_local std::weak_ptr<Graph> last_executed_optimized_graph;
+std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
+  return last_executed_optimized_graph.lock();
+}
+
 namespace {
 
 using tensor_list = std::vector<at::Tensor>;
@@ -72,7 +77,8 @@ struct ExecutionPlan {
       : code(graph), graph(std::move(graph)) {}
 
   void run(Stack& stack) const {
-    return InterpreterState(code).run(stack);
+    InterpreterState(code).run(stack);
+    last_executed_optimized_graph = graph;
   }
 
   operator bool() const {
@@ -519,22 +525,6 @@ struct GraphExecutorImpl {
     return execution_plan.run(stack);
   }
 
-  std::shared_ptr<Graph> graphFor(const Stack& stack) const {
-    AT_ASSERT(stack.size() >= num_inputs);
-
-    ArgumentSpec spec =
-        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
-
-    if (!optimize) {
-      AT_CHECK(fallback, "No graph found for given inputs");
-      return fallback.graph;
-    }
-
-    auto it = plan_cache.find(spec);
-    AT_CHECK(it != plan_cache.end(), "No graph found for given inputs");
-    return it->second.graph;
-  }
-
   GraphExecutorState getDebugState() {
     GraphExecutorState state;
     state.graph = graph.get();
@@ -753,10 +743,6 @@ std::shared_ptr<Graph> GraphExecutor::graph() const {
   return pImpl->graph;
 }
 
-std::shared_ptr<Graph> GraphExecutor::graphFor(const Stack& inputs) const {
-  return pImpl->graphFor(inputs);
-}
-
 GraphExecutorState GraphExecutor::getDebugState() {
   return pImpl->getDebugState();
 }
index 9066d9f..e2906a5 100644 (file)
@@ -35,7 +35,6 @@ struct TORCH_API GraphExecutor {
     return pImpl != nullptr;
   }
   std::shared_ptr<Graph> graph() const;
-  std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
   GraphExecutorState getDebugState();
 
  private:
@@ -47,11 +46,18 @@ struct TORCH_API GraphExecutor {
 TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
 
 TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
+TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();
 
 namespace detail {
 
 GraphExecutor* getGradExecutor(Operation& op);
 
+// for debugging information we expose a way to get the last actually
+// run graph. Previous approaches allowed querying the GraphExecutor
+// for what graph it would run in certain circumstances (graphFor), but
+// this is fragile because we sometimes change how these decisions are made.
+// This interface still allows our tests to look at optimized graphs, but
+// with less plumbing.
 } // namespace detail
 
 } // namespace jit
index 434b862..c3da6e2 100644 (file)
@@ -66,10 +66,6 @@ struct TORCH_API Function {
     return stack.front();
   }
 
-  std::shared_ptr<Graph> graph_for(Stack inputs) {
-    return get_executor().graphFor(inputs);
-  }
-
   std::shared_ptr<Graph> graph() const {
     return graph_;
   }
index cd7ddaa..f701cf8 100644 (file)
@@ -11,6 +11,7 @@
 #include <torch/csrc/jit/testing/file_check.h>
 
 #include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/graph_executor.h>
 #include <torch/csrc/jit/hooks_for_testing.h>
 #include <torch/csrc/jit/import_source.h>
 #include <torch/csrc/jit/irparser.h>
@@ -952,23 +953,6 @@ void initJitScriptBindings(PyObject* module) {
             didFinishEmitModule(self);
           })
       .def(
-          "graph_for",
-          [](py::args args, py::kwargs kwargs) {
-            // [pybind11 varargs] note: old version of pybind11 have a bug that
-            // leaks memory when py::args is mixed with positional arguments
-            // https://github.com/pybind/pybind11/pull/1216
-            // we work around this by not mixing positional arguments with
-            // varargs
-            Module& self = py::cast<Module&>(args[0]);
-            if (self.find_method("forward")) {
-              Method& m = self.get_method("forward");
-              return m.graph_for(createStackForSchema(
-                  m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
-            }
-            throw std::runtime_error(
-                "Attempted to call graph_for on a Module without a compiled forward()");
-          })
-      .def(
           "get_debug_state",
           [](Module& self) {
             if (self.find_method("forward")) {
@@ -1047,14 +1031,6 @@ void initJitScriptBindings(PyObject* module) {
             }
             return tensors;
           })
-      .def(
-          "graph_for",
-          [](py::args args, py::kwargs kwargs) {
-            // see: [pybind11 varargs]
-            Method& self = py::cast<Method&>(args[0]);
-            return self.graph_for(createStackForSchema(
-                self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
-          })
       .def_property_readonly("schema", &Method::getSchema)
       .def_property_readonly("code", [](Method& self) {
         std::ostringstream ss;
@@ -1136,7 +1112,8 @@ void initJitScriptBindings(PyObject* module) {
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
   m.def("_jit_clear_class_registry", ClassType::clearRegistry);
   m.def(
-      "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
+      "_debug_set_autodiff_subgraph_inlining",
+      debugSetAutodiffSubgraphInlining);
   m.def("_propagate_shapes", _propagate_shapes);
   m.def(
       "_propagate_and_assign_input_and_output_shapes",
@@ -1154,6 +1131,10 @@ void initJitScriptBindings(PyObject* module) {
     }
     return std::make_pair(ss.str(), std::move(constants));
   });
+  m.def(
+      "_last_executed_optimized_graph",
+      []() { return lastExecutedOptimizedGraph(); },
+      "Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
 
   py::class_<testing::FileCheck>(m, "FileCheck")
       .def(py::init<>())
index e99436c..1b477ae 100644 (file)
@@ -94,14 +94,6 @@ struct TORCH_API Method {
     return initial_ivalues_;
   }
 
-  // proxies for underlying unbound Function
-  std::shared_ptr<Graph> graph_for(Stack inputs) {
-    for (auto tp : initial_ivalues_) {
-      inputs.emplace_back(tp.value());
-    }
-    return function_->get_executor().graphFor(inputs);
-  }
-
   std::shared_ptr<Graph> graph() const {
     return function_->graph();
   }
index b4d33ce..5989dbd 100644 (file)
@@ -1208,6 +1208,9 @@ if _enabled:
                 "Mixed serialization of script and non-script modules is not supported. " +
                 "For purely script modules use my_script_module.save(<filename>) instead.")
 
+        def graph_for(self, *args, **kwargs):
+            return self._get_method('forward').graph_for(*args, **kwargs)
+
     class WeakScriptModuleProxy(ScriptModule):
         def __init__(self, original, stubs):
             # Guards behavior of __setattr__ and __getattr__ so ScriptModule
@@ -1548,6 +1551,14 @@ def annotate(the_type, the_value):
 
 Attribute = collections.namedtuple('Attribute', ['value', 'type'])
 
+last_executed_optimized_graph = torch._C._last_executed_optimized_graph
+
+
+def _graph_for(self, *args, **kwargs):
+    self(*args, **kwargs)
+    return last_executed_optimized_graph()
+
+torch._C.ScriptMethod.graph_for = _graph_for
 
 if not torch._C._jit_init():
     raise RuntimeError("JIT initialization failed")