From 1af09b57ef663d4ab0c02a00e2af1f1e2819d32f Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Mon, 7 May 2018 17:28:41 -0700 Subject: [PATCH] Add logic for StridedSlice ops in ShapeRefiner::ConstantPartialShape(). This mimics the logic in tensor_util.constant_value_as_shape, allowing the C++ shape inference code to infer more shapes than it could before. This change also adds an optional stride argument to InferenceContext::Subshape(). PiperOrigin-RevId: 195749522 --- tensorflow/core/common_runtime/shape_refiner.cc | 113 ++++++++++++++++++--- tensorflow/core/common_runtime/shape_refiner.h | 14 +++ .../core/common_runtime/shape_refiner_test.cc | 100 ++++++++++++++++++ tensorflow/core/framework/shape_inference.cc | 29 ++++-- tensorflow/core/framework/shape_inference.h | 7 ++ 5 files changed, 245 insertions(+), 18 deletions(-) diff --git a/tensorflow/core/common_runtime/shape_refiner.cc b/tensorflow/core/common_runtime/shape_refiner.cc index a077271..fa4d1ed 100644 --- a/tensorflow/core/common_runtime/shape_refiner.cc +++ b/tensorflow/core/common_runtime/shape_refiner.cc @@ -421,6 +421,28 @@ Status ShapeRefiner::EvaluateConstantTensorForEdge(const Node* node, kMaxTensorSize, disable_constant_propagation_); } +Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node, + int dst_idx, bool* evaluated, + int64* result) { + Tensor scalar; + TF_RETURN_IF_ERROR( + EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar)); + if (*evaluated) { + DCHECK_EQ(scalar.NumElements(), 1) + << "EvaluateConstantIntScalarEdge called on non-scalar edge: " + << scalar.NumElements(); + if (scalar.dtype() == DT_INT32) { + *result = scalar.scalar()(); + } else { + DCHECK_EQ(scalar.dtype(), DT_INT64) + << "EvaluateConstantIntScalarEdge called on non-integer edge: " + << scalar.dtype(); + *result = scalar.scalar()(); + } + } + return Status::OK(); +} + Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, const Node* node, int dst_idx, ShapeHandle* result) { @@ -471,19 +493,11 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, std::vector dims; // Pack is concatenating its input scalars to form the shape tensor vector. for (int i = 0; i < src_context->num_inputs(); ++i) { - Tensor scalar; - bool evaluated = false; - TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i, - &evaluated, &scalar)); + int64 size; + bool evaluated; + TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i, + &evaluated, &size)); if (evaluated) { - int64 size; - if (scalar.dtype() == DT_INT32) { - size = scalar.scalar()(); - } else if (scalar.dtype() == DT_INT64) { - size = scalar.scalar()(); - } else { - return errors::InvalidArgument("Pack input must be int32 or int64"); - } dims.push_back(size < 0 ? target_context->UnknownDim() : target_context->MakeDim(size)); } else { @@ -513,6 +527,9 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, TF_RETURN_IF_ERROR( target_context->Concatenate(*result, sub_result, result)); } + } else if (src_op == "StridedSlice") { + TF_RETURN_IF_ERROR( + PartialStridedSliceShape(input_edge->src(), src_context, result)); } else { Tensor t; bool evaluated = false; @@ -524,6 +541,78 @@ Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context, return Status::OK(); } +Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node, + InferenceContext* ctx, + ShapeHandle* result) { + // Only attempt to evaluate if begin/end/strides all are scalars. + for (int i = 1; i <= 3; ++i) { + ShapeHandle input_shape = ctx->input(i); + if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + } + + int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask; + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask)); + TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask)); + TF_RETURN_IF_ERROR( + GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask)); + + // Only attempt to evaluate if there are no special masks set (note that we + // can handle begin/end_mask == 1). + if (!(begin_mask == 0 || begin_mask == 1) || + !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 || + new_axis_mask != 0 || shrink_axis_mask != 0) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + + bool evaluated; + int64 begin; + if (begin_mask == 1) { + begin = 0; + } else { + TF_RETURN_IF_ERROR( + EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin)); + if (!evaluated) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + } + + int64 end; + if (end_mask == 1) { + end = std::numeric_limits::max(); + } else { + TF_RETURN_IF_ERROR( + EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end)); + if (!evaluated) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + } + + int64 stride; + TF_RETURN_IF_ERROR( + EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride)); + if (!evaluated) { + *result = ctx->UnknownShape(); + return Status::OK(); + } + + // Apply stride to input interpreted as a partial shape. + ShapeHandle input; + TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input)); + TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result)); + return Status::OK(); +} + Status ShapeRefiner::RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, ExtendedInferenceContext* ec) { diff --git a/tensorflow/core/common_runtime/shape_refiner.h b/tensorflow/core/common_runtime/shape_refiner.h index d49c437..9c96dcb 100644 --- a/tensorflow/core/common_runtime/shape_refiner.h +++ b/tensorflow/core/common_runtime/shape_refiner.h @@ -215,9 +215,18 @@ class ShapeRefiner { bool keep_nested_shapes, ExtendedInferenceContext* outer_context); + // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge + // value can be evaluated, 'evaluated' is set to true and the value returned + // in 'result'. Otherwise 'evaluated' is set to false. Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx, bool* evaluated, Tensor* result); + // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input + // tensors. The caller is responsible for checking that the specified edge is + // scalar and int32 or int64. + Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx, + bool* evaluated, int64* result); + // This function tries to materialize as much information about the 'node''s // dst_idx input as a statically computable shape, and the result may be // partially known, depending on what is statically inferable. @@ -243,6 +252,11 @@ class ShapeRefiner { const Node* node, int dst_idx, shape_inference::ShapeHandle* result); + // Implementation of ConstantPartialShape for StridedSlice nodes. + Status PartialStridedSliceShape(Node* slice_node, + shape_inference::InferenceContext* ctx, + shape_inference::ShapeHandle* result); + Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data, ExtendedInferenceContext* ec); diff --git a/tensorflow/core/common_runtime/shape_refiner_test.cc b/tensorflow/core/common_runtime/shape_refiner_test.cc index f48638a..8b9657e 100644 --- a/tensorflow/core/common_runtime/shape_refiner_test.cc +++ b/tensorflow/core/common_runtime/shape_refiner_test.cc @@ -60,6 +60,39 @@ class ShapeRefinerTest : public ::testing::Test { } static constexpr int64 kMaxTensorSize = ShapeRefiner::kMaxTensorSize; + + void TestStridedSlice(const PartialTensorShape& input_shape, int begin, + int end, int stride, const char* expected, + int begin_mask = 0, int end_mask = 0, + int ellipsis_mask = 0) { + Scope root = Scope::DisabledShapeInferenceScope(); + auto placeholder = + ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape)); + auto input = ops::Shape(root, placeholder); + auto begin_op = ops::Const(root, {begin}); + auto end_op = ops::Const(root, {end}); + auto stride_op = ops::Const(root, {stride}); + auto slice = ops::StridedSlice(root, input, begin_op, end_op, stride_op, + ops::StridedSlice::BeginMask(begin_mask) + .EndMask(end_mask) + .EllipsisMask(ellipsis_mask)); + Node* result; + TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32") + .Input(slice.node()) + .Finalize(root.graph(), &result)); + + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(placeholder.node())); + TF_ASSERT_OK(m.AddNode(input.node())); + TF_ASSERT_OK(m.AddNode(begin_op.node())); + TF_ASSERT_OK(m.AddNode(end_op.node())); + TF_ASSERT_OK(m.AddNode(stride_op.node())); + TF_ASSERT_OK(m.AddNode(slice.node())); + TF_ASSERT_OK(m.AddNode(result)); + + shape_inference::InferenceContext* ctx = m.GetContext(result); + EXPECT_EQ(ctx->DebugString(ctx->output(0)), expected); + } }; namespace { @@ -1156,6 +1189,73 @@ TEST_F(ShapeRefinerTest, ConstantValueAsShape_ConcatInvalidDimValue) { m.AddNode(result).error_message()); } +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSlice) { + TestStridedSlice( + /*input_shape=*/{1, -1, 3, -1, 5}, + /*begin=*/2, + /*end=*/5, + /*stride=*/1, + /*expected=*/"[3,?,5]"); +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceNegativeStride) { + // clang-format off + TestStridedSlice( + /*input_shape=*/{1, -1, 3, -1, 5}, + /*begin=*/10, + /*end=*/0, + /*stride=*/-1, + /*expected=*/"[5,?,3,?]"); + // clang-format on +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMasks) { + TestStridedSlice( + /*input_shape=*/{1, -1, 3, -1, 5}, + /*begin=*/3, + /*end=*/4, + /*stride=*/1, + /*expected=*/"[1,?,3,?,5]", + /*begin_mask=*/1, + /*end_mask=*/1); +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceInvalidMask) { + TestStridedSlice( + /*input_shape=*/{1, -1, 3}, + /*begin=*/2, + /*end=*/3, + /*stride=*/1, + /*expected=*/"[?,?,?]", + /*begin_mask=*/0, + /*end_mask=*/0, + /*ellipsis_mask=*/1); +} + +TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMulti) { + Scope root = Scope::DisabledShapeInferenceScope(); + auto input = ops::Placeholder(root, DT_INT32); + auto begin = ops::Const(root, {0, 0}); + auto end = ops::Const(root, {2, 2}); + auto stride = ops::Const(root, {1, 1}); + auto slice = ops::StridedSlice(root, input, begin, end, stride); + Node* result; + TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32") + .Input(slice.node()) + .Finalize(root.graph(), &result)); + + ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global()); + TF_ASSERT_OK(m.AddNode(input.node())); + TF_ASSERT_OK(m.AddNode(begin.node())); + TF_ASSERT_OK(m.AddNode(end.node())); + TF_ASSERT_OK(m.AddNode(stride.node())); + TF_ASSERT_OK(m.AddNode(slice.node())); + TF_ASSERT_OK(m.AddNode(result)); + + shape_inference::InferenceContext* ctx = m.GetContext(result); + EXPECT_EQ(ctx->DebugString(ctx->output(0)), "?"); +} + namespace { // Dummy op to test ShapeRefiner util functions diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index 2b995e8..3185875 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -605,10 +605,16 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start, return Subshape(s, start, std::numeric_limits::max() /* end */, out); } -Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in, +Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end, ShapeHandle* out) { - int64 start = start_in; - int64 end = end_in; + return Subshape(s, start, end, 1 /* stride */, out); +} + +Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end, + int64 stride, ShapeHandle* out) { + int64 start_in = start; + int64 end_in = end; + const int32 rank = Rank(s); if (start == 0 && ((RankKnown(s) && end >= rank) || end == std::numeric_limits::max())) { @@ -621,6 +627,9 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in, if (start > rank) start = rank; if (end > rank) end = rank; + + if (stride < 0 && start == rank) --start; + if (start < 0) { start = rank + start; if (start < 0) { @@ -638,16 +647,24 @@ Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in, ", for shape with rank ", rank); } } - if (start > end) { + if (stride > 0 && start > end) { *out = nullptr; return errors::InvalidArgument( "Subshape must have computed start <= end, but is ", start, " and ", end, " (computed from start ", start_in, " and end ", end_in, " over shape with rank ", rank, ")"); + } else if (stride < 0 && start < end) { + *out = nullptr; + return errors::InvalidArgument( + "Subshape must have computed start >= end since stride is negative, " + "but is ", + start, " and ", end, " (computed from start ", start_in, " and end ", + end_in, " over shape with rank ", rank, " and stride", stride, ")"); } + std::vector dims; - dims.reserve(end - start); - for (int i = start; i < end; ++i) { + dims.reserve((end - start) / stride); + for (int i = start; stride > 0 ? i < end : i > end; i += stride) { dims.push_back(Dim(s, i)); } return ReturnCreatedShape(dims, out); diff --git a/tensorflow/core/framework/shape_inference.h b/tensorflow/core/framework/shape_inference.h index 9431a62..3f3729d 100644 --- a/tensorflow/core/framework/shape_inference.h +++ b/tensorflow/core/framework/shape_inference.h @@ -434,6 +434,13 @@ class InferenceContext { Status Subshape(ShapeHandle s, int64 start, int64 end, ShapeHandle* out) TF_MUST_USE_RESULT; + // Returns in <*out> a sub-shape of , with dimensions [start:end:stride]. + // and can be negative, to index from the end of the shape. + // and are set to the rank of if > rank of . + // can be negative, to reverse the . + Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride, + ShapeHandle* out) TF_MUST_USE_RESULT; + // Returns in <*out> the result of appending the dimensions of to those // of . Status Concatenate(ShapeHandle s1, ShapeHandle s2, -- 2.7.4