From 2cb86382ebc8432b25469f813c9156507984043f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 24 May 2018 13:13:42 -0700 Subject: [PATCH] Extracts the Simplify Pack optimization into its own method. PiperOrigin-RevId: 197941474 --- .../core/grappler/optimizers/constant_folding.cc | 65 ++++++++++++---------- .../core/grappler/optimizers/constant_folding.h | 5 +- 2 files changed, 40 insertions(+), 30 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 62e1ab0..b8b8088 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1955,35 +1955,7 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, } } - if (IsPack(*node) && NumNonControlInputs(*node) == 1 && - !OptimizedNodeExists(*node, "_const_axis")) { - // Create constant axis node. - Tensor axis_t(DT_INT32, TensorShape({})); - NodeDef* axis_node = optimized_graph->add_node(); - axis_node->set_name(OptimizedNodeName(*node, "_const_axis")); - const int axis = node->attr().at("axis").i(); - if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() || - !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node) - .ok()) { - return Status::OK(); - } - // Add a control dependency to make sure axis_node is in the right frame. - const string ctrl_dep = ConstantFolding::AddControlDependency( - node->input(0), graph_, node_map_.get()); - axis_node->add_input(ctrl_dep); - axis_node->set_device(node->device()); - node->set_op("ExpandDims"); - if (node->attr().count("axis") != 0) { - node->mutable_attr()->erase("axis"); - } - if (node->attr().count("N") != 0) { - node->mutable_attr()->erase("N"); - } - (*node->mutable_attr())["Tdim"].set_type(DT_INT32); - node->add_input(axis_node->name()); - if (node->input_size() > 2) { - node->mutable_input()->SwapElements(1, node->input_size() - 1); - } + if (SimplifyPack(optimized_graph, node)) { graph_modified_ = true; return Status::OK(); } @@ -2052,6 +2024,41 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { + if (IsPack(*node) && NumNonControlInputs(*node) == 1 && + !OptimizedNodeExists(*node, "_const_axis")) { + // Create constant axis node. + Tensor axis_t(DT_INT32, TensorShape({})); + NodeDef* axis_node = optimized_graph->add_node(); + axis_node->set_name(OptimizedNodeName(*node, "_const_axis")); + const int axis = node->attr().at("axis").i(); + if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() || + !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node) + .ok()) { + return false; + } + // Add a control dependency to make sure axis_node is in the right frame. + const string ctrl_dep = ConstantFolding::AddControlDependency( + node->input(0), graph_, node_map_.get()); + axis_node->add_input(ctrl_dep); + axis_node->set_device(node->device()); + node->set_op("ExpandDims"); + if (node->attr().count("axis") != 0) { + node->mutable_attr()->erase("axis"); + } + if (node->attr().count("N") != 0) { + node->mutable_attr()->erase("N"); + } + (*node->mutable_attr())["Tdim"].set_type(DT_INT32); + node->add_input(axis_node->name()); + if (node->input_size() > 2) { + node->mutable_input()->SwapElements(1, node->input_size() - 1); + return true; + } + } + return false; +} + bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node) { if (IsEnter(*node) && node->input_size() > 0) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 9fd4c9c..be78004 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -164,9 +164,12 @@ class ConstantFolding : public GraphOptimizer { // +------+ bool SimplifySwitch(GraphDef* optimized_graph, NodeDef* node); - // Move constants past Enter node if applicable. + // Moves constants past Enter node if applicable. bool MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node); + // Simplifies Pack operation if applicable. + bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node); + // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4