computeChains with nomnigraph (#15366)
authorBram Wasti <bwasti@fb.com>
Wed, 19 Dec 2018 22:31:06 +0000 (14:31 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 19 Dec 2018 23:04:23 +0000 (15:04 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15366

swap the old implementation with a slightly easier one to understand

I ran the tests and compared the number of chains compared to the old algorithm.  This one outperforms on every test, but we have yet to see if that impacts performance at all.

old chain 34 nomnigraph chain 25
old chain 46 nomnigraph chain 34
old chain 228 nomnigraph chain 188
old chain 397 nomnigraph chain 338

Reviewed By: ilia-cher

Differential Revision: D13057451

fbshipit-source-id: ccd050bfead6eb94ab9c7b0a70b09a22c2b9e499

caffe2/core/net_async_base.cc
caffe2/core/net_dag_utils.cc
caffe2/core/net_dag_utils.h
caffe2/python/test/executor_test.py

index 6d9f741..680893f 100644 (file)
@@ -74,7 +74,7 @@ AsyncNetBase::AsyncNetBase(
   if (FLAGS_caffe2_net_async_inference_mode) {
     execution_chains_ = dag_utils::computeGroups(operator_nodes_);
   } else {
-    execution_chains_ = dag_utils::computeChains(operator_nodes_);
+    execution_chains_ = dag_utils::computeChains(*net_def, operator_nodes_);
   }
   chains_.reserve(execution_chains_.size());
   for (const auto& kv : execution_chains_) {
index 8786ac6..c46b5c5 100644 (file)
@@ -8,9 +8,12 @@
 #include "caffe2/core/operator.h"
 #include "caffe2/core/static_tracepoint.h"
 #include "caffe2/core/timer.h"
+#include "caffe2/opt/converter.h"
 #include "caffe2/proto/caffe2_pb.h"
 #include "caffe2/utils/proto_utils.h"
 
+#include "nomnigraph/Graph/Algorithms.h"
+
 namespace caffe2 {
 namespace dag_utils {
 
@@ -120,6 +123,146 @@ void updateOperatorNodes(
 }
 } // namespace
 
+using namespace nom::repr;
+using DepGraph = nom::Graph<NNGraph::NodeRef>;
+
+// \brief This function prunes edges in the dependency
+// graph to increase the chaining opportunity.
+// It does not eliminate parallelism opportunity.
+void optimizeDependencyGraph(DepGraph* deps) {
+  auto edges = deps->getMutableEdges();
+  for (const auto& edge : edges) {
+    auto tail = edge->tail();
+    auto head = edge->head();
+    deps->deleteEdge(edge);
+    std::unordered_set<DepGraph::NodeRef> seen;
+    nom::algorithm::reachable<DepGraph>(tail, nullptr, &seen);
+    // Removing that edge removes a dominator, which is invalid
+    if (!seen.count(head)) {
+      deps->createEdge(tail, head);
+    }
+  }
+}
+
+ExecutionChains computeChains(
+    const caffe2::NetDef& predict_net,
+    std::vector<OperatorNode>& orig_nodes) {
+  // These serve as the map into predict_net.op()
+  std::vector<NNGraph::NodeRef> nom_ops;
+  auto nn = convertToNNModule(predict_net, false, &nom_ops);
+  CAFFE_ENFORCE_EQ(nom_ops.size(), predict_net.op().size());
+
+  // Create a map from NodeRef to index into predict_net.op()
+  // Now we can use pure nomnigraph functions and map back later
+  std::unordered_map<NNGraph::NodeRef, int> nom_op_to_pos;
+  for (auto idx = 0; idx < nom_ops.size(); ++idx) {
+    nom_op_to_pos[nom_ops[idx]] = idx;
+  }
+
+  // The algorithm:
+  // 1) create dependency graph of ops
+  // 2) for all nodes thats have multiple in edges, remove all in edges
+  // 3) for all nodes thats have multiple out edges, remove all out edges
+  // 4) return the components as chains
+
+  // Caveats that can easily be handled
+  // 1) Cannot have a chain that crosses device options
+  //    insert extra edge at each boundary
+  // 2) All CPU async ops have to be the last op in a chain
+  //    insert extra out edge
+  DepGraph deps;
+
+  // Map NodeRef to the node in the dependency graph
+  std::unordered_map<NNGraph::NodeRef, DepGraph::NodeRef> dep_map;
+  for (const auto& node : nn::filter<NeuralNetOperator>(nn)) {
+    dep_map[node] = deps.createNode(node);
+  }
+
+  // 1) Create dependency graph
+  for (const auto& node : nn::filter<NeuralNetOperator>(nn)) {
+    for (const auto& output : nn::getOutputs(node)) {
+      for (const auto& consumer : nn::getConsumers(output)) {
+        // Record single dependencies first
+        if (!deps.hasEdge(dep_map[node], dep_map[consumer])) {
+          deps.createEdge(dep_map[node], dep_map[consumer]);
+        }
+      }
+    }
+  }
+
+  optimizeDependencyGraph(&deps);
+
+  // Fixup device boundary and async op issues
+  for (const auto& dep : deps.getMutableNodes()) {
+    int op_idx = nom_op_to_pos[dep->data()];
+    auto d1 = orig_nodes.at(op_idx).operator_->device_option();
+    auto outEdges = dep->getOutEdges();
+    for (const auto& outEdge : outEdges) {
+      int op2_idx = nom_op_to_pos[outEdge->head()->data()];
+      auto d2 = orig_nodes.at(op2_idx).operator_->device_option();
+      if (!IsSameDevice(d1, d2)) {
+        deps.createEdge(dep, outEdge->head());
+      }
+    }
+    if (d1.device_type() == PROTO_CUDA) {
+      continue;
+    }
+    if (orig_nodes.at(op_idx).operator_->HasAsyncPart()) {
+      outEdges = dep->getOutEdges();
+      for (const auto& outEdge : outEdges) {
+        // Clone out edges
+        deps.createEdge(outEdge->tail(), outEdge->head());
+      }
+    }
+  }
+
+  // 2) Prune in edges if multiplicity > 1
+  // 3) Prune out edges if multiplicity > 1
+  for (const auto& dep : deps.getMutableNodes()) {
+    auto inEdges = dep->getInEdges();
+    if (inEdges.size() > 1) {
+      for (const auto& inEdge : inEdges) {
+        NOM_REQUIRE_OR_CONT(inEdge);
+        deps.deleteEdge(inEdge);
+      }
+    }
+    auto outEdges = dep->getOutEdges();
+    if (outEdges.size() > 1) {
+      for (const auto& outEdge : outEdges) {
+        NOM_REQUIRE_OR_CONT(outEdge);
+        deps.deleteEdge(outEdge);
+      }
+    }
+  }
+
+  // 4) Return components as chains
+  std::vector<DepGraph::NodeRef> chain_starts;
+  for (const auto& dep : deps.getMutableNodes()) {
+    if (dep->getInEdges().size() == 0) {
+      chain_starts.emplace_back(dep);
+    }
+  }
+
+  ExecutionChains chains;
+  for (const auto& dep : chain_starts) {
+    DepGraph::NodeRef front = dep;
+    std::vector<int> ops;
+    do {
+      ops.emplace_back(nom_op_to_pos[front->data()]);
+      auto outEdges = front->getOutEdges();
+      if (outEdges.size()) {
+        front = outEdges.at(0)->head();
+      } else {
+        front = nullptr;
+      }
+    } while (front);
+    chains[nom_op_to_pos[dep->data()]] = ops;
+  }
+
+  updateOperatorNodes(orig_nodes, chains);
+  return chains;
+}
+
 ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes) {
   const std::vector<OpGraphNode> nodes = pruneOpNodeGraph(orig_nodes);
   vector<int> initial_frontier;
index 9b605a9..930f8e4 100644 (file)
@@ -43,6 +43,9 @@ struct OpGraphNode {
 
 using ExecutionChains = std::unordered_map<int, std::vector<int>>;
 
+C10_EXPORT ExecutionChains computeChains(
+    const caffe2::NetDef& predict_net,
+    std::vector<OperatorNode>& orig_nodes);
 C10_EXPORT ExecutionChains computeChains(std::vector<OperatorNode>& orig_nodes);
 
 // Instead of breaking down the DAG into chains, we partition it into clusters
index ee52717..8c9a42b 100644 (file)
@@ -2,14 +2,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from caffe2.python import core, workspace
+from caffe2.python import core, workspace, model_helper
+import random
 from caffe2.python.test.executor_test_util import (
     build_conv_model,
     build_resnet50_dataparallel_model,
     run_resnet50_epoch,
     ExecutorTestBase,
     executor_test_settings,
-    executor_test_model_names)
+    executor_test_model_names,
+)
 
 from caffe2.python.test_util import TestCase
 
@@ -24,10 +26,12 @@ ITERATIONS = 1
 
 
 class ExecutorCPUConvNetTest(ExecutorTestBase):
-    @given(executor=st.sampled_from(EXECUTORS),
-           model_name=st.sampled_from(executor_test_model_names()),
-           batch_size=st.sampled_from([1]),
-           num_workers=st.sampled_from([8]))
+    @given(
+        executor=st.sampled_from(EXECUTORS),
+        model_name=st.sampled_from(executor_test_model_names()),
+        batch_size=st.sampled_from([1]),
+        num_workers=st.sampled_from([8]),
+    )
     @executor_test_settings
     def test_executor(self, executor, model_name, batch_size, num_workers):
         model = build_conv_model(model_name, batch_size)
@@ -50,8 +54,7 @@ class ExecutorCPUConvNetTest(ExecutorTestBase):
 @unittest.skipIf(not workspace.has_gpu_support
                 and not workspace.has_hip_support, "no gpu")
 class ExecutorGPUResNetTest(ExecutorTestBase):
-    @given(executor=st.sampled_from(EXECUTORS),
-           num_workers=st.sampled_from([8]))
+    @given(executor=st.sampled_from(EXECUTORS), num_workers=st.sampled_from([8]))
     @executor_test_settings
     def test_executor(self, executor, num_workers):
         model = build_resnet50_dataparallel_model(
@@ -100,5 +103,33 @@ class ExecutorFailingOpTest(TestCase):
         self.assertFalse(res)
 
 
-if __name__ == '__main__':
+class ExecutorFuzzTest(ExecutorTestBase):
+    def test_fuzzy_model(self):
+        model = model_helper.ModelHelper(name="test")
+        inits = []
+        for i in range(100):
+            init = model.param_init_net.ConstantFill(
+                [], "ONE" + str(i), shape=[1], value=1.0
+            )
+            inits.append(init)
+        adds = []
+        for i in range(1000):
+            add = model.net.Add(
+                [random.choice(inits + adds), random.choice(inits + adds)],
+                "ADD" + str(i),
+            )
+            adds.append(add)
+
+        def run_model():
+            workspace.RunNet(model.net, 100)
+
+        self.compare_executors(
+            model,
+            ref_executor="simple",
+            test_executor="async_scheduling",
+            model_run_func=run_model,
+        )
+
+
+if __name__ == "__main__":
     unittest.main()