Adds support for hoisting out common denominator in arithmetic_optimizer
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 10 Apr 2018 18:09:37 +0000 (11:09 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 10 Apr 2018 18:12:23 +0000 (11:12 -0700)
PiperOrigin-RevId: 192314177

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index fa0f7c1..463c332 100644 (file)
@@ -695,15 +695,20 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
   }
 };
 
-// Use the commutativity and (left- and right-) distributive property of
-// multiplication over addition to hoist common factors out of aggregate nodes
-// where all the inputs are Mul nodes. This pattern occurs frequently in
-// regularization terms for the gradients during training.
+// Use the distributive property of multiplication and division over addition,
+// along with commutativity of the former, to hoist common factors/denominators
+// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
+// This pattern occurs frequently in regularization terms for the gradients
+// during training.
 //
 // For example, we can rewrite an expression of the form:
 //   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
 // to the following:
 //   Mul(x, AddN(y1, y2, y3, ... yn))
+// For division, we can rewrite
+//   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
+// to:
+//   Div(AddN(y1, y2, y3, ... yn), x)
 class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
  public:
   explicit HoistCommonFactorOutOfAggregation(
@@ -720,9 +725,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
 
+    bool common_factor_is_denominator = false;
     std::set<string> common_factors;
     std::vector<string> ctrl_deps;
-    TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps));
+    TF_RETURN_IF_ERROR(GetCommonFactors(
+        node, &common_factors, &common_factor_is_denominator, &ctrl_deps));
 
     if (common_factors.size() == 1) {
       const string& common_factor = *common_factors.begin();
@@ -730,24 +737,31 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
       // Gather up the non-shared factors
       bool shapes_match = true;
       std::vector<string> unique_factors;
-      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor, &shapes_match,
-                                          &unique_factors));
+      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
+                                          common_factor_is_denominator,
+                                          &shapes_match, &unique_factors));
 
       if (shapes_match) {
         NodeDef* input_0;
         TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
 
-        // Use a copy of the first Mul node for the outer multiplication.
-        NodeDef* new_mul_node = AddCopyNode(OuterMulNodeName(node), input_0);
+        // Use a copy of the first node for the outer multiplication/division.
+        NodeDef* new_outer_node = AddCopyNode(
+            OuterNodeName(node, common_factor_is_denominator), input_0);
         // And a copy of aggregation node as one of the inner operands
         NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
 
-        new_mul_node->set_device(node->device());
-        new_mul_node->set_input(0, common_factor);
-        new_mul_node->set_input(1, new_add_node->name());
+        new_outer_node->set_device(node->device());
+        if (common_factor_is_denominator) {
+          new_outer_node->set_input(0, new_add_node->name());
+          new_outer_node->set_input(1, common_factor);
+        } else {
+          new_outer_node->set_input(0, common_factor);
+          new_outer_node->set_input(1, new_add_node->name());
+        }
 
-        ctx_.node_map->AddOutput(common_factor, new_mul_node->name());
-        ctx_.node_map->AddOutput(new_add_node->name(), new_mul_node->name());
+        ctx_.node_map->AddOutput(common_factor, new_outer_node->name());
+        ctx_.node_map->AddOutput(new_add_node->name(), new_outer_node->name());
 
         // Hoist non-shared factors up into the new AddN node.
         for (int i = 0; i < unique_factors.size(); ++i) {
@@ -766,17 +780,18 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
         AddToOptimizationQueue(new_add_node);
         // do not optimize the same node twice
         rewritten_nodes_.insert(node->name());
-        *simplified_node_name = new_mul_node->name();
+        *simplified_node_name = new_outer_node->name();
       }
     }
     return Status::OK();
   }
 
  private:
-  // Get a name for new outer Mul node
-  string OuterMulNodeName(const NodeDef* node) const {
+  // Get a name for new outer node
+  string OuterNodeName(const NodeDef* node, bool is_div) const {
     auto scope_and_name = ParseNodeScopeAndName(node->name());
-    return OptimizedNodeName(scope_and_name, "Mul");
+    return is_div ? OptimizedNodeName(scope_and_name, "Div")
+                  : OptimizedNodeName(scope_and_name, "Mul");
   }
 
   // Get a name new inner Add node
@@ -785,11 +800,17 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
     return OptimizedNodeName(scope_and_name, "Add");
   }
 
-  // Determine the set of common factors if the input nodes are all Mul nodes.
+  // Determine the set of common factors if the input nodes are all Mul or
+  // Div nodes.
   Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
+                          bool* common_factor_is_denominator,
                           std::vector<string>* ctrl_deps) const {
     CHECK(common_factors->empty());
+    CHECK_NOTNULL(common_factor_is_denominator);
+    *common_factor_is_denominator = false;
 
+    bool has_mul = false;
+    bool has_div = false;
     for (int i = 0; i < node->input_size(); ++i) {
       if (i > 0 && common_factors->empty()) break;
       if (IsControlInput(node->input(i))) {
@@ -799,12 +820,36 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
       NodeDef* input;
       TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
 
-      if (!IsMul(*input)) {
+      if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
+          (IsAnyDiv(*input) && has_mul)) {
+        // Break if input is neither a Mul or Div, or if there are both Mul &
+        // Div Ops.
         common_factors->clear();
         break;
+      } else if (IsAnyDiv(*input)) {
+        has_div = true;
+        // In case of possible common dividers, we avoid hoisting out if any
+        // input is not float/double, since integer division is not distributive
+        // over addition.
+        OpInfo::TensorProperties properties0, properties1;
+        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
+        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
+        if (properties0.dtype() != DT_FLOAT &&
+            properties0.dtype() != DT_DOUBLE &&
+            properties1.dtype() != DT_FLOAT &&
+            properties1.dtype() != DT_DOUBLE) {
+          common_factors->clear();
+          break;
+        }
+      } else if (IsMul(*input)) {
+        has_mul = true;
       }
 
-      std::set<string> factors_i{input->input(0), input->input(1)};
+      // We only focus on common factors from denominators if any Op is a
+      // Div.
+      std::set<string> factors_i =
+          has_mul ? std::set<string>{input->input(0), input->input(1)}
+                  : std::set<string>{input->input(1)};
       if (i == 0) {
         std::swap(*common_factors, factors_i);
       } else {
@@ -819,6 +864,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
         ctrl_deps->push_back(input->input(i));
       }
     }
+
+    *common_factor_is_denominator = has_div;
     return Status::OK();
   }
 
@@ -827,6 +874,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
   // have the same shape since the other aggregation ops do not support
   // broadcasting.
   Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
+                          const bool common_factor_is_denominator,
                           bool* shapes_match,
                           std::vector<string>* unique_factors) const {
     *shapes_match = true;
@@ -837,11 +885,13 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
       if (IsControlInput(input)) {
         break;
       }
-      NodeDef* mul_node;
-      TF_RETURN_IF_ERROR(GetInputNode(input, &mul_node));
+      NodeDef* inner_node;
+      TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
       const int unique_factor_index =
-          mul_node->input(0) == common_factor ? 1 : 0;
-      unique_factors->push_back(mul_node->input(unique_factor_index));
+          common_factor_is_denominator
+              ? 0
+              : (inner_node->input(0) == common_factor ? 1 : 0);
+      unique_factors->push_back(inner_node->input(unique_factor_index));
       if (i > 0 && !IsAdd(*node)) {
         OpInfo::TensorProperties lhs;
         OpInfo::TensorProperties rhs;
@@ -857,7 +907,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
     // if graph rewrite happens in multiple passes without graph pruning between
     // them, it's possible that rewritten node already exists in a graph
     return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
-           ctx_.node_map->NodeExists(OuterMulNodeName(node));
+           ctx_.node_map->NodeExists(OuterNodeName(node, false)) ||
+           ctx_.node_map->NodeExists(OuterNodeName(node, true));
   }
 
   // keep names of the nodes that were optimized by this stage
index 9677175..e639812 100644 (file)
@@ -31,6 +31,9 @@ namespace grappler {
 
 namespace {
 
+constexpr char kHoistFactorOptimizerDiv[] =
+    "ArithmeticOptimizer/HoistCommonFactor_Div_";
+
 constexpr char kHoistFactorOptimizerMul[] =
     "ArithmeticOptimizer/HoistCommonFactor_Mul_";
 
@@ -42,6 +45,11 @@ string HoistMulName(const string& name) {
   return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
 }
 
+// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation
+string HoistDivName(const string& name) {
+  return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
+}
+
 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
 string HoistAddName(const string& name) {
   return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
@@ -558,7 +566,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
   EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
 }
 
-TEST_F(ArithmeticOptimizerTest, HoistFactor) {
+TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
   for (bool matching_shapes : {true, false}) {
     for (bool use_addn : {true, false}) {
       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -625,6 +633,81 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
   }
 }
 
+TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
+  for (bool matching_shapes : {true, false}) {
+    for (bool use_addn : {true, false}) {
+      for (bool use_ints : {true, false}) {
+        tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+        Output x = use_ints
+                       ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
+                       : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+        Output y1 = use_ints
+                        ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
+                        : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
+        Output y2;
+        if (matching_shapes) {
+          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
+                        : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
+        } else {
+          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
+                        : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
+        }
+        Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
+        Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
+        Output id =
+            use_addn
+                ? ops::Identity(s.WithOpName("id"),
+                                ops::AddN(s.WithOpName("add"), {div1, div2}))
+                : ops::Identity(s.WithOpName("id"),
+                                ops::Add(s.WithOpName("add"), div1, div2));
+
+        GrapplerItem item;
+        item.fetch = {"id"};
+        TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+        ArithmeticOptimizer optimizer;
+        EnableOnlyHoistCommonFactor(&optimizer);
+
+        GraphDef output;
+        OptimizeTwice(&optimizer, &item, &output);
+
+        // We expect the following rewrite(s) to occur:
+        //
+        //        Add                 Div
+        //      /    \               /   \
+        //    Div    Div       ->  Add    x
+        //    / \    / \           / \
+        //   y1  x  y2  x         y1  y2
+        //
+        // If "root" op is AddN and shapes does not match, this rewrite is not
+        // possible and graph should stay intact.
+        NodeMap node_map(&output);
+
+        if ((use_addn && !matching_shapes) || use_ints) {
+          VerifyGraphsMatch(item.graph, output, __LINE__);
+        } else {
+          EXPECT_EQ(9, output.node_size());
+
+          const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
+          ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+          EXPECT_EQ("y1", new_add_node->input(0));
+          EXPECT_EQ("y2", new_add_node->input(1));
+
+          const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
+          ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
+          EXPECT_EQ(new_add_node->name(), new_div_node->input(0));
+          EXPECT_EQ("x", new_div_node->input(1));
+
+          const NodeDef* id_node = node_map.GetNode("id");
+          ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+          EXPECT_EQ("id", id_node->name());
+          EXPECT_EQ(HoistDivName("add"), id_node->input(0));
+        }
+      }
+    }
+  }
+}
+
 TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});