From a797ded69e4fb2d8e7cd23b5f73a09abaabb31c6 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 24 May 2018 20:36:45 -0700 Subject: [PATCH] Extracts the 'simplify tile node' optimization into its own method. PiperOrigin-RevId: 197996636 --- .../core/grappler/optimizers/constant_folding.cc | 70 ++++++++++++++-------- .../core/grappler/optimizers/constant_folding.h | 3 + 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 8cd1968..a64e9a3 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1885,32 +1885,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, } } - if (use_shape_info && IsTile(*node) && - properties->GetInputProperties(node->name()).size() == 2) { - const auto& m = properties->GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(m.shape()) && m.has_value()) { - Tensor multiplies(m.dtype(), m.shape()); - if (!multiplies.FromProto(m.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - m.value().DebugString()); - } - // The node is replaceable iff all values in multiplies are 1. - bool replaceable = true; - if (multiplies.dtype() == DT_INT32) { - for (int j = 0; replaceable && j < multiplies.vec().size(); ++j) { - replaceable &= multiplies.vec()(j) == 1; - } - } else { - for (int j = 0; replaceable && j < multiplies.vec().size(); - ++j) { - replaceable &= multiplies.vec()(j) == 1; - } - } - if (replaceable) { - ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); - return Status::OK(); - } - } + bool simplify_tile_successful = false; + Status simplify_tile_status = + SimplifyTile(*properties, use_shape_info, optimized_graph, node, + &simplify_tile_successful); + if (!simplify_tile_status.ok()) { + return simplify_tile_status; + } else if (simplify_tile_successful) { + return Status::OK(); } bool simplify_pad_successful = false; @@ -1996,6 +1978,42 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +Status ConstantFolding::SimplifyTile(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, + bool* success) { + if (use_shape_info && IsTile(*node) && + properties.GetInputProperties(node->name()).size() == 2) { + const auto& m = properties.GetInputProperties(node->name())[1]; + if (TensorShape::IsValid(m.shape()) && m.has_value()) { + Tensor multiplies(m.dtype(), m.shape()); + if (!multiplies.FromProto(m.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + m.value().DebugString()); + } + // The node is replaceable iff all values in multiplies are 1. + bool replaceable = true; + if (multiplies.dtype() == DT_INT32) { + for (int j = 0; replaceable && j < multiplies.vec().size(); ++j) { + replaceable &= multiplies.vec()(j) == 1; + } + } else { + for (int j = 0; replaceable && j < multiplies.vec().size(); + ++j) { + replaceable &= multiplies.vec()(j) == 1; + } + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + *success = true; + return Status::OK(); + } + } + } + *success = false; + return Status::OK(); +} + Status ConstantFolding::SimplifyPad(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index fa9249f..30e6354 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -178,6 +178,9 @@ class ConstantFolding : public GraphOptimizer { Status SimplifyPad(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, bool* success); + // Simplifies a Tile operation to an Identity operation if applicable. + Status SimplifyTile(const GraphProperties& properties, bool use_shape_info, + GraphDef* optimized_graph, NodeDef* node, bool* success); // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4