Fixed handling of control dependencies in the arithmethic optimizer
authorBenoit Steiner <bsteiner@google.com>
Wed, 4 Apr 2018 23:17:46 +0000 (16:17 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 4 Apr 2018 23:22:19 +0000 (16:22 -0700)
PiperOrigin-RevId: 191665098

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
tensorflow/core/grappler/optimizers/graph_optimizer_stage.h
tensorflow/core/grappler/optimizers/graph_optimizer_stage_test.cc

index 919f23f..59a5695 100644 (file)
@@ -34,7 +34,6 @@ limitations under the License.
 #include "tensorflow/core/grappler/optimizers/graph_optimizer_stage.h"
 #include "tensorflow/core/grappler/optimizers/symbolic_shapes.h"
 #include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
 #include "tensorflow/core/grappler/utils/topological_sort.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
@@ -290,21 +289,16 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
 
   // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
   // optimizations will be migrated to stages
-  void AddFrameControlDeps(const NodeDef* old_node,
-                           const std::vector<NodeDef*>& new_nodes,
-                           const string& source_for_ctrl_dep,
-                           const std::vector<NodeDef*>& sinks_for_control_dep) {
-    const auto frame_it = ctx_.frame_map->find(old_node);
-    if (frame_it != ctx_.frame_map->end()) {
-      for (auto node : new_nodes) {
-        ctx_.frame_map->emplace(node, frame_it->second);
-      }
-      if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
-        const string ctrl_dep = ConstantFolding::AddControlDependency(
-            source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map);
-        for (auto node : sinks_for_control_dep) {
-          MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph,
-                               ctx_.node_map);
+  void ForwardControlDependencies(
+      NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+    for (const auto& src : src_nodes) {
+      for (int i = src->input_size() - 1; i >= 0; --i) {
+        if (IsControlInput(src->input(i))) {
+          *target_node->add_input() = src->input(i);
+          ctx_.node_map->AddOutput(NodeName(src->input(i)),
+                                   target_node->name());
+        } else {
+          break;
         }
       }
     }
@@ -703,7 +697,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
     CHECK(IsSupported(node));
 
     std::set<string> common_factors;
-    TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors));
+    std::vector<string> ctrl_deps;
+    TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps));
 
     if (common_factors.size() == 1) {
       const string& common_factor = *common_factors.begin();
@@ -735,9 +730,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
           new_add_node->set_input(i, unique_factors[i]);
         }
 
-        // Add frame dependencies that the original node might have had.
-        AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
-                            {new_add_node});
+        // Add control deps on add node
+        for (const string& ctrl_dep : ctrl_deps) {
+          *new_add_node->add_input() = ctrl_dep;
+          ctx_.node_map->AddOutput(NodeName(ctrl_dep), new_add_node->name());
+        }
 
         // optimize new inner aggregation node
         AddToOptimizationQueue(new_add_node);
@@ -763,14 +760,16 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
   }
 
   // Determine the set of common factors if the input nodes are all Mul nodes.
-  Status GetCommonFactors(const NodeDef* node,
-                          std::set<string>* common_factors) const {
+  Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
+                          std::vector<string>* ctrl_deps) const {
     CHECK(common_factors->empty());
 
     for (int i = 0; i < node->input_size(); ++i) {
       if (i > 0 && common_factors->empty()) break;
-      if (IsControlInput(node->input(i))) break;
-
+      if (IsControlInput(node->input(i))) {
+        ctrl_deps->push_back(node->input(i));
+        continue;
+      }
       NodeDef* input;
       TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
 
@@ -790,6 +789,9 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
             std::inserter(intersection, intersection.begin()));
         std::swap(*common_factors, intersection);
       }
+      for (int i = 2; i < input->input_size(); ++i) {
+        ctrl_deps->push_back(input->input(i));
+      }
     }
     return Status::OK();
   }
@@ -1275,20 +1277,15 @@ void ArithmeticOptimizer::DedupComputations() {
   }
 }
 
-void ArithmeticOptimizer::AddFrameControlDeps(
-    const NodeDef* old_node, const std::vector<NodeDef*>& new_nodes,
-    const string& source_for_ctrl_dep,
-    const std::vector<NodeDef*>& sinks_for_control_dep) {
-  const auto frame_it = frame_map_.find(old_node);
-  if (frame_it != frame_map_.end()) {
-    for (auto node : new_nodes) {
-      frame_map_.emplace(node, frame_it->second);
-    }
-    if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
-      const string ctrl_dep = ConstantFolding::AddControlDependency(
-          source_for_ctrl_dep, optimized_graph_, node_map_.get());
-      for (auto node : sinks_for_control_dep) {
-        MaybeAddControlInput(ctrl_dep, node, optimized_graph_, node_map_.get());
+void ArithmeticOptimizer::ForwardControlDependencies(
+    NodeDef* target_node, const std::vector<const NodeDef*>& src_nodes) {
+  for (const auto& src : src_nodes) {
+    for (int i = src->input_size() - 1; i >= 0; --i) {
+      if (IsControlInput(src->input(i))) {
+        *target_node->add_input() = src->input(i);
+        node_map_->AddOutput(NodeName(src->input(i)), target_node->name());
+      } else {
+        break;
       }
     }
   }
@@ -1408,10 +1405,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
           node_map_->AddOutput(new_transpose->name(), new_cast->name());
 
           nodes_to_simplify->PushBack(new_transpose);
-          //  Add frame dependencies that the original node might have had.
-          AddFrameControlDeps(node, {new_transpose, new_cast},
-                              new_transpose->input(0), {new_transpose});
-
+          ForwardControlDependencies(new_transpose, {cast, node});
           return new_cast->name();
         }
       }
@@ -1485,7 +1479,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
             node_map_->AddOutput(weights->name(), scaled_weights->name());
             scaled_weights->add_input(mul->input(1));
             node_map_->AddOutput(scale->name(), scaled_weights->name());
-            AddFrameControlDeps(node, {scaled_weights}, "", {});
+            ForwardControlDependencies(scaled_weights, {source});
 
             // Update `conv`'s weights to `scaled_weights`.
             conv->set_input(1, scaled_weights->name());
@@ -1521,7 +1515,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
   }
 
   if (IsAggregate(*node) && NumNonControlInputs(*node) > 0) {
-    // Discard aggregate nodes with a single input.
+    // Discard aggregate nodes with a single input and no control dependencies.
     if (node->input_size() == 1) {
       return node->input(0);
     }
@@ -1567,6 +1561,8 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
         return "";
       }
       new_const_node->set_device(node->device());
+      MaybeAddControlInput(NodeName(node->input(0)), new_const_node,
+                           optimized_graph_, node_map_.get());
       nodes_to_simplify->PushBack(new_const_node);
 
       // 2. Replace the aggregate node with Mul(Const(N), x).
@@ -1579,9 +1575,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
       new_mul_node->add_input(node->input(0));
       node_map_->AddOutput(node->input(0), new_mul_node->name());
 
-      CopyControlInputs(*node, new_mul_node, optimized_graph_, node_map_.get());
-      AddFrameControlDeps(node, {new_const_node, new_mul_node}, node->input(0),
-                          {new_const_node});
+      ForwardControlDependencies(new_mul_node, {node});
       return new_mul_node->name();
     }
   }
@@ -1614,7 +1608,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
         FlipBooleanAttr(attr_a, new_op);
         new_op->set_input(0, a->input(0));
         node_map_->UpdateInput(new_op->name(), a->name(), a->input(0));
-        AddFrameControlDeps(node, {new_op}, a->input(0), {new_op});
       }
       if (b_is_foldable) {
         const string attr_b =
@@ -1622,10 +1615,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
         FlipBooleanAttr(attr_b, new_op);
         new_op->set_input(1, b->input(0));
         node_map_->UpdateInput(new_op->name(), b->name(), b->input(0));
-        if (!a_is_foldable) {
-          AddFrameControlDeps(node, {new_op}, b->input(0), {new_op});
-        }
       }
+      std::vector<const NodeDef*> deps_to_forward({node});
+      if (a_is_foldable) {
+        deps_to_forward.push_back(a);
+      }
+      if (b_is_foldable) {
+        deps_to_forward.push_back(b);
+      }
+      ForwardControlDependencies(new_op, deps_to_forward);
     }
   }
 
@@ -1647,7 +1645,7 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
                                                        : "Transpose");
       new_op->set_input(0, input->input(0));
       node_map_->UpdateInput(new_op->name(), node->name(), input->input(0));
-      AddFrameControlDeps(node, {new_op}, "", {});
+      ForwardControlDependencies(new_op, {node, input});
       return new_op->name();
     }
   }
@@ -1663,8 +1661,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
   }
 
   const GraphOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
-                                  graph_properties_.get(), node_map_.get(),
-                                  &frame_map_);
+                                  graph_properties_.get(), node_map_.get());
   const ArithmeticOptimizerContext ctx_ext(&nodes_to_simplify);
 
   // Stop pipeline after first stage returning non-empty simplified tensor name.
@@ -1764,11 +1761,6 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   graph_properties_.reset(new GraphProperties(item));
   TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
 
-  // Identify loop frames
-  int num_frames;
-  TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
-                                               &frame_map_, &num_frames));
-
   // Perform the optimizations.
   TF_RETURN_IF_ERROR(SimplifyArithmeticOps());
 
index 63a7b55..7e81ed0 100644 (file)
@@ -20,7 +20,6 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
 #include "tensorflow/core/protobuf/rewriter_config.pb.h"
 
 namespace tensorflow {
@@ -100,13 +99,9 @@ class ArithmeticOptimizer : public GraphOptimizer {
   // Dedup redundant nodes in the graph.
   void DedupComputations();
 
-  // Fix frame dependencies by adding control dependencies from old_input to
-  // nodes in new_nodes_for_control_dep, and update frame_map for all nodes in
-  // new_nodes.
-  void AddFrameControlDeps(const NodeDef* old_node,
-                           const std::vector<NodeDef*>& new_nodes,
-                           const string& source_for_ctrl_dep,
-                           const std::vector<NodeDef*>& sinks_for_control_dep);
+  // Forward the control dependencies anchored on src_nodes to the target_nodes.
+  void ForwardControlDependencies(NodeDef* target_node,
+                                  const std::vector<const NodeDef*>& src_nodes);
 
   // Runs peep-hole optimizations on `optimized_graph`, e.g., removing inverse
   // transposes.
@@ -135,7 +130,6 @@ class ArithmeticOptimizer : public GraphOptimizer {
   bool fetch_nodes_known_ = false;
   std::unordered_set<string> nodes_to_preserve_;
   std::unique_ptr<NodeMap> node_map_;
-  FrameMap frame_map_;
   std::unique_ptr<GraphProperties> graph_properties_;
   GraphDef* optimized_graph_ = nullptr;  // Not owned.
 };
index 48f1dd5..e117341 100644 (file)
@@ -520,26 +520,23 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
 
   const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
   ASSERT_NE(add_6_node, nullptr);
-  EXPECT_EQ(3, add_6_node->input_size());
+  EXPECT_EQ(2, add_6_node->input_size());
   EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
   EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
-  EXPECT_EQ("^Placeholder", add_6_node->input(2));
 
   const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
   ASSERT_NE(add_4_node, nullptr);
   EXPECT_EQ("Add", add_4_node->op());
-  EXPECT_EQ(3, add_4_node->input_size());
+  EXPECT_EQ(2, add_4_node->input_size());
   EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
   EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
-  EXPECT_EQ("^Placeholder", add_4_node->input(2));
 
   const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
   ASSERT_NE(add_5_node, nullptr);
   EXPECT_EQ("Add", add_5_node->op());
-  EXPECT_EQ(3, add_5_node->input_size());
+  EXPECT_EQ(2, add_5_node->input_size());
   EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
   EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
-  EXPECT_EQ("^Placeholder", add_5_node->input(2));
 
   const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
   ASSERT_NE(add_const_node, nullptr);
index 8d3e965..7ed0474 100644 (file)
@@ -21,7 +21,6 @@ limitations under the License.
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/grappler_item.h"
 #include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/frame.h"
 
 namespace tensorflow {
 namespace grappler {
@@ -45,21 +44,16 @@ const NodeScopeAndName ParseNodeScopeAndName(const string& node_name);
 struct GraphOptimizerContext {
   GraphOptimizerContext(const std::unordered_set<string>* nodes_to_preserve,
                         GraphDef* optimized_graph,
-                        GraphProperties* graph_properties, NodeMap* node_map,
-                        FrameMap* frame_map)
+                        GraphProperties* graph_properties, NodeMap* node_map)
       : nodes_to_preserve(nodes_to_preserve),
         optimized_graph(optimized_graph),
         graph_properties(graph_properties),
-        node_map(node_map),
-        frame_map(frame_map) {}
+        node_map(node_map) {}
 
   const std::unordered_set<string>* nodes_to_preserve;
   GraphDef* optimized_graph;
   GraphProperties* graph_properties;
   NodeMap* node_map;
-  // TODO(ezhulenev): it seems that frame_map is only relevant for loop
-  // optimizer? Move it to loop-optimizer specific context extension.
-  FrameMap* frame_map;
 };
 
 Status GetInputNode(const GraphOptimizerContext& ctx, const string& input,
index 416327e..3f5ab87 100644 (file)
@@ -58,8 +58,8 @@ TEST_F(GraphOptimizerStageTest, ParseNodeNameAndScope_InScope) {
 TEST_F(GraphOptimizerStageTest, OptimizedNodeName) {
   GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                             /*optimized_graph*/ nullptr,
-                            /*graph_properties*/ nullptr, /*node_name*/ nullptr,
-                            /*frame_map*/ nullptr);
+                            /*graph_properties*/ nullptr,
+                            /*node_name*/ nullptr);
   FakeOptimizerStage stage("my_opt", "my_stg", ctx);
 
   const auto node = ParseNodeScopeAndName("a/b/c/Add");
@@ -94,8 +94,7 @@ TEST_F(GraphOptimizerStageTest, GetInputNodeAndProperties) {
   GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                             /*optimized_graph*/ &item.graph,
                             /*graph_properties*/ &properties,
-                            /*node_name*/ &node_map,
-                            /*frame_map*/ nullptr);
+                            /*node_name*/ &node_map);
   FakeOptimizerStage stage("my_opt", "my_stg", ctx);
 
   NodeDef* add_node;
@@ -134,8 +133,7 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
   GraphOptimizerContext ctx(/*nodes_to_preserve*/ nullptr,
                             /*optimized_graph*/ &item.graph,
                             /*graph_properties*/ &properties,
-                            /*node_name*/ &node_map,
-                            /*frame_map*/ nullptr);
+                            /*node_name*/ &node_map);
   FakeOptimizerStage stage("my_opt", "my_stg", ctx);
 
   NodeDef* add_node;
@@ -165,4 +163,4 @@ TEST_F(GraphOptimizerStageTest, AddNodes) {
 
 }  // namespace
 }  // end namespace grappler
-}  // end namespace tensorflow
\ No newline at end of file
+}  // end namespace tensorflow