Enable unary chain hoisting optimization for concat/split/splitv by default.
author A. Unique TensorFlower <gardener@tensorflow.org>
Thu, 3 May 2018 20:00:56 +0000 (13:00 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
Thu, 3 May 2018 20:40:29 +0000 (13:40 -0700)
PiperOrigin-RevId: 195297330

tensorflow/core/grappler/op_types.cc
tensorflow/core/grappler/op_types.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index 7c936df..c48dc00 100644 (file)
@@ -476,28 +476,40 @@ bool IsInvolution(const NodeDef& node) {
   return involution_ops->count(node.op()) > 0;
 }
 
+bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
+  if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
+    return true;  // An aggregate op with one non-control input qualifies.
+  }
+  static const std::unordered_set<string>*
+      value_and_order_and_shape_preserving_ops =
+          CHECK_NOTNULL((new const std::unordered_set<string>{
+              "CheckNumerics",
+              "DebugGradientIdentity",
+              "DeepCopy",  // Comma required: avoids "DeepCopyEnter".
+              "Enter",
+              "Exit",
+              "Identity",
+              "IdentityN",
+              "PreventGradient",
+              "Print",
+              "Snapshot",
+              "StopGradient",
+          }));
+  return value_and_order_and_shape_preserving_ops->count(node.op()) > 0;
+}
+
 bool IsValueAndOrderPreserving(const NodeDef& node) {
   if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
     return true;
   }
   static const std::unordered_set<string>* value_and_order_preserving_ops =
       CHECK_NOTNULL((new const std::unordered_set<string>{
-          "CheckNumerics",
-          "DebugGradientIdentity",
-          "DeepCopy"
-          "Enter",
-          "Exit",
           "ExpandDims",
-          "Identity",
-          "IdentityN",
-          "PreventGradient",
-          "Print",
-          "Reshape",
           "Snapshot",
           "Squeeze",
-          "StopGradient",
       }));
-  return value_and_order_preserving_ops->count(node.op()) > 0;
+  return value_and_order_preserving_ops->count(node.op()) > 0 ||
+         IsValueAndOrderAndShapePreserving(node);
 }
 
 bool IsValuePreserving(const NodeDef& node) {
@@ -564,7 +576,7 @@ bool IsUnaryElementWise(const NodeDef& node) {
           "Tanh",
       }));
   return element_wise_ops->count(node.op()) > 0 ||
-         (!IsIdentityN(node) && IsValueAndOrderPreserving(node));
+         (!IsIdentityN(node) && IsValueAndOrderAndShapePreserving(node));
 }
 
 bool HasOpDef(const NodeDef& node) {
index 7a1b438..e33dd21 100644 (file)
@@ -174,6 +174,10 @@ bool ModifiesInputsInPlace(const NodeDef& node);
 // own inverse such that f(f(x)) == x.
 bool IsInvolution(const NodeDef& node);
 
+// Returns true if the op preserves the order and value of elements
+// and shape of its first input tensor.
+bool IsValueAndOrderAndShapePreserving(const NodeDef& node);
+
 // Returns true if the op preserves the order and value of elements in its
 // first input tensor and possibly changes its shape.
 bool IsValueAndOrderPreserving(const NodeDef& node);
index d6510ba..2a5654f 100644 (file)
@@ -1400,6 +1400,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
       return n > 1;
     } else if (IsSplit(*node) || IsSplitV(*node)) {
       const int num_split = node->attr().at("num_split").i();
+      if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) {
+        // TODO(rmlarsen): Remove this constraint when we have optimizations
+        // in place for merging slices into splits.
+        return false;
+      }
       return num_split > 1 && !IsAlreadyOptimized(*node);
     }
     return false;
@@ -1458,13 +1463,13 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
     if (tails.empty()) {
       return Status::OK();
     }
-    AddControlInputs(ctrl_inputs, root_node);
     AddToOptimizationQueue(root_node);
     optimized_nodes_.insert(root_node->name());
     if (node_is_concat_) {
+      AddControlInputs(ctrl_inputs, root_node);
       return HoistChainForConcat(prefix_length, tails, root_node);
     } else {
-      return HoistChainForSplit(prefix_length, tails, root_node);
+      return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node);
     }
   }
 
@@ -1542,9 +1547,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
           IsInPreserveSet(*op)) {
         return false;
       }
-      if (node_is_concat_ &&
-          ctx().node_map->GetOutputs(op->name()).size() > 1) {
-        // TODO(rmlarsen): Allow and hoist outgoing control edges.
+      if (ctx().node_map->GetOutputs(op->name()).size() > 1) {
+        // TODO(rmlarsen): Allow outgoing control edges.
         return false;
       }
     }
@@ -1612,6 +1616,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
   }
 
   Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails,
+                            std::set<string>* ctrl_inputs,
                             NodeDef* split_node) {
     // Create a new chain before the split node to process the input tensor.
     const string& split_name = split_node->name();
@@ -1646,6 +1651,9 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage {
     cur_copy->add_input(orig_input);
     ctx().node_map->UpdateOutput(NodeName(orig_input), split_name,
                                  cur_copy->name());
+    // Make sure all the control inputs are satisfied before running the first
+    // node in the new chain.
+    AddControlInputs(ctrl_inputs, cur_copy);
 
     // Connect all consumers of the tail nodes directly to the
     // output port of Split from which the chain started.
index 3b297ec..6309dc1 100644 (file)
@@ -65,7 +65,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
     bool remove_redundant_bitcast = true;
     bool remove_redundant_cast = true;
     bool remove_negation = true;
-    bool hoist_cwise_unary_chains = false;
+    bool hoist_cwise_unary_chains = true;
     bool convert_sqrt_div_to_rsqrt_mul = false;
 
     // Choose which arithmetic optimizer stages will be enabled for a given
index f903f53..d32743f 100644 (file)
@@ -2320,16 +2320,16 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
     EXPECT_NE(node.name(), "cos_exp_b2");
 
     if (node.name() == "split1") {
-      EXPECT_EQ(3, node.input_size());
+      EXPECT_EQ(2, node.input_size());
       EXPECT_EQ("axis", node.input(0));
       EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1));
-      EXPECT_EQ("^ctrl1", node.input(2));
       found++;
     }
     if (node.name() == "ArithmeticOptimizer/_sin_a_split1") {
       EXPECT_EQ("Sin", node.op());
-      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ(2, node.input_size());
       EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("^ctrl1", node.input(1));
       found++;
     }
     if (node.name() == "id_a") {
@@ -2349,8 +2349,11 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
     }
     if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") {
       EXPECT_EQ("Exp", node.op());
-      EXPECT_EQ(1, node.input_size());
+      EXPECT_EQ(4, node.input_size());
       EXPECT_EQ("x", node.input(0));
+      EXPECT_EQ("^ctrl1", node.input(1));
+      EXPECT_EQ("^ctrl2", node.input(2));
+      EXPECT_EQ("^ctrl3", node.input(3));
       found++;
     }
     if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") {
@@ -2360,13 +2363,10 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) {
       found++;
     }
     if (node.name() == "split2") {
-      EXPECT_EQ(6, node.input_size());
+      EXPECT_EQ(3, node.input_size());
       EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0));
       EXPECT_EQ("size_splits2", node.input(1));
       EXPECT_EQ("axis", node.input(2));
-      EXPECT_EQ("^ctrl1", node.input(3));
-      EXPECT_EQ("^ctrl2", node.input(4));
-      EXPECT_EQ("^ctrl3", node.input(5));
       found++;
     }
     if (node.name() == "id_a2") {