Move "hoist common factor out of aggregation" optimization
authorA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 12 Mar 2018 20:50:35 +0000 (13:50 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Mon, 12 Mar 2018 20:54:36 +0000 (13:54 -0700)
to a separate stage.

1) Use a new naming scheme for optimized ops,
   share it with AddOpsRewrite
2) Make sure that tests actually test that optimized
   nodes exists in a graph

PiperOrigin-RevId: 188772892

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index 177b073..c0fcfaf 100644 (file)
@@ -290,25 +290,30 @@ NodeDef* GetTailOfValuePreservingChain(
 struct ArithmeticOptimizerContext {
   ArithmeticOptimizerContext(
       const std::unordered_set<string>* nodes_to_preserve,
-      GraphDef* optimized_graph, NodeMap* node_map,
+      GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map,
       SetVector<NodeDef*>* nodes_to_simplify)
       : nodes_to_preserve(nodes_to_preserve),
         optimized_graph(optimized_graph),
         node_map(node_map),
+        frame_map(frame_map),
         nodes_to_simplify(nodes_to_simplify) {}
 
   const std::unordered_set<string>* nodes_to_preserve;
   GraphDef* optimized_graph;
   NodeMap* node_map;
+  FrameMap* frame_map;
   SetVector<NodeDef*>* nodes_to_simplify;
 };
 
 // Base class for single arithmetic optimization: e.g. Bitcast optimization,
 // AddOps optimization, etc...
+// TODO(ezhulenev): extract this class to be reused by other multi-stage
+// graph optimizers (const_folding, dependency_optimizer, etc...)
 class ArithmeticOptimizerStage {
  public:
-  explicit ArithmeticOptimizerStage(ArithmeticOptimizerContext ctx)
-      : ctx_(ctx) {}
+  explicit ArithmeticOptimizerStage(const string& name,
+                                    const ArithmeticOptimizerContext& ctx)
+      : name_(name), ctx_(ctx) {}
   virtual ~ArithmeticOptimizerStage() = default;
 
   // Check if we should try to simplify node. Returning true doesn't
@@ -336,6 +341,46 @@ class ArithmeticOptimizerStage {
                              string* simplified_node_name) = 0;
 
  protected:
+  struct ScopedNodeName {
+    string scope;
+    string name;
+  };
+
+  const ScopedNodeName ParseScopedNodeName(const string& name) const {
+    auto pos = name.find_last_of("/");
+    if (pos == string::npos) {
+      return {"", name};
+    } else {
+      return {name.substr(0, pos), name.substr(pos + 1)};
+    }
+  }
+
+  // Prefix optimized node name with stage name and rewrite_rule
+  const string OptimizedNodeName(const string& rewrite_rule,
+                                 const ScopedNodeName& scoped_node_name) const {
+    return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule),
+                                 scoped_node_name);
+  }
+
+  // Prefix optimized node name with stage name and rewrite_rule
+  const string OptimizedNodeName(const string& rewrite_rule,
+                                 const ScopedNodeName& scoped_node_name,
+                                 const std::vector<string>& node_names) const {
+    return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule),
+                                 scoped_node_name, node_names);
+  }
+
+  // Prefix optimized node name with stage name
+  const string OptimizedNodeName(const ScopedNodeName& scoped_node_name) const {
+    return MakeOptimizedNodeName(name_, scoped_node_name);
+  }
+
+  // Prefix optimized node name with stage name
+  const string OptimizedNodeName(const ScopedNodeName& scoped_node_name,
+                                 const std::vector<string>& node_names) const {
+    return MakeOptimizedNodeName(name_, scoped_node_name, node_names);
+  }
+
   // Simplification graph rewrite can create additional nodes that are inputs
   // to final simplified node, they can be also added to the arithmetic
   // optimizer queue for further optimization.
@@ -374,7 +419,91 @@ class ArithmeticOptimizerStage {
     }
   }
 
-  ArithmeticOptimizerContext ctx_;
+  NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) {
+    CHECK(node_to_copy != nullptr);
+    CHECK(!ctx_.node_map->NodeExists(name))
+        << "Node " << name << " already exists in a graph";
+    NodeDef* new_node = ctx_.optimized_graph->add_node();
+    *new_node = *node_to_copy;
+    new_node->set_name(name);
+    ctx_.node_map->AddNode(name, new_node);
+    return new_node;
+  }
+
+  NodeDef* AddEmptyNode(const string& name) {
+    CHECK(!ctx_.node_map->NodeExists(name))
+        << "Node " << name << " already exists in a graph";
+    NodeDef* new_node = ctx_.optimized_graph->add_node();
+    new_node->set_name(name);
+    ctx_.node_map->AddNode(name, new_node);
+    return new_node;
+  }
+
+  // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all
+  // optimizations will be migrated to stages
+  void AddFrameControlDeps(const NodeDef* old_node,
+                           const std::vector<NodeDef*>& new_nodes,
+                           const string& source_for_ctrl_dep,
+                           const std::vector<NodeDef*>& sinks_for_control_dep) {
+    const auto frame_it = ctx_.frame_map->find(old_node);
+    if (frame_it != ctx_.frame_map->end()) {
+      for (auto node : new_nodes) {
+        ctx_.frame_map->emplace(node, frame_it->second);
+      }
+      if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) {
+        const string ctrl_dep = ConstantFolding::AddControlDependency(
+            source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map);
+        for (auto node : sinks_for_control_dep) {
+          MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph,
+                               ctx_.node_map);
+        }
+      }
+    }
+  }
+
+  const string name_;
+  const ArithmeticOptimizerContext ctx_;
+
+ private:
+  // Get a name for a new node obtained by optimizing a single node of the
+  // original graph. The optimized node is placed under the original node scope.
+  //
+  // Node name uniqueness is guaranteed by unique name of an original node in
+  // a same scope.
+  //
+  // Example: MakeOptimizedNodeName("AwesomeRewrite", "a/b/c/Add_1")
+  // Optimized name: "a/b/c/ArithmeticOptimizer/AwesomeRewrite_Add_1"
+  const string MakeOptimizedNodeName(
+      const string& prefix, const ScopedNodeName& scoped_node_name) const {
+    string node_name;
+    strings::StrAppend(&node_name, scoped_node_name.scope);
+    if (!node_name.empty()) strings::StrAppend(&node_name, "/");
+    strings::StrAppend(&node_name, kArithmeticOptimizer, "/", prefix, "_",
+                       scoped_node_name.name);
+    return node_name;
+  }
+
+  // Get a name for a new node obtained by optimizing multiple nodes of the
+  // original graph, starting from "root". The optimized node is placed under
+  // the original scope of a "root" node.
+  //
+  // Node name uniqueness is guaranteed by unique name of a "root" node in
+  // a same scope.
+  //
+  // Example:
+  //   MakeOptimizedNodeName("AwesomeRewrite", "a/b/Add_AB", ["x/y/Add_XY"])
+  // Optimized name:
+  //   "a/b/ArithmeticOptimizer/AwesomeRewrite_Add_AB_Add_XY"
+  const string MakeOptimizedNodeName(
+      const string& prefix, const ScopedNodeName& scoped_node_name,
+      const std::vector<string>& node_names) const {
+    string node_name = MakeOptimizedNodeName(prefix, scoped_node_name);
+    for (const string& optimized : node_names) {
+      auto scoped_node = ParseScopedNodeName(optimized);
+      strings::StrAppend(&node_name, "_", scoped_node.name);
+    }
+    return node_name;
+  }
 };
 
 // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
@@ -393,8 +522,8 @@ class ArithmeticOptimizerStage {
 //                         q   e
 class AddOpsRewriteStage : public ArithmeticOptimizerStage {
  public:
-  explicit AddOpsRewriteStage(ArithmeticOptimizerContext ctx)
-      : ArithmeticOptimizerStage(ctx), rewritten_nodes_() {}
+  explicit AddOpsRewriteStage(const ArithmeticOptimizerContext& ctx)
+      : ArithmeticOptimizerStage("AddOpsRewrite", ctx), rewritten_nodes_() {}
 
   ~AddOpsRewriteStage() override = default;
 
@@ -422,7 +551,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     AddOpsGroup group;
     TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
 
-    if (!group.absorbed_nodes.empty()) {
+    if (!group.absorbed_nodes.empty() && !IsRewritten(group)) {
       *simplified_node_name = RewriteAddOpsGroup(group);
     }
 
@@ -530,6 +659,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
              DrivesControlDependency(*node));
   }
 
+  // Check that optimized group node name doesn't exists. It might happen if
+  // graph optimized multiple times without pruning beween invocations.
+  bool IsRewritten(const AddOpsGroup& group) const {
+    return ctx_.node_map->NodeExists(AddOpsGroupName(group));
+  }
+
   // Create an AddOpsGroup with a root in a given node
   Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
     group->root_node = root_node;
@@ -559,39 +694,23 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     return Status::OK();
   }
 
-  const std::pair<string, string> ParseNodeScopeAndName(const string& name) {
-    auto pos = name.find_last_of("/");
-    if (pos == string::npos) {
-      return {"", name};
-    } else {
-      return {name.substr(0, pos), name.substr(pos + 1)};
-    }
-  }
-
   // New node for AddOpsGroup is added to the same scope as a root_node. All
   // absorbed nodes are stripped of their scope, and only names are used in a
   // new node name.
   //
   // Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"])
   //          node_name="a/b/c/AddOpsGroup_Add_2_Add_1_Add
-  string AddOpsGroupName(const AddOpsGroup& group) {
+  string AddOpsGroupName(const AddOpsGroup& group) const {
     CHECK_NOTNULL(group.root_node);
-    string node_name;
 
-    auto root_node = ParseNodeScopeAndName(group.root_node->name());
-    auto root_scope = root_node.first;
-    auto root_name = root_node.second;
-    if (!root_scope.empty()) {
-      strings::StrAppend(&node_name, root_scope, "/");
-    }
+    auto root = ParseScopedNodeName(group.root_node->name());
 
-    strings::StrAppend(&node_name, kArithmeticOptimizer, "/", "AddOpsGroup_",
-                       root_name);
-    for (const NodeDef* absorbed : group.absorbed_nodes) {
-      auto absorbed_node = ParseNodeScopeAndName(absorbed->name());
-      strings::StrAppend(&node_name, "_", absorbed_node.second);
-    }
-    return node_name;
+    std::vector<string> absorbed_node_names(group.absorbed_nodes.size());
+    std::transform(group.absorbed_nodes.begin(), group.absorbed_nodes.end(),
+                   absorbed_node_names.begin(),
+                   [](const NodeDef* node) { return node->name(); });
+
+    return OptimizedNodeName(root, absorbed_node_names);
   }
 
   // Create a new node for a AddOpsGroup and return it's name.
@@ -605,18 +724,17 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
     // copy attributes from a root node
     DataType dtype = group.root_node->attr().at("T").type();
 
-    // add new node
-    NodeDef* added_node = ctx_.optimized_graph->add_node();
-    added_node->set_name(node_name);
+    // add new AddN node
+    NodeDef* added_node = AddEmptyNode(node_name);
     added_node->set_op("AddN");
     added_node->set_device(group.root_node->device());
     (*added_node->mutable_attr())["T"].set_type(dtype);
     (*added_node->mutable_attr())["N"].set_i(group.inputs.size());
 
-    ctx_.node_map->AddNode(node_name, added_node);
-    for (string input : group.inputs) {
+    // all inputs of absorbed nodes are added to the new node
+    for (const string& input : group.inputs) {
       ctx_.node_map->AddOutput(input, node_name);
-      added_node->add_input(std::move(input));
+      added_node->add_input(input);
     }
 
     VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
@@ -635,11 +753,167 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage {
   std::unordered_set<string> rewritten_nodes_;
 };
 
+// Use the commutativity and (left- and right-) distributive property of
+// multiplication over addition to hoist common factors out of aggregate nodes
+// where all the inputs are Mul nodes. This pattern occurs frequently in
+// regularization terms for the gradients during training.
+//
+// For example, we can rewrite an expression of the form:
+//   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
+// to the following:
+//   Mul(x, AddN(y1, y2, y3, ... yn))
+class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
+ public:
+  explicit HoistCommonFactorOutOfAggregation(
+      const ArithmeticOptimizerContext& ctx)
+      : ArithmeticOptimizerStage("HoistCommonFactor", ctx) {}
+  ~HoistCommonFactorOutOfAggregation() override = default;
+
+  bool IsSupported(const NodeDef* node) const override {
+    return IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
+           !IsRewritten(node);
+  }
+
+  Status TrySimplify(const NodeDef* node,
+                     string* simplified_node_name) override {
+    CHECK(IsSupported(node));
+
+    std::set<string> common_factors;
+    TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors));
+
+    if (common_factors.size() == 1) {
+      const string& common_factor = *common_factors.begin();
+
+      // Gather up the non-shared factors
+      bool shapes_match = true;
+      std::vector<string> unique_factors;
+      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor, &shapes_match,
+                                          &unique_factors));
+
+      if (shapes_match) {
+        NodeDef* input_0;
+        TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
+
+        // Use a copy of the first Mul node for the outer multiplication.
+        NodeDef* new_mul_node = AddCopyNode(OuterMulNodeName(node), input_0);
+        // And a copy of aggregation node as one of the inner operands
+        NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
+
+        new_mul_node->set_device(node->device());
+        new_mul_node->set_input(0, common_factor);
+        new_mul_node->set_input(1, new_add_node->name());
+
+        ctx_.node_map->AddOutput(common_factor, new_mul_node->name());
+        ctx_.node_map->AddOutput(new_add_node->name(), new_mul_node->name());
+
+        // Hoist non-shared factors up into the new AddN node.
+        for (int i = 0; i < unique_factors.size(); ++i) {
+          new_add_node->set_input(i, unique_factors[i]);
+        }
+
+        // Add frame dependencies that the original node might have had.
+        AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
+                            {new_add_node});
+
+        // optimize new inner aggregation node
+        AddToOptimizationQueue(new_add_node);
+        // do not optimize the same node twice
+        rewritten_nodes_.insert(node->name());
+        *simplified_node_name = new_mul_node->name();
+      }
+    }
+    return Status::OK();
+  }
+
+ private:
+  // Get a name for new outer Mul node
+  string OuterMulNodeName(const NodeDef* node) const {
+    auto scoped_node = ParseScopedNodeName(node->name());
+    return OptimizedNodeName("Mul", scoped_node);
+  }
+
+  // Get a name new inner Add node
+  string InnerAddNodeName(const NodeDef* node) const {
+    auto scoped_node = ParseScopedNodeName(node->name());
+    return OptimizedNodeName("Add", scoped_node);
+  }
+
+  // Determine the set of common factors if the input nodes are all Mul nodes.
+  Status GetCommonFactors(const NodeDef* node,
+                          std::set<string>* common_factors) const {
+    CHECK(common_factors->empty());
+
+    for (int i = 0; i < node->input_size(); ++i) {
+      if (i > 0 && common_factors->empty()) break;
+      if (IsControlInput(node->input(i))) break;
+
+      NodeDef* input;
+      TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
+
+      if (!IsMul(*input)) {
+        common_factors->clear();
+        break;
+      }
+
+      std::set<string> factors_i{input->input(0), input->input(1)};
+      if (i == 0) {
+        std::swap(*common_factors, factors_i);
+      } else {
+        std::set<string> intersection;
+        std::set_intersection(
+            factors_i.begin(), factors_i.end(), common_factors->begin(),
+            common_factors->end(),
+            std::inserter(intersection, intersection.begin()));
+        std::swap(*common_factors, intersection);
+      }
+    }
+    return Status::OK();
+  }
+
+  // Gather up the non-shared factors (the y's in the example).
+  // Unless the aggregation is Add, we have to make sure that all the y's
+  // have the same shape since the other aggregation ops do not support
+  // broadcasting.
+  Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
+                          bool* shapes_match,
+                          std::vector<string>* unique_factors) const {
+    *shapes_match = true;
+    unique_factors->reserve(node->input_size());
+
+    for (int i = 0; i < node->input_size() && shapes_match; ++i) {
+      const string& input = node->input(i);
+      if (IsControlInput(input)) {
+        break;
+      }
+      NodeDef* mul_node;
+      TF_RETURN_IF_ERROR(GetInputNode(input, &mul_node));
+      const int unique_factor_index =
+          mul_node->input(0) == common_factor ? 1 : 0;
+      unique_factors->push_back(mul_node->input(unique_factor_index));
+      if (i > 0 && !IsAdd(*node)) {
+        *shapes_match = ShapesEqual(unique_factors->front(),
+                                    unique_factors->back(), *ctx_.node_map);
+      }
+    }
+    return Status::OK();
+  }
+
+  bool IsRewritten(const NodeDef* node) const {
+    // if graph rewrite happens in multiple passes without graph pruning between
+    // them, it's possible that rewritten node already exists in a graph
+    return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
+           ctx_.node_map->NodeExists(OuterMulNodeName(node));
+  }
+
+  // keep names of the nodes that were optimized by this stage
+  std::unordered_set<string> rewritten_nodes_;
+};
+
 // Removes inverse transpose nodes
 class RemoveInverseTranspose : public ArithmeticOptimizerStage {
  public:
-  explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx)
-      : ArithmeticOptimizerStage(ctx) {}
+  explicit RemoveInverseTranspose(const ArithmeticOptimizerContext& ctx)
+      : ArithmeticOptimizerStage("RemoveInverseTranspose", ctx) {}
   ~RemoveInverseTranspose() override = default;
 
   bool IsSupported(const NodeDef* node) const override {
@@ -702,8 +976,8 @@ class RemoveInverseTranspose : public ArithmeticOptimizerStage {
 // 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2)
 class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
  public:
-  explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx)
-      : ArithmeticOptimizerStage(ctx) {}
+  explicit RemoveRedundantBitcastStage(const ArithmeticOptimizerContext& ctx)
+      : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx) {}
   ~RemoveRedundantBitcastStage() override = default;
 
   bool IsSupported(const NodeDef* node) const override {
@@ -742,8 +1016,8 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage {
 // Remove Casts whose source type and destination type are equal.
 class RemoveRedundantCastStage : public ArithmeticOptimizerStage {
  public:
-  explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx)
-      : ArithmeticOptimizerStage(ctx) {}
+  explicit RemoveRedundantCastStage(const ArithmeticOptimizerContext& ctx)
+      : ArithmeticOptimizerStage("RemoveRedundantCast", ctx) {}
   ~RemoveRedundantCastStage() override = default;
 
   bool IsSupported(const NodeDef* node) const override { return IsCast(*node); }
@@ -1276,98 +1550,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     }
   }
 
-  // Use the commutativity and (left- and right-) distributive property of
-  // multiplication over addition to hoist common factors out of aggregate nodes
-  // where all the inputs are Mul nodes. This pattern occurs frequently in
-  // regularization terms for the gradients during training.
-  // For example, we can rewrite an expression of the form:
-  //   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
-  // to the following:
-  //   Mul(x, AddN(y1, y2, y3, ... yn))
-  if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
-      !OptimizedNodeExists(*node, "hoist_add") &&
-      !OptimizedNodeExists(*node, "hoist_mul")) {
-    // Determine the set of common factors if the input nodes are all Mul nodes.
-    std::set<string> common_factors;
-    for (int i = 0; i < node->input_size(); ++i) {
-      if (i > 0 && common_factors.empty()) {
-        break;
-      }
-      if (IsControlInput(node->input(i))) {
-        break;
-      }
-      const NodeDef* input = node_map_->GetNode(node->input(i));
-      if (input->op() == "Mul") {
-        std::set<string> factors_i{input->input(0), input->input(1)};
-        if (i == 0) {
-          std::swap(common_factors, factors_i);
-        } else {
-          std::set<string> intersection;
-          std::set_intersection(
-              factors_i.begin(), factors_i.end(), common_factors.begin(),
-              common_factors.end(),
-              std::inserter(intersection, intersection.begin()));
-          std::swap(common_factors, intersection);
-        }
-      } else {
-        common_factors.clear();
-      }
-    }
-    if (common_factors.size() == 1) {
-      const string& common_factor = *common_factors.begin();
-
-      // Gather up the non-shared factors (the y's in the example).
-      // Unless the aggregation is Add, we have to make sure that all the y's
-      // have the same shape since the other aggregation ops do not support
-      // broadcasting.
-      std::vector<string> unique_factors;
-      unique_factors.reserve(node->input_size());
-      bool shapes_match = true;
-      for (int i = 0; i < node->input_size() && shapes_match; ++i) {
-        const string& input = node->input(i);
-        if (IsControlInput(input)) {
-          break;
-        }
-        const NodeDef* mul_node = node_map_->GetNode(input);
-        const int unique_factor_index =
-            mul_node->input(0) == common_factor ? 1 : 0;
-        unique_factors.push_back(mul_node->input(unique_factor_index));
-        if (i > 0 && !IsAdd(*node)) {
-          shapes_match = ShapesEqual(unique_factors.front(),
-                                     unique_factors.back(), *node_map_);
-        }
-      }
-
-      if (shapes_match) {
-        // 1. Use a copy of the first Mul node for the outer multiplication.
-        NodeDef* new_mul_node = AddNode(OptimizedNodeName(*node, "hoist_mul"),
-                                        node_map_->GetNode(node->input(0)));
-        NodeDef* new_add_node = AddNode(*node, "hoist_add", /*copy_node=*/true);
-        new_mul_node->set_device(node->device());
-        new_mul_node->set_input(0, common_factor);
-        node_map_->AddOutput(common_factor, new_mul_node->name());
-        new_mul_node->set_input(1, new_add_node->name());
-        node_map_->AddOutput(new_add_node->name(), new_mul_node->name());
-
-        // 2. Hoist non-shared factors up into the new AddN node.
-        nodes_to_simplify->PushBack(new_add_node);
-        for (int i = 0; i < node->input_size(); ++i) {
-          const string& input = node->input(i);
-          if (IsControlInput(input)) {
-            break;
-          }
-          new_add_node->set_input(i, unique_factors[i]);
-        }
-
-        // 3. Add frame dependencies that the original node might have had.
-        AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
-                            {new_add_node});
-
-        return new_mul_node->name();
-      }
-    }
-  }
-
   // Fold Transpose into matrix multiplication.
   if ((node->op() == "MatMul" || node->op() == "SparseMatMul" ||
        node->op() == "BatchMatMul") &&
@@ -1444,8 +1626,9 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
   }
 
-  ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
-                                 node_map_.get(), &nodes_to_simplify);
+  const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
+                                       node_map_.get(), &frame_map_,
+                                       &nodes_to_simplify);
 
   std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
 
@@ -1453,6 +1636,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
     stages.push_back(
         std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
   }
+  if (options_.hoist_common_factor_out_of_aggregation) {
+    stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
+        new HoistCommonFactorOutOfAggregation(ctx)));
+  }
   if (options_.remove_inverse_transpose) {
     stages.push_back(std::unique_ptr<ArithmeticOptimizerStage>(
         new RemoveInverseTranspose(ctx)));
index 7870844..d5a7af5 100644 (file)
@@ -56,6 +56,7 @@ class ArithmeticOptimizer : public GraphOptimizer {
   // Granular control for arithmetic optimizer stages
   struct ArithmeticOptimizerOptions {
     bool combine_add_to_addn = true;
+    bool hoist_common_factor_out_of_aggregation = true;
     bool remove_inverse_transpose = true;
     bool remove_redundant_bitcast = true;
     bool remove_redundant_cast = true;
index 98842b2..e1f4762 100644 (file)
@@ -30,6 +30,22 @@ namespace grappler {
 
 namespace {
 
+constexpr char kHoistFactorOptimizerMul[] =
+    "ArithmeticOptimizer/HoistCommonFactor_Mul_";
+
+constexpr char kHoistFactorOptimizerAdd[] =
+    "ArithmeticOptimizer/HoistCommonFactor_Add_";
+
+// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation
+string HoistMulName(const string& name) {
+  return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
+}
+
+// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
+string HoistAddName(const string& name) {
+  return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
+}
+
 string OptimizedName(const string& name) {
   return AddPrefixToNodeName(name, kArithmeticOptimizer);
 }
@@ -61,22 +77,40 @@ class ArithmeticOptimizerTest : public GrapplerTest {
     TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
   }
 
+  // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent.
+  void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+                     GraphDef* output) {
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+  }
+
   // TODO(ezhulenev): Make private. After migration to stages each test
   // should explicitly enable required optimization for tests isolation
   void DisableAllStages(ArithmeticOptimizer* optimizer) {
     ArithmeticOptimizer::ArithmeticOptimizerOptions options;
     options.combine_add_to_addn = false;
+    options.hoist_common_factor_out_of_aggregation = false;
     options.remove_inverse_transpose = false;
     options.remove_redundant_bitcast = false;
     options.remove_redundant_cast = false;
     optimizer->options_ = options;
   }
 
+  void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
+    optimizer->options_.combine_add_to_addn = false;
+  }
+
   void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.combine_add_to_addn = true;
   }
 
+  void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.hoist_common_factor_out_of_aggregation = true;
+  }
+
   void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) {
     DisableAllStages(optimizer);
     optimizer->options_.remove_inverse_transpose = true;
@@ -396,59 +430,66 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
   }
 
   ArithmeticOptimizer optimizer;
-  DisableAllStages(&optimizer);
+  DisableAddToAddNCombining(&optimizer);
 
   GraphDef output;
-  Status status = optimizer.Optimize(nullptr, item, &output);
-  TF_EXPECT_OK(status);
-  // Run the optimizer twice to make sure the rewrite is idempotent.
-  item.graph.Swap(&output);
-  status = optimizer.Optimize(nullptr, item, &output);
-  TF_EXPECT_OK(status);
+  OptimizeTwice(&optimizer, &item, &output);
 
-  EXPECT_EQ(17, output.node_size());
-  // The graph gets optimized to
+  // We expect the following rewrite(s) to occur:
+  //
   // Mul(p,
-  //     Add(Add(Const(2), Const(2)),
-  //         Add(Const(2), Const(2))))
+  //     Add_6(Add_4(Const(2), Const(2)),
+  //           Add_5(Const(2), Const(2))))
+  NodeMap node_map(&output);
+
   EXPECT_EQ(17, output.node_size());
-  for (const auto& node : output.node()) {
-    if ("id" == node.name()) {
-      EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ(OptimizedName("Add_6_hoist_mul"), node.input(0));
-    } else if (OptimizedName("Add_6_hoist_mul") == node.name()) {
-      EXPECT_EQ("Mul", node.op());
-      EXPECT_EQ(2, node.input_size());
-      EXPECT_EQ("Placeholder", node.input(0));
-      EXPECT_EQ(OptimizedName("Add_6_hoist_add"), node.input(1));
-    } else if (OptimizedName("Add_6_hoist_add") == node.name()) {
-      EXPECT_EQ("Add", node.op());
-      EXPECT_EQ(3, node.input_size());
-      EXPECT_EQ(OptimizedName("Add_4_hoist_add"), node.input(0));
-      EXPECT_EQ(OptimizedName("Add_5_hoist_add"), node.input(1));
-      EXPECT_EQ("^Placeholder", node.input(2));
-    } else if (OptimizedName("Add_4_hoist_add") == node.name()) {
-      EXPECT_EQ("Add", node.op());
-      EXPECT_EQ(3, node.input_size());
-      EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
-      EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
-      EXPECT_EQ("^Placeholder", node.input(2));
-    } else if (OptimizedName("Add_5_hoist_add") == node.name()) {
-      EXPECT_EQ("Add", node.op());
-      EXPECT_EQ(3, node.input_size());
-      EXPECT_EQ(OptimizedName("Add_const"), node.input(0));
-      EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1));
-      EXPECT_EQ("^Placeholder", node.input(2));
-    } else if (OptimizedName("Add_const") == node.name()) {
-      EXPECT_EQ("Const", node.op());
-      EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("^Placeholder", node.input(0));
-    } else if (OptimizedName("Add_1_const") == node.name()) {
-      EXPECT_EQ("Const", node.op());
-      EXPECT_EQ(1, node.input_size());
-      EXPECT_EQ("^Placeholder", node.input(0));
-    }
-  }
+
+  const NodeDef* id_node = node_map.GetNode("id");
+  ASSERT_TRUE(id_node != nullptr);
+  EXPECT_EQ(1, id_node->input_size());
+  EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0));
+
+  const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6"));
+  ASSERT_TRUE(mul_node != nullptr);
+  EXPECT_EQ(2, mul_node->input_size());
+  EXPECT_EQ("Placeholder", mul_node->input(0));
+  EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1));
+
+  const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6"));
+  ASSERT_TRUE(add_6_node != nullptr);
+  EXPECT_EQ(3, add_6_node->input_size());
+  EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0));
+  EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1));
+  EXPECT_EQ("^Placeholder", add_6_node->input(2));
+
+  const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4"));
+  ASSERT_TRUE(add_4_node != nullptr);
+  EXPECT_EQ("Add", add_4_node->op());
+  EXPECT_EQ(3, add_4_node->input_size());
+  EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0));
+  EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1));
+  EXPECT_EQ("^Placeholder", add_4_node->input(2));
+
+  const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5"));
+  ASSERT_TRUE(add_5_node != nullptr);
+  EXPECT_EQ("Add", add_5_node->op());
+  EXPECT_EQ(3, add_5_node->input_size());
+  EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0));
+  EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1));
+  EXPECT_EQ("^Placeholder", add_5_node->input(2));
+
+  const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const"));
+  ASSERT_TRUE(add_const_node != nullptr);
+  EXPECT_EQ("Const", add_const_node->op());
+  EXPECT_EQ(1, add_const_node->input_size());
+  EXPECT_EQ("^Placeholder", add_const_node->input(0));
+
+  const NodeDef* add_1_const_node =
+      node_map.GetNode(OptimizedName("Add_1_const"));
+  ASSERT_TRUE(add_1_const_node != nullptr);
+  EXPECT_EQ("Const", add_1_const_node->op());
+  EXPECT_EQ(1, add_1_const_node->input_size());
+  EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
 }
 
 TEST_F(ArithmeticOptimizerTest, HoistFactor) {
@@ -469,31 +510,46 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
                                    ops::Add(s.WithOpName("add"), mul1, mul2));
 
       GrapplerItem item;
+      item.fetch = {"id"};
       TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
       ArithmeticOptimizer optimizer;
+      EnableOnlyHoistCommonFactor(&optimizer);
+
       GraphDef output;
-      Status status = optimizer.Optimize(nullptr, item, &output);
-      TF_EXPECT_OK(status);
-      // Run the optimizer twice to make sure the rewrite is idempotent.
-      item.graph.Swap(&output);
-      status = optimizer.Optimize(nullptr, item, &output);
-      TF_EXPECT_OK(status);
+      OptimizeTwice(&optimizer, &item, &output);
+
+      // We expect the following rewrite(s) to occur:
+      //
+      //        Add                 Mul
+      //      /    \               /   \
+      //    Mul    Mul       ->   x    Add
+      //    / \    / \                 / \
+      //   x  y1  y2  x              y1   y2
+      //
+      // If "root" op is AddN and shapes does not match, this rewrite is not
+      // possible and graph should stay intact.
+      NodeMap node_map(&output);
 
       if (use_addn && !matching_shapes) {
         VerifyGraphsMatch(item.graph, output, __LINE__);
       } else {
         EXPECT_EQ(9, output.node_size());
-        const NodeDef& new_add = output.node(8);
-        EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
-        EXPECT_EQ("y1", new_add.input(0));
-        EXPECT_EQ("y2", new_add.input(1));
-        const NodeDef& new_mul = output.node(7);
-        EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
-        EXPECT_EQ("x", new_mul.input(0));
-        EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
-        const NodeDef& new_id = output.node(6);
-        EXPECT_EQ("id", new_id.name());
-        EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
+
+        const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
+        ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+        EXPECT_EQ("y1", new_add_node->input(0));
+        EXPECT_EQ("y2", new_add_node->input(1));
+
+        const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add"));
+        ASSERT_TRUE(new_mul_node != nullptr) << "Hoisted Mul node not found";
+        EXPECT_EQ("x", new_mul_node->input(0));
+        EXPECT_EQ(new_add_node->name(), new_mul_node->input(1));
+
+        const NodeDef* id_node = node_map.GetNode("id");
+        ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+        EXPECT_EQ("id", id_node->name());
+        EXPECT_EQ(HoistMulName("add"), id_node->input(0));
       }
     }
   }
@@ -1249,8 +1305,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
   NodeMap node_map(&output);
 
   // check add tree was replaced with AddN
-  const NodeDef* collapsed_add = CHECK_NOTNULL(
-      node_map.GetNode("y/ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+  const NodeDef* collapsed_add =
+      node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+  ASSERT_TRUE(collapsed_add != nullptr);
 
   EXPECT_EQ("AddN", collapsed_add->op());
   EXPECT_EQ(3, collapsed_add->input_size());
@@ -1259,7 +1316,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
   EXPECT_EQ("c", collapsed_add->input(2));
 
   // check output was re-wired to new node
-  const NodeDef* updated_outputs = CHECK_NOTNULL(node_map.GetNode("outputs"));
+  const NodeDef* updated_outputs = node_map.GetNode("outputs");
+  ASSERT_TRUE(updated_outputs != nullptr);
 
   EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
 }
@@ -1306,8 +1364,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
   NodeMap node_map(&output);
 
   // check left Add subtree replaced with AddN
-  const NodeDef* collapsed_left = CHECK_NOTNULL(
-      node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+  const NodeDef* collapsed_left =
+      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab");
+  ASSERT_TRUE(collapsed_left != nullptr);
 
   EXPECT_EQ("AddN", collapsed_left->op());
   EXPECT_EQ(3, collapsed_left->input_size());
@@ -1316,8 +1375,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
   EXPECT_EQ("c", collapsed_left->input(2));
 
   // check right Add subtree replaced with AddN
-  const NodeDef* collapsed_right = CHECK_NOTNULL(
-      node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_xyz_Add_xy"));
+  const NodeDef* collapsed_right =
+      node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz_Add_xy");
+  ASSERT_TRUE(collapsed_right != nullptr);
 
   EXPECT_EQ("AddN", collapsed_right->op());
   EXPECT_EQ(3, collapsed_right->input_size());
@@ -1326,7 +1386,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
   EXPECT_EQ("z", collapsed_right->input(2));
 
   // check that Mul inputs re-wired to new Nodes
-  const NodeDef* updated_mul = CHECK_NOTNULL(node_map.GetNode("Mul"));
+  const NodeDef* updated_mul = node_map.GetNode("Mul");
+  ASSERT_TRUE(updated_mul != nullptr);
 
   EXPECT_EQ("Mul", updated_mul->op());
   EXPECT_EQ(2, updated_mul->input_size());
@@ -1367,8 +1428,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
   NodeMap node_map(&output);
 
   // check Add tree replaced with AddN
-  const NodeDef* collapsed_add = CHECK_NOTNULL(node_map.GetNode(
-      "ArithmeticOptimizer/AddOpsGroup_Add_all_Add_ab_Add_bc"));
+  const NodeDef* collapsed_add = node_map.GetNode(
+      "ArithmeticOptimizer/AddOpsRewrite_Add_all_Add_ab_Add_bc");
+  ASSERT_TRUE(collapsed_add != nullptr);
 
   EXPECT_EQ("AddN", collapsed_add->op());
   EXPECT_EQ(4, collapsed_add->input_size());