From 41658b0f288535ecca0512457610db4cb43bea27 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 29 May 2018 11:12:22 -0700 Subject: [PATCH] Extracts the 'remove shuffle or transpose node' optimization into its own method. PiperOrigin-RevId: 198425354 --- .../core/grappler/optimizers/constant_folding.cc | 93 +++++++++++++--------- .../core/grappler/optimizers/constant_folding.h | 6 ++ 2 files changed, 60 insertions(+), 39 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index fed5d87..90862b0 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1655,45 +1655,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } - // Remove Shuffle or Transpose op over dimensions of size 1. - if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && - properties->GetInputProperties(node->name()).size() >= 2) { - const auto& shape = properties->GetInputProperties(node->name())[0].shape(); - if (shape.unknown_rank()) { - // Not optimizable. - return Status::OK(); - } - const auto& p = properties->GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(p.shape()) && p.has_value()) { - Tensor perm(p.dtype(), p.shape()); - if (!perm.FromProto(p.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - p.value().DebugString()); - } - std::vector permutation; - for (int j = 0; j < perm.NumElements(); ++j) { - if (perm.dtype() == DT_INT64) { - permutation.push_back(perm.vec()(j)); - } else { - permutation.push_back(perm.vec()(j)); - } - } - if (permutation.size() != shape.dim_size()) { - // Number of elements in perm should be same as dim_size. Skip if not. - return Status::OK(); - } - // The node is replaceable iff - // dim_size == 0 || all dims have size 1 || - // all dims with > 1 size are not permuted. - bool replaceable = true; - for (int j = 0; replaceable && j < shape.dim_size(); ++j) { - replaceable &= shape.dim(j).size() == 1 || j == permutation[j]; - } - if (replaceable) { - ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); - return Status::OK(); - } - } + bool remove_shuffle_transpose_successful = false; + Status remove_shuffle_transpose_status = + RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph, + node, &remove_shuffle_transpose_successful); + if (!remove_shuffle_transpose_status.ok()) { + return remove_shuffle_transpose_status; + } else if (remove_shuffle_transpose_successful) { + return Status::OK(); } if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) { @@ -1823,6 +1792,52 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +Status ConstantFolding::RemoveShuffleOrTranspose( + const GraphProperties& properties, bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, bool* success) { + if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && + properties.GetInputProperties(node->name()).size() >= 2) { + const auto& shape = properties.GetInputProperties(node->name())[0].shape(); + if (shape.unknown_rank()) { + // Not optimizable. + return Status::OK(); + } + const auto& p = properties.GetInputProperties(node->name())[1]; + if (TensorShape::IsValid(p.shape()) && p.has_value()) { + Tensor perm(p.dtype(), p.shape()); + if (!perm.FromProto(p.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + p.value().DebugString()); + } + std::vector permutation; + for (int j = 0; j < perm.NumElements(); ++j) { + if (perm.dtype() == DT_INT64) { + permutation.push_back(perm.vec()(j)); + } else { + permutation.push_back(perm.vec()(j)); + } + } + if (permutation.size() != shape.dim_size()) { + // Number of elements in perm should be same as dim_size. Skip if not. + return Status::OK(); + } + // The node is replaceable iff + // dim_size == 0 || all dims have size 1 || + // all dims with > 1 size are not permuted. + bool replaceable = true; + for (int j = 0; replaceable && j < shape.dim_size(); ++j) { + replaceable &= shape.dim(j).size() == 1 || j == permutation[j]; + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + *success = true; + return Status::OK(); + } + } + } + *success = false; + return Status::OK(); +} bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index c760b05..43cabb4 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -199,6 +199,12 @@ class ConstantFolding : public GraphOptimizer { bool RemoveRandomShuffle(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node); + + // Removes Shuffle or Transpose op over dimensions of size 1. + Status RemoveShuffleOrTranspose(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, + bool* success); // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4