Add logic for StridedSlice ops in ShapeRefiner::ConstantPartialShape().
author    Skye Wanderman-Milne <skyewm@google.com>
          Tue, 8 May 2018 00:28:41 +0000 (17:28 -0700)
committer TensorFlower Gardener <gardener@tensorflow.org>
          Tue, 8 May 2018 01:32:07 +0000 (18:32 -0700)
This mimics the logic in tensor_util.constant_value_as_shape, allowing
the C++ shape inference code to infer more shapes than it could before.

This change also adds an optional stride argument to InferenceContext::Subshape().

PiperOrigin-RevId: 195749522
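
To make the effect concrete, here is a minimal sketch of the graph pattern this change targets, condensed from the TestStridedSlice helper added to shape_refiner_test.cc below (TensorAsShapeInt32 is a test-only op registered in that file, and the AddNode/GetContext boilerplate is elided):

    Scope root = Scope::DisabledShapeInferenceScope();
    // A placeholder with partial shape [1, ?, 3, ?, 5].
    PartialTensorShape x_shape({1, -1, 3, -1, 5});
    auto x = ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(x_shape));
    // Build shape(x)[2:5] as a graph: Shape followed by StridedSlice.
    auto shape = ops::Shape(root, x);
    auto slice = ops::StridedSlice(root, shape, ops::Const(root, {2}),
                                   ops::Const(root, {5}), ops::Const(root, {1}));
    // Feed the sliced shape vector to an op whose shape function treats its
    // input as a shape (the test-only TensorAsShapeInt32 op).
    Node* result;
    TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
                     .Input(slice.node())
                     .Finalize(root.graph(), &result));
    // After every node is added to a ShapeRefiner, the refiner reports the
    // output of 'result' as [3,?,5]; before this change StridedSlice was not
    // handled in ConstantPartialShape() and the result was an unknown shape.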

tensorflow/core/common_runtime/shape_refiner.cc
tensorflow/core/common_runtime/shape_refiner.h
tensorflow/core/common_runtime/shape_refiner_test.cc
tensorflow/core/framework/shape_inference.cc
tensorflow/core/framework/shape_inference.h

diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc
index a077271..fa4d1ed 100644
--- a/tensorflow/core/common_runtime/shape_refiner.cc
+++ b/tensorflow/core/common_runtime/shape_refiner.cc
@@ -421,6 +421,28 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node,
                                 kMaxTensorSize, disable_constant_propagation_);
 }
 
+Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
+                                                   int dst_idx, bool* evaluated,
+                                                   int64* result) {
+  Tensor scalar;
+  TF_RETURN_IF_ERROR(
+      EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
+  if (*evaluated) {
+    DCHECK_EQ(scalar.NumElements(), 1)
+        << "EvaluateConstantIntScalarEdge called on non-scalar edge: "
+        << scalar.NumElements();
+    if (scalar.dtype() == DT_INT32) {
+      *result = scalar.scalar<int32>()();
+    } else {
+      DCHECK_EQ(scalar.dtype(), DT_INT64)
+          << "EvaluateConstantIntScalarEdge called on non-integer edge: "
+          << scalar.dtype();
+      *result = scalar.scalar<int64>()();
+    }
+  }
+  return Status::OK();
+}
+
 Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
                                           const Node* node, int dst_idx,
                                           ShapeHandle* result) {
@@ -471,19 +493,11 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
     std::vector<DimensionHandle> dims;
     // Pack is concatenating its input scalars to form the shape tensor vector.
     for (int i = 0; i < src_context->num_inputs(); ++i) {
-      Tensor scalar;
-      bool evaluated = false;
-      TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
-                                                       &evaluated, &scalar));
+      int64 size;
+      bool evaluated;
+      TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
+                                                       &evaluated, &size));
       if (evaluated) {
-        int64 size;
-        if (scalar.dtype() == DT_INT32) {
-          size = scalar.scalar<int32>()();
-        } else if (scalar.dtype() == DT_INT64) {
-          size = scalar.scalar<int64>()();
-        } else {
-          return errors::InvalidArgument("Pack input must be int32 or int64");
-        }
         dims.push_back(size < 0 ? target_context->UnknownDim()
                                 : target_context->MakeDim(size));
       } else {
@@ -513,6 +527,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
       TF_RETURN_IF_ERROR(
           target_context->Concatenate(*result, sub_result, result));
     }
+  } else if (src_op == "StridedSlice") {
+    TF_RETURN_IF_ERROR(
+        PartialStridedSliceShape(input_edge->src(), src_context, result));
   } else {
     Tensor t;
     bool evaluated = false;
@@ -524,6 +541,78 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
   return Status::OK();
 }
 
+Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
+                                              InferenceContext* ctx,
+                                              ShapeHandle* result) {
+  // Only attempt to evaluate if begin/end/strides are all length-1 vectors.
+  for (int i = 1; i <= 3; ++i) {
+    ShapeHandle input_shape = ctx->input(i);
+    if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
+      *result = ctx->UnknownShape();
+      return Status::OK();
+    }
+  }
+
+  int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
+  TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
+  TF_RETURN_IF_ERROR(
+      GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));
+
+  // Only attempt to evaluate if there are no special masks set (note that we
+  // can handle begin/end_mask == 1).
+  if (!(begin_mask == 0 || begin_mask == 1) ||
+      !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
+      new_axis_mask != 0 || shrink_axis_mask != 0) {
+    *result = ctx->UnknownShape();
+    return Status::OK();
+  }
+
+  bool evaluated;
+  int64 begin;
+  if (begin_mask == 1) {
+    begin = 0;
+  } else {
+    TF_RETURN_IF_ERROR(
+        EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
+    if (!evaluated) {
+      *result = ctx->UnknownShape();
+      return Status::OK();
+    }
+  }
+
+  int64 end;
+  if (end_mask == 1) {
+    end = std::numeric_limits<int64>::max();
+  } else {
+    TF_RETURN_IF_ERROR(
+        EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
+    if (!evaluated) {
+      *result = ctx->UnknownShape();
+      return Status::OK();
+    }
+  }
+
+  int64 stride;
+  TF_RETURN_IF_ERROR(
+      EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
+  if (!evaluated) {
+    *result = ctx->UnknownShape();
+    return Status::OK();
+  }
+
+  // Apply begin/end/stride to the input, interpreted as a partial shape.
+  ShapeHandle input;
+  TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
+  TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
+  return Status::OK();
+}
+
 Status ShapeRefiner::RunShapeFn(const Node* node,
                                 const OpRegistrationData* op_reg_data,
                                 ExtendedInferenceContext* ec) {
diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h
index d49c437..9c96dcb 100644
--- a/tensorflow/core/common_runtime/shape_refiner.h
+++ b/tensorflow/core/common_runtime/shape_refiner.h
@@ -215,9 +215,18 @@ class ShapeRefiner {
                                 bool keep_nested_shapes,
                                 ExtendedInferenceContext* outer_context);
 
+  // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
+  // value can be evaluated, 'evaluated' is set to true and the value is
+  // returned in 'result'. Otherwise 'evaluated' is set to false.
   Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
                                        bool* evaluated, Tensor* result);
 
+  // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
+  // tensors. The caller is responsible for checking that the specified edge is
+  // scalar and int32 or int64.
+  Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
+                                       bool* evaluated, int64* result);
+
   // This function tries to materialize as much information about the 'node''s
   // dst_idx input as a statically computable shape, and the result may be
   // partially known, depending on what is statically inferable.
@@ -243,6 +252,11 @@ class ShapeRefiner {
                               const Node* node, int dst_idx,
                               shape_inference::ShapeHandle* result);
 
+  // Implementation of ConstantPartialShape for StridedSlice nodes.
+  Status PartialStridedSliceShape(Node* slice_node,
+                                  shape_inference::InferenceContext* ctx,
+                                  shape_inference::ShapeHandle* result);
+
   Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
                     ExtendedInferenceContext* ec);
 
diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc
index f48638a..8b9657e 100644
--- a/tensorflow/core/common_runtime/shape_refiner_test.cc
+++ b/tensorflow/core/common_runtime/shape_refiner_test.cc
@@ -60,6 +60,39 @@ class ShapeRefinerTest : public ::testing::Test {
   }
 
   static constexpr int64 kMaxTensorSize = ShapeRefiner::kMaxTensorSize;
+
+  void TestStridedSlice(const PartialTensorShape& input_shape, int begin,
+                        int end, int stride, const char* expected,
+                        int begin_mask = 0, int end_mask = 0,
+                        int ellipsis_mask = 0) {
+    Scope root = Scope::DisabledShapeInferenceScope();
+    auto placeholder =
+        ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
+    auto input = ops::Shape(root, placeholder);
+    auto begin_op = ops::Const(root, {begin});
+    auto end_op = ops::Const(root, {end});
+    auto stride_op = ops::Const(root, {stride});
+    auto slice = ops::StridedSlice(root, input, begin_op, end_op, stride_op,
+                                   ops::StridedSlice::BeginMask(begin_mask)
+                                       .EndMask(end_mask)
+                                       .EllipsisMask(ellipsis_mask));
+    Node* result;
+    TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
+                     .Input(slice.node())
+                     .Finalize(root.graph(), &result));
+
+    ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
+    TF_ASSERT_OK(m.AddNode(placeholder.node()));
+    TF_ASSERT_OK(m.AddNode(input.node()));
+    TF_ASSERT_OK(m.AddNode(begin_op.node()));
+    TF_ASSERT_OK(m.AddNode(end_op.node()));
+    TF_ASSERT_OK(m.AddNode(stride_op.node()));
+    TF_ASSERT_OK(m.AddNode(slice.node()));
+    TF_ASSERT_OK(m.AddNode(result));
+
+    shape_inference::InferenceContext* ctx = m.GetContext(result);
+    EXPECT_EQ(ctx->DebugString(ctx->output(0)), expected);
+  }
 };
 
 namespace {
@@ -1156,6 +1189,73 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) {
             m.AddNode(result).error_message());
 }
 
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSlice) {
+  TestStridedSlice(
+      /*input_shape=*/{1, -1, 3, -1, 5},
+      /*begin=*/2,
+      /*end=*/5,
+      /*stride=*/1,
+      /*expected=*/"[3,?,5]");
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceNegativeStride) {
+  // clang-format off
+  TestStridedSlice(
+      /*input_shape=*/{1, -1, 3, -1, 5},
+      /*begin=*/10,
+      /*end=*/0,
+      /*stride=*/-1,
+      /*expected=*/"[5,?,3,?]");
+  // clang-format on
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMasks) {
+  TestStridedSlice(
+      /*input_shape=*/{1, -1, 3, -1, 5},
+      /*begin=*/3,
+      /*end=*/4,
+      /*stride=*/1,
+      /*expected=*/"[1,?,3,?,5]",
+      /*begin_mask=*/1,
+      /*end_mask=*/1);
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceInvalidMask) {
+  TestStridedSlice(
+      /*input_shape=*/{1, -1, 3},
+      /*begin=*/2,
+      /*end=*/3,
+      /*stride=*/1,
+      /*expected=*/"[?,?,?]",
+      /*begin_mask=*/0,
+      /*end_mask=*/0,
+      /*ellipsis_mask=*/1);
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMulti) {
+  Scope root = Scope::DisabledShapeInferenceScope();
+  auto input = ops::Placeholder(root, DT_INT32);
+  auto begin = ops::Const(root, {0, 0});
+  auto end = ops::Const(root, {2, 2});
+  auto stride = ops::Const(root, {1, 1});
+  auto slice = ops::StridedSlice(root, input, begin, end, stride);
+  Node* result;
+  TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
+                   .Input(slice.node())
+                   .Finalize(root.graph(), &result));
+
+  ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
+  TF_ASSERT_OK(m.AddNode(input.node()));
+  TF_ASSERT_OK(m.AddNode(begin.node()));
+  TF_ASSERT_OK(m.AddNode(end.node()));
+  TF_ASSERT_OK(m.AddNode(stride.node()));
+  TF_ASSERT_OK(m.AddNode(slice.node()));
+  TF_ASSERT_OK(m.AddNode(result));
+
+  shape_inference::InferenceContext* ctx = m.GetContext(result);
+  EXPECT_EQ(ctx->DebugString(ctx->output(0)), "?");
+}
+
 namespace {
 
 // Dummy op to test ShapeRefiner util functions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index 2b995e8..3185875 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -605,10 +605,16 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start,
   return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
 }
 
-Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
                                   ShapeHandle* out) {
-  int64 start = start_in;
-  int64 end = end_in;
+  return Subshape(s, start, end, 1 /* stride */, out);
+}
+
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
+                                  int64 stride, ShapeHandle* out) {
+  int64 start_in = start;
+  int64 end_in = end;
+
   const int32 rank = Rank(s);
   if (start == 0 && ((RankKnown(s) && end >= rank) ||
                      end == std::numeric_limits<int64>::max())) {
@@ -621,6 +627,9 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
 
   if (start > rank) start = rank;
   if (end > rank) end = rank;
+
+  if (stride < 0 && start == rank) --start;
+
   if (start < 0) {
     start = rank + start;
     if (start < 0) {
@@ -638,16 +647,24 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
                                      ", for shape with rank ", rank);
     }
   }
-  if (start > end) {
+  if (stride > 0 && start > end) {
     *out = nullptr;
     return errors::InvalidArgument(
         "Subshape must have computed start <= end, but is ", start, " and ",
         end, " (computed from start ", start_in, " and end ", end_in,
         " over shape with rank ", rank, ")");
+  } else if (stride < 0 && start < end) {
+    *out = nullptr;
+    return errors::InvalidArgument(
+        "Subshape must have computed start >= end since stride is negative, "
+        "but is ",
+        start, " and ", end, " (computed from start ", start_in, " and end ",
+        end_in, " over shape with rank ", rank, " and stride", stride, ")");
   }
+
   std::vector<DimensionHandle> dims;
-  dims.reserve(end - start);
-  for (int i = start; i < end; ++i) {
+  dims.reserve((end - start) / stride);
+  for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
     dims.push_back(Dim(s, i));
   }
   return ReturnCreatedShape(dims, out);
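
As a quick sanity check on the negative-stride path above, here is a small standalone sketch (plain C++, no TensorFlow types; it mirrors only the clamping and traversal shown above and omits the negative-index and error handling the real code also performs) that prints which dimension indices get selected for the new ConstantValueAsShape_StridedSliceNegativeStride test case:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Mirrors the start/end clamping and the strided loop in Subshape() above.
    std::vector<int64_t> SubshapeIndices(int64_t rank, int64_t start,
                                         int64_t end, int64_t stride) {
      if (start > rank) start = rank;
      if (end > rank) end = rank;
      if (stride < 0 && start == rank) --start;
      std::vector<int64_t> indices;
      for (int64_t i = start; stride > 0 ? i < end : i > end; i += stride) {
        indices.push_back(i);
      }
      return indices;
    }

    int main() {
      // Shape [1, ?, 3, ?, 5] (rank 5) sliced as [10:0:-1].
      for (int64_t i : SubshapeIndices(/*rank=*/5, /*start=*/10, /*end=*/0,
                                       /*stride=*/-1)) {
        std::cout << i << " ";  // Prints "4 3 2 1": dims [5, ?, 3, ?].
      }
      std::cout << "\n";
      return 0;
    }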
diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h
index 9431a62..3f3729d 100644
--- a/tensorflow/core/framework/shape_inference.h
+++ b/tensorflow/core/framework/shape_inference.h
@@ -434,6 +434,13 @@ class InferenceContext {
   Status Subshape(ShapeHandle s, int64 start, int64 end,
                   ShapeHandle* out) TF_MUST_USE_RESULT;
 
+  // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
+  // <start> and <end> can be negative, to index from the end of the shape.
+  // <start> and <end> are set to the rank of <s> if > rank of <s>.
+  // <stride> can be negative, to traverse <s> in reverse order.
+  Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
+                  ShapeHandle* out) TF_MUST_USE_RESULT;
+
   // Returns in <*out> the result of appending the dimensions of <s2> to those
   // of <s1>.
   Status Concatenate(ShapeHandle s1, ShapeHandle s2,
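
For completeness, a hedged usage sketch of the new overload as it might be called from a shape function; the names 'c' (the InferenceContext*) and 'in' (an input ShapeHandle) are assumed here and are not part of this change:

    shape_inference::ShapeHandle fwd;
    shape_inference::ShapeHandle rev;
    // Dimensions [1:4] of 'in' (the existing overload; stride defaults to 1).
    TF_RETURN_IF_ERROR(c->Subshape(in, 1, 4, &fwd));
    // Dimensions [3:1:-1] of 'in': dims 3 and 2, in that order.
    TF_RETURN_IF_ERROR(c->Subshape(in, 3, 1, /*stride=*/-1, &rev));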