Identify and prune nodes that can never be executed
authorBenoit Steiner <bsteiner@google.com>
Fri, 4 May 2018 22:14:00 +0000 (15:14 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 4 May 2018 22:41:54 +0000 (15:41 -0700)
PiperOrigin-RevId: 195478951

tensorflow/core/grappler/optimizers/BUILD
tensorflow/core/grappler/optimizers/loop_optimizer.cc
tensorflow/core/grappler/optimizers/loop_optimizer.h
tensorflow/core/grappler/optimizers/loop_optimizer_test.cc
tensorflow/core/grappler/utils.h

index 5b5e1e0..900dfa9 100644 (file)
@@ -604,6 +604,7 @@ cc_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:graph_view",
         "//tensorflow/core/grappler:grappler_item",
         "//tensorflow/core/grappler:op_types",
         "//tensorflow/core/grappler:utils",
index 5adc5b9..7d3520f 100644 (file)
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor_shape.pb.h"
 #include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/grappler/graph_view.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/op_types.h"
 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
@@ -504,6 +505,140 @@ Status RemoveStackOps(const std::unordered_set<string>& nodes_to_preserve,
   return Status::OK();
 }
 
+Status RemoveDeadBranches(const std::unordered_set<string>& nodes_to_preserve,
+                          GraphDef* optimized_graph) {
+  std::unordered_set<const NodeDef*> dead_nodes;
+  std::unordered_map<NodeDef*, std::set<int>> dead_merge_inputs;
+  // TODO(bsteiner): also rewrite switches as identity. For now we just record
+  // them
+  std::unordered_set<GraphView::OutputPort, GraphView::HashPort>
+      identity_switches;
+
+  GraphView view(optimized_graph);
+  for (const NodeDef& node : optimized_graph->node()) {
+    if (!IsSwitch(node)) {
+      continue;
+    }
+    if (nodes_to_preserve.find(node.name()) != nodes_to_preserve.end()) {
+      continue;
+    }
+    GraphView::InputPort ctrl_port(&node, 1);
+    GraphView::OutputPort ctrl_node = view.GetRegularFanin(ctrl_port);
+    if (!IsConstant(*ctrl_node.node)) {
+      continue;
+    }
+    Tensor selector;
+    CHECK(selector.FromProto(ctrl_node.node->attr().at("value").tensor()));
+    const int dead_fanout = selector.scalar<bool>()() ? 0 : 1;
+    GraphView::OutputPort dead(const_cast<NodeDef*>(&node), dead_fanout);
+    identity_switches.insert(dead);
+
+    SetVector<GraphView::InputPort, GraphView::HashPort> zombie_inputs;
+    for (const GraphView::InputPort& port : view.GetFanout(dead)) {
+      if (dead_nodes.find(port.node) == dead_nodes.end()) {
+        zombie_inputs.PushBack(port);
+      }
+    }
+    // If we encounter a single node that must be preserved in the fanout of the
+    // switch node we need to preserve the entire switch fanout: we therefore
+    // work on a local copy that only gets committed to the master copy once the
+    // whole fanout has been explored.
+    std::unordered_set<const NodeDef*> local_dead_nodes = dead_nodes;
+    std::unordered_map<NodeDef*, std::set<int>> local_dead_merge_inputs =
+        dead_merge_inputs;
+    bool found_node_to_preserve = false;
+    while (!found_node_to_preserve && !zombie_inputs.Empty()) {
+      GraphView::InputPort dead = zombie_inputs.PopBack();
+      if (nodes_to_preserve.find(dead.node->name()) !=
+          nodes_to_preserve.end()) {
+        found_node_to_preserve = true;
+        break;
+      }
+
+      if (local_dead_nodes.find(dead.node) != local_dead_nodes.end()) {
+        continue;
+      }
+
+      if (IsMerge(*dead.node)) {
+        const int fanout = dead.node->attr().at("N").i();
+        if (fanout > 2) {
+          // This never happens in practice, so we'll just skip these to
+          // simplify the code for now.
+          found_node_to_preserve = true;
+          break;
+        }
+        GraphView::OutputPort value_index(dead.node, 1);
+        const std::unordered_set<GraphView::InputPort, GraphView::HashPort>&
+            index_fanout = view.GetFanout(value_index);
+        if (!index_fanout.empty()) {
+          // The 2nd output (that indicates which input is propagated) is
+          // connected. This never happens in practice, so we'll just skip this
+          // case to simplify the code for now.
+          found_node_to_preserve = true;
+          break;
+        }
+
+        bool fully_dead = false;
+        if (dead.port_id < 0) {
+          // If the control dependency never gets triggered the merge will also
+          // never get triggered.
+          local_dead_nodes.insert(dead.node);
+          fully_dead = true;
+        } else {
+          local_dead_merge_inputs[dead.node].insert(dead.port_id);
+          if (local_dead_merge_inputs[dead.node].size() ==
+              dead.node->attr().at("N").i()) {
+            fully_dead = true;
+          }
+          if (fully_dead) {
+            local_dead_nodes.insert(dead.node);
+            for (const GraphView::InputPort& port :
+                 view.GetFanouts(*dead.node, true)) {
+              zombie_inputs.PushBack(port);
+            }
+          }
+        }
+      } else {
+        if (local_dead_nodes.insert(dead.node).second) {
+          for (const GraphView::InputPort& dead_fanout :
+               view.GetFanouts(*dead.node, true)) {
+            zombie_inputs.PushBack(dead_fanout);
+          }
+        }
+      }
+    }
+    if (!found_node_to_preserve) {
+      std::swap(dead_nodes, local_dead_nodes);
+      std::swap(dead_merge_inputs, local_dead_merge_inputs);
+    }
+  }
+
+  int last = optimized_graph->node_size() - 1;
+  for (int i = optimized_graph->node_size() - 1; i >= 0; --i) {
+    NodeDef* node = optimized_graph->mutable_node(i);
+    if (dead_nodes.find(node) != dead_nodes.end()) {
+      optimized_graph->mutable_node()->SwapElements(i, last);
+      last--;
+    }
+  }
+  optimized_graph->mutable_node()->DeleteSubrange(last + 1, dead_nodes.size());
+
+  for (const auto& itr : dead_merge_inputs) {
+    NodeDef* dead_node = itr.first;
+    if (dead_nodes.find(dead_node) != dead_nodes.end()) {
+      // The node has been pruned since all its inputs are dead.
+      continue;
+    }
+    const std::set<int>& dead_inputs = itr.second;
+    for (int index : dead_inputs) {
+      dead_node->mutable_input()->DeleteSubrange(index, 1);
+    }
+    dead_node->set_op("Identity");
+    dead_node->mutable_attr()->erase("N");
+  }
+  return Status::OK();
+}
+
 }  // namespace
 
 Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
@@ -517,6 +652,11 @@ Status LoopOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   if (options_.enable_stack_push_removal) {
     TF_RETURN_IF_ERROR(RemoveStackOps(item.NodesToPreserve(), optimized_graph));
   }
+  if (opt_level_ == RewriterConfig::AGGRESSIVE &&
+      options_.enable_dead_branch_removal) {
+    TF_RETURN_IF_ERROR(
+        RemoveDeadBranches(item.NodesToPreserve(), optimized_graph));
+  }
 
   return Status::OK();
 }
index 764506f..85b8e65 100644 (file)
@@ -54,6 +54,7 @@ class LoopOptimizer : public GraphOptimizer {
   struct LoopOptimizerOptions {
     bool enable_loop_invariant_node_motion = false;
     bool enable_stack_push_removal = true;
+    bool enable_dead_branch_removal = true;
 
     static LoopOptimizerOptions Default(RewriterConfig::Toggle opt_level) {
       LoopOptimizerOptions options;
index 10ec544..6fd177b 100644 (file)
@@ -589,5 +589,112 @@ TEST_F(LoopOptimizerTest, RemovePushWithoutMatchingPop) {
   }
 }
 
+TEST_F(LoopOptimizerTest, RemoveDeadBranches) {
+  Scope scope = Scope::NewRootScope();
+  Output v_in = ops::Variable(scope.WithOpName("v_in"), {3}, DT_FLOAT);
+
+  Output ctrl1 = ops::Const(scope.WithOpName("ctrl1"), false, TensorShape({}));
+  ops::Switch s1(scope.WithOpName("switch1"), v_in, ctrl1);
+  Output square1 = ops::Square(scope.WithOpName("square1"), s1.output_false);
+  Output sqrt1 = ops::Sqrt(scope.WithOpName("sqrt1"), s1.output_true);
+
+  Output ctrl2 = ops::Const(scope.WithOpName("ctrl2"), true, TensorShape({}));
+  ops::Switch s2(scope.WithOpName("switch2"), v_in, ctrl2);
+  Output square2 = ops::Square(scope.WithOpName("square2"), s2.output_false);
+  Output sqrt2 = ops::Sqrt(scope.WithOpName("sqrt2"), s2.output_true);
+
+  Output ctrl3 = ops::Const(scope.WithOpName("ctrl3"), false, TensorShape({}));
+  ops::Switch s3(scope.WithOpName("switch3"), v_in, ctrl3);
+  Output square3 = ops::Square(scope.WithOpName("square3"), s3.output_false);
+  Output sqrt3 = ops::Sqrt(scope.WithOpName("sqrt3"), s3.output_true);
+
+  Output ctrl4 = ops::Const(scope.WithOpName("ctrl4"), false, TensorShape({}));
+  ops::Switch s4(scope.WithOpName("switch4"), v_in, ctrl4);
+  Output square4 = ops::Square(scope.WithOpName("square4"), s4.output_false);
+  Output sqrt4 = ops::Sqrt(scope.WithOpName("sqrt4"), s4.output_true);
+
+  ops::Merge m1(scope.WithOpName("m1"), {square1, sqrt1});
+  ops::Merge m2(scope.WithOpName("m2"), {v_in, square1});
+  ops::Merge m3(scope.WithOpName("m3"), {v_in, sqrt1});
+  ops::Merge m4(scope.WithOpName("m4"), {square1, sqrt2});
+  ops::Merge m5(scope.WithOpName("m5"), {square2, sqrt1});
+  ops::Merge m6(scope.WithOpName("m6").WithControlDependencies(sqrt2),
+                {v_in, square1});
+  ops::Merge m7(scope.WithOpName("m7").WithControlDependencies(sqrt1),
+                {v_in, square1});
+
+  ops::Switch s5(scope.WithOpName("switch5"), v_in, ctrl1);
+  Output id1 = ops::Identity(scope.WithOpName("id1"), s5.output_false);
+  Output id2 = ops::Identity(scope.WithOpName("id2"), s5.output_true);
+  ops::Merge m8(scope.WithOpName("m8"), {id1, id2});
+
+  ops::Switch s6(scope.WithOpName("switch6"), v_in, ctrl1);
+  Output id3 = ops::Identity(scope.WithOpName("id3"), s6.output_false);
+  Output id4 = ops::Identity(scope.WithOpName("id4"), s6.output_true);
+  ops::Merge m9(scope.WithOpName("m9"), {id3, id4});
+
+  GrapplerItem item;
+  item.fetch.push_back("m8");
+  item.fetch.push_back("id4");
+
+  TF_CHECK_OK(scope.ToGraphDef(&item.graph));
+
+  LoopOptimizer optimizer(RewriterConfig::AGGRESSIVE);
+  GraphDef output;
+  Status status = optimizer.Optimize(nullptr, item, &output);
+  TF_CHECK_OK(status);
+
+  for (const NodeDef& node : output.node()) {
+    // These nodes should have been pruned
+    EXPECT_NE("Square1", node.name());
+    EXPECT_NE("Sqrt2", node.name());
+    EXPECT_NE("m5", node.name());
+    EXPECT_NE("m7", node.name());
+
+    if (node.name() == "m1") {
+      // sqrt1 is dead
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("square1", node.input(0));
+    } else if (node.name() == "m2") {
+      // both inputs are alive
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+      EXPECT_EQ("square1", node.input(1));
+    } else if (node.name() == "m3") {
+      // sqrt1 is dead
+      EXPECT_EQ("Identity", node.op());
+      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+    } else if (node.name() == "m4") {
+      // both inputs are alive
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("square1", node.input(0));
+      EXPECT_EQ("sqrt2", node.input(1));
+    } else if (node.name() == "m6") {
+      // both inputs are alive and the control dependency can get triggered
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ("v_in", node.input(0));
+      EXPECT_EQ("square1", node.input(1));
+      EXPECT_EQ("^sqrt2", node.input(2));
+    } else if (node.name() == "m8") {
+      // The node is to be preserved because of a fetch
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("id1", node.input(0));
+      EXPECT_EQ("id2", node.input(1));
+    } else if (node.name() == "m9") {
+      // The node is to be preserved because of a fetch
+      EXPECT_EQ("Merge", node.op());
+      EXPECT_EQ(2, node.input_size());
+      EXPECT_EQ("id3", node.input(0));
+      EXPECT_EQ("id4", node.input(1));
+    }
+  }
+}
+
 }  // namespace grappler
 }  // namespace tensorflow
index b87ae05..1c6fef5 100644 (file)
@@ -65,7 +65,7 @@ class NodeMap {
 // A vector with a set. The set stores the same elements as the vector, and
 // quickly answers whether a value is in the vector. Duplicated elements are not
 // allowed for now.
-template <class T>
+template <class T, class Hash = std::hash<T>>
 class SetVector {
  public:
   // Returns false if value already existed in the set, true otherwise.
@@ -91,7 +91,7 @@ class SetVector {
   void Reserve(int64 size) { vector_.reserve(size); }
 
  private:
-  std::unordered_set<T> set_;
+  std::unordered_set<T, Hash> set_;
   std::vector<T> vector_;
 };