pass
else:
raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
- grad_executors = list(plan_state.code.grad_executors())
+ grad_executors = list(plan_state.code.grad_executor_states())
return grad_executors[diff_graph_idx or 0]
raise RuntimeError('Expected ScriptModule')
ge_state = script_module.get_debug_state()
fwd_plan = get_execution_plan(ge_state)
- grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
- bwd_plan = get_execution_plan(grad_executor.get_debug_state())
+ grad_executor_state = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
+ bwd_plan = get_execution_plan(grad_executor_state)
# Running JIT passes requires that we own the graph (with a shared_ptr).
# The debug state struct does not own its graph so we make a copy of it.
return bwd_plan.graph.copy()
else:
recording_inputs = reference_tensors
- if isinstance(func, torch._C.Graph):
- ge = torch._C.GraphExecutor(func, optimize)
- else:
- ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
- _force_outplace=_force_outplace)
+ ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
+ _force_outplace=_force_outplace)
if export_import:
ge = self.getExportImportCopy(ge)
return MyInplaceFn.apply(x)
x = torch.randn(5, 5)
- ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
+ ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
ge(x)
return a * b / (a - b) + b
V = Variable
a, b = V(torch.rand(1)), V(torch.rand(1))
- ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
+ ge = torch.jit.trace(foo, (a, b))
a, b = V(torch.rand(1), requires_grad=True), V(
torch.rand(1), requires_grad=True)
r, = ge(a, b)
value_map = {}
pb_graph = pb_graph or graph_pb2.GraphDef()
- if isinstance(graph, (torch._C.GraphExecutor, torch._C.GraphExecutorState)):
+ if isinstance(graph, torch._C.GraphExecutorState):
visualize_graph_executor(graph, name_prefix, pb_graph,
partial(visualize, pb_graph=pb_graph))
return pb_graph
The strategy is to embed all different configurations as independent subgraphs,
while inlining the original graph as the one that actually produces the values.
"""
- if isinstance(state, torch._C.GraphExecutor):
- state = state.get_debug_state()
-
if state.autograd_fallback_graph is not None:
visualize(graph=state.autograd_fallback_graph,
name_prefix=name_prefix + 'autograd_fallback/',
});
// NOLINTNEXTLINE(bugprone-unused-raii)
py::class_<ArgumentSpec>(m, "ArgumentSpec");
- py::class_<Code>(m, "Code").def("grad_executors", [](Code& c) {
- return py::make_iterator(
- c.grad_executors().begin(), c.grad_executors().end());
+ py::class_<Code>(m, "Code").def("grad_executor_states", [](Code& c) {
+ std::vector<GraphExecutorState> states;
+ for (auto& e : c.grad_executors()) {
+ states.emplace_back(e->getDebugState());
+ }
+ return states;
});
py::class_<ExecutionPlanState>(m, "ExecutionPlanState")
.def_property_readonly(
"fallback", [](GraphExecutorState& s) { return s.fallback; });
- py::class_<GraphExecutor>(m, "GraphExecutor", py::dynamic_attr())
- .def(
- py::init([](py::function func,
- py::tuple inputs,
- py::function var_name_lookup_fn,
- bool optimize,
- bool _force_outplace) {
- auto graph = tracer::createGraphByTracing(
- func, toStack(inputs), var_name_lookup_fn, _force_outplace);
- return GraphExecutor(graph, optimize);
- }),
- py::arg("func"),
- py::arg("inputs"),
- py::arg("var_name_lookup_fn"),
- py::arg("optimize") = true,
- py::arg("_force_outplace") = false)
- .def(
- py::init([](std::shared_ptr<Graph> graph, bool optimize) {
- return GraphExecutor(std::move(graph), optimize);
- }),
- py::arg("graph"),
- py::arg("optimize") = true)
- .def(
- "graph_for",
- [](GraphExecutor& ge, py::args args) {
- return ge.graphFor(evilDeprecatedBadCreateStackDoNotUse(
- args, ge.graph()->inputs()));
- })
- .def_property_readonly(
- "graph", [](GraphExecutor& ge) { return ge.graph(); })
- .def(
- "get_debug_state",
- [](GraphExecutor& ge) { return ge.getDebugState(); })
- .def("__call__", [](GraphExecutor& ge, py::args args) -> py::object {
- const auto& graph = ge.graph();
- auto stack =
- evilDeprecatedBadCreateStackDoNotUse(args, graph->inputs());
- {
- AutoNoGIL no_gil_guard;
- ge.run(stack);
- }
- return createPyObjectForStack(std::move(stack));
- });
-
py::class_<PyTorchStreamWriter>(m, "PyTorchFileWriter")
.def(py::init<std::string>())
.def(