From 68546a6cfd18ac1a16f6d6a1843882aea4243f55 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 18 May 2018 06:27:13 -0700 Subject: [PATCH] Extracts the following optimizations into methods: SimplifyArithmeticOperations ReduceDivToReciprocalMul PiperOrigin-RevId: 197137281 --- .../core/grappler/optimizers/constant_folding.cc | 127 +++++++++++++++------ .../core/grappler/optimizers/constant_folding.h | 15 ++- 2 files changed, 105 insertions(+), 37 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 782ccff..9137b9d 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "tensorflow/core/grappler/optimizers/constant_folding.h" + #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.pb.h" @@ -1566,9 +1567,13 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node, Status ConstantFolding::ReplaceOperationWithConstant( double value, const GraphProperties& properties, - const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) { + const TensorShapeProto& shape, NodeDef* node, GraphDef* graph, + bool* success) { const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); - if (dtype == DT_INVALID) return Status::OK(); + if (dtype == DT_INVALID) { + *success = false; + return Status::OK(); + } AttrValue tensor_attr; TF_RETURN_IF_ERROR( @@ -1587,7 +1592,7 @@ Status ConstantFolding::ReplaceOperationWithConstant( node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep); node->set_input(i, ctrl_dep); } - graph_modified_ = true; + *success = true; return Status::OK(); } @@ -1605,7 +1610,6 @@ Status ConstantFolding::SimplifyGraph(GraphDef* optimized_graph, Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, GraphProperties* properties, bool use_shape_info) { - const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { ReplaceOperationWithIdentity(1, *properties, node, optimized_graph); return Status::OK(); @@ -2029,6 +2033,48 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, return Status::OK(); } + bool arithmetic_simplification_succeed = false; + Status simplify_arithmetic_status = SimplifyArithmeticOperations( + optimized_graph, properties, node, use_shape_info, + &arithmetic_simplification_succeed); + if (!simplify_arithmetic_status.ok()) { + return simplify_arithmetic_status; + } else if (arithmetic_simplification_succeed) { + graph_modified_ = true; + return Status::OK(); + } + + if (ReduceDivToReciprocalMul(optimized_graph, node)) { + graph_modified_ = true; + return Status::OK(); + } + + if (ConstantPushDown(node)) { + graph_modified_ = true; + return Status::OK(); + } + + if (PartialConstPropThroughIdentityN(node)) { + graph_modified_ = true; + return Status::OK(); + } + + if (PartialAssocOpConstFolding(optimized_graph, properties, node)) { + graph_modified_ = true; + return Status::OK(); + } + + if (PartialConcatConstFolding(optimized_graph, properties, node)) { + graph_modified_ = true; + return Status::OK(); + } + + return Status::OK(); +} + +Status ConstantFolding::SimplifyArithmeticOperations( + GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node, + bool use_shape_info, bool* success) { const bool is_mul = IsMul(*node) || IsLogicalAnd(*node); const bool is_matmul = IsMatMul(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); @@ -2059,12 +2105,14 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, ((is_mul && x_is_one) || (is_add && x_is_zero))) { // 1 * y = y or 0 + y = y. ReplaceOperationWithSnapshot(1, *properties, node, optimized_graph); + *success = true; return Status::OK(); } if (y_matches_output_shape && (is_sub && x_is_zero)) { // Replace 0 - y with Neg(y). ReplaceSubtractionFromZeroByNegation(node, optimized_graph); + *success = true; return Status::OK(); } @@ -2073,6 +2121,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, DataType type = node->attr().at("T").type(); if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) { ReplaceDivisionOfOnesByReciprocal(node, optimized_graph); + *success = true; return Status::OK(); } } @@ -2086,40 +2135,68 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, ((is_add || is_sub) && y_is_zero))) { // x * 1 = x or x / 1 = x or x +/- 0 = x ReplaceOperationWithSnapshot(0, *properties, node, optimized_graph); + *success = true; return Status::OK(); } // x OR true = true OR y = true. + bool updated_graph = false; const PartialTensorShape shp(output_shape); if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) { - TF_RETURN_IF_ERROR(ReplaceOperationWithConstant( - 1, *properties, output_shape, node, optimized_graph)); + bool replace_succeed = false; + Status replace_op_status = + ReplaceOperationWithConstant(1, *properties, output_shape, node, + optimized_graph, &replace_succeed); + if (!replace_op_status.ok()) { + return replace_op_status; + } else if (replace_succeed) { + updated_graph = true; + } } // Simplify multiplication and matmul by zeros. // Also optimize zeros divided by a tensor, but only if we are in // aggressive mode, since we might get rid of divisions by zero. + const bool is_aggressive = opt_level_ == RewriterConfig::AGGRESSIVE; bool optimize_zeros_divided_by_y = is_any_div && x_is_zero && is_aggressive; if ((x_is_zero || y_is_zero) && (is_mul || is_matmul || optimize_zeros_divided_by_y)) { if (shp.IsFullyDefined()) { - TF_RETURN_IF_ERROR(ReplaceOperationWithConstant( - 0, *properties, output_shape, node, optimized_graph)); - return Status::OK(); + bool replace_succeed = false; + Status replace_op_status = + ReplaceOperationWithConstant(0, *properties, output_shape, node, + optimized_graph, &replace_succeed); + if (!replace_op_status.ok()) { + return replace_op_status; + } else if (replace_succeed) { + *success = true; + return Status::OK(); + } } // Even if an input shape is only partially known, we may known that it // matches the output shape and thus forward the corresponding zero // input. if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) { ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); + *success = true; return Status::OK(); } else if (is_mul && y_is_zero && y_matches_output_shape) { ReplaceOperationWithIdentity(1, *properties, node, optimized_graph); + *success = true; return Status::OK(); } } + if (updated_graph) { + *success = true; + return Status::OK(); + } } + *success = false; + return Status::OK(); +} +bool ConstantFolding::ReduceDivToReciprocalMul(GraphDef* optimized_graph, + NodeDef* node) { // Strength reduce floating point division by a constant Div(x, const) to // multiplication by the reciprocal Mul(x, Reciprocal(const)). This in turn // will be constant folded to Mul(x, 1.0/const). @@ -2128,15 +2205,15 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, const NodeDef* denom = node_map_->GetNode(const_input); CHECK(denom != nullptr); if (!IsReallyConstant(*denom)) { - return Status::OK(); + return false; } if (node->attr().count("T") == 0) { - return Status::OK(); + return false; } DataType type = node->attr().at("T").type(); if (IsDiv(*node) && !(DataTypeIsFloating(type) || DataTypeIsComplex(type))) { - return Status::OK(); + return false; } // Insert new reciprocal op and change node from Div to Mul. NodeDef* reciprocal_node = optimized_graph->add_node(); @@ -2150,31 +2227,9 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, node->set_input(1, reciprocal_node->name()); node_map_->AddNode(reciprocal_node->name(), reciprocal_node); node_map_->UpdateOutput(node->name(), const_input, reciprocal_node->name()); - graph_modified_ = true; - return Status::OK(); - } - - if (ConstantPushDown(node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (PartialConstPropThroughIdentityN(node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (PartialAssocOpConstFolding(optimized_graph, properties, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (PartialConcatConstFolding(optimized_graph, properties, node)) { - graph_modified_ = true; - return Status::OK(); + return true; } - - return Status::OK(); + return false; } bool ConstantFolding::ConstantPushDown(NodeDef* node) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 227caba..6c99120 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -88,7 +88,8 @@ class ConstantFolding : public GraphOptimizer { Status ReplaceOperationWithConstant(double value, const GraphProperties& properties, const TensorShapeProto& shape, - NodeDef* node, GraphDef* graph); + NodeDef* node, GraphDef* graph, + bool* success); void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); Status FoldGraph(GraphDef* output); @@ -121,6 +122,18 @@ class ConstantFolding : public GraphOptimizer { // the transformation applied successfully. bool ConstantPushDown(NodeDef* node); + // Strength reduces floating point division by a constant Div(x, const) to + // multiplication by the reciprocal Mul(x, Reciprocal(const)). + bool ReduceDivToReciprocalMul(GraphDef* optimized_graph, NodeDef* node); + + // Simplifies arithmetic operations with ones or zeros. Returns the status, + // and updates the success input argument that denotes if any simplification + // was applied. + Status SimplifyArithmeticOperations(GraphDef* optimized_graph, + GraphProperties* properties, + NodeDef* node, bool use_shape_info, + bool* success); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4