From afa17984849881f39fb56c6e3500d866539924d5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Tue, 10 Apr 2018 11:09:37 -0700
Subject: [PATCH] Adds support for hoisting out common denominator in
 arithmetic_optimizer

PiperOrigin-RevId: 192314177
---
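Note (illustration only, not part of the commit): the dtype check added to
GetCommonFactors below exists because real division distributes over
addition, while truncating integer division does not, so hoisting a common
denominator is only safe for DT_FLOAT/DT_DOUBLE inputs. A minimal standalone
C++ sketch of the failure case (file name hypothetical):

  // distributivity_check.cc -- illustrative only.
  #include <iostream>

  int main() {
    // Floats: y1/x + y2/x == (y1 + y2)/x (up to rounding), so the hoist
    // AddN(Div(y1, x), Div(y2, x)) -> Div(AddN(y1, y2), x) is safe.
    float yf1 = 3.0f, yf2 = 5.0f, xf = 2.0f;
    std::cout << yf1 / xf + yf2 / xf << " vs " << (yf1 + yf2) / xf
              << "\n";  // prints "4 vs 4"

    // Ints: 3/2 + 5/2 == 1 + 2 == 3, but (3 + 5)/2 == 4, so the same
    // rewrite would change the result; the optimizer must skip it.
    int y1 = 3, y2 = 5, x = 2;
    std::cout << y1 / x + y2 / x << " vs " << (y1 + y2) / x
              << "\n";  // prints "3 vs 4"
    return 0;
  }

This is also why the HoistFactorDiv test added below expects integer graphs
to be left intact.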
 .../grappler/optimizers/arithmetic_optimizer.cc    | 103 +++++++++++++++------
 .../optimizers/arithmetic_optimizer_test.cc        |  85 ++++++++++++++++-
 2 files changed, 161 insertions(+), 27 deletions(-)

diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index fa0f7c1..463c332 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -695,15 +695,20 @@ class AddOpsRewriteStage : public ArithmeticNodesGroupOptimizerStage {
   }
 };
 
-// Use the commutativity and (left- and right-) distributive property of
-// multiplication over addition to hoist common factors out of aggregate nodes
-// where all the inputs are Mul nodes. This pattern occurs frequently in
-// regularization terms for the gradients during training.
+// Use the distributive property of multiplication and division over addition,
+// along with commutativity of the former, to hoist common factors/denominators
+// out of aggregate nodes where ALL the inputs are Mul/Div nodes.
+// This pattern occurs frequently in regularization terms for the gradients
+// during training.
 //
 // For example, we can rewrite an expression of the form:
 //   AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
 // to the following:
 //   Mul(x, AddN(y1, y2, y3, ... yn))
+// For division, we can rewrite
+//   AddN(Div(y1, x), Div(y2, x), Div(y3, x), ... Div(yn, x))
+// to:
+//   Div(AddN(y1, y2, y3, ... yn), x)
 class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
  public:
   explicit HoistCommonFactorOutOfAggregation(
@@ -720,9 +725,11 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 
   Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
     TF_RETURN_IF_ERROR(EnsureNodeIsSupported(node));
+    bool common_factor_is_denominator = false;
     std::set<string> common_factors;
     std::vector<string> ctrl_deps;
-    TF_RETURN_IF_ERROR(GetCommonFactors(node, &common_factors, &ctrl_deps));
+    TF_RETURN_IF_ERROR(GetCommonFactors(
+        node, &common_factors, &common_factor_is_denominator, &ctrl_deps));
 
     if (common_factors.size() == 1) {
       const string& common_factor = *common_factors.begin();
@@ -730,24 +737,31 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
       // Gather up the non-shared factors
       bool shapes_match = true;
       std::vector<string> unique_factors;
-      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor, &shapes_match,
-                                          &unique_factors));
+      TF_RETURN_IF_ERROR(GetUniqueFactors(node, common_factor,
+                                          common_factor_is_denominator,
+                                          &shapes_match, &unique_factors));
 
       if (shapes_match) {
         NodeDef* input_0;
         TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &input_0));
 
-        // Use a copy of the first Mul node for the outer multiplication.
-        NodeDef* new_mul_node = AddCopyNode(OuterMulNodeName(node), input_0);
+        // Use a copy of the first node for the outer multiplication/division.
+        NodeDef* new_outer_node = AddCopyNode(
+            OuterNodeName(node, common_factor_is_denominator), input_0);
         // And a copy of aggregation node as one of the inner operands
         NodeDef* new_add_node = AddCopyNode(InnerAddNodeName(node), node);
 
-        new_mul_node->set_device(node->device());
-        new_mul_node->set_input(0, common_factor);
-        new_mul_node->set_input(1, new_add_node->name());
+        new_outer_node->set_device(node->device());
+        if (common_factor_is_denominator) {
+          new_outer_node->set_input(0, new_add_node->name());
+          new_outer_node->set_input(1, common_factor);
+        } else {
+          new_outer_node->set_input(0, common_factor);
+          new_outer_node->set_input(1, new_add_node->name());
+        }
 
-        ctx_.node_map->AddOutput(common_factor, new_mul_node->name());
-        ctx_.node_map->AddOutput(new_add_node->name(), new_mul_node->name());
+        ctx_.node_map->AddOutput(common_factor, new_outer_node->name());
+        ctx_.node_map->AddOutput(new_add_node->name(), new_outer_node->name());
 
         // Hoist non-shared factors up into the new AddN node.
         for (int i = 0; i < unique_factors.size(); ++i) {
@@ -766,17 +780,18 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
         AddToOptimizationQueue(new_add_node);
         // do not optimize the same node twice
         rewritten_nodes_.insert(node->name());
-        *simplified_node_name = new_mul_node->name();
+        *simplified_node_name = new_outer_node->name();
       }
     }
     return Status::OK();
   }
 
  private:
-  // Get a name for new outer Mul node
-  string OuterMulNodeName(const NodeDef* node) const {
+  // Get a name for the new outer node
+  string OuterNodeName(const NodeDef* node, bool is_div) const {
     auto scope_and_name = ParseNodeScopeAndName(node->name());
-    return OptimizedNodeName(scope_and_name, "Mul");
+    return is_div ? OptimizedNodeName(scope_and_name, "Div")
+                  : OptimizedNodeName(scope_and_name, "Mul");
   }
 
   // Get a name new inner Add node
@@ -785,11 +800,17 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
     return OptimizedNodeName(scope_and_name, "Add");
   }
 
-  // Determine the set of common factors if the input nodes are all Mul nodes.
+  // Determine the set of common factors if the input nodes are all Mul or
+  // Div nodes.
   Status GetCommonFactors(const NodeDef* node, std::set<string>* common_factors,
+                          bool* common_factor_is_denominator,
                           std::vector<string>* ctrl_deps) const {
     CHECK(common_factors->empty());
+    CHECK_NOTNULL(common_factor_is_denominator);
+    *common_factor_is_denominator = false;
+    bool has_mul = false;
+    bool has_div = false;
     for (int i = 0; i < node->input_size(); ++i) {
       if (i > 0 && common_factors->empty()) break;
       if (IsControlInput(node->input(i))) {
@@ -799,12 +820,36 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
 
       NodeDef* input;
       TF_RETURN_IF_ERROR(GetInputNode(node->input(i), &input));
-      if (!IsMul(*input)) {
+      if ((!IsMul(*input) && !IsAnyDiv(*input)) || (IsMul(*input) && has_div) ||
+          (IsAnyDiv(*input) && has_mul)) {
+        // Break if input is neither a Mul nor a Div, or if there are both Mul
+        // and Div Ops.
         common_factors->clear();
         break;
+      } else if (IsAnyDiv(*input)) {
+        has_div = true;
+        // In case of possible common dividers, we avoid hoisting out if any
+        // input is not float/double, since integer division is not distributive
+        // over addition.
+        OpInfo::TensorProperties properties0, properties1;
+        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(0), &properties0));
+        TF_RETURN_IF_ERROR(GetTensorProperties(input->input(1), &properties1));
+        if (properties0.dtype() != DT_FLOAT &&
+            properties0.dtype() != DT_DOUBLE &&
+            properties1.dtype() != DT_FLOAT &&
+            properties1.dtype() != DT_DOUBLE) {
+          common_factors->clear();
+          break;
+        }
+      } else if (IsMul(*input)) {
+        has_mul = true;
       }
 
-      std::set<string> factors_i{input->input(0), input->input(1)};
+      // We only focus on common factors from denominators if any Op is a
+      // Div.
+      std::set<string> factors_i =
+          has_mul ? std::set<string>{input->input(0), input->input(1)}
+                  : std::set<string>{input->input(1)};
       if (i == 0) {
         std::swap(*common_factors, factors_i);
       } else {
@@ -819,6 +864,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
         ctrl_deps->push_back(input->input(i));
       }
     }
+
+    *common_factor_is_denominator = has_div;
     return Status::OK();
   }
 
@@ -827,6 +874,7 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
   // have the same shape since the other aggregation ops do not support
   // broadcasting.
   Status GetUniqueFactors(const NodeDef* node, const string& common_factor,
+                          const bool common_factor_is_denominator,
                           bool* shapes_match,
                           std::vector<string>* unique_factors) const {
     *shapes_match = true;
@@ -837,11 +885,13 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
       if (IsControlInput(input)) {
         break;
       }
-      NodeDef* mul_node;
-      TF_RETURN_IF_ERROR(GetInputNode(input, &mul_node));
+      NodeDef* inner_node;
+      TF_RETURN_IF_ERROR(GetInputNode(input, &inner_node));
       const int unique_factor_index =
-          mul_node->input(0) == common_factor ? 1 : 0;
-      unique_factors->push_back(mul_node->input(unique_factor_index));
+          common_factor_is_denominator
+              ? 0
+              : (inner_node->input(0) == common_factor ? 1 : 0);
+      unique_factors->push_back(inner_node->input(unique_factor_index));
       if (i > 0 && !IsAdd(*node)) {
         OpInfo::TensorProperties lhs;
         OpInfo::TensorProperties rhs;
@@ -857,7 +907,8 @@ class HoistCommonFactorOutOfAggregation : public ArithmeticOptimizerStage {
     // if graph rewrite happens in multiple passes without graph pruning between
     // them, it's possible that rewritten node already exists in a graph
    return rewritten_nodes_.find(node->name()) != rewritten_nodes_.end() ||
-           ctx_.node_map->NodeExists(OuterMulNodeName(node));
+           ctx_.node_map->NodeExists(OuterNodeName(node, false)) ||
+           ctx_.node_map->NodeExists(OuterNodeName(node, true));
   }
 
   // keep names of the nodes that were optimized by this stage
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index 9677175..e639812 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -31,6 +31,9 @@ namespace grappler {
 namespace {
 
+constexpr char kHoistFactorOptimizerDiv[] =
+    "ArithmeticOptimizer/HoistCommonFactor_Div_";
+
 constexpr char kHoistFactorOptimizerMul[] =
     "ArithmeticOptimizer/HoistCommonFactor_Mul_";
 
@@ -42,6 +45,11 @@ string HoistMulName(const string& name) {
   return AddPrefixToNodeName(name, kHoistFactorOptimizerMul, "");
 }
 
+// Optimized name of outer Div node by HoistCommonFactorOutOfAggregation
+string HoistDivName(const string& name) {
+  return AddPrefixToNodeName(name, kHoistFactorOptimizerDiv, "");
+}
+
 // Optimized name of inner Add node by HoistCommonFactorOutOfAggregation
 string HoistAddName(const string& name) {
   return AddPrefixToNodeName(name, kHoistFactorOptimizerAdd, "");
@@ -558,7 +566,7 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
   EXPECT_EQ("^Placeholder", add_1_const_node->input(0));
 }
 
-TEST_F(ArithmeticOptimizerTest, HoistFactor) {
+TEST_F(ArithmeticOptimizerTest, HoistFactorMul) {
   for (bool matching_shapes : {true, false}) {
     for (bool use_addn : {true, false}) {
       tensorflow::Scope s = tensorflow::Scope::NewRootScope();
@@ -625,6 +633,81 @@ TEST_F(ArithmeticOptimizerTest, HoistFactor) {
   }
 }
 
+TEST_F(ArithmeticOptimizerTest, HoistFactorDiv) {
+  for (bool matching_shapes : {true, false}) {
+    for (bool use_addn : {true, false}) {
+      for (bool use_ints : {true, false}) {
+        tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+        Output x = use_ints
+                       ? ops::Const(s.WithOpName("x"), {1, 2}, {1, 2})
+                       : ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+        Output y1 = use_ints
+                        ? ops::Const(s.WithOpName("y1"), {3, 4}, {1, 2})
+                        : ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
+        Output y2;
+        if (matching_shapes) {
+          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5, 6}, {1, 2})
+                        : ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
+        } else {
+          y2 = use_ints ? ops::Const(s.WithOpName("y2"), {5}, {1, 1})
+                        : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
+        }
+        Output div1 = ops::Div(s.WithOpName("div1"), y1, x);
+        Output div2 = ops::Div(s.WithOpName("div2"), y2, x);
+        Output id =
+            use_addn
+                ? ops::Identity(s.WithOpName("id"),
+                                ops::AddN(s.WithOpName("add"), {div1, div2}))
+                : ops::Identity(s.WithOpName("id"),
+                                ops::Add(s.WithOpName("add"), div1, div2));
+
+        GrapplerItem item;
+        item.fetch = {"id"};
+        TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+        ArithmeticOptimizer optimizer;
+        EnableOnlyHoistCommonFactor(&optimizer);
+
+        GraphDef output;
+        OptimizeTwice(&optimizer, &item, &output);
+
+        // We expect the following rewrite(s) to occur:
+        //
+        //           Add                Div
+        //          /   \              /   \
+        //        Div   Div    ->    Add    x
+        //        / \   / \         /   \
+        //       y1  x y2  x       y1    y2
+        //
+        // If the "root" op is AddN and shapes do not match, or inputs are
+        // integers, this rewrite is not possible and the graph stays intact.
+        NodeMap node_map(&output);
+
+        if ((use_addn && !matching_shapes) || use_ints) {
+          VerifyGraphsMatch(item.graph, output, __LINE__);
+        } else {
+          EXPECT_EQ(9, output.node_size());
+
+          const NodeDef* new_add_node = node_map.GetNode(HoistAddName("add"));
+          ASSERT_TRUE(new_add_node != nullptr) << "Hoisted Add node not found";
+          EXPECT_EQ("y1", new_add_node->input(0));
+          EXPECT_EQ("y2", new_add_node->input(1));
+
+          const NodeDef* new_div_node = node_map.GetNode(HoistDivName("add"));
+          ASSERT_TRUE(new_div_node != nullptr) << "Hoisted Div node not found";
+          EXPECT_EQ(new_add_node->name(), new_div_node->input(0));
+          EXPECT_EQ("x", new_div_node->input(1));
+
+          const NodeDef* id_node = node_map.GetNode("id");
+          ASSERT_TRUE(id_node != nullptr) << "Id node not found";
+          EXPECT_EQ("id", id_node->name());
+          EXPECT_EQ(HoistDivName("add"), id_node->input(0));
+        }
+      }
+    }
+  }
+}
+
 TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
   tensorflow::Scope s = tensorflow::Scope::NewRootScope();
   Output re = ops::Const(s.WithOpName("re"), {1.0f, 2.0f, 3.0f, 4.0f}, {2, 2});
-- 
2.7.4