int output_pos;
string node_name = ParseNodeName(input, &output_pos);
const NodeDef* input_node = node_map.GetNode(node_name);
- return input_node->attr().at(kOutputShapesAttr).list().shape(output_pos);
+ const auto& attr = input_node->attr();
+ if (attr.find(kOutputShapesAttr) == attr.end()) {
+ return PartialTensorShape(); // unknown shape
+ } else {
+ return attr.at(kOutputShapesAttr).list().shape(output_pos);
+ }
}
bool ShapesEqual(const string& input_x, const string& input_y,
is_value_preserving_non_branching);
}
+// Context passed to each arithmetic optimizer stage. Each stage is
+// responsible for updating the node map for all added or deleted nodes, to
+// keep it consistent with the optimized graph.
+struct ArithmeticOptimizerContext {
+ ArithmeticOptimizerContext(
+ const std::unordered_set<string>* nodes_to_preserve,
+ GraphDef* optimized_graph, NodeMap* node_map,
+ SetVector<NodeDef*>* nodes_to_simplify)
+ : nodes_to_preserve(nodes_to_preserve),
+ optimized_graph(optimized_graph),
+ node_map(node_map),
+ nodes_to_simplify(nodes_to_simplify) {}
+
+ const std::unordered_set<string>* nodes_to_preserve;
+ GraphDef* optimized_graph;
+ NodeMap* node_map;
+ SetVector<NodeDef*>* nodes_to_simplify;
+};
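+
+// A context instance is created once per optimization pass and shared by all
+// arithmetic optimizer stages (see the stage wiring in ArithmeticOptimizer
+// below).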
+
+// Base class for a single arithmetic optimization: e.g. Bitcast optimization,
+// AddOps optimization, etc...
+class ArithmeticOptimizerStage {
+ public:
+ explicit ArithmeticOptimizerStage(ArithmeticOptimizerContext ctx)
+ : ctx_(ctx) {}
+ virtual ~ArithmeticOptimizerStage() = default;
+
+ // Check if we should try to simplify the given node. Returning true doesn't
+ // guarantee that the node will be simplified.
+ //
+ // Implementations should perform just a basic sanity check, without any
+ // expensive graph traversals.
+ virtual bool IsSupported(const NodeDef* node) const = 0;
+
+ // Try to simplify the given node. If the node was successfully simplified,
+ // return the name of the new simplified node via the output parameter.
+ //
+ // Consumers of the old node's outputs will be automatically re-wired to
+ // consume the outputs of the new simplified node.
+ //
+ // Return an error status only if some precondition failed, or the graph is
+ // malformed. In every other case return Status::OK(), even if nothing was
+ // simplified.
+ //
+ // A simplified node will always be considered for further optimization and
+ // will be automatically added to the optimization queue. If a simplified
+ // node has the same name as the original node, it has to be explicitly added
+ // to the optimization queue for a second pass.
+ virtual Status TrySimplify(const NodeDef* node,
+ string* simplified_node_name) = 0;
+
+ protected:
+ // A simplification rewrite can create additional nodes that are inputs to
+ // the final simplified node; they can also be added to the arithmetic
+ // optimizer queue for further optimization.
+ void AddToOptimizationQueue(NodeDef* node) {
+ ctx_.nodes_to_simplify->PushBack(node);
+ }
+
+ // Get a node by input name from the node map. Return an error if the node
+ // was not found.
+ Status GetInputNode(const string& input, NodeDef** node) const {
+ string node_name = NodeName(input);
+ NodeDef* node_by_name = ctx_.node_map->GetNode(node_name);
+ if (node_by_name == nullptr) {
+ return errors::FailedPrecondition("Node ", node_name,
+ " doesn't exists in a node map");
+ }
+ *node = node_by_name;
+ return Status::OK();
+ }
+
+ // Get the input shape from the node map. If the node doesn't exist, or has
+ // no known output shapes, return an unknown shape.
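+ // Shapes are read from the kOutputShapesAttr ("_output_shapes") attribute,
+ // which is expected to be annotated onto the graph beforehand (see the
+ // GraphProperties::AnnotateOutputShapes call below).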
+ PartialTensorShape GetInputShape(const string& input) const {
+ int position;
+ string node_name = ParseNodeName(input, &position);
+ NodeDef* node;
+ Status node_status = GetInputNode(node_name, &node);
+ if (!node_status.ok()) {
+ return PartialTensorShape(); // unknown shape
+ }
+ const auto& attr = node->attr();
+ if (attr.find(kOutputShapesAttr) == attr.end()) {
+ return PartialTensorShape(); // unknown shape
+ } else {
+ return attr.at(kOutputShapesAttr).list().shape(position);
+ }
+ }
+
+ ArithmeticOptimizerContext ctx_;
+};
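+
+// A minimal sketch of a concrete stage built on this interface. The stage
+// name and the identity-forwarding rewrite below are illustrative only, not
+// part of the optimizer:
+//
+//   class RemoveIdentityStage : public ArithmeticOptimizerStage {
+//    public:
+//     explicit RemoveIdentityStage(ArithmeticOptimizerContext ctx)
+//         : ArithmeticOptimizerStage(ctx) {}
+//
+//     // Cheap check only; no graph traversal.
+//     bool IsSupported(const NodeDef* node) const override {
+//       return IsIdentity(*node) && node->input_size() == 1;
+//     }
+//
+//     Status TrySimplify(const NodeDef* node,
+//                        string* simplified_node_name) override {
+//       // Forward the single input; the optimizer re-wires consumers of
+//       // `node` to consume this tensor instead.
+//       *simplified_node_name = node->input(0);
+//       return Status::OK();
+//     }
+//   };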
+
+// Rewrite a tree of Add/AddN ops with a single AddN operation, consuming all
+// the original inputs of the absorbed nodes.
+//
+// All nodes in an Add/AddN subgraph must have fully specified and identical
+// shapes. All nodes must have the same device placement.
+//
+// Example:
+//                AddN_1
+//              /   |    \
+//         Add_1    z    Add_2     ->  AddN(x, y, z, w, q, e)
+//          / \         /  \
+//         x   y       w   Add_3
+//                          / \
+//                         q   e
+class AddOpsRewriteStage : public ArithmeticOptimizerStage {
+ public:
+ explicit AddOpsRewriteStage(ArithmeticOptimizerContext ctx)
+ : ArithmeticOptimizerStage(ctx), rewritten_nodes_() {}
+
+ ~AddOpsRewriteStage() override = default;
+
+ // Check if a node can become the root of an AddOpsGroup
+ bool IsSupported(const NodeDef* node) const override {
+ // check basic preconditions
+ if (!IsRewritable(node)) {
+ return false;
+ }
+ // and must have fully defined shape
+ // TODO(ezhulenev): support partially defined shapes, when we can prove that
+ // unknown dimensions in the rewritten subgraph are the same.
+ PartialTensorShape shape = GetInputShape(node->name());
+ if (!shape.IsFullyDefined()) {
+ return false;
+ }
+ // and must have inputs of fully defined shape identical to the output
+ // TODO(ezhulenev): relax this condition to support equal unknown dimensions
+ return HasAllInputsOfIdenticalShape(*node, shape);
+ }
+
+ Status TrySimplify(const NodeDef* node,
+ string* simplified_node_name) override {
+ CHECK(IsSupported(node))
+ << "Node " << node->name()
+ << " is not supported by add ops group optimizer step";
+ AddOpsGroup group;
+ TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group));
+
+ if (!group.absorbed_nodes.empty()) {
+ *simplified_node_name = RewriteAddOpsGroup(group);
+ }
+
+ return Status::OK();
+ }
+
+ private:
+ // Holds together an Add ops subgraph that we want to rewrite as a single
+ // AddN node.
+ //
+ // For the graph above the AddOpsGroup will be:
+ //    root_node: AddN_1
+ //    absorbed_nodes: [Add_1, Add_2, Add_3]
+ //    inputs: [x, y, z, w, q, e]
+ struct AddOpsGroup {
+ const NodeDef* root_node;
+ PartialTensorShape root_shape;
+ // Add/AddN operations below the root level that were absorbed by this group
+ std::vector<NodeDef*> absorbed_nodes;
+ // Inputs of absorbed nodes that will be forwarded to the rewritten AddN node
+ std::vector<string> inputs;
+ };
+
+ // Check if all inputs are fully defined and identical to the expected shape
+ bool HasAllInputsOfIdenticalShape(const NodeDef& node,
+ const PartialTensorShape& shape) const {
+ const AddOpsRewriteStage* self = this;
+ return std::all_of(node.input().begin(), node.input().end(),
+ [self, &shape](const string& input) {
+ auto input_shape = self->GetInputShape(input);
+ return input_shape.IsFullyDefined() &&
+ input_shape.IsIdenticalTo(shape);
+ });
+ }
+
+ // TODO(ezhulenev): use GraphRewriter?
+ bool IsDrivenByControlDependency(const NodeDef& node) const {
+ return std::any_of(node.input().begin(), node.input().end(),
+ IsControlInput);
+ }
+
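+ // Check if any consumer of the node's outputs reads it as a control input.
+ // Control inputs are encoded as "^node_name" in a consumer's input list, and
+ // ParseNodeName returns a negative position for them, which is what the
+ // `position < 0` check below relies on.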
+ // TODO(ezhulenev): use GraphRewriter?
+ bool DrivesControlDependency(const NodeDef& node) const {
+ int position;
+ for (const NodeDef* output : ctx_.node_map->GetOutputs(node.name())) {
+ for (int i = 0; i < output->input_size(); ++i) {
+ auto input = output->input(i);
+ string name = ParseNodeName(input, &position);
+ if (name == node.name() && /*control input*/ position < 0) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ // Check if a node can be absorbed by the current AddOpsGroup
+ bool IsAbsorbableByAddOpsGroup(const string& name, const AddOpsGroup& group) {
+ NodeDef* node;
+ Status node_status = GetInputNode(name, &node);
+ if (!node_status.ok()) {
+ return false;
+ }
+
+ PartialTensorShape shape = GetInputShape(name);
+ CHECK(shape.IsIdenticalTo(group.root_shape))
+ << "Cannot absorb a node of incompatible shape";
+
+ // check basic preconditions
+ if (!IsRewritable(node)) {
+ return false;
+ }
+ // it must have a single output consumer (if we reached this node from a
+ // previously absorbed node or from the root node, a single consumer means
+ // that the node is not used as an input to any other op outside of the
+ // group)
+ if (ctx_.node_map->GetOutputs(node->name()).size() != 1) {
+ return false;
+ }
+ // must be on the same device as the root node
+ if (node->device() != group.root_node->device()) {
+ return false;
+ }
+ // All input shapes must be fully defined and equal to the node shape
+ return HasAllInputsOfIdenticalShape(*node, shape);
+ }
+
+ // Node requirements for both a root node and an absorbed node
+ bool IsRewritable(const NodeDef* node) const {
+ // only Add or AddN can be a root node
+ // TODO(ezhulenev): check if AccumulateNV2 can be supported too
+ if (!IsAdd(*node) && !IsAddN(*node)) {
+ return false;
+ }
+ // it must not be in a preserve set
+ if (ctx_.nodes_to_preserve->find(node->name()) !=
+ ctx_.nodes_to_preserve->end()) {
+ return false;
+ }
+ // it must not be a node created or absorbed by a previous iteration
+ if (rewritten_nodes_.find(node->name()) != rewritten_nodes_.end()) {
+ return false;
+ }
+ // should not drive or be driven by control dependency
+ // TODO(ezhulenev): relax this condition for root node
+ return !(IsDrivenByControlDependency(*node) ||
+ DrivesControlDependency(*node));
+ }
+
+ // Create an AddOpsGroup rooted at the given node
+ Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) {
+ group->root_node = root_node;
+ group->root_shape = GetInputShape(root_node->name());
+
+ group->absorbed_nodes.reserve(root_node->input_size());
+ for (int i = 0; i < root_node->input_size(); ++i) {
+ TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(root_node->input(i), group));
+ }
+
+ return Status::OK();
+ }
+
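+ // Recursively absorb Add/AddN nodes reachable through the given input, and
+ // collect everything else as group inputs. For the example tree at the top
+ // of this class, absorbing the inputs of AddN_1 proceeds depth-first: Add_1
+ // is absorbed and contributes x and y; z is not absorbable and becomes a
+ // group input; Add_2 and Add_3 are absorbed and contribute w, q and e.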
+ Status AbsorbInputByAddOpsGroup(const string& input, AddOpsGroup* group) {
+ NodeDef* node;
+ TF_RETURN_IF_ERROR(GetInputNode(input, &node));
+
+ if (IsAbsorbableByAddOpsGroup(input, *group)) {
+ group->absorbed_nodes.push_back(node);
+ for (int i = 0; i < node->input_size(); ++i) {
+ TF_RETURN_IF_ERROR(AbsorbInputByAddOpsGroup(node->input(i), group));
+ }
+ } else {
+ // If the node can't be absorbed, add it to the AddOpsGroup inputs
+ group->inputs.push_back(input);
+ }
+ return Status::OK();
+ }
+
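+ // Split a node name into its scope and name parts,
+ // e.g. "a/b/c/Add_2" -> {"a/b/c", "Add_2"}, "Add" -> {"", "Add"}.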
+ std::pair<string, string> ParseNodeScopeAndName(const string& name) const {
+ auto pos = name.find_last_of('/');
+ if (pos == string::npos) {
+ return {"", name};
+ } else {
+ return {name.substr(0, pos), name.substr(pos + 1)};
+ }
+ }
+
+ // The new node for an AddOpsGroup is added to the same scope as the
+ // root_node. All absorbed nodes are stripped of their scope, and only their
+ // names are used in the new node name.
+ //
+ // Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"])
+ //          node_name="a/b/c/ArithmeticOptimizer/AddOpsGroup_Add_2_Add_1_Add"
+ string AddOpsGroupName(const AddOpsGroup& group) {
+ CHECK_NOTNULL(group.root_node);
+ string node_name;
+
+ auto root_node = ParseNodeScopeAndName(group.root_node->name());
+ auto root_scope = root_node.first;
+ auto root_name = root_node.second;
+ if (!root_scope.empty()) {
+ strings::StrAppend(&node_name, root_scope, "/");
+ }
+
+ strings::StrAppend(&node_name, kArithmeticOptimizer, "/", "AddOpsGroup_",
+ root_name);
+ for (const NodeDef* absorbed : group.absorbed_nodes) {
+ auto absorbed_node = ParseNodeScopeAndName(absorbed->name());
+ strings::StrAppend(&node_name, "_", absorbed_node.second);
+ }
+ return node_name;
+ }
+
+ // Create a new node for an AddOpsGroup and return its name.
+ string RewriteAddOpsGroup(const AddOpsGroup& group) {
+ CHECK_GT(group.absorbed_nodes.size(), 0)
+ << "AddOpsGroup must have non empty absorbed nodes";
+
+ // name for a new node constructed from AddOpsGroup
+ string node_name = AddOpsGroupName(group);
+
+ // copy attributes from a root node
+ DataType dtype = group.root_node->attr().at("T").type();
+
+ // add new node
+ NodeDef* added_node = ctx_.optimized_graph->add_node();
+ added_node->set_name(node_name);
+ added_node->set_op("AddN");
+ added_node->set_device(group.root_node->device());
+ (*added_node->mutable_attr())["T"].set_type(dtype);
+ (*added_node->mutable_attr())["N"].set_i(group.inputs.size());
+
+ ctx_.node_map->AddNode(node_name, added_node);
+ for (string input : group.inputs) {
+ ctx_.node_map->AddOutput(input, node_name);
+ added_node->add_input(std::move(input));
+ }
+
+ VLOG(1) << "Absorbed " << group.absorbed_nodes.size()
+ << " Add/AddN nodes from the graph";
+
+ // keep track of nodes that were created or absorbed as a part of rewrite
+ rewritten_nodes_.insert(node_name);
+ for (const NodeDef* absorbed : group.absorbed_nodes) {
+ rewritten_nodes_.insert(absorbed->name());
+ }
+
+ return node_name;
+ }
+
+ // keep nodes that were added or absorbed as part of an AddOpsGroup rewrite
+ std::unordered_set<string> rewritten_nodes_;
+};
+
} // namespace
class UniqueNodes {
}
}
+// TODO(ezhulenev): extract each individual simplify rewrite into separate
+// ArithmeticOptimizerStage
string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* node, SetVector<NodeDef*>* nodes_to_simplify) {
// Remove involutions applied twice.
for (int i = 0; i < optimized_graph_->node_size(); ++i) {
nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i));
}
+
+ ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_,
+ node_map_.get(), &nodes_to_simplify);
+
+ std::vector<std::unique_ptr<ArithmeticOptimizerStage>> stages;
+
+ // Add/AddN tree rewrites
+ if (options_.enable_add_to_addn_combining) {
+ stages.push_back(
+ std::unique_ptr<ArithmeticOptimizerStage>(new AddOpsRewriteStage(ctx)));
+ }
+
+ VLOG(1) << "Simplify arithmetic ops using " << stages.size()
+ << " arithmetic optimization stages";
+
while (!nodes_to_simplify.Empty()) {
const NodeDef* node = nodes_to_simplify.PopBack();
- const string simplified_tensor =
+
+ // TODO(ezhulenev): move all rewrites into separate stages
+ string simplified_tensor =
TrySimplifyAndReplaceUses(node, &nodes_to_simplify);
+
+ // if the node was not simplified, try to run it through all configured stages
+ if (simplified_tensor.empty()) {
+ for (auto& stage : stages) {
+ if (stage->IsSupported(node)) {
+ TF_RETURN_IF_ERROR(stage->TrySimplify(node, &simplified_tensor));
+ if (!simplified_tensor.empty()) {
+ break;
+ }
+ }
+ }
+ }
+
+ // if it's still empty, go to the next node
if (simplified_tensor.empty()) {
continue;
}
+ // re-wire consumers of an old node to the new one
if (NodeName(simplified_tensor) != node->name()) {
// Always consider simplified_tensor for further optimizations.
NodeDef* simplified_node = node_map_->GetNode(simplified_tensor);
// Shapes are only needed in aggressive mode.
graph_properties_.reset(new GraphProperties(item));
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+ // TODO(ezhulenev): Use GraphProperties to lookup tensor shapes directly
TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
// Perform the optimizations.
namespace tensorflow {
namespace grappler {
+
namespace {
string OptimizedName(const string& name) {
}
}
}
+} // namespace
-class ArithmeticOptimizerTest : public ::testing::Test {};
+class ArithmeticOptimizerTest : public ::testing::Test {
+ protected:
+ // Optimize a graph using ArithmeticOptimizer and prune all the nodes that no
+ // longer have any output consumers.
+ void OptimizeAndPrune(ArithmeticOptimizer* optimizer, GrapplerItem* item,
+ GraphDef* output) {
+ TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output));
+ item->graph.Swap(output);
+ TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output));
+ }
+
+ // TODO(ezhulenev): Make private. After migration to stages each test
+ // should explicitly enable the required optimization for test isolation.
+ void DisableAllStages(ArithmeticOptimizer* optimizer) {
+ ArithmeticOptimizer::ArithmeticOptimizerOptions options{
+ /*enable_add_to_addn_combining=*/false};
+ optimizer->options_ = options;
+ }
+
+ void EnableAddToAddNCombining(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.enable_add_to_addn_combining = true;
+ }
+};
TEST_F(ArithmeticOptimizerTest, NoOp) {
// This trivial graph is so basic there's nothing to optimize.
for (int i = 0; i < item.graph.node_size(); ++i) {
item.graph.mutable_node(i)->set_device(devices[i]);
}
+
ArithmeticOptimizer optimizer;
+ DisableAllStages(&optimizer);
+
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
[](const NodeDef& node) { return node.op() == "Cast"; }));
}
-} // namespace
+TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ tensorflow::Scope sx = s.NewSubScope("x");
+ tensorflow::Scope sy = s.NewSubScope("y");
+
+ auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
+ auto add_ab = ops::Add(sx.WithOpName("Add_ab"), a, b);
+ auto add_abc = ops::Add(sy.WithOpName("Add_abc"), add_ab, c);
+
+ auto outputs = ops::Identity(s.WithOpName("outputs"), add_abc);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableAddToAddNCombining(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ //      +
+ //     / \
+ //    +   c    -->  AddN(a, b, c)
+ //   / \
+ //  a   b
+ EXPECT_EQ(5, output.node_size());
+
+ NodeMap node_map(&output);
+
+ // check add tree was replaced with AddN
+ const NodeDef* collapsed_add = CHECK_NOTNULL(
+ node_map.GetNode("y/ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+
+ EXPECT_EQ("AddN", collapsed_add->op());
+ EXPECT_EQ(3, collapsed_add->input_size());
+ EXPECT_EQ("a", collapsed_add->input(0));
+ EXPECT_EQ("b", collapsed_add->input(1));
+ EXPECT_EQ("c", collapsed_add->input(2));
+
+ // check output was re-wired to new node
+ const NodeDef* updated_outputs = CHECK_NOTNULL(node_map.GetNode("outputs"));
+
+ EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0));
+}
+
+TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_abc = ops::Add(s.WithOpName("Add_abc"), add_ab, c);
+
+ auto x = ops::Variable(s.WithOpName("x"), {2, 2}, DT_FLOAT);
+ auto y = ops::Variable(s.WithOpName("y"), {2, 2}, DT_FLOAT);
+ auto z = ops::Variable(s.WithOpName("z"), {2, 2}, DT_FLOAT);
+ auto add_xy = ops::Add(s.WithOpName("Add_xy"), x, y);
+ auto add_xyz = ops::Add(s.WithOpName("Add_xyz"), add_xy, z);
+
+ auto mul = ops::Multiply(s.WithOpName("Mul"), add_abc, add_xyz);
+ auto outputs = ops::Identity(s.WithOpName("outputs"), mul);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableAddToAddNCombining(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ //        *
+ //      /   \
+ //     +     +                       *
+ //    / \   / \                    /   \
+ //   +   c x   +   -->  AddN(a, b, c)  AddN(x, y, z)
+ //  / \       / \
+ // a   b     y   z
+ EXPECT_EQ(10, output.node_size());
+
+ NodeMap node_map(&output);
+
+ // check left Add subtree replaced with AddN
+ const NodeDef* collapsed_left = CHECK_NOTNULL(
+ node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab"));
+
+ EXPECT_EQ("AddN", collapsed_left->op());
+ EXPECT_EQ(3, collapsed_left->input_size());
+ EXPECT_EQ("a", collapsed_left->input(0));
+ EXPECT_EQ("b", collapsed_left->input(1));
+ EXPECT_EQ("c", collapsed_left->input(2));
+
+ // check right Add subtree replaced with AddN
+ const NodeDef* collapsed_right = CHECK_NOTNULL(
+ node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_xyz_Add_xy"));
+
+ EXPECT_EQ("AddN", collapsed_right->op());
+ EXPECT_EQ(3, collapsed_right->input_size());
+ EXPECT_EQ("x", collapsed_right->input(0));
+ EXPECT_EQ("y", collapsed_right->input(1));
+ EXPECT_EQ("z", collapsed_right->input(2));
+
+ // check that Mul inputs re-wired to new Nodes
+ const NodeDef* updated_mul = CHECK_NOTNULL(node_map.GetNode("Mul"));
+
+ EXPECT_EQ("Mul", updated_mul->op());
+ EXPECT_EQ(2, updated_mul->input_size());
+ EXPECT_EQ(collapsed_left->name(), updated_mul->input(0));
+ EXPECT_EQ(collapsed_right->name(), updated_mul->input(1));
+}
+
+TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+ auto a = ops::Variable(s.WithOpName("a"), {2, 2}, DT_FLOAT);
+ auto b = ops::Variable(s.WithOpName("b"), {2, 2}, DT_FLOAT);
+ auto c = ops::Variable(s.WithOpName("c"), {2, 2}, DT_FLOAT);
+ auto add_ab = ops::Add(s.WithOpName("Add_ab"), a, b);
+ auto add_bc = ops::Add(s.WithOpName("Add_bc"), b, c);
+ auto add_all = ops::Add(s.WithOpName("Add_all"), add_ab, add_bc);
+ auto outputs = ops::Identity(s.WithOpName("outputs"), add_all);
+
+ GrapplerItem item;
+ item.fetch = {"outputs"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableAddToAddNCombining(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+
+ // We expect the following rewrite(s) to occur:
+ //
+ //      +
+ //     / \
+ //    +   +    -->  AddN(a, b, b, c)
+ //   / \ / \               ^
+ //  a   b   c       b added twice!
+ EXPECT_EQ(5, output.node_size());
+
+ NodeMap node_map(&output);
+
+ // check Add tree replaced with AddN
+ const NodeDef* collapsed_add = CHECK_NOTNULL(node_map.GetNode(
+ "ArithmeticOptimizer/AddOpsGroup_Add_all_Add_ab_Add_bc"));
+
+ EXPECT_EQ("AddN", collapsed_add->op());
+ EXPECT_EQ(4, collapsed_add->input_size());
+ EXPECT_EQ("a", collapsed_add->input(0));
+ EXPECT_EQ("b", collapsed_add->input(1));
+ EXPECT_EQ("b", collapsed_add->input(2));
+ EXPECT_EQ("c", collapsed_add->input(3));
+}
+
} // namespace grappler
} // namespace tensorflow