migrate subgraph slicing to use `moveBefore/moveAfter` (#13862)
authorMichael Suo <suo@fb.com>
Thu, 15 Nov 2018 01:20:36 +0000 (17:20 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 15 Nov 2018 01:33:36 +0000 (17:33 -0800)
Summary:
Migrate the `CreateAutodiffSubgraphs` pass to use topologically-safe moves instead of DynamicDAG. This is to unify the interface that we use for determining safe node moves to prepare for mutability.

The pass looks a lot like GraphFuser now, and there's a lot of code duplication. I plan to pull common stuff out into a "subgraph manipulation utils" thing, but didn't want to clutter this PR.

Future steps:
- Get rid of code duplication (see above)
- Use DynamicDAG to back the `moveBefore/After` calls.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13862

Differential Revision: D13072871

Pulled By: suo

fbshipit-source-id: 92e7880ef444e0aefd51df60964bba7feaf42ae0

16 files changed:
test/cpp/jit/gtest.cpp
test/cpp/jit/no-gtest.cpp
test/cpp/jit/tests.h
test/expect/TestJit.test_cpp_cuda.expect
test/expect/TestScript.test_lstm_fusion_cuda-backward.expect
test/expect/TestScript.test_lstm_fusion_cuda-forward.expect
test/expect/TestScript.test_milstm_fusion_cuda-backward.expect
test/expect/TestScript.test_milstm_fusion_cuda-forward.expect
torch/CMakeLists.txt
torch/csrc/jit/graph_executor.cpp
torch/csrc/jit/init.cpp
torch/csrc/jit/ir.h
torch/csrc/jit/passes/create_autodiff_subgraphs.cpp
torch/csrc/jit/passes/create_autodiff_subgraphs.h
torch/csrc/jit/passes/utils/subgraph_utils.cpp [new file with mode: 0644]
torch/csrc/jit/passes/utils/subgraph_utils.h [new file with mode: 0644]

index 0b60f62..27f9f14 100644 (file)
@@ -26,6 +26,7 @@ JIT_TEST(IValue)
 JIT_TEST(SchemaParser)
 JIT_TEST(TopologicalIndex)
 JIT_TEST(TopologicalMove)
+JIT_TEST(SubgraphUtils)
 
 #define JIT_TEST_CUDA(name)    \
   TEST(JitTest, name##_CUDA) { \
index a53aab1..6be01f0 100644 (file)
@@ -28,6 +28,7 @@ std::string runJITCPPTests() {
   testSchemaParser();
   testTopologicalIndex();
   testTopologicalMove();
+  testSubgraphUtils();
   return out.str();
 }
 
index 9605d7d..5b2c12e 100644 (file)
 #include "torch/csrc/jit/interpreter.h"
 #include "torch/csrc/jit/ir.h"
 #include "torch/csrc/jit/operator.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
 #include "torch/csrc/jit/passes/constant_propagation.h"
 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
 #include "torch/csrc/jit/passes/lower_grad_of.h"
 #include "torch/csrc/jit/passes/requires_grad_analysis.h"
 #include "torch/csrc/jit/passes/shape_analysis.h"
+#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
 #include "torch/csrc/jit/symbolic_variable.h"
 #include "torch/csrc/jit/tracer.h"
 #include "torch/csrc/utils/hash.h"
@@ -697,11 +699,41 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) {
 
 void testCreateAutodiffSubgraphs(std::ostream& out = std::cout) {
   auto graph = build_lstm();
-  CreateAutodiffSubgraphs(*graph, /*threshold=*/2);
+  CreateAutodiffSubgraphs(graph, /*threshold=*/2);
   out << "testCreateAutodiffSubgraphs\n";
   out << *graph << "\n";
 }
 
+void testSubgraphUtils() {
+  auto graph = build_lstm();
+  EliminateCommonSubexpression(graph);
+
+  std::vector<Node*> originalNodes(
+      graph->nodes().begin(), graph->nodes().end());
+
+  // Merge everything into a single subgraph
+  bool first = true;
+  Node* subgraph;
+  for (auto it = graph->nodes().rbegin(); it != graph->nodes().rend();) {
+    if (first) {
+      subgraph = SubgraphUtils::createSingletonSubgraph(
+          *it, prim::DifferentiableGraph);
+      it = ++subgraph->reverseIterator();
+      first = false;
+    }
+
+    SubgraphUtils::mergeNodeIntoSubgraph(*it, subgraph);
+    it = ++subgraph->reverseIterator();
+  }
+
+  // Unmerge and compare with original node listing
+  SubgraphUtils::unmergeSubgraph(subgraph);
+  EliminateCommonSubexpression(graph);
+
+  std::vector<Node*> newNodes(graph->nodes().begin(), graph->nodes().end());
+  ASSERT_EQ(originalNodes.size(), newNodes.size());
+}
+
 autograd::Variable var(at::Type& t, at::IntList sizes, bool requires_grad) {
   return autograd::make_variable(at::rand(sizes, t.options()), requires_grad);
 }
index 0ccba95..8a92aa7 100644 (file)
@@ -65,31 +65,29 @@ graph(%0 : Dynamic
       %3 : Dynamic
       %4 : Dynamic) {
   %7 : int = prim::Constant[value=1]()
-  %19 : int = prim::Constant[value=1]()
-  %23 : Dynamic, %24 : Dynamic = prim::DifferentiableGraph_0(%0, %3, %1, %4, %2)
-  return (%24, %23);
+  %23 : Dynamic, %24 : Dynamic = prim::DifferentiableGraph_0(%2, %1, %4, %0, %3)
+  return (%23, %24);
 }
-with prim::DifferentiableGraph_0 = graph(%1 : Dynamic
-      %2 : Dynamic
-      %4 : Dynamic
-      %5 : Dynamic
-      %17 : Dynamic) {
-  %0 : Dynamic = aten::mm(%1, %2)
-  %3 : Dynamic = aten::mm(%4, %5)
-  %7 : int = prim::Constant[value=1]()
-  %6 : Dynamic = aten::add(%0, %3, %7)
-  %8 : Dynamic, %9 : Dynamic, %10 : Dynamic, %11 : Dynamic = prim::ConstantChunk[chunks=4, dim=1](%6)
-  %12 : Dynamic = aten::sigmoid(%8)
-  %13 : Dynamic = aten::sigmoid(%11)
-  %14 : Dynamic = aten::tanh(%10)
-  %15 : Dynamic = aten::sigmoid(%9)
-  %16 : Dynamic = aten::mul(%15, %17)
-  %18 : Dynamic = aten::mul(%12, %14)
-  %20 : int = prim::Constant[value=1]()
-  %19 : Dynamic = aten::add(%16, %18, %20)
-  %21 : Dynamic = aten::tanh(%19)
-  %22 : Dynamic = aten::mul(%13, %21)
-  return (%19, %22);
+with prim::DifferentiableGraph_0 = graph(%13 : Dynamic
+      %32 : Dynamic
+      %33 : Dynamic
+      %35 : Dynamic
+      %36 : Dynamic) {
+  %37 : Dynamic = aten::mm(%35, %36)
+  %34 : Dynamic = aten::mm(%32, %33)
+  %30 : int = prim::Constant[value=1]()
+  %31 : Dynamic = aten::add(%37, %34, %30)
+  %24 : Dynamic, %25 : Dynamic, %26 : Dynamic, %27 : Dynamic = prim::ConstantChunk[chunks=4, dim=1](%31)
+  %22 : Dynamic = aten::sigmoid(%24)
+  %20 : Dynamic = aten::sigmoid(%27)
+  %18 : Dynamic = aten::tanh(%26)
+  %16 : Dynamic = aten::sigmoid(%25)
+  %14 : Dynamic = aten::mul(%16, %13)
+  %11 : Dynamic = aten::mul(%22, %18)
+  %8 : Dynamic = aten::add(%14, %11, %30)
+  %4 : Dynamic = aten::tanh(%8)
+  %2 : Dynamic = aten::mul(%20, %4)
+  return (%2, %8);
 }
 
 testDifferentiate
index 0f8d60f..c49ae51 100644 (file)
@@ -17,7 +17,7 @@ graph(%0 : Float(*, *)
       %cellgate : Float(*, *)
       %outgate : Float(*, *)
       %18 : Float(*, *)) {
-  %19 : Float(*, *), %20 : Float(*, *) = prim::FusionGroup_0(%forgetgate, %ingate, %cellgate, %outgate, %11, %0, %18, %1)
+  %19 : Float(*, *), %20 : Float(*, *) = prim::FusionGroup_0(%forgetgate, %ingate, %cellgate, %outgate, %9, %1, %18, %0)
   %21 : Float(*, *) = aten::t(%13)
   %22 : Float(*, *) = aten::mm(%20, %21)
   %23 : Float(*, *) = aten::t(%10)
@@ -25,10 +25,10 @@ graph(%0 : Float(*, *)
   %25 : Float(*, *) = aten::t(%24)
   %26 : Float(*, *) = aten::t(%12)
   %27 : Float(*, *) = aten::mm(%20, %26)
-  %28 : Float(*, *) = aten::t(%9)
+  %28 : Float(*, *) = aten::t(%11)
   %29 : Float(*, *) = aten::mm(%28, %20)
   %30 : Float(*, *) = aten::t(%29)
-  return (%30, %27, %25, %22, %20, %20, %19);
+  return (%19, %20, %20, %22, %25, %27, %30);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Float(*, *)
index 6350144..2bc74cd 100644 (file)
@@ -5,27 +5,27 @@ graph(%x : Float(*, *)
       %w_hh : Float(*, *)
       %b_ih : Float(*)
       %b_hh : Float(*)) {
-  %7 : Float(*, *), %8 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %x, %w_hh, %hx, %b_ih, %b_hh, %cx)
-  return (%8, %7);
+  %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %b_hh, %b_ih, %hx, %w_hh, %x, %w_ih)
+  return (%hy, %cy);
 }
 with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)
+      %1 : Float(*)
+      %2 : Float(*)
       %3 : Float(*, *)
-      %4 : Float(*)
-      %5 : Float(*)
+      %4 : Float(*, *)
+      %5 : Float(*, *)
       %6 : Float(*, *)) {
-  %7 : Float(*, *) = aten::t(%0)
-  %8 : Float(*, *) = aten::mm(%1, %7)
-  %9 : Float(*, *) = aten::t(%2)
+  %7 : Float(*, *) = aten::t(%6)
+  %8 : Float(*, *) = aten::mm(%5, %7)
+  %9 : Float(*, *) = aten::t(%4)
   %10 : Float(*, *) = aten::mm(%3, %9)
   %11 : int = prim::Constant[value=1]()
-  %12 : Float(*, *) = prim::FusionGroup_0(%4, %8, %10)
-  %13 : Dynamic[] = prim::ListConstruct(%12, %5)
+  %12 : Float(*, *) = prim::FusionGroup_0(%2, %8, %10)
+  %13 : Dynamic[] = prim::ListConstruct(%12, %1)
   %14 : Dynamic[] = aten::broadcast_tensors(%13)
   %15 : Dynamic, %16 : Dynamic = prim::ListUnpack(%14)
-  %hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%6, %16, %15)
-  return (%cy, %hy, %7, %9, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18);
+  %hy : Float(*, *), %18 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %16, %15)
+  return (%hy, %cy, %7, %9, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %18);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*)
       %1 : Float(*, *)
index 678012b..3b0a2fe 100644 (file)
@@ -11,10 +11,10 @@ graph(%0 : Float(*, *)
       %10 : Undefined
       %11 : Undefined
       %12 : Float(*, *)
-      %13 : Float(*, *)
+      %13 : Float(*)
       %14 : Float(*)
       %15 : Float(*)
-      %16 : Float(*)
+      %16 : Float(*, *)
       %17 : Float(*, *)
       %18 : Float(*, *)
       %Wx : Float(*, *)
@@ -26,18 +26,18 @@ graph(%0 : Float(*, *)
       %cellgate : Float(*, *)
       %outgate : Float(*, *)
       %27 : Float(*, *)) {
-  %28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %17, %0, %27, %1)
+  %28 : Float(*, *) = prim::FusionGroup_0(%ingate, %forgetgate, %cellgate, %outgate, %12, %1, %27, %0)
   %29 : Float(*, *) = aten::mul(%28, %Uz)
   %30 : Float(*, *) = aten::mul(%28, %Wx)
-  %31 : Float(*, *) = prim::FusionGroup_1(%28, %22, %16)
-  %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %14, %28, %Uz, %15)
-  %34 : Float(*, *) = aten::t(%13)
+  %31 : Float(*, *) = prim::FusionGroup_1(%28, %22, %13)
+  %32 : Float(*, *), %33 : Float(*, *) = prim::FusionGroup_2(%Wx, %15, %28, %Uz, %14)
+  %34 : Float(*, *) = aten::t(%16)
   %35 : Float(*, *) = aten::mm(%34, %31)
   %36 : Float(*, *) = aten::t(%35)
-  %37 : Float(*, *) = aten::t(%12)
+  %37 : Float(*, *) = aten::t(%17)
   %38 : Float(*, *) = aten::mm(%37, %33)
   %39 : Float(*, *) = aten::t(%38)
-  return (%39, %36, %32, %30, %29, %28);
+  return (%28, %29, %30, %32, %36, %39);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*, *)
       %1 : Float(*, *)
index 5811d17..9d5aaf2 100644 (file)
@@ -7,29 +7,29 @@ graph(%x : Float(*, *)
       %beta_i : Float(*)
       %beta_h : Float(*)
       %bias : Float(*)) {
-  %9 : Float(*, *), %10 : Float(*, *) = prim::DifferentiableGraph_0(%w_ih, %x, %w_hh, %hx, %alpha, %beta_i, %beta_h, %bias, %cx)
-  return (%10, %9);
+  %hy : Float(*, *), %cy : Float(*, *) = prim::DifferentiableGraph_0(%cx, %bias, %beta_h, %beta_i, %alpha, %hx, %w_hh, %x, %w_ih)
+  return (%hy, %cy);
 }
 with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *)
-      %1 : Float(*, *)
-      %2 : Float(*, *)
-      %3 : Float(*, *)
+      %1 : Float(*)
+      %2 : Float(*)
+      %3 : Float(*)
       %4 : Float(*)
-      %5 : Float(*)
-      %6 : Float(*)
-      %7 : Float(*)
+      %5 : Float(*, *)
+      %6 : Float(*, *)
+      %7 : Float(*, *)
       %8 : Float(*, *)) {
-  %9 : Float(*, *) = aten::t(%0)
-  %Wx.1 : Float(*, *) = aten::mm(%1, %9)
-  %11 : Float(*, *) = aten::t(%2)
-  %Uz.1 : Float(*, *) = aten::mm(%3, %11)
+  %9 : Float(*, *) = aten::t(%8)
+  %Wx.1 : Float(*, *) = aten::mm(%7, %9)
+  %11 : Float(*, *) = aten::t(%6)
+  %Uz.1 : Float(*, *) = aten::mm(%5, %11)
   %13 : int = prim::Constant[value=1]()
-  %14 : Float(*, *), %15 : Float(*, *) = prim::FusionGroup_0(%6, %Uz.1, %5, %Wx.1, %4)
-  %16 : Dynamic[] = prim::ListConstruct(%14, %7)
+  %14 : Float(*, *), %15 : Float(*, *) = prim::FusionGroup_0(%2, %Uz.1, %3, %Wx.1, %4)
+  %16 : Dynamic[] = prim::ListConstruct(%14, %1)
   %17 : Dynamic[] = aten::broadcast_tensors(%16)
   %18 : Dynamic, %19 : Dynamic = prim::ListUnpack(%17)
-  %hy : Float(*, *), %21 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%8, %19, %18)
-  return (%cy, %hy, %9, %Wx.1, %11, %Uz.1, %15, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %21);
+  %hy : Float(*, *), %21 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_1(%0, %19, %18)
+  return (%hy, %cy, %9, %Wx.1, %11, %Uz.1, %15, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %21);
 }
 with prim::FusionGroup_0 = graph(%0 : Float(*)
       %1 : Float(*, *)
index ef32f00..183f6c7 100644 (file)
@@ -184,6 +184,7 @@ set(TORCH_SRCS
   ${TORCH_SRC_DIR}/csrc/jit/passes/requires_grad_analysis.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/specialize_undef.cpp
   ${TORCH_SRC_DIR}/csrc/jit/passes/python_print.cpp
+  ${TORCH_SRC_DIR}/csrc/jit/passes/utils/subgraph_utils.cpp
   ${TORCH_SRC_DIR}/csrc/jit/fuser/interface.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_prim_ops.cpp
   ${TORCH_SRC_DIR}/csrc/jit/register_special_ops.cpp
index 4a06572..a187ce4 100644 (file)
@@ -439,7 +439,7 @@ private:
     // Phase 5. Apply non-differentiable optimizations to the graphs we've found
     //          (or the whole grpah if we know we won't need its derivative).
     if (needsGradient(opt_graph)) {
-      auto diff_nodes = CreateAutodiffSubgraphs(*opt_graph, autodiffSubgraphNodeThreshold);
+      auto diff_nodes = CreateAutodiffSubgraphs(opt_graph, autodiffSubgraphNodeThreshold);
       for (Node * dnode : diff_nodes) {
         auto diff_graph = std::move(dnode->g(attr::Subgraph));
         Gradient gradient = differentiate(diff_graph);
index 11eeb9d..d8dc964 100644 (file)
@@ -133,7 +133,7 @@ void initJITBindings(PyObject *module) {
      return ConstantPropagation(g);
    })
    .def("_jit_pass_erase_shape_information", EraseShapeInformation)
-   .def("_jit_pass_create_autodiff_subgraphs", [](Graph& graph) {
+   .def("_jit_pass_create_autodiff_subgraphs", [](std::shared_ptr<Graph> graph) {
      CreateAutodiffSubgraphs(graph);
    })
    .def("_jit_run_cpp_tests", [] {
index 6b203ad..1caa856 100644 (file)
@@ -839,6 +839,7 @@ public:
   TORCH_API Node* createUndefined();
   TORCH_API Node* createNoneGenerator();
   TORCH_API Node* createFusionGroup();
+  TORCH_API Node* createDifferentiableSubgraph();
   TORCH_API Node* createTuple(at::ArrayRef<Value*> values);
   TORCH_API Node* createTupleUnpack(Value * v);
   TORCH_API Node* createTupleIndex(Value * tup, int64_t index);
index 1be8d4e..aba60c0 100644 (file)
 #include "torch/csrc/jit/passes/create_autodiff_subgraphs.h"
 
-#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/autodiff.h"
 #include "torch/csrc/jit/assertions.h"
-#include "torch/csrc/jit/dynamic_dag.h"
-
-#include <cstddef>
-#include <limits>
-
-namespace torch { namespace jit {
+#include "torch/csrc/jit/autodiff.h"
+#include "torch/csrc/jit/ir.h"
+#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
+#include "torch/csrc/jit/passes/utils/subgraph_utils.h"
 
-struct Graph;
+namespace torch {
+namespace jit {
 
 namespace {
 
-// Move nodes that exist in graph g into a 'group_node_kind' node.
-// All inputs shared by the nodes become inputs to the new node.
-// Outputs from 'nodes' are redirected to outputs of the new node,
-// and the original nodes are removed.
-// prereq: it is topologically valid to place the new node
-// right before nodes[0] (i.e. it will not create cycles and all uses of
-// new node will be after this position).
-// prereq: nodes are in topological order
-Node* mergeNodes(Block * block, Symbol group_node_kind, ArrayRef<Node*> nodes) {
-  JIT_ASSERT(nodes.size() > 0);
-  std::unordered_map<Value*, Value*> value_map;
-  Graph * graph = block->owningGraph();
-
-  auto new_graph = std::make_shared<Graph>();
-  Node * group_node = graph->create(group_node_kind, 0);
-  group_node->g_(attr::Subgraph, new_graph);
-
-  auto getOrCreateInput = [&](Value * v) {
-    if(value_map.count(v) > 0) {
-      return value_map[v];
-    }
-    if (auto value = toIValue(v)) {
-      Value * nv = new_graph->insertConstant(*value);
-      value_map[v] = nv;
-      return nv;
-    }
-    Value * nv = new_graph->addInput()->setType(v->type());
-    group_node->addInput(v);
-    value_map[v] = nv;
-    return nv;
-  };
-  std::unordered_set<Node*> group_set(nodes.begin(), nodes.end());
-  for(auto n : nodes) {
-    auto nn = new_graph->appendNode(new_graph->createClone(n, getOrCreateInput));
-    for(size_t i = 0; i < nn->outputs().size(); ++i) {
-      auto old_output = n->outputs()[i];
-      auto new_output = nn->outputs()[i];
-      value_map[old_output] = new_output;
-      std::vector<Use> to_replace;
-      for(auto u : old_output->uses()) {
-        // Uses within the set do not need to be made outputs
-        if(group_set.count(u.user) > 0)
-          continue;
-        // Other uses do, but we
-        // cannot replace them here or we invalid the uses list iterator
-        to_replace.push_back(u);
-      }
-      if(to_replace.size() > 0) {
-        new_graph->registerOutput(new_output);
-        Value * external_output = group_node->addOutput()->setType(old_output->type());
-        for(auto u : to_replace) {
-          u.user->replaceInput(u.offset, external_output);
-        }
+class SubgraphSlicer {
+ public:
+  SubgraphSlicer(Block* block, size_t minSubgraphSize)
+      : block_(block), minSubgraphSize_(minSubgraphSize) {}
+
+  void run(std::vector<Node*>& diffGraphs) {
+    // We need to run the slicer multiple times in order to get all merge
+    // opportunities. This is because moveBeforeTopologicalValid may reorder
+    // nodes to be AFTER the current iteration point. In order to properly
+    // consider those nodes for merging, we need run the pass until no changes
+    // have been made.
+    //
+    // Example:
+    //   c = f(a, b)
+    //   d = f(c)
+    //   e = f(d)  <- iter is here, moving upward
+    // After c.moveBeforeTopologicallyValid(e), we have:
+    //   c = f(a, b)
+    //   e = f(d)  <- iter still here
+    //   d = f(c)  <- this was node moved on the other side.
+    bool any_changed = true;
+    while (any_changed) {
+      any_changed = false;
+      for (auto it = block_->nodes().rbegin(); it != block_->nodes().rend();) {
+        bool changed;
+        std::tie(it, changed) = scanNode(*it);
+        any_changed |= changed;
       }
     }
-  }
-  group_node->insertBefore(nodes[0]);
-  // delete backward, so that nodes are use-free before deletion
-  for(size_t i = nodes.size(); i > 0; --i) {
-    nodes[i - 1]->destroy();
-  }
-  JIT_ASSERT(isDifferentiable(*new_graph));
-  return group_node;
-}
 
-bool shouldConsiderForMerge(detail::Vertex<Node*>* v) {
-  if (v->data.size() >= 2) {
-    return true;
-  }
-  JIT_ASSERT(v->data.size() == 1);
-  auto * node = *v->data.begin();
-  if (node->kind() == prim::Constant) {
-    return false;
-  }
-  return isDifferentiable(node);
-}
+    // Done constructing subgraphs. Do some post-processing cleanup:
+    // 1. Run CSE to delete redundanet constant nodes.
+    // 2. We may need to re-inline ones that are too small.
+    auto curNode = *block_->nodes().rbegin();
+    while (curNode != *block_->nodes().rend()) {
+      for (auto subBlock : curNode->blocks()) {
+        SubgraphSlicer(subBlock, minSubgraphSize_).run(diffGraphs);
+      }
 
-static detail::DynamicDAG<Node*> make_dependency_graph(Block * block) {
-  detail::DynamicDAG<Node*> dag;
-  std::unordered_map<Node*,detail::Vertex<Node*>*> node_to_vertex;
-  // NB: the block's param and return nodes are not in the dependency graph.
-  for (Node * node : block->nodes()) {
-    node_to_vertex[node] = dag.newVertex(node);
-  }
-  for (auto * node : block->nodes()) {
-    for (auto * v : node->outputs()) {
-      for (auto & use : v->uses()) {
-        // [Determine data dependencies]
-        // Consider the following code:
-        //     y = f(x)
-        //     if k:
-        //        w += y
-        //     z = g(y)
-        // This produces a dependency graph with 3 vertices:
-        // (0: f)   (1: if k ...)   (2: g)
-        // We need to peek into the if Node* to determine its data dependencies
-        // (the body depends on the output of f, so Vertex 1 depends on Vertex 0).
-        // For each Use of y, we find an owning node of y that is a part of the
-        // dependency graph (in this case, the Vertex containing the if Node*)
-        // and then record the dependency.
-        auto * owning_node = use.user;
-        if (owning_node == block->return_node()) {
-          // The return node is not in the dag. Carry on.
-          continue;
-        }
-        while (true) {
-          auto search = node_to_vertex.find(owning_node);
-          if (search == node_to_vertex.end()) {
-            owning_node = owning_node->owningBlock()->owningNode();
-            JIT_ASSERT(owning_node != nullptr);
-            continue;
-          }
-          // NB: DynamicDAG is a simple graph (no multi-edges).
-          // addEdge is a no-op if the edge already exists.
-          dag.addEdge(node_to_vertex[node], search->second);
-          break;
+      // Save the previous node, since we might delete `curNode` in next block
+      auto prevNode = curNode->prev();
+      if (curNode->kind() == prim::DifferentiableGraph) {
+        // Inlining nodes may cause some subexpression to come back in the
+        // subgraphs (for example, copying constants in repeatedly will generate
+        // redundant prim::Constants). Run CSE to clean them up.
+        EliminateCommonSubexpression(curNode->g(attr::Subgraph));
+
+        if (!inlineIfTooSmall(curNode)) {
+          diffGraphs.push_back(curNode);
         }
       }
+      curNode = prevNode;
     }
+    // Run CSE one more time to eliminate duplicates that may have occured
+    // while re-inlining subgraphs.
+    EliminateCommonSubexpression(block_);
   }
-  return dag;
-}
 
-static void find_differentiable_groups(
-    detail::DynamicDAG<Node*>& dep_graph,
-    size_t distance_threshold=256,
-    size_t producer_edge_threshold=16) {
-  // A Vertex contains a Node* or a differentiable group of Node*.
-  // Perform graph contraction on dep_graph: contract two vertices(x, y) if
-  // the following conditions hold:
-  // - x, y can be merged to form a differentiable group
-  // - the contraction would not invalidate the dag (it creates no cycles).
+ private:
+  // Inline this node's group subgraph into the outer graph if it's smaller
+  // than the specified minimum size.
   //
-  // This performs a greedy algorithm. This greedy algorithm considers
-  // dep_graph vertices in reverse topological order by reverse iterating through
-  // ord indices. For a certain ord, we attempt to merge the vertex at that ord
-  // with each of its parents. If the vertex at the ord cannot be merged with any
-  // of its parents, then we move on to a smaller ord and repeat.
-  //
-  // Each contractEdge call is effectively constant because we limit the size
-  // of the affected region (via the distance_threshold) and the fan in/fan out
-  // via producer_edge_threshold.
-  // In addition, each sort of in_edges is bounded by producer_edge threshold.
-  // This makes the complexity of find_differential_groups effectively O(V + E).
-
-  // Iterate in reverse topological order
-  int64_t ord = dep_graph.max_size() - 1;
-  for (int64_t ord = dep_graph.max_size() - 1; ord >= 0; --ord) {
-    if (!dep_graph.at(ord)) continue;
-
-    auto* consumer = dep_graph.at(ord).value();
-    if (!shouldConsiderForMerge(consumer)) continue;
-
-    // To bound the complexity of the sort. Makes the algorithm less optimal.
-    if (consumer->in_edges().size() > producer_edge_threshold) continue;
-
-    // Iterate through consumer->in_edges() in reverse topological order.
-    // sort is performed once per ord in dep_graph and once per contraction.
-    // There can be at most dep_graph.max_size() contractions, so
-    // we do at most 2 * dep_graph.max_size() sorts.
-    consumer->in_edges().sort();
+  // Returns true if an inlining has occured, false otherwise.
+  bool inlineIfTooSmall(Node* n) {
+    JIT_ASSERT(n->kind() == prim::DifferentiableGraph);
+    auto subgraph = SubgraphUtils::getSubgraph(n);
+    size_t i = 0;
+    for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
+         ++it) {
+      if (++i >= minSubgraphSize_) {
+        return false;
+      }
+    }
 
-    for (auto it = consumer->in_edges().rbegin(); it != consumer->in_edges().rend(); ++it) {
-      auto * producer = *it;
-      // The distance threshold makes this algorithm "not optimal": it will miss
-      // some possible contraction opportunities, but it hopefully lets us:
-      // 1) preserve locality of tensors. We don't want to keep them alive for too long.
-      // 2) Help bound the computation complexity for contractEdge
-      if (consumer->ord - producer->ord > distance_threshold) continue;
-      if (!shouldConsiderForMerge(producer)) continue;
+    SubgraphUtils::unmergeSubgraph(n);
+    return true;
+  }
 
-      // If the edge contraction is successful, dep_graph.at(ord) may have changed
-      // as well as consumer->in_edges() so we break out of this loop
-      if (dep_graph.contractEdge(producer, consumer)) {
-        // Stay at the current ord until we are done considering the vertex
-        // at this ord for contraction
-        ++ord;
-        break;
+  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
+    value_list result;
+    for (auto i : inputs) {
+      if (i->node()->owningBlock() == block_) {
+        result.push_back(i);
       }
     }
+    // Sort in reverse topological order
+    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
+      return a->node()->isAfter(b->node());
+    });
+    return result;
   }
-}
-
-static void reorder_according_to_dag(Block * block, const detail::DynamicDAG<Node*>& dep_graph) {
-  for (size_t ord = 0; ord < dep_graph.max_size(); ++ord) {
-    const auto& vertex = dep_graph.at(ord);
-    if (!vertex.has_value()) continue;
 
-    auto& nodes = vertex.value()->data;
-    for (Node* node : nodes) {
-      // Move all nodes according to the topological order in dep_graph. A lot
-      // of the moves are unnecessary but this is a quick & easy solution.
-      node->moveBefore(block->return_node());
+  bool shouldConsiderForMerge(Node* node) {
+    // if we're already in the process of merging
+    if (node->kind() == prim::DifferentiableGraph) {
+      return true;
+    }
+    if (node->kind() == prim::Constant) {
+      return false;
     }
+    return isDifferentiable(node);
   }
-}
-
-static void merge_differentiable_groups(
-    Block * block,
-    const detail::DynamicDAG<Node*>& dep_graph,
-    size_t size_threshold,
-    std::vector<Node*>& diff_graphs) {
-  for (size_t ord = 0; ord < dep_graph.max_size(); ++ord) {
-    const auto& vertex = dep_graph.at(ord);
-    if (!vertex) continue;
-    if (!shouldConsiderForMerge(vertex.value())) continue;
 
-    auto& nodes = vertex.value()->data;
-    if (nodes.size() < size_threshold) continue;
+  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
+    if (shouldConsiderForMerge(consumer)) {
+      if (consumer->kind() != prim::DifferentiableGraph) {
+        consumer = SubgraphUtils::createSingletonSubgraph(
+            consumer, prim::DifferentiableGraph);
+      }
+      auto inputs = sortReverseTopological(consumer->inputs());
+      for (auto input : inputs) {
+        if (auto group = tryMerge(consumer, input->node())) {
+          // we successfully merged, so the new group's `inputs` may have
+          // changed. So rescan the new group for more merging opportunities.
+          return std::make_pair(group.value()->reverseIterator(), true);
+        }
+      }
+    }
 
-    diff_graphs.push_back(mergeNodes(block, prim::DifferentiableGraph, nodes));
+    return std::make_pair(++consumer->reverseIterator(), false);
   }
-}
 
-void CreateAutodiffSubgraphsPK(
-    Block * block,
-    size_t size_threshold,
-    std::vector<Node*>& diff_graphs) {
-  for (auto * node : block->nodes()) {
-    // Find subgraphs to run this on recursively.
-    if (isDifferentiable(node)) continue;
-    for (auto * sub_block : node->blocks()) {
-      CreateAutodiffSubgraphsPK(sub_block, size_threshold, diff_graphs);
+  // Try to merge `producer` into `consumer`. If successful, this destroys
+  // `producer` and returns the `consumer` group.
+  c10::optional<Node*> tryMerge(Node* consumer, Node* producer) {
+    JIT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
+    bool canMerge = shouldConsiderForMerge(producer) &&
+        producer->moveBeforeTopologicallyValid(consumer);
+
+    if (!canMerge) {
+      return c10::nullopt;
     }
-  }
 
-  auto dep_graph = make_dependency_graph(block);
-  find_differentiable_groups(dep_graph);
-  reorder_according_to_dag(block, dep_graph);
-  merge_differentiable_groups(block, dep_graph, size_threshold, diff_graphs);
-}
+    SubgraphUtils::mergeNodeIntoSubgraph(producer, consumer);
+
+    return consumer;
+  }
 
+  Block* block_;
+  size_t minSubgraphSize_;
+};
 } // anonymous namespace
 
-std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold) {
+std::vector<Node*> CreateAutodiffSubgraphs(
+    std::shared_ptr<Graph> graph,
+    size_t threshold) {
   std::vector<Node*> diff_nodes;
-  CreateAutodiffSubgraphsPK(graph.block(), threshold, diff_nodes);
+  SubgraphSlicer(graph->block(), threshold).run(diff_nodes);
   return diff_nodes;
 }
 
-}}
+} // namespace jit
+} // namespace torch
index 1908b03..5c95fb9 100644 (file)
@@ -11,6 +11,7 @@ namespace torch { namespace jit {
 // subgraphs that are differentiable by the jit's autodiff passes
 // threshold - minimum number of nodes that will appear in a block
 // returns all differentiable blocks that have been found
-TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(Graph & graph, size_t threshold = 2);
-
+TORCH_API std::vector<Node*> CreateAutodiffSubgraphs(
+    std::shared_ptr<Graph> graph,
+    size_t threshold = 2);
 }}
diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.cpp b/torch/csrc/jit/passes/utils/subgraph_utils.cpp
new file mode 100644 (file)
index 0000000..0a7885e
--- /dev/null
@@ -0,0 +1,163 @@
+#include "subgraph_utils.h"
+
+namespace torch {
+namespace jit {
+namespace SubgraphUtils {
+namespace {
+bool isSubgraphNodeKind(Symbol s) {
+  return s == prim::DifferentiableGraph || s == prim::FusionGroup;
+}
+
+bool isSubgraphNodeKind(Node* n) {
+  return isSubgraphNodeKind(n->kind());
+}
+
+// Combine the nodes in two subgraph together. The nodes will end up in
+// `mergeTo`, and `mergeFrom` is destroyed.
+void mergeSubgraph(Node* mergeTo, Node* mergeFrom) {
+  const auto nodes = unmergeSubgraph(mergeFrom);
+  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
+    mergeNodeIntoSubgraph(*it, mergeTo);
+  }
+}
+} // namespace
+
+std::shared_ptr<Graph> getSubgraph(Node* n) {
+  JIT_ASSERT(isSubgraphNodeKind(n));
+  return n->g(attr::Subgraph);
+}
+
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode) {
+  JIT_ASSERT(subgraphNode->kind() == prim::DifferentiableGraph);
+  auto outerGraph = subgraphNode->owningGraph();
+
+  std::vector<Node*> temporary_nodes;
+  auto subgraph = getSubgraph(subgraphNode);
+
+  // Initialize a map of inner graph values to outer graph values
+  std::unordered_map<const Value*, Value*> innerToOuter;
+  const auto innerInputs = subgraph->inputs();
+  const auto outerInputs = subgraphNode->inputs();
+  for (size_t i = 0; i < innerInputs.size(); ++i) {
+    innerToOuter[innerInputs[i]] = outerInputs[i];
+  }
+
+  // Clone all nodes
+  for (auto inner : subgraph->nodes()) {
+    Node* outer = outerGraph->createClone(
+        inner, [&](Value* k) -> Value* { return innerToOuter.at(k); });
+    outer->insertBefore(subgraphNode);
+    temporary_nodes.emplace_back(outer);
+    const auto innerOutputs = inner->outputs();
+    const auto outerOutputs = outer->outputs();
+    for (size_t i = 0; i < innerOutputs.size(); ++i) {
+      innerToOuter[innerOutputs[i]] = outerOutputs[i];
+    }
+  }
+
+  // Replace uses of group outputs and destroy the group
+  const auto subgraphOutputs = subgraph->outputs();
+  for (size_t i = 0; i < subgraphOutputs.size(); ++i) {
+    const auto outerOutput = innerToOuter.at(subgraphOutputs[i]);
+    subgraphNode->outputs()[i]->replaceAllUsesWith(outerOutput);
+  }
+  subgraphNode->destroy();
+
+  return temporary_nodes;
+}
+
+void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode) {
+  JIT_ASSERT(isSubgraphNodeKind(subgraphNode));
+  if (isSubgraphNodeKind(toMerge)) {
+    return mergeSubgraph(subgraphNode, toMerge);
+  }
+
+  auto subgraph = getSubgraph(subgraphNode);
+
+  // Map from values in the surrounding graph to inputs in the subgraph
+  std::unordered_map<Value*, Value*> inputsMap;
+
+  JIT_ASSERT(subgraphNode->inputs().size() == subgraph->inputs().size());
+  size_t idx = 0;
+  for (auto input : subgraphNode->inputs()) {
+    inputsMap[input] = subgraph->inputs()[idx];
+    idx++;
+  }
+
+  // Add n's inputs to the group's input list if we don't already have them
+  WithInsertPoint guard(*subgraph->nodes().begin());
+  for (auto input : toMerge->inputs()) {
+    if (inputsMap.count(input) == 0) {
+      // Clone constants inside the subgraph instead of referencing them, to
+      // enable more optimizations
+      if (auto value = toIValue(input)) {
+        auto nv = subgraph->insertConstant(*value);
+        inputsMap[input] = nv;
+      } else {
+        // The common case: this is a regular input, so just register it with
+        // the group node and inner subgraph
+        subgraphNode->addInput(input);
+        auto inputToGraph = subgraph->addInput();
+        inputToGraph->setType(input->type());
+        inputsMap[input] = inputToGraph;
+      }
+    }
+  }
+
+  // Merge the node into the graph
+  auto mergedNode = subgraph->insertNode(
+      subgraph->createClone(toMerge, [&](Value* v) { return inputsMap[v]; }));
+
+  // If n's outputs were inputs to `group`, remove them since we just merged
+  // n in.
+  //
+  // i.e.,
+  // x = f(w); group(x, y, z) becomes group(w, y, z).
+  // x, y, z = f(w); group(x, y, z) becomes group(w).
+  auto inputs = subgraphNode->inputs();
+  for (size_t i = 0; i < toMerge->outputs().size(); ++i) {
+    auto it = std::find(inputs.begin(), inputs.end(), toMerge->outputs()[i]);
+    if (it != inputs.end()) {
+      size_t p = it - inputs.begin();
+      subgraphNode->removeInput(p);
+      subgraph->inputs()[p]->replaceAllUsesWith(mergedNode->outputs()[i]);
+      subgraph->eraseInput(p);
+    }
+  }
+
+  // Add n's outputs to the group node and inner subgraph outputs.
+  for (size_t i = 0; i < toMerge->outputs().size(); i++) {
+    auto oldOutput = toMerge->outputs()[i];
+
+    // Only register the output in the group node if it's actually used
+    // outside the subgraph.
+    const auto hasUsesOutsideSubgraph = std::any_of(
+        oldOutput->uses().cbegin(),
+        oldOutput->uses().cend(),
+        [&](const Use& use) { return use.user->isAfter(subgraphNode); });
+
+    if (hasUsesOutsideSubgraph) {
+      auto newOutput = mergedNode->outputs()[i];
+      subgraph->registerOutput(newOutput);
+      auto groupOutput = subgraphNode->addOutput();
+      groupOutput->copyMetadata(oldOutput);
+      oldOutput->replaceAllUsesWith(groupOutput);
+    }
+  }
+
+  // Remove the original node now that the merge is complete
+  toMerge->destroy();
+}
+
+Node* createSingletonSubgraph(Node* n, Symbol subgraphKind) {
+  JIT_ASSERT(isSubgraphNodeKind(subgraphKind));
+  auto graph = n->owningGraph();
+  auto subgraph = graph->create(subgraphKind, 0);
+  subgraph->g_(attr::Subgraph, std::make_shared<Graph>(graph->current_scope()));
+  subgraph->insertBefore(n);
+  mergeNodeIntoSubgraph(n, subgraph);
+  return subgraph;
+}
+} // namespace SubgraphUtils
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/utils/subgraph_utils.h b/torch/csrc/jit/passes/utils/subgraph_utils.h
new file mode 100644 (file)
index 0000000..fdc76b3
--- /dev/null
@@ -0,0 +1,36 @@
+#pragma once
+
+#include "torch/csrc/jit/ir.h"
+
+namespace torch {
+namespace jit {
+
+// Utilities for dealing with nodes that contain subgraphs.
+//
+// They handle the complexity of editing inputs/outputs as you merge nodes in
+// and out of subgraphs.
+namespace SubgraphUtils {
+
+// Create a new subgraph node that contains only `n`. The new subgraph will have
+// `subgraphKind` as its type.
+//
+// `n` is destroyed.
+//
+// Returns the new subgraph node.
+Node* createSingletonSubgraph(Node* n, Symbol subgraphKind);
+
+// Merge a node into a subgraph node. If `toMerge` is also a subgraph, the
+// subgraphs are merged.
+// `toMerge` is destroyed.
+void mergeNodeIntoSubgraph(Node* toMerge, Node* subgraphNode);
+
+// Move nodes from a subgraph node to the outer graph.
+// `subgraphNode` is destroyed.
+std::vector<Node*> unmergeSubgraph(Node* subgraphNode);
+
+// Convenience function
+std::shared_ptr<Graph> getSubgraph(Node* n);
+
+} // namespace SubgraphUtils
+} // namespace jit
+} // namespace torch