Don't remove identity nodes if they follow a device crossing and have consumers on...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 16:12:41 +0000 (09:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 16:16:44 +0000 (09:16 -0700)
Simplify code in DependencyOptimizer a bit.

PiperOrigin-RevId: 188730185

tensorflow/core/grappler/optimizers/dependency_optimizer.cc
tensorflow/core/grappler/optimizers/dependency_optimizer_test.cc

index a5b2572..63bc196 100644 (file)
@@ -274,12 +274,17 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
   //           +----------+             y --^> b
 
   if (is_noop || is_identity) {
+    if (is_identity && !SafeToRemoveIdentity(*node)) {
+      return;
+    }
+
     const auto& output_node_set = node_map_->GetOutputs(node_name);
     const std::vector<NodeDef*> output_nodes(output_node_set.begin(),
                                              output_node_set.end());
     const int num_outputs = output_nodes.size();
     const int num_inputs = node->input_size();
 
+    // Don't increase the number of edges in the graph.
     if (num_inputs * num_outputs > num_inputs + num_outputs) {
       return;
     }
@@ -293,39 +298,34 @@ void DependencyOptimizer::OptimizeNode(int node_idx,
       input_nodes.push_back(input_node);
     }
 
-    // Make sure that we don't increase the number of edges that cross
-    // device boundaries.
-    if ((num_inputs == 1 && num_outputs > 1 &&
-         input_nodes[0]->device() != node->device()) ||
-        (num_inputs > 1 && num_outputs == 1 &&
-         output_nodes[0]->device() != node->device())) {
+    // TODO(rmlarsen): Not all device crossings are equally expensive.
+    // Assign a cost to each based on device affinity and compute a
+    // cost before and after.
+    const string& node_dev = node->device();
+    int num_cross_in = 0;
+    for (NodeDef* input_node : input_nodes) {
+      num_cross_in += static_cast<int>(input_node->device() != node_dev);
+    }
+    int num_cross_out = 0;
+    for (NodeDef* output_node : output_nodes) {
+      num_cross_out += static_cast<int>(output_node->device() != node_dev);
+    }
+    if (is_identity && num_cross_in > 0 && num_cross_out > 0) {
+      // This identity node follows a device crossing, so it might be
+      // following a _Recv node after partioning. Do not remove such nodes,
+      // unless they only have consumers on the same device as themselves.
       return;
     }
-    if (num_inputs == 2 && num_outputs == 2) {
-      const string& noop_dev = node->device();
-      const string& in0_dev = input_nodes[0]->device();
-      const string& in1_dev = input_nodes[1]->device();
-      const string& out0_dev = output_nodes[0]->device();
-      const string& out1_dev = output_nodes[1]->device();
-      const int num_cross_before = static_cast<int>(in0_dev != noop_dev) +
-                                   static_cast<int>(in1_dev != noop_dev) +
-                                   static_cast<int>(out0_dev != noop_dev) +
-                                   static_cast<int>(out1_dev != noop_dev);
-      const int num_cross_after = static_cast<int>(in0_dev != out0_dev) +
-                                  static_cast<int>(in0_dev != out1_dev) +
-                                  static_cast<int>(in1_dev != out0_dev) +
-                                  static_cast<int>(in1_dev != out1_dev);
-      if (num_cross_after > num_cross_before) {
-        return;
-      }
-      // To avoid potentially removing Identity nodes following _Recv nodes,
-      // we require that no device crossings occur in that case.
-      // TODO(rmlarsen): See if we can relax this condition.
-      if (is_identity && (num_cross_after > 0 || num_cross_before > 0)) {
-        return;
+    const int num_cross_before = num_cross_in + num_cross_out;
+    int num_cross_after = 0;
+    for (NodeDef* input_node : input_nodes) {
+      for (NodeDef* output_node : output_nodes) {
+        num_cross_after +=
+            static_cast<int>(input_node->device() != output_node->device());
       }
     }
-    if (is_identity && !SafeToRemoveIdentity(*node)) {
+    if (num_cross_after > num_cross_before) {
+      // Avoid increasing the number of device crossings.
       return;
     }
 
index b66cc17..cc1e142 100644 (file)
@@ -595,6 +595,57 @@ TEST_F(DependencyOptimizerTest, IdentityN) {
   EXPECT_EQ("id_b:1", output.node(8).input(0));
 }
 
+TEST_F(DependencyOptimizerTest,
+       Identity_DeviceCrossing_ConsumerOnDifferentDevice) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output x_on_1 =
+      ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
+  Output one_on_3 =
+      ops::Const(s.WithOpName("one_on_3").WithDevice("/gpu:3"), {1.0f}, {});
+  Output x_on_2 =
+      ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
+  Output result =
+      ops::Add(s.WithOpName("result").WithDevice("/gpu:3"), x_on_2, one_on_3);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"result"};
+  DependencyOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+
+  VerifyGraphsEqual(item.graph, output, __FUNCTION__);
+}
+
+TEST_F(DependencyOptimizerTest, Identity_DeviceCrossing_ConsumerOnSameDevice) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  Output x_on_1 =
+      ops::Const(s.WithOpName("x_on_1").WithDevice("/gpu:1"), {1.0f}, {});
+  Output one_on_2 =
+      ops::Const(s.WithOpName("one_on_2").WithDevice("/gpu:2"), {1.0f}, {});
+  Output x_on_2 =
+      ops::Identity(s.WithOpName("x_on_2").WithDevice("/gpu:2"), x_on_1);
+  Output result =
+      ops::Add(s.WithOpName("result").WithDevice("/gpu:2"), x_on_2, one_on_2);
+
+  GrapplerItem item;
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+  item.fetch = {"result"};
+  DependencyOptimizer optimizer;
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_EXPECT_OK(status);
+  LOG(INFO) << output.DebugString();
+  EXPECT_EQ(3, output.node_size());
+  for (const auto& node : output.node()) {
+    EXPECT_NE("x_on_2", node.name());
+    if (node.name() == "result") {
+      EXPECT_EQ("x_on_1", node.input(0));
+    }
+  }
+}
+
 }  // namespace
 }  // namespace grappler
 }  // namespace tensorflow