From ceda30408f66a7eea86dc359164deb662d5a32d0 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 3 May 2018 13:00:56 -0700 Subject: [PATCH] Enable unary chain hoisting optimization for concat/split/splitv by default. PiperOrigin-RevId: 195297330 --- tensorflow/core/grappler/op_types.cc | 38 ++++++++++++++-------- tensorflow/core/grappler/op_types.h | 4 +++ .../grappler/optimizers/arithmetic_optimizer.cc | 18 +++++++--- .../grappler/optimizers/arithmetic_optimizer.h | 2 +- .../optimizers/arithmetic_optimizer_test.cc | 16 ++++----- 5 files changed, 51 insertions(+), 27 deletions(-) diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc index 7c936df..c48dc00 100644 --- a/tensorflow/core/grappler/op_types.cc +++ b/tensorflow/core/grappler/op_types.cc @@ -476,28 +476,40 @@ bool IsInvolution(const NodeDef& node) { return involution_ops->count(node.op()) > 0; } +bool IsValueAndOrderAndShapePreserving(const NodeDef& node) { + if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { + return true; + } + static const std::unordered_set* + value_and_order_and_shape_preserving_ops = + CHECK_NOTNULL((new const std::unordered_set{ + "CheckNumerics", + "DebugGradientIdentity", + "DeepCopy" + "Enter", + "Exit", + "Identity", + "IdentityN", + "PreventGradient", + "Print", + "Snapshot", + "StopGradient", + })); + return value_and_order_and_shape_preserving_ops->count(node.op()) > 0; +} + bool IsValueAndOrderPreserving(const NodeDef& node) { if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { return true; } static const std::unordered_set* value_and_order_preserving_ops = CHECK_NOTNULL((new const std::unordered_set{ - "CheckNumerics", - "DebugGradientIdentity", - "DeepCopy" - "Enter", - "Exit", "ExpandDims", - "Identity", - "IdentityN", - "PreventGradient", - "Print", - "Reshape", "Snapshot", "Squeeze", - "StopGradient", })); - return value_and_order_preserving_ops->count(node.op()) > 0; + return value_and_order_preserving_ops->count(node.op()) > 0 || + IsValueAndOrderAndShapePreserving(node); } bool IsValuePreserving(const NodeDef& node) { @@ -564,7 +576,7 @@ bool IsUnaryElementWise(const NodeDef& node) { "Tanh", })); return element_wise_ops->count(node.op()) > 0 || - (!IsIdentityN(node) && IsValueAndOrderPreserving(node)); + (!IsIdentityN(node) && IsValueAndOrderAndShapePreserving(node)); } bool HasOpDef(const NodeDef& node) { diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h index 7a1b438..e33dd21 100644 --- a/tensorflow/core/grappler/op_types.h +++ b/tensorflow/core/grappler/op_types.h @@ -174,6 +174,10 @@ bool ModifiesInputsInPlace(const NodeDef& node); // own inverse such that f(f(x)) == x. bool IsInvolution(const NodeDef& node); +// Returns true if the op preserves the order and value of elements +// and shape of its first input tensor. +bool IsValueAndOrderAndShapePreserving(const NodeDef& node); + // Returns true if the op preserves the order and value of elements in its // first input tensor and possible changes its shape. bool IsValueAndOrderPreserving(const NodeDef& node); diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index d6510ba..2a5654f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -1400,6 +1400,11 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { return n > 1; } else if (IsSplit(*node) || IsSplitV(*node)) { const int num_split = node->attr().at("num_split").i(); + if (NumNonControlOutputs(*node, *ctx().node_map) > num_split) { + // TODO(rmlarsen): Remove this constraint when we have optimizations + // in place for merging slices into splits. + return false; + } return num_split > 1 && !IsAlreadyOptimized(*node); } return false; @@ -1458,13 +1463,13 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { if (tails.empty()) { return Status::OK(); } - AddControlInputs(ctrl_inputs, root_node); AddToOptimizationQueue(root_node); optimized_nodes_.insert(root_node->name()); if (node_is_concat_) { + AddControlInputs(ctrl_inputs, root_node); return HoistChainForConcat(prefix_length, tails, root_node); } else { - return HoistChainForSplit(prefix_length, tails, root_node); + return HoistChainForSplit(prefix_length, tails, ctrl_inputs, root_node); } } @@ -1542,9 +1547,8 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { IsInPreserveSet(*op)) { return false; } - if (node_is_concat_ && - ctx().node_map->GetOutputs(op->name()).size() > 1) { - // TODO(rmlarsen): Allow and hoist outgoing control edges. + if (ctx().node_map->GetOutputs(op->name()).size() > 1) { + // TODO(rmlarsen): Allow outgoing control edges. return false; } } @@ -1612,6 +1616,7 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { } Status HoistChainForSplit(const int prefix_length, const ChainLinkSet& tails, + std::set* ctrl_inputs, NodeDef* split_node) { // Create a new chain before the split node to process the input tensor. const string& split_name = split_node->name(); @@ -1646,6 +1651,9 @@ class HoistCWiseUnaryChainsStage : public ArithmeticOptimizerStage { cur_copy->add_input(orig_input); ctx().node_map->UpdateOutput(NodeName(orig_input), split_name, cur_copy->name()); + // Make sure all the control inputs are satisfied before running the first + // node in the new chain. + AddControlInputs(ctrl_inputs, cur_copy); // Connect all consumers of the tail nodes directly to the // output port of Split from which the chain started. diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h index 3b297ec..6309dc1 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h @@ -65,7 +65,7 @@ class ArithmeticOptimizer : public GraphOptimizer { bool remove_redundant_bitcast = true; bool remove_redundant_cast = true; bool remove_negation = true; - bool hoist_cwise_unary_chains = false; + bool hoist_cwise_unary_chains = true; bool convert_sqrt_div_to_rsqrt_mul = false; // Choose which arithmetic optimizer stages will be enabled for a given diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc index f903f53..d32743f 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc @@ -2320,16 +2320,16 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { EXPECT_NE(node.name(), "cos_exp_b2"); if (node.name() == "split1") { - EXPECT_EQ(3, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("axis", node.input(0)); EXPECT_EQ("ArithmeticOptimizer/_sin_a_split1", node.input(1)); - EXPECT_EQ("^ctrl1", node.input(2)); found++; } if (node.name() == "ArithmeticOptimizer/_sin_a_split1") { EXPECT_EQ("Sin", node.op()); - EXPECT_EQ(1, node.input_size()); + EXPECT_EQ(2, node.input_size()); EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^ctrl1", node.input(1)); found++; } if (node.name() == "id_a") { @@ -2349,8 +2349,11 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { } if (node.name() == "ArithmeticOptimizer/_exp_a2_split2") { EXPECT_EQ("Exp", node.op()); - EXPECT_EQ(1, node.input_size()); + EXPECT_EQ(4, node.input_size()); EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^ctrl1", node.input(1)); + EXPECT_EQ("^ctrl2", node.input(2)); + EXPECT_EQ("^ctrl3", node.input(3)); found++; } if (node.name() == "ArithmeticOptimizer/_cos_exp_a2_split2") { @@ -2360,13 +2363,10 @@ TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryIntoSplit) { found++; } if (node.name() == "split2") { - EXPECT_EQ(6, node.input_size()); + EXPECT_EQ(3, node.input_size()); EXPECT_EQ("ArithmeticOptimizer/_cos_exp_a2_split2", node.input(0)); EXPECT_EQ("size_splits2", node.input(1)); EXPECT_EQ("axis", node.input(2)); - EXPECT_EQ("^ctrl1", node.input(3)); - EXPECT_EQ("^ctrl2", node.input(4)); - EXPECT_EQ("^ctrl3", node.input(5)); found++; } if (node.name() == "id_a2") { -- 2.7.4