Extracts the 'simplify strided slice' optimization into its own method.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 May 2018 18:34:30 +0000 (11:34 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 25 May 2018 18:37:35 +0000 (11:37 -0700)
PiperOrigin-RevId: 198078724

tensorflow/core/grappler/optimizers/constant_folding.cc
tensorflow/core/grappler/optimizers/constant_folding.h

index a64e9a3..90c52b3 100644 (file)
@@ -1793,96 +1793,14 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
     }
   }
 
-  if (use_shape_info && IsStridedSlice(*node) &&
-      properties->GetInputProperties(node->name()).size() == 4) {
-    if (node->attr().at("new_axis_mask").i() != 0 ||
-        node->attr().at("shrink_axis_mask").i() != 0) {
-      // Skip nodes with new/shrink axis mask, since they involve dimension
-      // changes.
-      return Status::OK();
-    }
-    const auto& input = properties->GetInputProperties(node->name())[0];
-    for (int j = 0; j < input.shape().dim_size(); ++j) {
-      // Skip if input shape is not fully determined.
-      if (input.shape().dim(j).size() < 0) {
-        return Status::OK();
-      }
-    }
-    const auto& b = properties->GetInputProperties(node->name())[1];
-    const auto& e = properties->GetInputProperties(node->name())[2];
-    const auto& s = properties->GetInputProperties(node->name())[3];
-    if (TensorShape::IsValid(b.shape()) && b.has_value() &&
-        TensorShape::IsValid(e.shape()) && e.has_value() &&
-        TensorShape::IsValid(s.shape()) && s.has_value()) {
-      Tensor begin(b.dtype(), b.shape());
-      if (!begin.FromProto(b.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       b.value().DebugString());
-      }
-      Tensor end(e.dtype(), e.shape());
-      if (!end.FromProto(e.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       e.value().DebugString());
-      }
-      Tensor strides(s.dtype(), s.shape());
-      if (!strides.FromProto(s.value())) {
-        return errors::InvalidArgument("Cannot parse tensor from proto: ",
-                                       s.value().DebugString());
-      }
-      int begin_mask = node->attr().at("begin_mask").i();
-      int end_mask = node->attr().at("end_mask").i();
-      std::set<int> expanded_ellipsis_indices;
-      int ellipsis_index = -1;
-      for (int j = 0; j < input.shape().dim_size(); ++j) {
-        // find the ellipsis_mask. If not found, insert one in the end if
-        // necessary.
-        if (node->attr().at("ellipsis_mask").i() & 1 << j ||
-            (ellipsis_index == -1 && j >= strides.NumElements())) {
-          ellipsis_index = j;
-        }
-        // insert the indices that are immediately after ellipsis_index if
-        // necessary.
-        if (ellipsis_index != -1 &&
-            input.shape().dim_size() >
-                strides.NumElements() + j - ellipsis_index) {
-          expanded_ellipsis_indices.insert(j);
-        }
-      }
-
-      // The node is replaceable iff unknown_rank == false &&
-      // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
-      //  && strides == 1) for all dimensions.
-      bool replaceable = !input.shape().unknown_rank();
-      for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
-        if (expanded_ellipsis_indices.find(j) !=
-            expanded_ellipsis_indices.end()) {
-          // ellipsis_mask is effective on current dimension.
-          continue;
-        }
-        // when we have ellipsis_mask in between, input.shape().dim_size() will
-        // be greater than strides.NumElements(), since we will insert
-        // as many as expanded_ellipsis_indices.size() axes during computation.
-        // We need to subtract this number from j.
-        int i = j;
-        if (ellipsis_index != -1 &&
-            j >= ellipsis_index + expanded_ellipsis_indices.size()) {
-          i = j - expanded_ellipsis_indices.size();
-        }
-        int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
-                                          : begin.vec<int64>()(i);
-        int e =
-            end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
-        int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
-                                            : strides.vec<int64>()(i);
-        replaceable &=
-            (begin_mask & 1 << i || b == 0) &&
-            (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
-      }
-      if (replaceable) {
-        ReplaceOperationWithIdentity(0, *properties, node, optimized_graph);
-        return Status::OK();
-      }
-    }
+  bool simplify_strided_slice_successful = false;
+  Status simplify_strided_slice_status =
+      SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
+                           &simplify_strided_slice_successful);
+  if (!simplify_strided_slice_status.ok()) {
+    return simplify_strided_slice_status;
+  } else if (simplify_strided_slice_successful) {
+    return Status::OK();
   }
 
   bool simplify_tile_successful = false;
@@ -1978,6 +1896,106 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
   return Status::OK();
 }
 
+Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
+                                             bool use_shape_info,
+                                             GraphDef* optimized_graph,
+                                             NodeDef* node, bool* success) {
+  if (use_shape_info && IsStridedSlice(*node) &&
+      properties.GetInputProperties(node->name()).size() == 4) {
+    if (node->attr().at("new_axis_mask").i() != 0 ||
+        node->attr().at("shrink_axis_mask").i() != 0) {
+      // Skip nodes with new/shrink axis mask, since they involve dimension
+      // changes.
+      return Status::OK();
+    }
+    const auto& input = properties.GetInputProperties(node->name())[0];
+    for (int j = 0; j < input.shape().dim_size(); ++j) {
+      // Skip if input shape is not fully determined.
+      if (input.shape().dim(j).size() < 0) {
+        return Status::OK();
+      }
+    }
+    const auto& b = properties.GetInputProperties(node->name())[1];
+    const auto& e = properties.GetInputProperties(node->name())[2];
+    const auto& s = properties.GetInputProperties(node->name())[3];
+    if (TensorShape::IsValid(b.shape()) && b.has_value() &&
+        TensorShape::IsValid(e.shape()) && e.has_value() &&
+        TensorShape::IsValid(s.shape()) && s.has_value()) {
+      Tensor begin(b.dtype(), b.shape());
+      if (!begin.FromProto(b.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       b.value().DebugString());
+      }
+      Tensor end(e.dtype(), e.shape());
+      if (!end.FromProto(e.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       e.value().DebugString());
+      }
+      Tensor strides(s.dtype(), s.shape());
+      if (!strides.FromProto(s.value())) {
+        return errors::InvalidArgument("Cannot parse tensor from proto: ",
+                                       s.value().DebugString());
+      }
+      int begin_mask = node->attr().at("begin_mask").i();
+      int end_mask = node->attr().at("end_mask").i();
+      std::set<int> expanded_ellipsis_indices;
+      int ellipsis_index = -1;
+      for (int j = 0; j < input.shape().dim_size(); ++j) {
+        // find the ellipsis_mask. If not found, insert one in the end if
+        // necessary.
+        if (node->attr().at("ellipsis_mask").i() & 1 << j ||
+            (ellipsis_index == -1 && j >= strides.NumElements())) {
+          ellipsis_index = j;
+        }
+        // insert the indices that are immediately after ellipsis_index if
+        // necessary.
+        if (ellipsis_index != -1 &&
+            input.shape().dim_size() >
+                strides.NumElements() + j - ellipsis_index) {
+          expanded_ellipsis_indices.insert(j);
+        }
+      }
+
+      // The node is replaceable iff unknown_rank == false &&
+      // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
+      //  && strides == 1) for all dimensions.
+      bool replaceable = !input.shape().unknown_rank();
+      for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
+        if (expanded_ellipsis_indices.find(j) !=
+            expanded_ellipsis_indices.end()) {
+          // ellipsis_mask is effective on current dimension.
+          continue;
+        }
+        // when we have ellipsis_mask in between, input.shape().dim_size() will
+        // be greater than strides.NumElements(), since we will insert
+        // as many as expanded_ellipsis_indices.size() axes during computation.
+        // We need to subtract this number from j.
+        int i = j;
+        if (ellipsis_index != -1 &&
+            j >= ellipsis_index + expanded_ellipsis_indices.size()) {
+          i = j - expanded_ellipsis_indices.size();
+        }
+        int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
+                                          : begin.vec<int64>()(i);
+        int e =
+            end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
+        int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
+                                            : strides.vec<int64>()(i);
+        replaceable &=
+            (begin_mask & 1 << i || b == 0) &&
+            (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
+      }
+      if (replaceable) {
+        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+        *success = true;
+        return Status::OK();
+      }
+    }
+  }
+  *success = false;
+  return Status::OK();
+}
+
 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph, NodeDef* node,
index 30e6354..6c42b8f 100644 (file)
@@ -181,6 +181,11 @@ class ConstantFolding : public GraphOptimizer {
   // Simplifies a Tile operation to an Identity operation if applicable.
   Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node, bool* success);
+
+  // Simplifies a StridedSlice operation to an Identity operation if applicable.
+  Status SimplifyStridedSlice(const GraphProperties& properties,
+                              bool use_shape_info, GraphDef* optimized_graph,
+                              NodeDef* node, bool* success);
   // Points to an externally provided device or to owned_device_;
   RewriterConfig::Toggle opt_level_;
   DeviceBase* cpu_device_;