Remove GraphExecutor's python bindings (#19141)
authorZachary DeVito <zdevito@fb.com>
Sat, 13 Apr 2019 15:28:13 +0000 (08:28 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 13 Apr 2019 15:42:24 +0000 (08:42 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19141
ghimport-source-id: 796a41f5514d29959af052fcf5391a2834850a80

Reviewed By: jamesr66a

Differential Revision: D14888702

Pulled By: zdevito

fbshipit-source-id: c280145f08e7bc210434d1c99396a3257b626cf9

test/test_jit.py
torch/contrib/_tensorboard_vis.py
torch/csrc/jit/init.cpp

index 49908a8..9f09362 100644 (file)
@@ -220,7 +220,7 @@ def get_grad_executor(plan_state, diff_graph_idx=None):
             pass
         else:
             raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
-    grad_executors = list(plan_state.code.grad_executors())
+    grad_executors = list(plan_state.code.grad_executor_states())
     return grad_executors[diff_graph_idx or 0]
 
 
@@ -229,8 +229,8 @@ def backward_graph(script_module, diff_graph_idx=None):
         raise RuntimeError('Expected ScriptModule')
     ge_state = script_module.get_debug_state()
     fwd_plan = get_execution_plan(ge_state)
-    grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
-    bwd_plan = get_execution_plan(grad_executor.get_debug_state())
+    grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
+    bwd_plan = get_execution_plan(grad_executor_state)
     # Running JIT passes requires that we own the graph (with a shared_ptr).
     # The debug state struct does not own its graph so we make a copy of it.
     return bwd_plan.graph.copy()
@@ -540,11 +540,8 @@ class JitTestCase(TestCase):
         else:
             recording_inputs = reference_tensors
 
-        if isinstance(func, torch._C.Graph):
-            ge = torch._C.GraphExecutor(func, optimize)
-        else:
-            ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
-                                 _force_outplace=_force_outplace)
+        ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
+                             _force_outplace=_force_outplace)
 
         if export_import:
             ge = self.getExportImportCopy(ge)
@@ -1440,7 +1437,7 @@ graph(%x : Tensor,
             return MyInplaceFn.apply(x)
 
         x = torch.randn(5, 5)
-        ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
+        ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
         with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
             ge(x)
 
@@ -1856,7 +1853,7 @@ graph(%x : Tensor,
             return a * b / (a - b) + b
         V = Variable
         a, b = V(torch.rand(1)), V(torch.rand(1))
-        ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
+        ge = torch.jit.trace(foo, (a, b))
         a, b = V(torch.rand(1), requires_grad=True), V(
             torch.rand(1), requires_grad=True)
         r, = ge(a, b)
index d9af6e0..3a1613b 100644 (file)
@@ -27,7 +27,7 @@ def visualize(graph, name_prefix='', pb_graph=None, executors_it=None):
     value_map = {}
     pb_graph = pb_graph or graph_pb2.GraphDef()
 
-    if isinstance(graph, (torch._C.GraphExecutor, torch._C.GraphExecutorState)):
+    if isinstance(graph, torch._C.GraphExecutorState):
         visualize_graph_executor(graph, name_prefix, pb_graph,
                                  partial(visualize, pb_graph=pb_graph))
         return pb_graph
@@ -65,9 +65,6 @@ def visualize_graph_executor(state, name_prefix, pb_graph, inline_graph):
     The strategy is to embed all different configurations as independent subgraphs,
     while inlining the original graph as the one that actually produces the values.
     """
-    if isinstance(state, torch._C.GraphExecutor):
-        state = state.get_debug_state()
-
     if state.autograd_fallback_graph is not None:
         visualize(graph=state.autograd_fallback_graph,
                   name_prefix=name_prefix + 'autograd_fallback/',
index b67cc90..ce67201 100644 (file)
@@ -239,9 +239,12 @@ void initJITBindings(PyObject* module) {
       });
   // NOLINTNEXTLINE(bugprone-unused-raii)
   py::class_<ArgumentSpec>(m, "ArgumentSpec");
-  py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
-    return py::make_iterator(
-        c.grad_executors().begin(), c.grad_executors().end());
+  py::class_<Code>(m, "Code").def("grad_executor_states", [](Code& c) {
+    std::vector<GraphExecutorState> states;
+    for (auto& e : c.grad_executors()) {
+      states.emplace_back(e->getDebugState());
+    }
+    return states;
   });
 
   py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
@@ -275,50 +278,6 @@ void initJITBindings(PyObject* module) {
       .def_property_readonly(
           "fallback", [](GraphExecutorState& s) { return s.fallback; });
 
-  py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
-      .def(
-          py::init([](py::function func,
-                      py::tuple inputs,
-                      py::function var_name_lookup_fn,
-                      bool optimize,
-                      bool _force_outplace) {
-            auto graph = tracer::createGraphByTracing(
-                func, toStack(inputs), var_name_lookup_fn, _force_outplace);
-            return GraphExecutor(graph, optimize);
-          }),
-          py::arg("func"),
-          py::arg("inputs"),
-          py::arg("var_name_lookup_fn"),
-          py::arg("optimize") = true,
-          py::arg("_force_outplace") = false)
-      .def(
-          py::init([](std::shared_ptr<Graph> graph, bool optimize) {
-            return GraphExecutor(std::move(graph), optimize);
-          }),
-          py::arg("graph"),
-          py::arg("optimize") = true)
-      .def(
-          "graph_for",
-          [](GraphExecutor& ge, py::args args) {
-            return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(
-                args, ge.graph()->inputs()));
-          })
-      .def_property_readonly(
-          "graph", [](GraphExecutor& ge) { return ge.graph(); })
-      .def(
-          "get_debug_state",
-          [](GraphExecutor& ge) { return ge.getDebugState(); })
-      .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
-        const auto& graph = ge.graph();
-        auto stack =
-            evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
-        {
-          AutoNoGIL no_gil_guard;
-          ge.run(stack);
-        }
-        return createPyObjectForStack(std::move(stack));
-      });
-
   py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
       .def(py::init<std::string>())
       .def(