add unpack_outputs to inlineCallTo
authorZachary DeVito <zdevito@fb.com>
Wed, 19 Dec 2018 23:02:13 +0000 (15:02 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 23:11:59 +0000 (15:11 -0800)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15382

Differential Revision: D13518844

Pulled By: zdevito

fbshipit-source-id: 981936988080af80629b70bf5f6dfa52ceb09c2f

torch/csrc/jit/autodiff.cpp
torch/csrc/jit/ir.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/to_batch.cpp

index 54e88be..686a0b9 100644 (file)
@@ -150,19 +150,6 @@ bool isDifferentiable(Graph & g) {
                      static_cast<bool(*)(Node*)>(isDifferentiable));
 }
 
-// TODO: Remove this after #15355.
-namespace {
-  std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
-    auto outputs = inlineCallTo(g, callee, inputs);
-    if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
-      auto tc = createTupleUnpack(outputs.at(0));
-      outputs = std::vector<Value*>(tc.begin(), tc.end());
-    }
-    return outputs;
-  }
-} //anonymous namespace
-
-
 // NB: Write gradient using torchscript
 // For example, node aten::mul() should be defined as follows
 // def forward(x, y):
@@ -200,7 +187,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
   {
     WithInsertPoint guard(node->next());
     auto fw_graph = compiled_graphs->forward;
-    new_outputs = inlineUnpackedCallTo(*graph, *fw_graph, node->inputs());
+    new_outputs = inlineCallTo(*graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true);
     for (size_t i = 0; i < node->outputs().size(); ++i) {
       new_outputs.at(i)->setType(node->outputs()[i]->type());
       new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]);
@@ -213,7 +200,7 @@ static c10::optional<std::vector<Value*>> build_script_grad(
   auto it = grad_vec.begin();
   grad_vec.insert(it, new_outputs.back());
   ArrayRef<Value*> grad(grad_vec);
-  auto grad_inputs = inlineUnpackedCallTo(*graph, *bw_graph, grad);
+  auto grad_inputs = inlineCallTo(*graph, *bw_graph, grad, /*unpack_outputs=*/true);
   return grad_inputs;
 };
 
index df038e3..3ff7340 100644 (file)
@@ -1495,7 +1495,7 @@ at::ArrayRef<Value*> createTupleUnpack(Value* v) {
   return g.insertNode(g.createTupleUnpack(v))->outputs();
 }
 
-std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
+std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs) {
   std::unordered_map<Value*, Value*> value_map;
   auto value_map_func = [&](Value* v) { return value_map.at(v); };
   JIT_ASSERT(callee.inputs().size() == inputs.size());
@@ -1514,6 +1514,21 @@ std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> input
   for (auto* output : callee.outputs()) {
     outputs.push_back(value_map_func(output));
   }
+
+  if (unpack_outputs && outputs.size() == 1 &&
+      callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
+      auto tup = outputs[0];
+      outputs.clear();
+      for(Value* v : createTupleUnpack(tup)) {
+        outputs.emplace_back(v);
+      }
+      // if this was a peephole tuple unpack we can just get rid of
+      // the tuple construct here and prevent needing DCE
+      if (tup->node()->kind() == prim::TupleConstruct && !tup->node()->hasUses()) {
+        tup->node()->destroy();
+      }
+  }
+
   return outputs;
 }
 
index 71a5361..251033d 100644 (file)
@@ -1105,7 +1105,9 @@ inline Node* Graph::createPythonOp(
 TORCH_API void LintGraph(std::shared_ptr<Graph>& graph);
 
 TORCH_API at::ArrayRef<Value*> createTupleUnpack(Value* v);
-TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs);
+// unpack_outputs - if true, and the callee returns a single tuple value, then insert a tuple unpack node
+//                  and return the resulting values
+TORCH_API std::vector<Value*> inlineCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs, bool unpack_outputs=false);
 
 
 }} // namespace torch::jit
index 2451304..5a1c940 100644 (file)
@@ -21,12 +21,7 @@ std::shared_ptr<Graph> ToBatch::getBatchOperator(const std::string& name, int64_
 }
 
 std::vector<Value*> inlineUnpackedCallTo(Graph& g, Graph& callee, ArrayRef<Value*> inputs) {
-  auto outputs = inlineCallTo(g, callee, inputs);
-  if (callee.outputs().size() == 1 && callee.outputs().at(0)->type()->kind() == TupleType::Kind) {
-    auto tc = createTupleUnpack(outputs.at(0));
-    outputs = std::vector<Value*>(tc.begin(), tc.end());
-  }
-  return outputs;
+  return inlineCallTo(g, callee, inputs, /*unpack_outputs=*/true);
 }
 
 // replace aten operator node with BatchTensor operator graph