From a0106575e1f445dde23c96a85b650f38251a2ca3 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 23 May 2018 12:35:05 -0700 Subject: [PATCH] Extracts the SimplifyReshape optimization into its own method. PiperOrigin-RevId: 197770994 --- .../core/grappler/optimizers/constant_folding.cc | 75 ++++++++++++---------- .../core/grappler/optimizers/constant_folding.h | 19 ++++-- 2 files changed, 53 insertions(+), 41 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 8bdb164..4ebe1ca 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1631,20 +1631,20 @@ Status ConstantFolding::ReplaceOperationWithConstant( return Status::OK(); } -Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, - GraphProperties* properties, - bool use_shape_info) { +Status ConstantFolding::SimplifyGraph(bool use_shape_info, + GraphDef* optimized_graph, + GraphProperties* properties) { for (int i = 0; i < optimized_graph->node_size(); ++i) { - TF_RETURN_IF_ERROR(SimplifyNode(optimized_graph->mutable_node(i), - optimized_graph, properties, - use_shape_info)); + TF_RETURN_IF_ERROR(SimplifyNode(use_shape_info, + optimized_graph->mutable_node(i), + optimized_graph, properties)); } return Status::OK(); } -Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, - GraphProperties* properties, - bool use_shape_info) { +Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, + GraphDef* optimized_graph, + GraphProperties* properties) { if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { ReplaceOperationWithIdentity(1, *properties, node, optimized_graph); return Status::OK(); @@ -2150,20 +2150,16 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, graph_modified_ = true; return Status::OK(); } - if (use_shape_info && IsSimplifiableReshape(*node, *properties)) { - DataType output_type = node->attr().at("T").type(); - node->set_op("Identity"); - node->clear_attr(); - (*node->mutable_attr())["T"].set_type(output_type); - *node->mutable_input(1) = AsControlDependency(node->input(1)); + + if (SimplifyReshape(*properties, use_shape_info, node)) { graph_modified_ = true; return Status::OK(); } bool arithmetic_simplification_succeed = false; - Status simplify_arithmetic_status = SimplifyArithmeticOperations( - optimized_graph, properties, node, use_shape_info, - &arithmetic_simplification_succeed); + Status simplify_arithmetic_status = + SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph, + node, &arithmetic_simplification_succeed); if (!simplify_arithmetic_status.ok()) { return simplify_arithmetic_status; } else if (arithmetic_simplification_succeed) { @@ -2204,9 +2200,21 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, return Status::OK(); } +bool ConstantFolding::SimplifyReshape(const GraphProperties& properties, + bool use_shape_info, NodeDef* node) { + if (!use_shape_info) return false; + if (!IsSimplifiableReshape(*node, properties)) return false; + DataType output_type = node->attr().at("T").type(); + node->set_op("Identity"); + node->clear_attr(); + (*node->mutable_attr())["T"].set_type(output_type); + *node->mutable_input(1) = AsControlDependency(node->input(1)); + return true; +} + Status ConstantFolding::SimplifyArithmeticOperations( - GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node, - bool use_shape_info, bool* success) { + const GraphProperties& properties, bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, bool* success) { const bool is_mul = IsMul(*node) || IsLogicalAnd(*node); const bool is_matmul = IsMatMul(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); @@ -2215,8 +2223,8 @@ Status ConstantFolding::SimplifyArithmeticOperations( // Simplify arithmetic operations with ones or zeros. if (use_shape_info && (is_mul || is_matmul || is_add || is_sub || is_any_div) && - properties->HasInputProperties(node->name()) && - properties->HasOutputProperties(node->name())) { + properties.HasInputProperties(node->name()) && + properties.HasOutputProperties(node->name())) { const NodeDef* x = node_map_->GetNode(node->input(0)); const NodeDef* y = node_map_->GetNode(node->input(1)); if (x == nullptr || y == nullptr) { @@ -2224,19 +2232,19 @@ Status ConstantFolding::SimplifyArithmeticOperations( node->DebugString()); } const TensorShapeProto& output_shape = - properties->GetOutputProperties(node->name())[0].shape(); + properties.GetOutputProperties(node->name())[0].shape(); // Simplify element-wise multiplication by ones or addition/subtraction // of zeros. const TensorShapeProto& y_shape = - properties->GetInputProperties(node->name())[1].shape(); + properties.GetInputProperties(node->name())[1].shape(); const bool x_is_zero = IsZeros(*x); const bool x_is_one = x_is_zero ? false : IsOnes(*x); const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape); if (y_matches_output_shape && ((is_mul && x_is_one) || (is_add && x_is_zero))) { // 1 * y = y or 0 + y = y. - ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph); + ReplaceOperationWithSnapshot(1, properties, node, optimized_graph); *success = true; return Status::OK(); } @@ -2259,14 +2267,14 @@ Status ConstantFolding::SimplifyArithmeticOperations( } const TensorShapeProto& x_shape = - properties->GetInputProperties(node->name())[0].shape(); + properties.GetInputProperties(node->name())[0].shape(); const bool y_is_zero = IsZeros(*y); const bool y_is_one = y_is_zero ? false : IsOnes(*y); const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape); if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || ((is_add || is_sub) && y_is_zero))) { // x * 1 = x or x / 1 = x or x +/- 0 = x - ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph); + ReplaceOperationWithSnapshot(0, properties, node, optimized_graph); *success = true; return Status::OK(); } @@ -2276,9 +2284,8 @@ Status ConstantFolding::SimplifyArithmeticOperations( const PartialTensorShape shp(output_shape); if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) { bool replace_succeed = false; - Status replace_op_status = - ReplaceOperationWithConstant(1, *properties, output_shape, node, - optimized_graph, &replace_succeed); + Status replace_op_status = ReplaceOperationWithConstant( + 1, properties, output_shape, node, optimized_graph, &replace_succeed); if (!replace_op_status.ok()) { return replace_op_status; } else if (replace_succeed) { @@ -2296,7 +2303,7 @@ Status ConstantFolding::SimplifyArithmeticOperations( if (shp.IsFullyDefined()) { bool replace_succeed = false; Status replace_op_status = - ReplaceOperationWithConstant(0, *properties, output_shape, node, + ReplaceOperationWithConstant(0, properties, output_shape, node, optimized_graph, &replace_succeed); if (!replace_op_status.ok()) { return replace_op_status; @@ -2309,11 +2316,11 @@ Status ConstantFolding::SimplifyArithmeticOperations( // matches the output shape and thus forward the corresponding zero // input. if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) { - ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); *success = true; return Status::OK(); } else if (is_mul && y_is_zero && y_matches_output_shape) { - ReplaceOperationWithIdentity(1, *properties, node, optimized_graph); + ReplaceOperationWithIdentity(1, properties, node, optimized_graph); *success = true; return Status::OK(); } @@ -2855,7 +2862,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, TF_RETURN_IF_ERROR(FoldGraph(optimized_graph)); node_map_.reset(new NodeMap(optimized_graph)); TF_RETURN_IF_ERROR( - SimplifyGraph(optimized_graph, &properties, can_use_shape_info)); + SimplifyGraph(can_use_shape_info, optimized_graph, &properties)); return Status::OK(); } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index e477934..3cf379f 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -97,10 +97,10 @@ class ConstantFolding : public GraphOptimizer { const GraphProperties& properties) const; bool IsSimplifiableReshape(const NodeDef& node, const GraphProperties& properties) const; - Status SimplifyGraph(GraphDef* output, GraphProperties* properties, - bool use_shape_info); - Status SimplifyNode(NodeDef* node, GraphDef* optimized_graph, - GraphProperties* properties, bool use_shape_info); + Status SimplifyGraph(bool use_shape_info, GraphDef* optimized_graph, + GraphProperties* properties); + Status SimplifyNode(bool use_shape_info, NodeDef* node, + GraphDef* optimized_graph, GraphProperties* properties); Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item, GraphDef* output); @@ -134,11 +134,16 @@ class ConstantFolding : public GraphOptimizer { // Simplifies arithmetic operations with ones or zeros. Returns the status, // and updates the success input argument that denotes if any simplification // was applied. - Status SimplifyArithmeticOperations(GraphDef* optimized_graph, - GraphProperties* properties, - NodeDef* node, bool use_shape_info, + Status SimplifyArithmeticOperations(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, bool* success); + // Simplifies a Reshape operation to an Identity operation if the input node + // to the operation is a constant. + bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info, + NodeDef* node); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4