From 0b219538cfb70503baa74053d3e827951e0ea6f9 Mon Sep 17 00:00:00 2001
From: Zachary DeVito
Date: Wed, 19 Dec 2018 15:02:13 -0800
Subject: [PATCH] add unpack_outputs to inlineCallTo

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15382

Differential Revision: D13518844

Pulled By: zdevito

fbshipit-source-id: 981936988080af80629b70bf5f6dfa52ceb09c2f
---
 torch/csrc/jit/autodiff.cpp        | 17 ++---------------
 torch/csrc/jit/ir.cpp              | 17 ++++++++++++++++-
 torch/csrc/jit/ir.h                |  4 +++-
 torch/csrc/jit/passes/to_batch.cpp |  7 +------
 4 files changed, 22 insertions(+), 23 deletions(-)

diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp
index 54e88be..686a0b9 100644
--- a/torch/csrc/jit/autodiff.cpp
+++ b/torch/csrc/jit/autodiff.cpp
@@ -150,19 +150,6 @@ bool isDifferentiable(Graph & g) {
       static_cast<bool (*)(Node*)>(isDifferentiable));
 }
 
-// TODO: Remove this after #15355.
-namespace {
-  std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
-    auto outputs = inlineCallTo(g, callee, inputs);
-    if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
-      auto tc = createTupleUnpack(outputs.at(0));
-      outputs = std::vector<Value*>(tc.begin(), tc.end());
-    }
-    return outputs;
-  }
-} //anonymous namespace
-
-
 // NB: Write gradient using torchscript
 // For example, node aten::mul() should be defined as follows
 // def forward(x, y):
@@ -200,7 +187,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
   {
     WithInsertPoint guard(node->next());
     auto fw_graph = compiled_graphs->forward;
-    new_outputs = inlineUnpackedCallTo(*graph, *fw_graph, node->inputs());
+    new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
     for (size_t i = 0; i < node->outputs().size(); ++i) {
       new_outputs.at(i)->setType(node->outputs()[i]->type());
       new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
@@ -213,7 +200,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
     auto it = grad_vec.begin();
     grad_vec.insert(it, new_outputs.back());
     ArrayRef<Value*> grad(grad_vec);
-    auto grad_inputs = inlineUnpackedCallTo(*graph, *bw_graph, grad);
+    auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
     return grad_inputs;
   };
 
diff --git a/torch/csrc/jit/ir.cpp b/torch/csrc/jit/ir.cpp
index df038e3..3ff7340 100644
--- a/torch/csrc/jit/ir.cpp
+++ b/torch/csrc/jit/ir.cpp
@@ -1495,7 +1495,7 @@ at::ArrayRef<Value*> createTupleUnpack(Value* v) {
   return g.insertNode(g.createTupleUnpack(v))->outputs();
 }
 
-std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs) {
   std::unordered_map<Value*, Value*> value_map;
   auto value_map_func = [&](Value* v) { return value_map.at(v); };
   JIT_ASSERT(callee.inputs().size() == inputs.size());
@@ -1514,6 +1514,21 @@ std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> input
   for (auto* output : callee.outputs()) {
     outputs.push_back(value_map_func(output));
   }
+
+  if (unpack_outputs && outputs.size() == 1 &&
+      callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
+    auto tup = outputs[0];
+    outputs.clear();
+    for(Value* v : createTupleUnpack(tup)) {
+      outputs.emplace_back(v);
+    }
+    // if this was a peephole tuple unpack we can just get rid of
+    // the tuple construct here and prevent needing DCE
+    if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
+      tup->node()->destroy();
+    }
+  }
+
   return outputs;
 }
 
diff --git a/torch/csrc/jit/ir.h b/torch/csrc/jit/ir.h
index 71a5361..251033d 100644
--- a/torch/csrc/jit/ir.h
+++ b/torch/csrc/jit/ir.h
@@ -1105,7 +1105,9 @@ inline Node* Graph::createPythonOp(
 TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
 
 TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
-TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
+// unpack_outputs - if true, and the callee returns a single tuple value, then insert a tuple unpack node
+// and return the resulting values
+TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs=false);
 
 }} // namespace torch::jit
 
diff --git a/torch/csrc/jit/passes/to_batch.cpp b/torch/csrc/jit/passes/to_batch.cpp
index 2451304..5a1c940 100644
--- a/torch/csrc/jit/passes/to_batch.cpp
+++ b/torch/csrc/jit/passes/to_batch.cpp
@@ -21,12 +21,7 @@ std::shared_ptr<Graph> ToBatch::getBatchOperator(const std::string& name, int64_
 }
 
 std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
-  auto outputs = inlineCallTo(g, callee, inputs);
-  if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
-    auto tc = createTupleUnpack(outputs.at(0));
-    outputs = std::vector<Value*>(tc.begin(), tc.end());
-  }
-  return outputs;
+  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
}
 
 // replace aten operator node with BatchTensor operator graph
-- 
2.7.4
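
Reviewer note (not part of the patch): a minimal sketch of how a call site that
previously used the removed inlineUnpackedCallTo helper migrates to the new flag.
The names `graph`, `fn_graph`, and `node` below are illustrative placeholders,
not identifiers from this patch.

    // Before this patch, callers unpacked a single tuple output by hand:
    auto outs = inlineCallTo(*graph, *fn_graph, node->inputs());
    if (fn_graph->outputs().size() == 1 &&
        fn_graph->outputs().at(0)->type()->kind() == TupleType::Kind) {
      auto unpacked = createTupleUnpack(outs.at(0));
      outs = std::vector<Value*>(unpacked.begin(), unpacked.end());
    }

    // With this patch, the same result is a single call:
    auto unpacked_outs =
        inlineCallTo(*graph, *fn_graph, node->inputs(), /*unpack_outputs=*/true);

Because unpack_outputs defaults to false, existing three-argument callers keep
their current behavior.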