From 3890dba889fce1d49a199c72892863f28de02179 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 25 May 2018 11:34:30 -0700 Subject: [PATCH] Extracts the 'simplify strided slice' optimization into its own method. PiperOrigin-RevId: 198078724 --- .../core/grappler/optimizers/constant_folding.cc | 198 +++++++++++---------- .../core/grappler/optimizers/constant_folding.h | 5 + 2 files changed, 113 insertions(+), 90 deletions(-) diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index a64e9a3..90c52b3 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1793,96 +1793,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, } } - if (use_shape_info && IsStridedSlice(*node) && - properties->GetInputProperties(node->name()).size() == 4) { - if (node->attr().at("new_axis_mask").i() != 0 || - node->attr().at("shrink_axis_mask").i() != 0) { - // Skip nodes with new/shrink axis mask, since they involve dimension - // changes. - return Status::OK(); - } - const auto& input = properties->GetInputProperties(node->name())[0]; - for (int j = 0; j < input.shape().dim_size(); ++j) { - // Skip if input shape is not fully determined. - if (input.shape().dim(j).size() < 0) { - return Status::OK(); - } - } - const auto& b = properties->GetInputProperties(node->name())[1]; - const auto& e = properties->GetInputProperties(node->name())[2]; - const auto& s = properties->GetInputProperties(node->name())[3]; - if (TensorShape::IsValid(b.shape()) && b.has_value() && - TensorShape::IsValid(e.shape()) && e.has_value() && - TensorShape::IsValid(s.shape()) && s.has_value()) { - Tensor begin(b.dtype(), b.shape()); - if (!begin.FromProto(b.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - b.value().DebugString()); - } - Tensor end(e.dtype(), e.shape()); - if (!end.FromProto(e.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - e.value().DebugString()); - } - Tensor strides(s.dtype(), s.shape()); - if (!strides.FromProto(s.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - s.value().DebugString()); - } - int begin_mask = node->attr().at("begin_mask").i(); - int end_mask = node->attr().at("end_mask").i(); - std::set expanded_ellipsis_indices; - int ellipsis_index = -1; - for (int j = 0; j < input.shape().dim_size(); ++j) { - // find the ellipsis_mask. If not found, insert one in the end if - // necessary. - if (node->attr().at("ellipsis_mask").i() & 1 << j || - (ellipsis_index == -1 && j >= strides.NumElements())) { - ellipsis_index = j; - } - // insert the indices that are immediately after ellipsis_index if - // necessary. - if (ellipsis_index != -1 && - input.shape().dim_size() > - strides.NumElements() + j - ellipsis_index) { - expanded_ellipsis_indices.insert(j); - } - } - - // The node is replaceable iff unknown_rank == false && - // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim) - // && strides == 1) for all dimensions. - bool replaceable = !input.shape().unknown_rank(); - for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) { - if (expanded_ellipsis_indices.find(j) != - expanded_ellipsis_indices.end()) { - // ellipsis_mask is effective on current dimension. - continue; - } - // when we have ellipsis_mask in between, input.shape().dim_size() will - // be greater than strides.NumElements(), since we will insert - // as many as expanded_ellipsis_indices.size() axes during computation. - // We need to subtract this number from j. - int i = j; - if (ellipsis_index != -1 && - j >= ellipsis_index + expanded_ellipsis_indices.size()) { - i = j - expanded_ellipsis_indices.size(); - } - int b = begin.dtype() == DT_INT32 ? begin.vec()(i) - : begin.vec()(i); - int e = - end.dtype() == DT_INT32 ? end.vec()(i) : end.vec()(i); - int s = strides.dtype() == DT_INT32 ? strides.vec()(i) - : strides.vec()(i); - replaceable &= - (begin_mask & 1 << i || b == 0) && - (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1; - } - if (replaceable) { - ReplaceOperationWithIdentity(0, *properties, node, optimized_graph); - return Status::OK(); - } - } + bool simplify_strided_slice_successful = false; + Status simplify_strided_slice_status = + SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node, + &simplify_strided_slice_successful); + if (!simplify_strided_slice_status.ok()) { + return simplify_strided_slice_status; + } else if (simplify_strided_slice_successful) { + return Status::OK(); } bool simplify_tile_successful = false; @@ -1978,6 +1896,106 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, return Status::OK(); } +Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, + bool use_shape_info, + GraphDef* optimized_graph, + NodeDef* node, bool* success) { + if (use_shape_info && IsStridedSlice(*node) && + properties.GetInputProperties(node->name()).size() == 4) { + if (node->attr().at("new_axis_mask").i() != 0 || + node->attr().at("shrink_axis_mask").i() != 0) { + // Skip nodes with new/shrink axis mask, since they involve dimension + // changes. + return Status::OK(); + } + const auto& input = properties.GetInputProperties(node->name())[0]; + for (int j = 0; j < input.shape().dim_size(); ++j) { + // Skip if input shape is not fully determined. + if (input.shape().dim(j).size() < 0) { + return Status::OK(); + } + } + const auto& b = properties.GetInputProperties(node->name())[1]; + const auto& e = properties.GetInputProperties(node->name())[2]; + const auto& s = properties.GetInputProperties(node->name())[3]; + if (TensorShape::IsValid(b.shape()) && b.has_value() && + TensorShape::IsValid(e.shape()) && e.has_value() && + TensorShape::IsValid(s.shape()) && s.has_value()) { + Tensor begin(b.dtype(), b.shape()); + if (!begin.FromProto(b.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + b.value().DebugString()); + } + Tensor end(e.dtype(), e.shape()); + if (!end.FromProto(e.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + e.value().DebugString()); + } + Tensor strides(s.dtype(), s.shape()); + if (!strides.FromProto(s.value())) { + return errors::InvalidArgument("Cannot parse tensor from proto: ", + s.value().DebugString()); + } + int begin_mask = node->attr().at("begin_mask").i(); + int end_mask = node->attr().at("end_mask").i(); + std::set expanded_ellipsis_indices; + int ellipsis_index = -1; + for (int j = 0; j < input.shape().dim_size(); ++j) { + // find the ellipsis_mask. If not found, insert one in the end if + // necessary. + if (node->attr().at("ellipsis_mask").i() & 1 << j || + (ellipsis_index == -1 && j >= strides.NumElements())) { + ellipsis_index = j; + } + // insert the indices that are immediately after ellipsis_index if + // necessary. + if (ellipsis_index != -1 && + input.shape().dim_size() > + strides.NumElements() + j - ellipsis_index) { + expanded_ellipsis_indices.insert(j); + } + } + + // The node is replaceable iff unknown_rank == false && + // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim) + // && strides == 1) for all dimensions. + bool replaceable = !input.shape().unknown_rank(); + for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) { + if (expanded_ellipsis_indices.find(j) != + expanded_ellipsis_indices.end()) { + // ellipsis_mask is effective on current dimension. + continue; + } + // when we have ellipsis_mask in between, input.shape().dim_size() will + // be greater than strides.NumElements(), since we will insert + // as many as expanded_ellipsis_indices.size() axes during computation. + // We need to subtract this number from j. + int i = j; + if (ellipsis_index != -1 && + j >= ellipsis_index + expanded_ellipsis_indices.size()) { + i = j - expanded_ellipsis_indices.size(); + } + int b = begin.dtype() == DT_INT32 ? begin.vec()(i) + : begin.vec()(i); + int e = + end.dtype() == DT_INT32 ? end.vec()(i) : end.vec()(i); + int s = strides.dtype() == DT_INT32 ? strides.vec()(i) + : strides.vec()(i); + replaceable &= + (begin_mask & 1 << i || b == 0) && + (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1; + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + *success = true; + return Status::OK(); + } + } + } + *success = false; + return Status::OK(); +} + Status ConstantFolding::SimplifyTile(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 30e6354..6c42b8f 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -181,6 +181,11 @@ class ConstantFolding : public GraphOptimizer { // Simplifies a Tile operation to an Identity operation if applicable. Status SimplifyTile(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node, bool* success); + + // Simplifies a StridedSlice operation to an Identity operation if applicable. + Status SimplifyStridedSlice(const GraphProperties& properties, + bool use_shape_info, GraphDef* optimized_graph, + NodeDef* node, bool* success); // Points to an externally provided device or to owned_device_; RewriterConfig::Toggle opt_level_; DeviceBase* cpu_device_; -- 2.7.4