Add/AddN optimizer/rewriter
authorA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Mar 2018 22:42:12 +0000 (14:42 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Thu, 8 Mar 2018 22:46:40 +0000 (14:46 -0800)
Collapse a sub-graph of Add/AddN operations of fully specified
and identical shapes to a single AddN operation.

PiperOrigin-RevId: 188392302

tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc

index 709a434..3cf42fd 100644 (file)
@@ -214,7 +214,12 @@ PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
   int output_pos;
   string node_name = ParseNodeName(input, &output_pos);
   const NodeDef* input_node = node_map.GetNode(node_name);
-  return input_node->attr().at(kOutputShapesAttr).list().shape(output_pos);
+  auto attr = input_node->attr();
+  if (attr.find(kOutputShapesAttr) == attr.end()) {
+    return PartialTensorShape();  // unknown shape
+  } else {
+    return attr.at(kOutputShapesAttr).list().shape(output_pos);
+  }
 }
 
 bool ShapesEqual(const string& input_x, const string& input_y,
@@ -292,6 +297,359 @@ NodeDef* GetTailOfValuePreservingChain(
                         is_value_preserving_non_branching);
 }
 
+// Context passed to each arithmetic optimizer stage. Optimizer stage is
+// responsible for updating the node map for all added or deleted nodes, to keep
+// it consistent with optimized graph.
+struct ArithmeticOptimizerContext {
+  ArithmeticOptimizerContext(
+      const std::unordered_set<string>* nodes_to_preserve,
+      GraphDef* optimized_graph, NodeMap* node_map,
+      SetVector<NodeDef*>* nodes_to_simplify)
+      : nodes_to_preserve(nodes_to_preserve),
+        optimized_graph(optimized_graph),
+        node_map(node_map),
+        nodes_to_simplify(nodes_to_simplify) {}
+
+  const std::unordered_set<string>* nodes_to_preserve;
+  GraphDef* optimized_graph;
+  NodeMap* node_map;
+  SetVector<NodeDef*>* nodes_to_simplify;
+};
+
+// Base class for single arithmetic optimization: e.g. Bitcast optimization,
+// AddOps optimization, etc...
+class ArithmeticOptimizerStage {
+ public:
+  explicit ArithmeticOptimizerStage(ArithmeticOptimizerContext ctx)
+      : ctx_(ctx) {}
+  virtual ~ArithmeticOptimizerStage() = default;
+
+  // Check if we should try to simplify node. Returning true doesn't
+  // guarantee that node will be simplified.
+  //
+  // Should implement just a basic sanity check, without any expensive graph
+  // traversals.
+  virtual bool IsSupported(const NodeDef* node) const = 0;
+
+  // Try to simplify the given node. If successfully simplified a given node,
+  // return a name of a new simplified version using output parameter.
+  //
+  // Consumers of an old node's outputs will be automatically re-wired to
+  // consume outputs of a new simplified node.
+  //
+  // Return error status only if some precondition is failed, or got an
+  // incorrect graph. In every other case return Status:OK(), even if didn't
+  // simplify anything.
+  //
+  // A simplified node will be always considered for further optimization and
+  // will be automatically added to the optimization queue. If a simplified node
+  // has the same name as original node it has to be explicitly added to the
+  // optimization queue for second pass.
+  virtual Status TrySimplify(const NodeDef* node,
+                             string* simplified_node_name) = 0;
+
+ protected:
+  // Simplification graph rewrite can create additional nodes that are inputs
+  // to final simplified node, they can be also added to the arithmetic
+  // optimizer queue for further optimization.
+  void AddToOptimizationQueue(NodeDef* node) {
+    ctx_.nodes_to_simplify->PushBack(node);
+  }
+
+  // Get a node by input name from a node map. Return a error if node was not
+  // found.
+  Status GetInputNode(const string& input, NodeDef** node) const {
+    string node_name = NodeName(input);
+    NodeDef* node_by_name = ctx_.node_map->GetNode(node_name);
+    if (node_by_name == nullptr) {
+      return errors::FailedPrecondition("Node ", node_name,
+                                        " doesn't exists in a node map");
+    }
+    *node = node_by_name;
+    return Status::OK();
+  }
+
+  // Get input shape from a node map. If node doesn't exists return unknown
+  // shape.
+  PartialTensorShape GetInputShape(const string& input) const {
+    int position;
+    string node_name = ParseNodeName(input, &position);
+    NodeDef* node;
+    Status node_status = GetInputNode(node_name, &node);
+    if (!node_status.ok()) {
+      return PartialTensorShape();  // unknown shape
+    }
+    auto attr = node->attr();
+    if (attr.find(kOutputShapesAttr) == attr.end()) {
+      return PartialTensorShape();  // unknown shape
+    } else {
+      return attr.at(kOutputShapesAttr).list().shape(position);
+    }
+  }
+
+  ArithmeticOptimizerContext ctx_;
+};
+
+// Rewrite a tree of Add/AddN with a single AddN operation, consuming all the
+// original inputs of absorbed nodes.
+//
+// All nodes in a Add/AddN subgraph must have fully specified and identical
+// shape. All nodes must have the same device placement.
+//
+// Example:
+//                AddN_1
+//             /    |    \
+//          Add_1   z   Add_2       -> AddN(z, y, z, w, q, e)
+//          /  \        /  \
+//         x    y      w    Add_3
+//                          / \
+//                         q   e
+class AddOpsRewriteStage : public ArithmeticOptimizerStage {
+ public:
+  explicit AddOpsRewriteStage(ArithmeticOptimizerContext ctx)
+      : ArithmeticOptimizerStage(ctx), rewritten_nodes_() {}
+
+  ~AddOpsRewriteStage() override = default;
+
+  // Check if a node can become a root of AddOpsGroup
+  bool IsSupported(const NodeDef* node) const override {
+    // check basic preconditions
+    if (!IsRewritable(node)) {
+      return false;
+    }
+    // and must have fully defined shape
+    // TODO(ezhulenev): support partially defined shapes, when we can prove that
+    // unknown dimensions in the rewritten subgraph are the same.
+    PartialTensorShape shape = GetInputShape(node->name());
+    if (!shape.IsFullyDefined()) {
+      return false;
+    }
+    // and must have inputs of fully defined shape identical to the output
+    // TODO(ezhulenev): relax this condition to support equal unknown dimensions
+    return HasAllInputsOfIdenticalShape(*node, shape);
+  }
+
+  Status TrySimplify(const NodeDef* node,
+                     string* simplified_node_name) override {
+    CHECK(IsSupported(node))
+        << "Node " << node->name()
+        << " is not supported by add ops group optimizer step";
+    AddOpsGroup group;
+    TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
+
+    if (!group.absorbed_nodes.empty()) {
+      *simplified_node_name = RewriteAddOpsGroup(group);
+    }
+
+    return Status::OK();
+  }
+
+ private:
+  // Holds together an add ops subgraph that we want to rewrite together.
+  //
+  // For the graph above the AddOpsGroup will be:
+  //   root_node: AddN_1
+  //   absorbed_nodes: [Add_1, Add_2]
+  //   input_nodes: [x, y, z, w, q, e]
+  struct AddOpsGroup {
+    const NodeDef* root_node;
+    PartialTensorShape root_shape;
+    // Add/AddN operations below the root level that were absorbed by this group
+    std::vector<NodeDef*> absorbed_nodes;
+    // Inputs of absorbed nodes that will be forwarded to rewritten AddN node
+    std::vector<string> inputs;
+  };
+
+  // Check if all inputs are fully defined and identical to expected shape
+  bool HasAllInputsOfIdenticalShape(const NodeDef& node,
+                                    const PartialTensorShape& shape) const {
+    const AddOpsRewriteStage* self = this;
+    return std::all_of(node.input().begin(), node.input().end(),
+                       [self, &shape](const string& input) {
+                         auto input_shape = self->GetInputShape(input);
+                         return input_shape.IsFullyDefined() &&
+                                input_shape.IsIdenticalTo(shape);
+                       });
+  }
+
+  // TODO(ezhulenev): use GraphRewriter?
+  bool IsDrivenByControlDependency(const NodeDef& node) const {
+    return std::any_of(node.input().begin(), node.input().end(),
+                       IsControlInput);
+  }
+
+  // TODO(ezhulenev): use GraphRewriter?
+  bool DrivesControlDependency(const NodeDef& node) const {
+    int position;
+    for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) {
+      for (int i = 0; i < output->input_size(); ++i) {
+        auto input = output->input(i);
+        string name = ParseNodeName(input, &position);
+        if (name == node.name() && /*control input*/ position < 0) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
+  // Check if a node can be absorbed by current AddOpsGroup
+  bool IsAbsorbableByAddOpsGroup(const string& name, const AddOpsGroup& group) {
+    NodeDef* node;
+    Status node_status = GetInputNode(name, &node);
+    if (!node_status.ok()) {
+      return false;
+    }
+
+    PartialTensorShape shape = GetInputShape(name);
+    CHECK(shape.IsIdenticalTo(group.root_shape))
+        << "Cannot absorb a node of incompatible shape";
+
+    // check basic preconditions
+    if (!IsRewritable(node)) {
+      return false;
+    }
+    // with a single output consumer (presumably if we reach this node from
+    // previously absorbed or a root node, it means that this node is not used
+    // as an input to any other op, outside of the group)
+    if (ctx_.node_map->GetOutputs(node->name()).size() != 1) {
+      return false;
+    }
+    // must be on the same device as a root node
+    if (node->device() != group.root_node->device()) {
+      return false;
+    }
+    // All input shapes must be fully defined and equal to the node shape
+    return HasAllInputsOfIdenticalShape(*node, shape);
+  }
+
+  // Node requirements both for a root node and an absorbed node
+  bool IsRewritable(const NodeDef* node) const {
+    // only Add or AddN can be a root node
+    // TODO(ezhulenev): check if AccumulateNV2 can be supported too
+    if (!IsAdd(*node) && !IsAddN(*node)) {
+      return false;
+    }
+    // it must not be in a preserve set
+    if (ctx_.nodes_to_preserve->find(node->name()) !=
+        ctx_.nodes_to_preserve->end()) {
+      return false;
+    }
+    // it must not be a node created or absorbed by previous iteration
+    if (rewritten_nodes_.find(node->name()) != rewritten_nodes_.end()) {
+      return false;
+    }
+    // should not drive or be driven by control dependency
+    // TODO(ezhulenev): relax this condition for root node
+    return !(IsDrivenByControlDependency(*node) ||
+             DrivesControlDependency(*node));
+  }
+
+  // Create an AddOpsGroup with a root in a given node
+  Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
+    group->root_node = root_node;
+    group->root_shape = GetInputShape(root_node->name());
+
+    group->absorbed_nodes.reserve(root_node->input_size());
+    for (int i = 0; i < root_node->input_size(); ++i) {
+      TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(root_node->input(i), group));
+    }
+
+    return Status::OK();
+  }
+
+  Status AbsorbInputByAddOpsGroup(const string& input, AddOpsGroup* group) {
+    NodeDef* node;
+    TF_RETURN_IF_ERROR(GetInputNode(input, &node));
+
+    if (IsAbsorbableByAddOpsGroup(input, *group)) {
+      group->absorbed_nodes.push_back(node);
+      for (int i = 0; i < node->input_size(); ++i) {
+        TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(node->input(i), group));
+      }
+    } else {
+      // If node can't be absorbed, add it to AddOpsGroup input
+      group->inputs.push_back(input);
+    }
+    return Status::OK();
+  }
+
+  const std::pair<string, string> ParseNodeScopeAndName(const string& name) {
+    auto pos = name.find_last_of("/");
+    if (pos == string::npos) {
+      return {"", name};
+    } else {
+      return {name.substr(0, pos), name.substr(pos + 1)};
+    }
+  }
+
+  // New node for AddOpsGroup is added to the same scope as a root_node. All
+  // absorbed nodes are stripped of their scope, and only names are used in a
+  // new node name.
+  //
+  // Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"])
+  //          node_name="a/b/c/AddOpsGroup_Add_2_Add_1_Add
+  string AddOpsGroupName(const AddOpsGroup& group) {
+    CHECK_NOTNULL(group.root_node);
+    string node_name;
+
+    auto root_node = ParseNodeScopeAndName(group.root_node->name());
+    auto root_scope = root_node.first;
+    auto root_name = root_node.second;
+    if (!root_scope.empty()) {
+      strings::StrAppend(&node_name, root_scope, "/");
+    }
+
+    strings::StrAppend(&node_name, kArithmeticOptimizer, "/", "AddOpsGroup_",
+                       root_name);
+    for (const NodeDef* absorbed : group.absorbed_nodes) {
+      auto absorbed_node = ParseNodeScopeAndName(absorbed->name());
+      strings::StrAppend(&node_name, "_", absorbed_node.second);
+    }
+    return node_name;
+  }
+
+  // Create a new node for a AddOpsGroup and return it's name.
+  string RewriteAddOpsGroup(const AddOpsGroup& group) {
+    CHECK_GT(group.absorbed_nodes.size(), 0)
+        << "AddOpsGroup must have non empty absorbed nodes";
+
+    // name for a new node constructed from AddOpsGroup
+    string node_name = AddOpsGroupName(group);
+
+    // copy attributes from a root node
+    DataType dtype = group.root_node->attr().at("T").type();
+
+    // add new node
+    NodeDef* added_node = ctx_.optimized_graph->add_node();
+    added_node->set_name(node_name);
+    added_node->set_op("AddN");
+    added_node->set_device(group.root_node->device());
+    (*added_node->mutable_attr())["T"].set_type(dtype);
+    (*added_node->mutable_attr())["N"].set_i(group.inputs.size());
+
+    ctx_.node_map->AddNode(node_name, added_node);
+    for (string input : group.inputs) {
+      ctx_.node_map->AddOutput(input, node_name);
+      added_node->add_input(std::move(input));
+    }
+
+    VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
+            << " Add/AddN nodes from the graph";
+
+    // keep track of nodes that were created or absorbed as a part of rewrite
+    rewritten_nodes_.insert(node_name);
+    for (const NodeDef* absorbed : group.absorbed_nodes) {
+      rewritten_nodes_.insert(absorbed->name());
+    }
+
+    return node_name;
+  }
+
+  // keep nodes that were added or absorbed as a part of AddOpsGroup rewrite
+  std::unordered_set<string> rewritten_nodes_;
+};
+
 }  // namespace
 
 class UniqueNodes {
@@ -516,6 +874,8 @@ void ArithmeticOptimizer::AddFrameControlDeps(
   }
 }
 
+// TODO(ezhulenev): extract each individual simplify rewrite into separate
+// ArithmeticOptimizerStage
 string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
     const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
   // Remove involutions applied twice.
@@ -1025,14 +1385,46 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
   for (int i = 0; i < optimized_graph_->node_size(); ++i) {
     nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
   }
+
+  ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
+                                 node_map_.get(), &nodes_to_simplify);
+
+  std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
+
+  // Add/AddN tree rewrites
+  if (options_.enable_add_to_addn_combining) {
+    stages.push_back(
+        std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
+  }
+
+  VLOG(1) << "Simplify arithmetic ops using " << stages.size()
+          << " arithmetic optimization stages";
+
   while (!nodes_to_simplify.Empty()) {
     const NodeDef* node = nodes_to_simplify.PopBack();
-    const string simplified_tensor =
+
+    // TODO(ezhulenev): move all rewrites into separate stages
+    string simplified_tensor =
         TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+
+    // if it was not simplified try to run it through all configured stages
+    if (simplified_tensor.empty()) {
+      for (auto& stage : stages) {
+        if (stage->IsSupported(node)) {
+          TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
+          if (!simplified_tensor.empty()) {
+            break;
+          }
+        }
+      }
+    }
+
+    // if it's still empty go to the next Node
     if (simplified_tensor.empty()) {
       continue;
     }
 
+    // re-wire consumers of an old node to the new one
     if (NodeName(simplified_tensor) != node->name()) {
       // Always consider simplified_tensor for further optimizations.
       NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
@@ -1087,6 +1479,7 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
   // Shapes are only needed in aggressive mode.
   graph_properties_.reset(new GraphProperties(item));
   TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+  // TODO(ezhulenev): Use GraphProperties to lookup tensor shapes directly
   TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
 
   // Perform the optimizations.
index afd538d..9cff8ca 100644 (file)
@@ -32,9 +32,14 @@ constexpr char kArithmeticOptimizer[] = "ArithmeticOptimizer";
 // run a model.
 class ArithmeticOptimizer : public GraphOptimizer {
  public:
-  ArithmeticOptimizer() : opt_level_(RewriterConfig::ON) {}
+  ArithmeticOptimizer()
+      : opt_level_(RewriterConfig::ON),
+        options_(ArithmeticOptimizerOptions::Default(RewriterConfig::ON)) {}
+
   explicit ArithmeticOptimizer(RewriterConfig::Toggle opt_level)
-      : opt_level_(opt_level) {}
+      : opt_level_(opt_level),
+        options_(ArithmeticOptimizerOptions::Default(opt_level)) {}
+
   ~ArithmeticOptimizer() override {}
 
   string name() const override { return "arithmetic_optimizer"; };
@@ -46,6 +51,21 @@ class ArithmeticOptimizer : public GraphOptimizer {
                 const GraphDef& optimized_graph, double result) override;
 
  private:
+  friend class ArithmeticOptimizerTest;
+
+  // Granular control for arithmetic optimizer stages
+  struct ArithmeticOptimizerOptions {
+    // rewrite a tree of Add/AddN ops with a single AddN
+    bool enable_add_to_addn_combining;
+
+    // Choose which arithmetic optimizer stages will be enabled for a given
+    // optimization level by default.
+    static ArithmeticOptimizerOptions Default(
+        RewriterConfig::Toggle opt_level) {
+      return {/*enable_add_to_addn_combining*/ true};
+    }
+  };
+
   // Returns true is a node with given name and the optimizer prefix already
   // exists.
   string OptimizedNodeName(const NodeDef& node, StringPiece suffix) const;
@@ -97,13 +117,14 @@ class ArithmeticOptimizer : public GraphOptimizer {
                                    SetVector<NodeDef*>* nodes_to_simplify);
 
   RewriterConfig::Toggle opt_level_;
+  ArithmeticOptimizerOptions options_;
 
-  bool fetch_nodes_known_;
+  bool fetch_nodes_known_ = false;
   std::unordered_set<string> nodes_to_preserve_;
   std::unique_ptr<NodeMap> node_map_;
   FrameMap frame_map_;
   std::unique_ptr<GraphProperties> graph_properties_;
-  GraphDef* optimized_graph_;  // Not owned.
+  GraphDef* optimized_graph_ = nullptr;  // Not owned.
 };
 
 }  // end namespace grappler
index 2a82b25..a56351c 100644 (file)
@@ -26,6 +26,7 @@ limitations under the License.
 
 namespace tensorflow {
 namespace grappler {
+
 namespace {
 
 string OptimizedName(const string& name) {
@@ -46,8 +47,32 @@ void VerifyGraphsMatch(const GraphDef& original_graph,
     }
   }
 }
+}  // namespace
 
-class ArithmeticOptimizerTest : public ::testing::Test {};
+class ArithmeticOptimizerTest : public ::testing::Test {
+ protected:
+  // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
+  // longer have any output consumers.
+  void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+                        GraphDef* output) {
+    TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+    item->graph.Swap(output);
+    TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
+  }
+
+  // TODO(ezhulenev): Make private. After migration to stages each test
+  // should explicitly enable required optimization for tests isolation
+  void DisableAllStages(ArithmeticOptimizer* optimizer) {
+    ArithmeticOptimizer::ArithmeticOptimizerOptions options{
+        /*enable_add_to_addn_combining*/ false};
+    optimizer->options_ = options;
+  }
+
+  void EnableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
+    DisableAllStages(optimizer);
+    optimizer->options_.enable_add_to_addn_combining = true;
+  }
+};
 
 TEST_F(ArithmeticOptimizerTest, NoOp) {
   // This trivial graph is so basic there's nothing to optimize.
@@ -350,7 +375,10 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
   for (int i = 0; i < item.graph.node_size(); ++i) {
     item.graph.mutable_node(i)->set_device(devices[i]);
   }
+
   ArithmeticOptimizer optimizer;
+  DisableAllStages(&optimizer);
+
   GraphDef output;
   Status status = optimizer.Optimize(nullptr, item, &output);
   TF_EXPECT_OK(status);
@@ -1164,6 +1192,169 @@ TEST_F(ArithmeticOptimizerTest, RemoveRedundantCast) {
                    [](const NodeDef& node) { return node.op() == "Cast"; }));
 }
 
-}  // namespace
+TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+  tensorflow::Scope sx = s.NewSubScope("x");
+  tensorflow::Scope sy = s.NewSubScope("y");
+
+  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
+  auto add_ab = ops::Add(sx.WithOpName("Add_ab"), a, b);
+  auto add_abc = ops::Add(sy.WithOpName("Add_abc"), add_ab, c);
+
+  auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+  GrapplerItem item;
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableAddToAddNCombining(&optimizer);
+
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  // We expect the following rewrite(s) to occur:
+  //
+  //     +
+  //    / \
+  //   +   c      -->    AddN(a, b, c)
+  //  / \
+  // a   b
+  EXPECT_EQ(5, output.node_size());
+
+  NodeMap node_map(&output);
+
+  // check add tree was replaced with AddN
+  const NodeDef* collapsed_add = CHECK_NOTNULL(
+      node_map.GetNode("y/ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+
+  EXPECT_EQ("AddN", collapsed_add->op());
+  EXPECT_EQ(3, collapsed_add->input_size());
+  EXPECT_EQ("a", collapsed_add->input(0));
+  EXPECT_EQ("b", collapsed_add->input(1));
+  EXPECT_EQ("c", collapsed_add->input(2));
+
+  // check output was re-wired to new node
+  const NodeDef* updated_outputs = CHECK_NOTNULL(node_map.GetNode("outputs"));
+
+  EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+}
+
+TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
+  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+  auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+  auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
+  auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
+  auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
+  auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
+  auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
+
+  auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
+  auto outputs = ops::Identity(s.WithOpName("outputs"), mul);
+
+  GrapplerItem item;
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableAddToAddNCombining(&optimizer);
+
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  // We expect the following rewrite(s) to occur:
+  //
+  //         *
+  //      /     \
+  //     +       +                        *
+  //    / \     / \                    /     \
+  //   +   c   x   + -->    AddN(a, b, c)  AddN(x, y, z))
+  //  / \         / \
+  // a   b       y   z
+  EXPECT_EQ(10, output.node_size());
+
+  NodeMap node_map(&output);
+
+  // check left Add subtree replaced with AddN
+  const NodeDef* collapsed_left = CHECK_NOTNULL(
+      node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+
+  EXPECT_EQ("AddN", collapsed_left->op());
+  EXPECT_EQ(3, collapsed_left->input_size());
+  EXPECT_EQ("a", collapsed_left->input(0));
+  EXPECT_EQ("b", collapsed_left->input(1));
+  EXPECT_EQ("c", collapsed_left->input(2));
+
+  // check right Add subtree replaced with AddN
+  const NodeDef* collapsed_right = CHECK_NOTNULL(
+      node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_xyz_Add_xy"));
+
+  EXPECT_EQ("AddN", collapsed_right->op());
+  EXPECT_EQ(3, collapsed_right->input_size());
+  EXPECT_EQ("x", collapsed_right->input(0));
+  EXPECT_EQ("y", collapsed_right->input(1));
+  EXPECT_EQ("z", collapsed_right->input(2));
+
+  // check that Mul inputs re-wired to new Nodes
+  const NodeDef* updated_mul = CHECK_NOTNULL(node_map.GetNode("Mul"));
+
+  EXPECT_EQ("Mul", updated_mul->op());
+  EXPECT_EQ(2, updated_mul->input_size());
+  EXPECT_EQ(collapsed_left->name(), updated_mul->input(0));
+  EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
+  tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+  auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+  auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+  auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
+  auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+  auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
+  auto add_all = ops::Add(s.WithOpName("Add_all"), add_ab, add_bc);
+  auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
+
+  GrapplerItem item;
+  item.fetch = {"outputs"};
+  TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+  GraphDef output;
+  ArithmeticOptimizer optimizer;
+  EnableAddToAddNCombining(&optimizer);
+
+  OptimizeAndPrune(&optimizer, &item, &output);
+
+  // We expect the following rewrite(s) to occur:
+  //
+  //     +
+  //    / \
+  //   +   +     -->    AddN(a, b, b, c)
+  //  / \ / \                   ^
+  // a   b   c                  b added twice!
+  EXPECT_EQ(5, output.node_size());
+
+  NodeMap node_map(&output);
+
+  // check Add tree replaced with AddN
+  const NodeDef* collapsed_add = CHECK_NOTNULL(node_map.GetNode(
+      "ArithmeticOptimizer/AddOpsGroup_Add_all_Add_ab_Add_bc"));
+
+  EXPECT_EQ("AddN", collapsed_add->op());
+  EXPECT_EQ(4, collapsed_add->input_size());
+  EXPECT_EQ("a", collapsed_add->input(0));
+  EXPECT_EQ("b", collapsed_add->input(1));
+  EXPECT_EQ("b", collapsed_add->input(2));
+  EXPECT_EQ("c", collapsed_add->input(3));
+}
+
 }  // namespace grappler
 }  // namespace tensorflow