From 6d41d9fb0ca1b3f25d24242ca9e45364828baca8 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 14 May 2018 09:45:42 -0700 Subject: [PATCH] Extracts the following optimizations into methods: PartialConstPropThroughIdentityN ConstantPushDown PiperOrigin-RevId: 196520167 --- .../core/grappler/optimizers/constant_folding.cc | 58 ++++++++++++++-------- .../core/grappler/optimizers/constant_folding.h | 8 +++ 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 171d492..b2dcbf9 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -2157,6 +2157,30 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, return Status::OK(); } + if (ConstantPushDown(node)) { + graph_modified_ = true; + return Status::OK(); + } + + if (PartialConstPropThroughIdentityN(node)) { + graph_modified_ = true; + return Status::OK(); + } + + if (PartialAssocOpConstFolding(optimized_graph, properties, node)) { + graph_modified_ = true; + return Status::OK(); + } + + if (PartialConcatConstFolding(optimized_graph, properties, node)) { + graph_modified_ = true; + return Status::OK(); + } + + return Status::OK(); +} + +bool ConstantFolding::ConstantPushDown(NodeDef* node) { // Consider the transformation // // + + = parent @@ -2178,22 +2202,22 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, // division/multiplication. // Don't touch BiasAdd since they can't handle vectors as their first // inputs. - if (has_fetch_ && (IsAdd(*node) || is_mul) && + if (has_fetch_ && (IsAdd(*node) || IsMul(*node)) && NumNonControlInputs(*node) == 2) { NodeDef* left_child = node_map_->GetNode(node->input(0)); NodeDef* right_child = node_map_->GetNode(node->input(1)); // One child must be constant, and the other the same op as the parent. if (node->op() != left_child->op() && node->op() != right_child->op()) { - return Status::OK(); + return false; } const bool left_child_is_constant = IsReallyConstant(*left_child); const bool right_child_is_constant = IsReallyConstant(*right_child); if (!left_child_is_constant && !right_child_is_constant) { - return Status::OK(); + return false; } if (node->device() != left_child->device() || node->device() != right_child->device()) { - return Status::OK(); + return false; } NodeDef* op_child_node = left_child_is_constant ? right_child : left_child; NodeDef* const_child_node = @@ -2203,7 +2227,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, nodes_to_preserve_.find(op_child_node->name()) != nodes_to_preserve_.end() || NumNonControlOutputs(*op_child_node, *node_map_) > 1) { - return Status::OK(); + return false; } // Identify the nodes to swap. @@ -2213,7 +2237,7 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, const bool right_leaf_is_constant = IsReallyConstant(*right_leaf); if (left_leaf_is_constant && right_leaf_is_constant) { // Child is already foldable, leave it alone. - return Status::OK(); + return false; } const int non_const_leaf_input = left_leaf_is_constant ? 1 : 0; const int parent_const_input = left_child_is_constant ? 0 : 1; @@ -2238,10 +2262,12 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, node->input(parent_const_input)); std::swap(*node->mutable_input(parent_const_input), *op_child_node->mutable_input(non_const_leaf_input)); - graph_modified_ = true; - return Status::OK(); + return true; } + return false; +} +bool ConstantFolding::PartialConstPropThroughIdentityN(NodeDef* node) { // Partial constant propagation through IdentityN. if (IsIdentityN(*node) && NumNonControlInputs(*node) > 0) { const std::set& tmp = node_map_->GetOutputs(node->name()); @@ -2294,22 +2320,10 @@ Status ConstantFolding::SimplifyNode(NodeDef* node, GraphDef* optimized_graph, for (NodeDef* consumer : consumers) { DedupControlInputs(consumer); } - graph_modified_ = true; - return Status::OK(); + return true; } } - - if (PartialAssocOpConstFolding(optimized_graph, properties, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (PartialConcatConstFolding(optimized_graph, properties, node)) { - graph_modified_ = true; - return Status::OK(); - } - - return Status::OK(); + return false; } bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index f92f755..227caba 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -113,6 +113,14 @@ class ConstantFolding : public GraphOptimizer { bool PartialAssocOpConstFolding(GraphDef* optimized_graph, GraphProperties* properties, NodeDef* node); + // Applies partial constant propagation through IdentityN operator. + // Returns true if the transformation applied successfully. + bool PartialConstPropThroughIdentityN(NodeDef* node); + + // Pushes down constants on '+' and '*' operators if applicable. Returns true + // the transformation applied successfully. + bool ConstantPushDown(NodeDef* node); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4