From 27533f61ddfa674ceccb59777d24e2fe0157f70c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 12 Mar 2018 13:50:35 -0700 Subject: [PATCH] Move "hoist common factor out of aggregation" optimization to a separate stage. 1) Use a new naming scheme for optimized ops, share it with AddOpsRewrite 2) Make sure that tests actually test that optimized nodes exists in a graph PiperOrigin-RevId: 188772892 --- .../grappler/optimizers/arithmetic_optimizer.cc | 461 +++++++++++++++------ .../grappler/optimizers/arithmetic_optimizer.h | 1 + .../optimizers/arithmetic_optimizer_test.cc | 212 ++++++---- 3 files changed, 462 insertions(+), 212 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 177b073..c0fcfaf 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -290,25 +290,30 @@ NodeDef* GetTailOfValuePreservingChain( struct ArithmeticOptimizerContext { ArithmeticOptimizerContext( const std::unordered_set* nodes_to_preserve, - GraphDef* optimized_graph, NodeMap* node_map, + GraphDef* optimized_graph, NodeMap* node_map, FrameMap* frame_map, SetVector* nodes_to_simplify) : nodes_to_preserve(nodes_to_preserve), optimized_graph(optimized_graph), node_map(node_map), + frame_map(frame_map), nodes_to_simplify(nodes_to_simplify) {} const std::unordered_set* nodes_to_preserve; GraphDef* optimized_graph; NodeMap* node_map; + FrameMap* frame_map; SetVector* nodes_to_simplify; }; // Base class for single arithmetic optimization: e.g. Bitcast optimization, // AddOps optimization, etc... +// TODO(ezhulenev): extract this class to be reused by other multi-stage +// graph optimizers (const_folding, dependency_optimizer, etc...) 
class ArithmeticOptimizerStage { public: - explicit ArithmeticOptimizerStage(ArithmeticOptimizerContext ctx) - : ctx_(ctx) {} + explicit ArithmeticOptimizerStage(const string& name, + const ArithmeticOptimizerContext& ctx) + : name_(name), ctx_(ctx) {} virtual ~ArithmeticOptimizerStage() = default; // Check if we should try to simplify node. Returning true doesn't @@ -336,6 +341,46 @@ class ArithmeticOptimizerStage { string* simplified_node_name) = 0; protected: + struct ScopedNodeName { + string scope; + string name; + }; + + const ScopedNodeName ParseScopedNodeName(const string& name) const { + auto pos = name.find_last_of("/"); + if (pos == string::npos) { + return {"", name}; + } else { + return {name.substr(0, pos), name.substr(pos + 1)}; + } + } + + // Prefix optimized node name with stage name and rewrite_rule + const string OptimizedNodeName(const string& rewrite_rule, + const ScopedNodeName& scoped_node_name) const { + return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule), + scoped_node_name); + } + + // Prefix optimized node name with stage name and rewrite_rule + const string OptimizedNodeName(const string& rewrite_rule, + const ScopedNodeName& scoped_node_name, + const std::vector& node_names) const { + return MakeOptimizedNodeName(strings::StrCat(name_, "_", rewrite_rule), + scoped_node_name, node_names); + } + + // Prefix optimized node name with stage name + const string OptimizedNodeName(const ScopedNodeName& scoped_node_name) const { + return MakeOptimizedNodeName(name_, scoped_node_name); + } + + // Prefix optimized node name with stage name + const string OptimizedNodeName(const ScopedNodeName& scoped_node_name, + const std::vector& node_names) const { + return MakeOptimizedNodeName(name_, scoped_node_name, node_names); + } + // Simplification graph rewrite can create additional nodes that are inputs // to final simplified node, they can be also added to the arithmetic // optimizer queue for further optimization. 
@@ -374,7 +419,91 @@ class ArithmeticOptimizerStage { } } - ArithmeticOptimizerContext ctx_; + NodeDef* AddCopyNode(const string& name, const NodeDef* node_to_copy) { + CHECK(node_to_copy != nullptr); + CHECK(!ctx_.node_map->NodeExists(name)) + << "Node " << name << " already exists in a graph"; + NodeDef* new_node = ctx_.optimized_graph->add_node(); + *new_node = *node_to_copy; + new_node->set_name(name); + ctx_.node_map->AddNode(name, new_node); + return new_node; + } + + NodeDef* AddEmptyNode(const string& name) { + CHECK(!ctx_.node_map->NodeExists(name)) + << "Node " << name << " already exists in a graph"; + NodeDef* new_node = ctx_.optimized_graph->add_node(); + new_node->set_name(name); + ctx_.node_map->AddNode(name, new_node); + return new_node; + } + + // TODO(ezhulenev): remove this method from ArithmeticOptimizer when all + // optimizations will be migrated to stages + void AddFrameControlDeps(const NodeDef* old_node, + const std::vector& new_nodes, + const string& source_for_ctrl_dep, + const std::vector& sinks_for_control_dep) { + const auto frame_it = ctx_.frame_map->find(old_node); + if (frame_it != ctx_.frame_map->end()) { + for (auto node : new_nodes) { + ctx_.frame_map->emplace(node, frame_it->second); + } + if (!source_for_ctrl_dep.empty() && !sinks_for_control_dep.empty()) { + const string ctrl_dep = ConstantFolding::AddControlDependency( + source_for_ctrl_dep, ctx_.optimized_graph, ctx_.node_map); + for (auto node : sinks_for_control_dep) { + MaybeAddControlInput(ctrl_dep, node, ctx_.optimized_graph, + ctx_.node_map); + } + } + } + } + + const string name_; + const ArithmeticOptimizerContext ctx_; + + private: + // Get a name for a new node obtained by optimizing a single node of the + // original graph. The optimized node is placed under the original node scope. + // + // Node name uniqueness is guaranteed by unique name of an original node in + // a same scope. 
+ // + // Example: MakeOptimizedNodeName("AwesomeRewrite", "a/b/c/Add_1") + // Optimized name: "a/b/c/ArithmeticOptimizer/AwesomeRewrite_Add_1" + const string MakeOptimizedNodeName( + const string& prefix, const ScopedNodeName& scoped_node_name) const { + string node_name; + strings::StrAppend(&node_name, scoped_node_name.scope); + if (!node_name.empty()) strings::StrAppend(&node_name, "/"); + strings::StrAppend(&node_name, kArithmeticOptimizer, "/", prefix, "_", + scoped_node_name.name); + return node_name; + } + + // Get a name for a new node obtained by optimizing multiple nodes of the + // original graph, starting from "root". The optimized node is placed under + // the original scope of a "root" node. + // + // Node name uniqueness is guaranteed by unique name of a "root" node in + // a same scope. + // + // Example: + // MakeOptimizedNodeName("AwesomeRewrite", "a/b/Add_AB", ["x/y/Add_XY"]) + // Optimized name: + // "a/b/ArithmeticOptimizer/AwesomeRewrite_Add_AB_Add_XY" + const string MakeOptimizedNodeName( + const string& prefix, const ScopedNodeName& scoped_node_name, + const std::vector& node_names) const { + string node_name = MakeOptimizedNodeName(prefix, scoped_node_name); + for (const string& optimized : node_names) { + auto scoped_node = ParseScopedNodeName(optimized); + strings::StrAppend(&node_name, "_", scoped_node.name); + } + return node_name; + } }; // Rewrite a tree of Add/AddN with a single AddN operation, consuming all the @@ -393,8 +522,8 @@ class ArithmeticOptimizerStage { // q e class AddOpsRewriteStage : public ArithmeticOptimizerStage { public: - explicit AddOpsRewriteStage(ArithmeticOptimizerContext ctx) - : ArithmeticOptimizerStage(ctx), rewritten_nodes_() {} + explicit AddOpsRewriteStage(const ArithmeticOptimizerContext& ctx) + : ArithmeticOptimizerStage("AddOpsRewrite", ctx), rewritten_nodes_() {} ~AddOpsRewriteStage() override = default; @@ -422,7 +551,7 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { AddOpsGroup 
group; TF_RETURN_IF_ERROR(CreateAddOpsGroup(node, &group)); - if (!group.absorbed_nodes.empty()) { + if (!group.absorbed_nodes.empty() && !IsRewritten(group)) { *simplified_node_name = RewriteAddOpsGroup(group); } @@ -530,6 +659,12 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { DrivesControlDependency(*node)); } + // Check that the optimized group node name doesn't exist. It might happen if + // the graph is optimized multiple times without pruning between invocations. + bool IsRewritten(const AddOpsGroup& group) const { + return ctx_.node_map->NodeExists(AddOpsGroupName(group)); + } + // Create an AddOpsGroup with a root in a given node Status CreateAddOpsGroup(const NodeDef* root_node, AddOpsGroup* group) { group->root_node = root_node; @@ -559,39 +694,23 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { return Status::OK(); } - const std::pair ParseNodeScopeAndName(const string& name) { - auto pos = name.find_last_of("/"); - if (pos == string::npos) { - return {"", name}; - } else { - return {name.substr(0, pos), name.substr(pos + 1)}; - } - } - // New node for AddOpsGroup is added to the same scope as a root_node. All // absorbed nodes are stripped of their scope, and only names are used in a // new node name.
// // Example: AddOpsGroup(root="a/b/c/Add_2", absorbed=["d/Add_1", "e/Add"]) // node_name="a/b/c/AddOpsGroup_Add_2_Add_1_Add - string AddOpsGroupName(const AddOpsGroup& group) { + string AddOpsGroupName(const AddOpsGroup& group) const { CHECK_NOTNULL(group.root_node); - string node_name; - auto root_node = ParseNodeScopeAndName(group.root_node->name()); - auto root_scope = root_node.first; - auto root_name = root_node.second; - if (!root_scope.empty()) { - strings::StrAppend(&node_name, root_scope, "/"); - } + auto root = ParseScopedNodeName(group.root_node->name()); - strings::StrAppend(&node_name, kArithmeticOptimizer, "/", "AddOpsGroup_", - root_name); - for (const NodeDef* absorbed : group.absorbed_nodes) { - auto absorbed_node = ParseNodeScopeAndName(absorbed->name()); - strings::StrAppend(&node_name, "_", absorbed_node.second); - } - return node_name; + std::vector absorbed_node_names(group.absorbed_nodes.size()); + std::transform(group.absorbed_nodes.begin(), group.absorbed_nodes.end(), + absorbed_node_names.begin(), + [](const NodeDef* node) { return node->name(); }); + + return OptimizedNodeName(root, absorbed_node_names); } // Create a new node for a AddOpsGroup and return it's name. 
@@ -605,18 +724,17 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { // copy attributes from a root node DataType dtype = group.root_node->attr().at("T").type(); - // add new node - NodeDef* added_node = ctx_.optimized_graph->add_node(); - added_node->set_name(node_name); + // add new AddN node + NodeDef* added_node = AddEmptyNode(node_name); added_node->set_op("AddN"); added_node->set_device(group.root_node->device()); (*added_node->mutable_attr())["T"].set_type(dtype); (*added_node->mutable_attr())["N"].set_i(group.inputs.size()); - ctx_.node_map->AddNode(node_name, added_node); - for (string input : group.inputs) { + // all inputs of absorbed nodes are added to the new node + for (const string& input : group.inputs) { ctx_.node_map->AddOutput(input, node_name); - added_node->add_input(std::move(input)); + added_node->add_input(input); } VLOG(1) << "Absorbed " << group.absorbed_nodes.size() @@ -635,11 +753,167 @@ class AddOpsRewriteStage : public ArithmeticOptimizerStage { std::unordered_set rewritten_nodes_; }; +// Use the commutativity and (left- and right-) distributive property of +// multiplication over addition to hoist common factors out of aggregate nodes +// where all the inputs are Mul nodes. This pattern occurs frequently in +// regularization terms for the gradients during training. +// +// For example, we can rewrite an expression of the form: +// AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn)) +// to the following: +// Mul(x, AddN(y1, y2, y3, ... 
yn)) +class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage { + public: + explicit HoistCommonFactorOutOfAggregation( + const ArithmeticOptimizerContext& ctx) + : ArithmeticOptimizerStage("HoistCommonFactor", ctx) {} + ~HoistCommonFactorOutOfAggregation() override = default; + + bool IsSupported(const NodeDef* node) const override { + return IsAggregate(*node) && NumNonControlInputs(*node) > 1 && + !IsRewritten(node); + } + + Status TrySimplify(const NodeDef* node, + string* simplified_node_name) override { + CHECK(IsSupported(node)); + + std::set common_factors; + TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors)); + + if (common_factors.size() == 1) { + const string& common_factor = *common_factors.begin(); + + // Gather up the non-shared factors + bool shapes_match = true; + std::vector unique_factors; + TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor, &shapes_match, + &unique_factors)); + + if (shapes_match) { + NodeDef* input_0; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0)); + + // Use a copy of the first Mul node for the outer multiplication. + NodeDef* new_mul_node = AddCopyNode(OuterMulNodeName(node), input_0); + // And a copy of aggregation node as one of the inner operands + NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node); + + new_mul_node->set_device(node->device()); + new_mul_node->set_input(0, common_factor); + new_mul_node->set_input(1, new_add_node->name()); + + ctx_.node_map->AddOutput(common_factor, new_mul_node->name()); + ctx_.node_map->AddOutput(new_add_node->name(), new_mul_node->name()); + + // Hoist non-shared factors up into the new AddN node. + for (int i = 0; i < unique_factors.size(); ++i) { + new_add_node->set_input(i, unique_factors[i]); + } + + // Add frame dependencies that the original node might have had. 
+ AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, + {new_add_node}); + + // optimize new inner aggregation node + AddToOptimizationQueue(new_add_node); + // do not optimize the same node twice + rewritten_nodes_.insert(node->name()); + *simplified_node_name = new_mul_node->name(); + } + } + return Status::OK(); + } + + private: + // Get a name for a new outer Mul node + string OuterMulNodeName(const NodeDef* node) const { + auto scoped_node = ParseScopedNodeName(node->name()); + return OptimizedNodeName("Mul", scoped_node); + } + + // Get a name for a new inner Add node + string InnerAddNodeName(const NodeDef* node) const { + auto scoped_node = ParseScopedNodeName(node->name()); + return OptimizedNodeName("Add", scoped_node); + } + + // Determine the set of common factors if the input nodes are all Mul nodes. + Status GetCommonFactors(const NodeDef* node, + std::set* common_factors) const { + CHECK(common_factors->empty()); + + for (int i = 0; i < node->input_size(); ++i) { + if (i > 0 && common_factors->empty()) break; + if (IsControlInput(node->input(i))) break; + + NodeDef* input; + TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input)); + + if (!IsMul(*input)) { + common_factors->clear(); + break; + } + + std::set factors_i{input->input(0), input->input(1)}; + if (i == 0) { + std::swap(*common_factors, factors_i); + } else { + std::set intersection; + std::set_intersection( + factors_i.begin(), factors_i.end(), common_factors->begin(), + common_factors->end(), + std::inserter(intersection, intersection.begin())); + std::swap(*common_factors, intersection); + } + } + return Status::OK(); + } + + // Gather up the non-shared factors (the y's in the example). + // Unless the aggregation is Add, we have to make sure that all the y's + // have the same shape since the other aggregation ops do not support + // broadcasting.
+ Status GetUniqueFactors(const NodeDef* node, const string& common_factor, + bool* shapes_match, + std::vector* unique_factors) const { + *shapes_match = true; + unique_factors->reserve(node->input_size()); + + for (int i = 0; i < node->input_size() && shapes_match; ++i) { + const string& input = node->input(i); + if (IsControlInput(input)) { + break; + } + NodeDef* mul_node; + TF_RETURN_IF_ERROR(GetInputNode(input, &mul_node)); + const int unique_factor_index = + mul_node->input(0) == common_factor ? 1 : 0; + unique_factors->push_back(mul_node->input(unique_factor_index)); + if (i > 0 && !IsAdd(*node)) { + *shapes_match = ShapesEqual(unique_factors->front(), + unique_factors->back(), *ctx_.node_map); + } + } + return Status::OK(); + } + + bool IsRewritten(const NodeDef* node) const { + // if graph rewrite happens in multiple passes without graph pruning between + // them, it's possible that rewritten node already exists in a graph + return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() || + ctx_.node_map->NodeExists(OuterMulNodeName(node)); + } + + // keep names of the nodes that were optimized by this stage + std::unordered_set rewritten_nodes_; +}; + // Removes inverse transpose nodes class RemoveInverseTranspose : public ArithmeticOptimizerStage { public: - explicit RemoveInverseTranspose(ArithmeticOptimizerContext ctx) - : ArithmeticOptimizerStage(ctx) {} + explicit RemoveInverseTranspose(const ArithmeticOptimizerContext& ctx) + : ArithmeticOptimizerStage("RemoveInverseTranspose", ctx) {} ~RemoveInverseTranspose() override = default; bool IsSupported(const NodeDef* node) const override { @@ -702,8 +976,8 @@ class RemoveInverseTranspose : public ArithmeticOptimizerStage { // 2) Rewrite Bitcast(Bitcast(x, type1), type2) => Bitcast(x, type2) class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage { public: - explicit RemoveRedundantBitcastStage(ArithmeticOptimizerContext ctx) - : ArithmeticOptimizerStage(ctx) {} + explicit 
RemoveRedundantBitcastStage(const ArithmeticOptimizerContext& ctx) + : ArithmeticOptimizerStage("RemoveRedundantBitcast", ctx) {} ~RemoveRedundantBitcastStage() override = default; bool IsSupported(const NodeDef* node) const override { @@ -742,8 +1016,8 @@ class RemoveRedundantBitcastStage : public ArithmeticOptimizerStage { // Remove Casts whose source type and destination type are equal. class RemoveRedundantCastStage : public ArithmeticOptimizerStage { public: - explicit RemoveRedundantCastStage(ArithmeticOptimizerContext ctx) - : ArithmeticOptimizerStage(ctx) {} + explicit RemoveRedundantCastStage(const ArithmeticOptimizerContext& ctx) + : ArithmeticOptimizerStage("RemoveRedundantCast", ctx) {} ~RemoveRedundantCastStage() override = default; bool IsSupported(const NodeDef* node) const override { return IsCast(*node); } @@ -1276,98 +1550,6 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses( } } - // Use the commutativity and (left- and right-) distributive property of - // multiplication over addition to hoist common factors out of aggregate nodes - // where all the inputs are Mul nodes. This pattern occurs frequently in - // regularization terms for the gradients during training. - // For example, we can rewrite an expression of the form: - // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn)) - // to the following: - // Mul(x, AddN(y1, y2, y3, ... yn)) - if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 && - !OptimizedNodeExists(*node, "hoist_add") && - !OptimizedNodeExists(*node, "hoist_mul")) { - // Determine the set of common factors if the input nodes are all Mul nodes. 
- std::set common_factors; - for (int i = 0; i < node->input_size(); ++i) { - if (i > 0 && common_factors.empty()) { - break; - } - if (IsControlInput(node->input(i))) { - break; - } - const NodeDef* input = node_map_->GetNode(node->input(i)); - if (input->op() == "Mul") { - std::set factors_i{input->input(0), input->input(1)}; - if (i == 0) { - std::swap(common_factors, factors_i); - } else { - std::set intersection; - std::set_intersection( - factors_i.begin(), factors_i.end(), common_factors.begin(), - common_factors.end(), - std::inserter(intersection, intersection.begin())); - std::swap(common_factors, intersection); - } - } else { - common_factors.clear(); - } - } - if (common_factors.size() == 1) { - const string& common_factor = *common_factors.begin(); - - // Gather up the non-shared factors (the y's in the example). - // Unless the aggregation is Add, we have to make sure that all the y's - // have the same shape since the other aggregation ops do not support - // broadcasting. - std::vector unique_factors; - unique_factors.reserve(node->input_size()); - bool shapes_match = true; - for (int i = 0; i < node->input_size() && shapes_match; ++i) { - const string& input = node->input(i); - if (IsControlInput(input)) { - break; - } - const NodeDef* mul_node = node_map_->GetNode(input); - const int unique_factor_index = - mul_node->input(0) == common_factor ? 1 : 0; - unique_factors.push_back(mul_node->input(unique_factor_index)); - if (i > 0 && !IsAdd(*node)) { - shapes_match = ShapesEqual(unique_factors.front(), - unique_factors.back(), *node_map_); - } - } - - if (shapes_match) { - // 1. Use a copy of the first Mul node for the outer multiplication. 
- NodeDef* new_mul_node = AddNode(OptimizedNodeName(*node, "hoist_mul"), - node_map_->GetNode(node->input(0))); - NodeDef* new_add_node = AddNode(*node, "hoist_add", /*copy_node=*/true); - new_mul_node->set_device(node->device()); - new_mul_node->set_input(0, common_factor); - node_map_->AddOutput(common_factor, new_mul_node->name()); - new_mul_node->set_input(1, new_add_node->name()); - node_map_->AddOutput(new_add_node->name(), new_mul_node->name()); - - // 2. Hoist non-shared factors up into the new AddN node. - nodes_to_simplify->PushBack(new_add_node); - for (int i = 0; i < node->input_size(); ++i) { - const string& input = node->input(i); - if (IsControlInput(input)) { - break; - } - new_add_node->set_input(i, unique_factors[i]); - } - - // 3. Add frame dependencies that the original node might have had. - AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor, - {new_add_node}); - - return new_mul_node->name(); - } - } - } - // Fold Transpose into matrix multiplication. 
if ((node->op() == "MatMul" || node->op() == "SparseMatMul" || node->op() == "BatchMatMul") && @@ -1444,8 +1626,9 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { nodes_to_simplify.PushBack(optimized_graph_->mutable_node(i)); } - ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_, - node_map_.get(), &nodes_to_simplify); + const ArithmeticOptimizerContext ctx(&nodes_to_preserve_, optimized_graph_, + node_map_.get(), &frame_map_, + &nodes_to_simplify); std::vector> stages; @@ -1453,6 +1636,10 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() { stages.push_back( std::unique_ptr(new AddOpsRewriteStage(ctx))); } + if (options_.hoist_common_factor_out_of_aggregation) { + stages.push_back(std::unique_ptr( + new HoistCommonFactorOutOfAggregation(ctx))); + } if (options_.remove_inverse_transpose) { stages.push_back(std::unique_ptr( new RemoveInverseTranspose(ctx))); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 7870844..d5a7af5 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -56,6 +56,7 @@ class ArithmeticOptimizer : public GraphOptimizer { // Granular control for arithmetic optimizer stages struct ArithmeticOptimizerOptions { bool combine_add_to_addn = true; + bool hoist_common_factor_out_of_aggregation = true; bool remove_inverse_transpose = true; bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index 98842b2..e1f4762 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -30,6 +30,22 @@ namespace grappler { namespace { +constexpr char kHoistFactorOptimizerMul[] = + 
"ArithmeticOptimizer/HoistCommonFactor_Mul_"; + +constexpr char kHoistFactorOptimizerAdd[] = + "ArithmeticOptimizer/HoistCommonFactor_Add_"; + +// Optimized name of outer Mul node by HoistCommonFactorOutOfAggregation +string HoistMulName(const string& name) { + return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, ""); +} + +// Optimized name of inner Add node by HoistCommonFactorOutOfAggregation +string HoistAddName(const string& name) { + return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, ""); +} + string OptimizedName(const string& name) { return AddPrefixToNodeName(name, kArithmeticOptimizer); } @@ -61,22 +77,40 @@ class ArithmeticOptimizerTest : public GrapplerTest { TF_EXPECT_OK(ModelPruner().Optimize(nullptr, *item, output)); } + // Run ArithmeticOptimizer twice to make sure the rewrite is idempotent. + void OptimizeTwice(ArithmeticOptimizer* optimizer, GrapplerItem* item, + GraphDef* output) { + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + item->graph.Swap(output); + TF_EXPECT_OK(optimizer->Optimize(nullptr, *item, output)); + } + // TODO(ezhulenev): Make private. 
After migration to stages each test // should explicitly enable required optimization for tests isolation void DisableAllStages(ArithmeticOptimizer* optimizer) { ArithmeticOptimizer::ArithmeticOptimizerOptions options; options.combine_add_to_addn = false; + options.hoist_common_factor_out_of_aggregation = false; options.remove_inverse_transpose = false; options.remove_redundant_bitcast = false; options.remove_redundant_cast = false; optimizer->options_ = options; } + void DisableAddToAddNCombining(ArithmeticOptimizer* optimizer) { + optimizer->options_.combine_add_to_addn = false; + } + void EnableOnlyAddToAddNCombining(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.combine_add_to_addn = true; } + void EnableOnlyHoistCommonFactor(ArithmeticOptimizer* optimizer) { + DisableAllStages(optimizer); + optimizer->options_.hoist_common_factor_out_of_aggregation = true; + } + void EnableOnlyRemoveInverseTranspose(ArithmeticOptimizer* optimizer) { DisableAllStages(optimizer); optimizer->options_.remove_inverse_transpose = true; @@ -396,59 +430,66 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) { } ArithmeticOptimizer optimizer; - DisableAllStages(&optimizer); + DisableAddToAddNCombining(&optimizer); GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. 
- item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); - EXPECT_EQ(17, output.node_size()); - // The graph gets optimized to + // We expect the following rewrite(s) to occur: + // // Mul(p, - // Add(Add(Const(2), Const(2)), - // Add(Const(2), Const(2)))) + // Add_6(Add_4(Const(2), Const(2)), + // Add_5(Const(2), Const(2)))) + NodeMap node_map(&output); + EXPECT_EQ(17, output.node_size()); - for (const auto& node : output.node()) { - if ("id" == node.name()) { - EXPECT_EQ(1, node.input_size()); - EXPECT_EQ(OptimizedName("Add_6_hoist_mul"), node.input(0)); - } else if (OptimizedName("Add_6_hoist_mul") == node.name()) { - EXPECT_EQ("Mul", node.op()); - EXPECT_EQ(2, node.input_size()); - EXPECT_EQ("Placeholder", node.input(0)); - EXPECT_EQ(OptimizedName("Add_6_hoist_add"), node.input(1)); - } else if (OptimizedName("Add_6_hoist_add") == node.name()) { - EXPECT_EQ("Add", node.op()); - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ(OptimizedName("Add_4_hoist_add"), node.input(0)); - EXPECT_EQ(OptimizedName("Add_5_hoist_add"), node.input(1)); - EXPECT_EQ("^Placeholder", node.input(2)); - } else if (OptimizedName("Add_4_hoist_add") == node.name()) { - EXPECT_EQ("Add", node.op()); - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ(OptimizedName("Add_const"), node.input(0)); - EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1)); - EXPECT_EQ("^Placeholder", node.input(2)); - } else if (OptimizedName("Add_5_hoist_add") == node.name()) { - EXPECT_EQ("Add", node.op()); - EXPECT_EQ(3, node.input_size()); - EXPECT_EQ(OptimizedName("Add_const"), node.input(0)); - EXPECT_EQ(OptimizedName("Add_1_const"), node.input(1)); - EXPECT_EQ("^Placeholder", node.input(2)); - } else if (OptimizedName("Add_const") == node.name()) { - EXPECT_EQ("Const", node.op()); - EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("^Placeholder", node.input(0)); - } else if (OptimizedName("Add_1_const") == node.name()) { - 
EXPECT_EQ("Const", node.op()); - EXPECT_EQ(1, node.input_size()); - EXPECT_EQ("^Placeholder", node.input(0)); - } - } + + const NodeDef* id_node = node_map.GetNode("id"); + ASSERT_TRUE(id_node != nullptr); + EXPECT_EQ(1, id_node->input_size()); + EXPECT_EQ(HoistMulName("Add_6"), id_node->input(0)); + + const NodeDef* mul_node = node_map.GetNode(HoistMulName("Add_6")); + ASSERT_TRUE(mul_node != nullptr); + EXPECT_EQ(2, mul_node->input_size()); + EXPECT_EQ("Placeholder", mul_node->input(0)); + EXPECT_EQ(HoistAddName("Add_6"), mul_node->input(1)); + + const NodeDef* add_6_node = node_map.GetNode(HoistAddName("Add_6")); + ASSERT_TRUE(add_6_node != nullptr); + EXPECT_EQ(3, add_6_node->input_size()); + EXPECT_EQ(HoistAddName("Add_4"), add_6_node->input(0)); + EXPECT_EQ(HoistAddName("Add_5"), add_6_node->input(1)); + EXPECT_EQ("^Placeholder", add_6_node->input(2)); + + const NodeDef* add_4_node = node_map.GetNode(HoistAddName("Add_4")); + ASSERT_TRUE(add_4_node != nullptr); + EXPECT_EQ("Add", add_4_node->op()); + EXPECT_EQ(3, add_4_node->input_size()); + EXPECT_EQ(OptimizedName("Add_const"), add_4_node->input(0)); + EXPECT_EQ(OptimizedName("Add_1_const"), add_4_node->input(1)); + EXPECT_EQ("^Placeholder", add_4_node->input(2)); + + const NodeDef* add_5_node = node_map.GetNode(HoistAddName("Add_5")); + ASSERT_TRUE(add_5_node != nullptr); + EXPECT_EQ("Add", add_5_node->op()); + EXPECT_EQ(3, add_5_node->input_size()); + EXPECT_EQ(OptimizedName("Add_const"), add_5_node->input(0)); + EXPECT_EQ(OptimizedName("Add_1_const"), add_5_node->input(1)); + EXPECT_EQ("^Placeholder", add_5_node->input(2)); + + const NodeDef* add_const_node = node_map.GetNode(OptimizedName("Add_const")); + ASSERT_TRUE(add_const_node != nullptr); + EXPECT_EQ("Const", add_const_node->op()); + EXPECT_EQ(1, add_const_node->input_size()); + EXPECT_EQ("^Placeholder", add_const_node->input(0)); + + const NodeDef* add_1_const_node = + node_map.GetNode(OptimizedName("Add_1_const")); + ASSERT_TRUE(add_1_const_node 
!= nullptr); + EXPECT_EQ("Const", add_1_const_node->op()); + EXPECT_EQ(1, add_1_const_node->input_size()); + EXPECT_EQ("^Placeholder", add_1_const_node->input(0)); } TEST_F(ArithmeticOptimizerTest, HoistFactor) { @@ -469,31 +510,46 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) { ops::Add(s.WithOpName("add"), mul1, mul2)); GrapplerItem item; + item.fetch = {"id"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); + ArithmeticOptimizer optimizer; + EnableOnlyHoistCommonFactor(&optimizer); + GraphDef output; - Status status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); - // Run the optimizer twice to make sure the rewrite is idempotent. - item.graph.Swap(&output); - status = optimizer.Optimize(nullptr, item, &output); - TF_EXPECT_OK(status); + OptimizeTwice(&optimizer, &item, &output); + + // We expect the following rewrite(s) to occur: + // + // Add Mul + // / \ / \ + // Mul Mul -> x Add + // / \ / \ / \ + // x y1 y2 x y1 y2 + // + // If "root" op is AddN and shapes do not match, this rewrite is not + // possible and the graph should stay intact.
+ NodeMap node_map(&output); if (use_addn && !matching_shapes) { VerifyGraphsMatch(item.graph, output, __LINE__); } else { EXPECT_EQ(9, output.node_size()); - const NodeDef& new_add = output.node(8); - EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name()); - EXPECT_EQ("y1", new_add.input(0)); - EXPECT_EQ("y2", new_add.input(1)); - const NodeDef& new_mul = output.node(7); - EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name()); - EXPECT_EQ("x", new_mul.input(0)); - EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1)); - const NodeDef& new_id = output.node(6); - EXPECT_EQ("id", new_id.name()); - EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0)); + + const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add")); + ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found"; + EXPECT_EQ("y1", new_add_node->input(0)); + EXPECT_EQ("y2", new_add_node->input(1)); + + const NodeDef* new_mul_node = node_map.GetNode(HoistMulName("add")); + ASSERT_TRUE(new_mul_node != nullptr) << "Hoisted Mul node not found"; + EXPECT_EQ("x", new_mul_node->input(0)); + EXPECT_EQ(new_add_node->name(), new_mul_node->input(1)); + + const NodeDef* id_node = node_map.GetNode("id"); + ASSERT_TRUE(id_node != nullptr) << "Id node not found"; + EXPECT_EQ("id", id_node->name()); + EXPECT_EQ(HoistMulName("add"), id_node->input(0)); } } } @@ -1249,8 +1305,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) { NodeMap node_map(&output); // check add tree was replaced with AddN - const NodeDef* collapsed_add = CHECK_NOTNULL( - node_map.GetNode("y/ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab")); + const NodeDef* collapsed_add = + node_map.GetNode("y/ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab"); + ASSERT_TRUE(collapsed_add != nullptr); EXPECT_EQ("AddN", collapsed_add->op()); EXPECT_EQ(3, collapsed_add->input_size()); @@ -1259,7 +1316,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteCollapseAddsOfIdenticalShape) { EXPECT_EQ("c", 
collapsed_add->input(2)); // check output was re-wired to new node - const NodeDef* updated_outputs = CHECK_NOTNULL(node_map.GetNode("outputs")); + const NodeDef* updated_outputs = node_map.GetNode("outputs"); + ASSERT_TRUE(updated_outputs != nullptr); EXPECT_EQ(collapsed_add->name(), updated_outputs->input(0)); } @@ -1306,8 +1364,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) { NodeMap node_map(&output); // check left Add subtree replaced with AddN - const NodeDef* collapsed_left = CHECK_NOTNULL( - node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_abc_Add_ab")); + const NodeDef* collapsed_left = + node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_abc_Add_ab"); + ASSERT_TRUE(collapsed_left != nullptr); EXPECT_EQ("AddN", collapsed_left->op()); EXPECT_EQ(3, collapsed_left->input_size()); @@ -1316,8 +1375,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) { EXPECT_EQ("c", collapsed_left->input(2)); // check right Add subtree replaced with AddN - const NodeDef* collapsed_right = CHECK_NOTNULL( - node_map.GetNode("ArithmeticOptimizer/AddOpsGroup_Add_xyz_Add_xy")); + const NodeDef* collapsed_right = + node_map.GetNode("ArithmeticOptimizer/AddOpsRewrite_Add_xyz_Add_xy"); + ASSERT_TRUE(collapsed_right != nullptr); EXPECT_EQ("AddN", collapsed_right->op()); EXPECT_EQ(3, collapsed_right->input_size()); @@ -1326,7 +1386,8 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteMultiplePasses) { EXPECT_EQ("z", collapsed_right->input(2)); // check that Mul inputs re-wired to new Nodes - const NodeDef* updated_mul = CHECK_NOTNULL(node_map.GetNode("Mul")); + const NodeDef* updated_mul = node_map.GetNode("Mul"); + ASSERT_TRUE(updated_mul != nullptr); EXPECT_EQ("Mul", updated_mul->op()); EXPECT_EQ(2, updated_mul->input_size()); @@ -1367,8 +1428,9 @@ TEST_F(ArithmeticOptimizerTest, AddOpsRewriteAddInputThroughMultiplePaths) { NodeMap node_map(&output); // check Add tree replaced with AddN - const NodeDef* collapsed_add = 
CHECK_NOTNULL(node_map.GetNode( - "ArithmeticOptimizer/AddOpsGroup_Add_all_Add_ab_Add_bc")); + const NodeDef* collapsed_add = node_map.GetNode( + "ArithmeticOptimizer/AddOpsRewrite_Add_all_Add_ab_Add_bc"); + ASSERT_TRUE(collapsed_add != nullptr); EXPECT_EQ("AddN", collapsed_add->op()); EXPECT_EQ(4, collapsed_add->input_size()); -- 2.7.4