From 5acba9b600d5463dd4b542c7f606c02da6bc6f6c Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Sat, 26 May 2018 08:25:12 -0700 Subject: [PATCH] Extracts the 'remove random shuffle node' optimization into its own method. PiperOrigin-RevId: 198169790 --- .../core/grappler/optimizers/constant_folding.cc | 32 ++++++++++++++-------- .../core/grappler/optimizers/constant_folding.h | 5 ++++ 2 files changed, 26 insertions(+), 11 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index df32d4a..fed5d87 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1695,17 +1695,9 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, } } } - // Remove RandomShuffle op if it is scalar or first dimension is of size 1. - if (use_shape_info && IsRandomShuffle(*node) && - !properties->GetInputProperties(node->name()).empty()) { - const auto& shape = properties->GetInputProperties(node->name())[0].shape(); - // The node is replaceable iff - // unknown_rank == false && (dim_size == 0 || first dim is of size 1) - if (!shape.unknown_rank() && - (shape.dim_size() == 0 || shape.dim(0).size() == 1)) { - ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); - return Status::OK(); - } + + if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) { + return Status::OK(); } bool remove_reverse_successful = false; @@ -1831,6 +1823,24 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, + NodeDef* node) { + if (use_shape_info && IsRandomShuffle(*node) && + !properties.GetInputProperties(node->name()).empty()) { + const auto& shape = properties.GetInputProperties(node->name())[0].shape(); + // The node is replaceable iff + // unknown_rank == false && (dim_size == 0 || first dim is of size 1) + if (!shape.unknown_rank() && + (shape.dim_size() == 0 || shape.dim(0).size() == 1)) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + return true; + } + } + return false; +} + Status ConstantFolding::RemoveReverse(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 9a3ea03..c760b05 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -194,6 +194,11 @@ class ConstantFolding : public GraphOptimizer { // Removes Reverse op over dimensions with size 1. Status RemoveReverse(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, bool* success); + + // Removes RandomShuffle op if it is scalar or first dimension is of size 1. + bool RemoveRandomShuffle(const GraphProperties& properties, + bool use_shape_info, GraphDef* optimized_graph, + NodeDef* node); // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4