From b9c20d5224caea1f360285f765d351014fd29df5 Mon Sep 17 00:00:00 2001
From: Zachary DeVito
Date: Tue, 16 Apr 2019 09:01:03 -0700
Subject: [PATCH] graph_for based on last_optimized_executed_graph (#19142)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19142
ghimport-source-id: 822013fb7e93032c74867fc77c6774c680aef6d1

Differential Revision: D14888703

Pulled By: zdevito

fbshipit-source-id: a2ad65a042d08b1adef965c2cceef37bb5d26ba9
---
 torch/csrc/jit/graph_executor.cpp        | 28 +++++++--------------------
 torch/csrc/jit/graph_executor.h          |  8 +++++++-
 torch/csrc/jit/script/compilation_unit.h |  4 ----
 torch/csrc/jit/script/init.cpp           | 33 +++++++-------------------------
 torch/csrc/jit/script/module.h           |  8 --------
 torch/jit/__init__.py                    | 11 +++++++++++
 6 files changed, 32 insertions(+), 60 deletions(-)

diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp
index 92d149f..9ce52ee 100644
--- a/torch/csrc/jit/graph_executor.cpp
+++ b/torch/csrc/jit/graph_executor.cpp
@@ -54,6 +54,11 @@ void debugSetAutodiffSubgraphInlining(bool state) {
   autodiff_subgraph_inlining = state;
 }
 
+thread_local std::weak_ptr<Graph> last_executed_optimized_graph;
+std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
+  return last_executed_optimized_graph.lock();
+}
+
 namespace {
 
 using tensor_list = std::vector<at::Tensor>;
@@ -72,7 +77,8 @@ struct ExecutionPlan {
       : code(graph), graph(std::move(graph)) {}
 
   void run(Stack& stack) const {
-    return InterpreterState(code).run(stack);
+    InterpreterState(code).run(stack);
+    last_executed_optimized_graph = graph;
   }
 
   operator bool() const {
@@ -519,22 +525,6 @@ struct GraphExecutorImpl {
     return execution_plan.run(stack);
   }
 
-  std::shared_ptr<Graph> graphFor(const Stack& stack) const {
-    AT_ASSERT(stack.size() >= num_inputs);
-
-    ArgumentSpec spec =
-        arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
-
-    if (!optimize) {
-      AT_CHECK(fallback, "No graph found for given inputs");
-      return fallback.graph;
-    }
-
-    auto it = plan_cache.find(spec);
-    AT_CHECK(it != plan_cache.end(), "No graph found for given inputs");
-    return it->second.graph;
-  }
-
   GraphExecutorState getDebugState() {
     GraphExecutorState state;
     state.graph = graph.get();
@@ -753,10 +743,6 @@ std::shared_ptr<Graph> GraphExecutor::graph() const {
   return pImpl->graph;
 }
 
-std::shared_ptr<Graph> GraphExecutor::graphFor(const Stack& inputs) const {
-  return pImpl->graphFor(inputs);
-}
-
 GraphExecutorState GraphExecutor::getDebugState() {
   return pImpl->getDebugState();
 }
diff --git a/torch/csrc/jit/graph_executor.h b/torch/csrc/jit/graph_executor.h
index 9066d9f..e2906a5 100644
--- a/torch/csrc/jit/graph_executor.h
+++ b/torch/csrc/jit/graph_executor.h
@@ -35,7 +35,6 @@ struct TORCH_API GraphExecutor {
     return pImpl != nullptr;
  }
   std::shared_ptr<Graph> graph() const;
-  std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
   GraphExecutorState getDebugState();
 
  private:
@@ -47,11 +46,18 @@ struct TORCH_API GraphExecutor {
 TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
 
 TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
+TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();
 
 namespace detail {
 
 GraphExecutor* getGradExecutor(Operation& op);
 
+// for debugging information we expose a way to get the last actually
+// run graph. Previous approaches allowed querying the GraphExecutor
+// for what graph it would run in certain circumstances (graphFor), but
+// this is fragile because we sometimes change how these decisions are made.
+// This interface still allows our tests to look at optimized graphs, but
+// with less plumbing.
 } // namespace detail
 
 } // namespace jit
diff --git a/torch/csrc/jit/script/compilation_unit.h b/torch/csrc/jit/script/compilation_unit.h
index 434b862..c3da6e2 100644
--- a/torch/csrc/jit/script/compilation_unit.h
+++ b/torch/csrc/jit/script/compilation_unit.h
@@ -66,10 +66,6 @@ struct TORCH_API Function {
     return stack.front();
   }
 
-  std::shared_ptr<Graph> graph_for(Stack inputs) {
-    return get_executor().graphFor(inputs);
-  }
-
   std::shared_ptr<Graph> graph() const {
     return graph_;
   }
diff --git a/torch/csrc/jit/script/init.cpp b/torch/csrc/jit/script/init.cpp
index cd7ddaa..f701cf8 100644
--- a/torch/csrc/jit/script/init.cpp
+++ b/torch/csrc/jit/script/init.cpp
@@ -11,6 +11,7 @@
 #include
 #include
+#include
 #include
 #include
 #include
@@ -952,23 +953,6 @@ void initJitScriptBindings(PyObject* module) {
             didFinishEmitModule(self);
           })
       .def(
-          "graph_for",
-          [](py::args args, py::kwargs kwargs) {
-            // [pybind11 varargs] note: old version of pybind11 have a bug that
-            // leaks memory when py::args is mixed with positional arguments
-            // https://github.com/pybind/pybind11/pull/1216
-            // we work around this by not mixing positional arguments with
-            // varargs
-            Module& self = py::cast<Module&>(args[0]);
-            if (self.find_method("forward")) {
-              Method& m = self.get_method("forward");
-              return m.graph_for(createStackForSchema(
-                  m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
-            }
-            throw std::runtime_error(
-                "Attempted to call graph_for on a Module without a compiled forward()");
-          })
-      .def(
           "get_debug_state",
           [](Module& self) {
             if (self.find_method("forward")) {
@@ -1047,14 +1031,6 @@
             }
             return tensors;
           })
-      .def(
-          "graph_for",
-          [](py::args args, py::kwargs kwargs) {
-            // see: [pybind11 varargs]
-            Method& self = py::cast<Method&>(args[0]);
-            return self.graph_for(createStackForSchema(
-                self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
-          })
       .def_property_readonly("schema", &Method::getSchema)
       .def_property_readonly("code", [](Method& self) {
         std::ostringstream ss;
@@ -1136,7 +1112,8 @@ void initJitScriptBindings(PyObject* module) {
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
   m.def("_jit_clear_class_registry", ClassType::clearRegistry);
   m.def(
-      "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
+      "_debug_set_autodiff_subgraph_inlining",
+      debugSetAutodiffSubgraphInlining);
   m.def("_propagate_shapes", _propagate_shapes);
   m.def(
       "_propagate_and_assign_input_and_output_shapes",
@@ -1154,6 +1131,10 @@ void initJitScriptBindings(PyObject* module) {
     }
     return std::make_pair(ss.str(), std::move(constants));
   });
+  m.def(
+      "_last_executed_optimized_graph",
+      []() { return lastExecutedOptimizedGraph(); },
+      "Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
 
   py::class_<testing::FileCheck>(m, "FileCheck")
       .def(py::init<>())
diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h
index e99436c..1b477ae 100644
--- a/torch/csrc/jit/script/module.h
+++ b/torch/csrc/jit/script/module.h
@@ -94,14 +94,6 @@ struct TORCH_API Method {
     return initial_ivalues_;
   }
 
-  // proxies for underlying unbound Function
-  std::shared_ptr<Graph> graph_for(Stack inputs) {
-    for (auto tp : initial_ivalues_) {
-      inputs.emplace_back(tp.value());
-    }
-    return function_->get_executor().graphFor(inputs);
-  }
-
   std::shared_ptr<Graph> graph() const {
     return function_->graph();
   }
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index b4d33ced..5989dbd 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -1208,6 +1208,9 @@ if _enabled:
                     "Mixed serialization of script and non-script modules is not supported. " +
                     "For purely script modules use my_script_module.save() instead.")
 
+        def graph_for(self, *args, **kwargs):
+            return self._get_method('forward').graph_for(*args, **kwargs)
+
     class WeakScriptModuleProxy(ScriptModule):
         def __init__(self, original, stubs):
             # Guards behavior of __setattr__ and __getattr__ so ScriptModule
@@ -1548,6 +1551,14 @@ def annotate(the_type, the_value):
 
 Attribute = collections.namedtuple('Attribute', ['value', 'type'])
 
+last_executed_optimized_graph = torch._C._last_executed_optimized_graph
+
+
+def _graph_for(self, *args, **kwargs):
+    self(*args, **kwargs)
+    return last_executed_optimized_graph()
+
+torch._C.ScriptMethod.graph_for = _graph_for
 
 if not torch._C._jit_init():
     raise RuntimeError("JIT initialization failed")
-- 
2.7.4
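
A minimal usage sketch (not part of the patch; the module definition and inputs are illustrative) of what this change enables: graph_for now runs the module and then returns the optimized graph the executor actually used on the current thread, retrieved through _last_executed_optimized_graph, instead of asking the GraphExecutor what it would run.

    # Usage sketch (illustrative only, assumes the torch.jit API of this era):
    # graph_for executes forward() and then reads back the optimized graph the
    # executor just ran on this thread.
    import torch

    class MyModule(torch.jit.ScriptModule):
        @torch.jit.script_method
        def forward(self, x):
            return x + x

    m = MyModule()
    x = torch.ones(2, 2)

    # Runs forward(x), then returns the last executed optimized graph.
    print(m.graph_for(x))

    # Equivalent per-thread query exposed at module level by this patch.
    print(torch.jit.last_executed_optimized_graph())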