From da07aa28e0eef4aebe4851e9bdfc40e7b098cf04 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 23 May 2018 18:13:23 -0700 Subject: [PATCH] Extracts the SimplifyReduction optimization into its own method. PiperOrigin-RevId: 197823183 --- .../core/grappler/optimizers/constant_folding.cc | 37 ++++++++++++++-------- .../core/grappler/optimizers/constant_folding.h | 6 ++-- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 4ebe1ca..bf606fb 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2133,20 +2133,8 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } } - if (IsSimplifiableReduction(*node, *properties)) { - // Replace the reduction node with an identity node, that can be further - // optimized by the model pruner. - DataType output_type; - if (node->attr().count("T") > 0) { - output_type = node->attr().at("T").type(); - } else { - // This is an 'any' or 'all' reduction. The output is always boolean. - output_type = DT_BOOL; - } - node->set_op("Identity"); - node->clear_attr(); - (*node->mutable_attr())["T"].set_type(output_type); - *node->mutable_input(1) = AsControlDependency(node->input(1)); + + if (SimplifyReduction(*properties, node)) { graph_modified_ = true; return Status::OK(); } @@ -2200,6 +2188,27 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +bool ConstantFolding::SimplifyReduction(const GraphProperties& properties, + NodeDef* node) { + if (IsSimplifiableReduction(*node, properties)) { + // Replace the reduction node with an identity node, that can be further + // optimized by the model pruner. + DataType output_type; + if (node->attr().count("T") > 0) { + output_type = node->attr().at("T").type(); + } else { + // This is an 'any' or 'all' reduction. The output is always boolean. + output_type = DT_BOOL; + } + node->set_op("Identity"); + node->clear_attr(); + (*node->mutable_attr())["T"].set_type(output_type); + *node->mutable_input(1) = AsControlDependency(node->input(1)); + return true; + } + return false; +} + bool ConstantFolding::SimplifyReshape(const GraphProperties& properties, bool use_shape_info, NodeDef* node) { if (!use_shape_info) return false; diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 3cf379f..07a2e01 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -139,11 +139,13 @@ class ConstantFolding : public GraphOptimizer { GraphDef* optimized_graph, NodeDef* node, bool* success); - // Simplifies a Reshape operation to an Identity operation if the input node - // to the operation is a constant. + // Simplifies a Reshape operation to an Identity operation if applicable. bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info, NodeDef* node); + // Simplifies a Reduction operation to an Identity operation if applicable. + bool SimplifyReduction(const GraphProperties& properties, NodeDef* node); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4