autodiff_subgraph_inlining = state;
}
+// Thread-local record of the most recent optimized graph executed on this
+// thread. Held as a weak_ptr so this global does not extend the graph's
+// lifetime beyond its owning executor.
+thread_local std::weak_ptr<Graph> last_executed_optimized_graph;
+// Returns the optimized graph most recently run on the current thread, or
+// nullptr if nothing has run yet or the graph has since been destroyed
+// (weak_ptr::lock() yields an empty shared_ptr in that case).
+std::shared_ptr<Graph> lastExecutedOptimizedGraph() {
+ return last_executed_optimized_graph.lock();
+}
+
namespace {
using tensor_list = std::vector<at::Tensor>;
: code(graph), graph(std::move(graph)) {}
void run(Stack& stack) const {
- return InterpreterState(code).run(stack);
+ InterpreterState(code).run(stack);
+ // Record which graph actually ran so lastExecutedOptimizedGraph() can
+ // report it for debugging/tests.
+ last_executed_optimized_graph = graph;
}
operator bool() const {
return execution_plan.run(stack);
}
- std::shared_ptr<Graph> graphFor(const Stack& stack) const {
- AT_ASSERT(stack.size() >= num_inputs);
-
- ArgumentSpec spec =
- arg_spec_creator_.create(autograd::GradMode::is_enabled(), stack);
-
- if (!optimize) {
- AT_CHECK(fallback, "No graph found for given inputs");
- return fallback.graph;
- }
-
- auto it = plan_cache.find(spec);
- AT_CHECK(it != plan_cache.end(), "No graph found for given inputs");
- return it->second.graph;
- }
-
GraphExecutorState getDebugState() {
GraphExecutorState state;
state.graph = graph.get();
return pImpl->graph;
}
-std::shared_ptr<Graph> GraphExecutor::graphFor(const Stack& inputs) const {
- return pImpl->graphFor(inputs);
-}
-
GraphExecutorState GraphExecutor::getDebugState() {
return pImpl->getDebugState();
}
return pImpl != nullptr;
}
std::shared_ptr<Graph> graph() const;
- std::shared_ptr<Graph> graphFor(const Stack& inputs) const;
GraphExecutorState getDebugState();
private:
TORCH_API void runRequiredPasses(const std::shared_ptr<Graph>& g);
TORCH_API void debugSetAutodiffSubgraphInlining(bool state);
+TORCH_API std::shared_ptr<Graph> lastExecutedOptimizedGraph();
namespace detail {
GraphExecutor* getGradExecutor(Operation& op);
+// For debugging we expose a way to get the last graph that was actually
+// run. Previous approaches allowed querying the GraphExecutor
+// for what graph it would run in certain circumstances (graphFor), but
+// this is fragile because we sometimes change how these decisions are made.
+// This interface still allows our tests to look at optimized graphs, but
+// with less plumbing.
} // namespace detail
} // namespace jit
return stack.front();
}
- std::shared_ptr<Graph> graph_for(Stack inputs) {
- return get_executor().graphFor(inputs);
- }
-
std::shared_ptr<Graph> graph() const {
return graph_;
}
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/csrc/jit/constants.h>
+#include <torch/csrc/jit/graph_executor.h>
#include <torch/csrc/jit/hooks_for_testing.h>
#include <torch/csrc/jit/import_source.h>
#include <torch/csrc/jit/irparser.h>
didFinishEmitModule(self);
})
.def(
- "graph_for",
- [](py::args args, py::kwargs kwargs) {
- // [pybind11 varargs] note: old version of pybind11 have a bug that
- // leaks memory when py::args is mixed with positional arguments
- // https://github.com/pybind/pybind11/pull/1216
- // we work around this by not mixing positional arguments with
- // varargs
- Module& self = py::cast<Module&>(args[0]);
- if (self.find_method("forward")) {
- Method& m = self.get_method("forward");
- return m.graph_for(createStackForSchema(
- m.getSchema(), tuple_slice(std::move(args), 1), kwargs));
- }
- throw std::runtime_error(
- "Attempted to call graph_for on a Module without a compiled forward()");
- })
- .def(
"get_debug_state",
[](Module& self) {
if (self.find_method("forward")) {
}
return tensors;
})
- .def(
- "graph_for",
- [](py::args args, py::kwargs kwargs) {
- // see: [pybind11 varargs]
- Method& self = py::cast<Method&>(args[0]);
- return self.graph_for(createStackForSchema(
- self.getSchema(), tuple_slice(std::move(args), 1), kwargs));
- })
.def_property_readonly("schema", &Method::getSchema)
.def_property_readonly("code", [](Method& self) {
std::ostringstream ss;
m.def("_jit_set_emit_module_hook", setEmitModuleHook);
m.def("_jit_clear_class_registry", ClassType::clearRegistry);
m.def(
- "_debug_set_autodiff_subgraph_inlining", debugSetAutodiffSubgraphInlining);
+ "_debug_set_autodiff_subgraph_inlining",
+ debugSetAutodiffSubgraphInlining);
m.def("_propagate_shapes", _propagate_shapes);
m.def(
"_propagate_and_assign_input_and_output_shapes",
}
return std::make_pair(ss.str(), std::move(constants));
});
+ // Python debugging hook wrapping lastExecutedOptimizedGraph(); returns
+ // None (empty shared_ptr) if no graph has run on this thread yet.
+ m.def(
+ "_last_executed_optimized_graph",
+ []() { return lastExecutedOptimizedGraph(); },
+ "Retrieve the optimized graph that was run the last time the graph executor ran on this thread");
py::class_<testing::FileCheck>(m, "FileCheck")
.def(py::init<>())
return initial_ivalues_;
}
- // proxies for underlying unbound Function
- std::shared_ptr<Graph> graph_for(Stack inputs) {
- for (auto tp : initial_ivalues_) {
- inputs.emplace_back(tp.value());
- }
- return function_->get_executor().graphFor(inputs);
- }
-
std::shared_ptr<Graph> graph() const {
return function_->graph();
}
"Mixed serialization of script and non-script modules is not supported. " +
"For purely script modules use my_script_module.save(<filename>) instead.")
+ def graph_for(self, *args, **kwargs):
+ # Debugging helper: delegates to the compiled forward method's
+ # graph_for, returning the optimized graph used for these inputs.
+ return self._get_method('forward').graph_for(*args, **kwargs)
+
class WeakScriptModuleProxy(ScriptModule):
def __init__(self, original, stubs):
# Guards behavior of __setattr__ and __getattr__ so ScriptModule
Attribute = collections.namedtuple('Attribute', ['value', 'type'])
+# Debugging hook: fetches the optimized graph most recently executed by the
+# graph executor on this thread (see torch._C._last_executed_optimized_graph).
+last_executed_optimized_graph = torch._C._last_executed_optimized_graph
+
+
+def _graph_for(self, *args, **kwargs):
+ # NOTE: this actually executes the method (side effects included) so that
+ # the graph executor records its optimized graph, then retrieves it.
+ self(*args, **kwargs)
+ return last_executed_optimized_graph()
+
+# Expose graph_for on ScriptMethod for tests and debugging.
+torch._C.ScriptMethod.graph_for = _graph_for
if not torch._C._jit_init():
raise RuntimeError("JIT initialization failed")