kMaxTensorSize, disable_constant_propagation_);
}
+Status ShapeRefiner::EvaluateConstantIntScalarEdge(const Node* node,
+ int dst_idx, bool* evaluated,
+ int64* result) {
+ Tensor scalar;
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantTensorForEdge(node, dst_idx, evaluated, &scalar));
+ if (*evaluated) {
+ DCHECK_EQ(scalar.NumElements(), 1)
+ << "EvaluateConstantIntScalarEdge called on non-scalar edge: "
+ << scalar.NumElements();
+ if (scalar.dtype() == DT_INT32) {
+ *result = scalar.scalar<int32>()();
+ } else {
+ DCHECK_EQ(scalar.dtype(), DT_INT64)
+ << "EvaluateConstantIntScalarEdge called on non-integer edge: "
+ << scalar.dtype();
+ *result = scalar.scalar<int64>()();
+ }
+ }
+ return Status::OK();
+}
+
Status ShapeRefiner::ConstantPartialShape(InferenceContext* target_context,
const Node* node, int dst_idx,
ShapeHandle* result) {
std::vector<DimensionHandle> dims;
// Pack is concatenating its input scalars to form the shape tensor vector.
for (int i = 0; i < src_context->num_inputs(); ++i) {
- Tensor scalar;
- bool evaluated = false;
- TF_RETURN_IF_ERROR(EvaluateConstantTensorForEdge(input_edge->src(), i,
- &evaluated, &scalar));
+ int64 size;
+ bool evaluated;
+ TF_RETURN_IF_ERROR(EvaluateConstantIntScalarEdge(input_edge->src(), i,
+ &evaluated, &size));
if (evaluated) {
- int64 size;
- if (scalar.dtype() == DT_INT32) {
- size = scalar.scalar<int32>()();
- } else if (scalar.dtype() == DT_INT64) {
- size = scalar.scalar<int64>()();
- } else {
- return errors::InvalidArgument("Pack input must be int32 or int64");
- }
dims.push_back(size < 0 ? target_context->UnknownDim()
: target_context->MakeDim(size));
} else {
TF_RETURN_IF_ERROR(
target_context->Concatenate(*result, sub_result, result));
}
+ } else if (src_op == "StridedSlice") {
+ TF_RETURN_IF_ERROR(
+ PartialStridedSliceShape(input_edge->src(), src_context, result));
} else {
Tensor t;
bool evaluated = false;
return Status::OK();
}
+Status ShapeRefiner::PartialStridedSliceShape(Node* slice_node,
+ InferenceContext* ctx,
+ ShapeHandle* result) {
+ // Only attempt to evaluate if begin/end/strides all are scalars.
+ for (int i = 1; i <= 3; ++i) {
+ ShapeHandle input_shape = ctx->input(i);
+ if (ctx->Value(ctx->Dim(input_shape, 0)) != 1) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+ }
+
+ int begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask;
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "begin_mask", &begin_mask));
+ TF_RETURN_IF_ERROR(GetNodeAttr(slice_node->attrs(), "end_mask", &end_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "ellipsis_mask", &ellipsis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "new_axis_mask", &new_axis_mask));
+ TF_RETURN_IF_ERROR(
+ GetNodeAttr(slice_node->attrs(), "shrink_axis_mask", &shrink_axis_mask));
+
+ // Only attempt to evaluate if there are no special masks set (note that we
+ // can handle begin/end_mask == 1).
+ if (!(begin_mask == 0 || begin_mask == 1) ||
+ !(end_mask == 0 || end_mask == 1) || ellipsis_mask != 0 ||
+ new_axis_mask != 0 || shrink_axis_mask != 0) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+
+ bool evaluated;
+ int64 begin;
+ if (begin_mask == 1) {
+ begin = 0;
+ } else {
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantIntScalarEdge(slice_node, 1, &evaluated, &begin));
+ if (!evaluated) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+ }
+
+ int64 end;
+ if (end_mask == 1) {
+ end = std::numeric_limits<int64>::max();
+ } else {
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantIntScalarEdge(slice_node, 2, &evaluated, &end));
+ if (!evaluated) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+ }
+
+ int64 stride;
+ TF_RETURN_IF_ERROR(
+ EvaluateConstantIntScalarEdge(slice_node, 3, &evaluated, &stride));
+ if (!evaluated) {
+ *result = ctx->UnknownShape();
+ return Status::OK();
+ }
+
+ // Apply stride to input interpreted as a partial shape.
+ ShapeHandle input;
+ TF_RETURN_IF_ERROR(ConstantPartialShape(ctx, slice_node, 0, &input));
+ TF_RETURN_IF_ERROR(ctx->Subshape(input, begin, end, stride, result));
+ return Status::OK();
+}
+
Status ShapeRefiner::RunShapeFn(const Node* node,
const OpRegistrationData* op_reg_data,
ExtendedInferenceContext* ec) {
bool keep_nested_shapes,
ExtendedInferenceContext* outer_context);
+ // Attempts to evaluate the 'dst_idx'-th input to 'node'. If the input edge
+ // value can be evaluated, 'evaluated' is set to true and the value returned
+ // in 'result'. Otherwise 'evaluated' is set to false.
Status EvaluateConstantTensorForEdge(const Node* node, int dst_idx,
bool* evaluated, Tensor* result);
+ // Wrapper around EvaluateConstantTensorForEdge for scalar int32/int64 input
+ // tensors. The caller is responsible for checking that the specified edge is
+ // scalar and int32 or int64.
+ Status EvaluateConstantIntScalarEdge(const Node* node, int dst_idx,
+ bool* evaluated, int64* result);
+
// This function tries to materialize as much information about the 'node''s
// dst_idx input as a statically computable shape, and the result may be
// partially known, depending on what is statically inferable.
const Node* node, int dst_idx,
shape_inference::ShapeHandle* result);
+ // Implementation of ConstantPartialShape for StridedSlice nodes.
+ Status PartialStridedSliceShape(Node* slice_node,
+ shape_inference::InferenceContext* ctx,
+ shape_inference::ShapeHandle* result);
+
Status RunShapeFn(const Node* node, const OpRegistrationData* op_reg_data,
ExtendedInferenceContext* ec);
}
static constexpr int64 kMaxTensorSize = ShapeRefiner::kMaxTensorSize;
+
+ void TestStridedSlice(const PartialTensorShape& input_shape, int begin,
+ int end, int stride, const char* expected,
+ int begin_mask = 0, int end_mask = 0,
+ int ellipsis_mask = 0) {
+ Scope root = Scope::DisabledShapeInferenceScope();
+ auto placeholder =
+ ops::Placeholder(root, DT_INT32, ops::Placeholder::Shape(input_shape));
+ auto input = ops::Shape(root, placeholder);
+ auto begin_op = ops::Const(root, {begin});
+ auto end_op = ops::Const(root, {end});
+ auto stride_op = ops::Const(root, {stride});
+ auto slice = ops::StridedSlice(root, input, begin_op, end_op, stride_op,
+ ops::StridedSlice::BeginMask(begin_mask)
+ .EndMask(end_mask)
+ .EllipsisMask(ellipsis_mask));
+ Node* result;
+ TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
+ .Input(slice.node())
+ .Finalize(root.graph(), &result));
+
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
+ TF_ASSERT_OK(m.AddNode(placeholder.node()));
+ TF_ASSERT_OK(m.AddNode(input.node()));
+ TF_ASSERT_OK(m.AddNode(begin_op.node()));
+ TF_ASSERT_OK(m.AddNode(end_op.node()));
+ TF_ASSERT_OK(m.AddNode(stride_op.node()));
+ TF_ASSERT_OK(m.AddNode(slice.node()));
+ TF_ASSERT_OK(m.AddNode(result));
+
+ shape_inference::InferenceContext* ctx = m.GetContext(result);
+ EXPECT_EQ(ctx->DebugString(ctx->output(0)), expected);
+ }
};
namespace {
m.AddNode(result).error_message());
}
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSlice) {
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3, -1, 5},
+ /*begin=*/2,
+ /*end=*/5,
+ /*stride=*/1,
+ /*expected=*/"[3,?,5]");
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceNegativeStride) {
+ // clang-format off
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3, -1, 5},
+ /*begin=*/10,
+ /*end=*/0,
+ /*stride=*/-1,
+ /*expected=*/"[5,?,3,?]");
+ // clang-format on
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMasks) {
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3, -1, 5},
+ /*begin=*/3,
+ /*end=*/4,
+ /*stride=*/1,
+ /*expected=*/"[1,?,3,?,5]",
+ /*begin_mask=*/1,
+ /*end_mask=*/1);
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceInvalidMask) {
+ TestStridedSlice(
+ /*input_shape=*/{1, -1, 3},
+ /*begin=*/2,
+ /*end=*/3,
+ /*stride=*/1,
+ /*expected=*/"[?,?,?]",
+ /*begin_mask=*/0,
+ /*end_mask=*/0,
+ /*ellipsis_mask=*/1);
+}
+
+TEST_F(ShapeRefinerTest, ConstantValueAsShape_StridedSliceMulti) {
+ Scope root = Scope::DisabledShapeInferenceScope();
+ auto input = ops::Placeholder(root, DT_INT32);
+ auto begin = ops::Const(root, {0, 0});
+ auto end = ops::Const(root, {2, 2});
+ auto stride = ops::Const(root, {1, 1});
+ auto slice = ops::StridedSlice(root, input, begin, end, stride);
+ Node* result;
+ TF_ASSERT_OK(NodeBuilder("test", "TensorAsShapeInt32")
+ .Input(slice.node())
+ .Finalize(root.graph(), &result));
+
+ ShapeRefiner m(TF_GRAPH_DEF_VERSION, OpRegistry::Global());
+ TF_ASSERT_OK(m.AddNode(input.node()));
+ TF_ASSERT_OK(m.AddNode(begin.node()));
+ TF_ASSERT_OK(m.AddNode(end.node()));
+ TF_ASSERT_OK(m.AddNode(stride.node()));
+ TF_ASSERT_OK(m.AddNode(slice.node()));
+ TF_ASSERT_OK(m.AddNode(result));
+
+ shape_inference::InferenceContext* ctx = m.GetContext(result);
+ EXPECT_EQ(ctx->DebugString(ctx->output(0)), "?");
+}
+
namespace {
// Dummy op to test ShapeRefiner util functions
return Subshape(s, start, std::numeric_limits<int64>::max() /* end */, out);
}
-Status InferenceContext::Subshape(ShapeHandle s, int64 start_in, int64 end_in,
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
ShapeHandle* out) {
- int64 start = start_in;
- int64 end = end_in;
+ return Subshape(s, start, end, 1 /* stride */, out);
+}
+
+Status InferenceContext::Subshape(ShapeHandle s, int64 start, int64 end,
+ int64 stride, ShapeHandle* out) {
+ int64 start_in = start;
+ int64 end_in = end;
+
const int32 rank = Rank(s);
if (start == 0 && ((RankKnown(s) && end >= rank) ||
end == std::numeric_limits<int64>::max())) {
if (start > rank) start = rank;
if (end > rank) end = rank;
+
+ if (stride < 0 && start == rank) --start;
+
if (start < 0) {
start = rank + start;
if (start < 0) {
", for shape with rank ", rank);
}
}
- if (start > end) {
+ if (stride > 0 && start > end) {
*out = nullptr;
return errors::InvalidArgument(
"Subshape must have computed start <= end, but is ", start, " and ",
end, " (computed from start ", start_in, " and end ", end_in,
" over shape with rank ", rank, ")");
+ } else if (stride < 0 && start < end) {
+ *out = nullptr;
+ return errors::InvalidArgument(
+ "Subshape must have computed start >= end since stride is negative, "
+ "but is ",
+ start, " and ", end, " (computed from start ", start_in, " and end ",
+ end_in, " over shape with rank ", rank, " and stride", stride, ")");
}
+
std::vector<DimensionHandle> dims;
- dims.reserve(end - start);
- for (int i = start; i < end; ++i) {
+ dims.reserve((end - start) / stride);
+ for (int i = start; stride > 0 ? i < end : i > end; i += stride) {
dims.push_back(Dim(s, i));
}
return ReturnCreatedShape(dims, out);
Status Subshape(ShapeHandle s, int64 start, int64 end,
ShapeHandle* out) TF_MUST_USE_RESULT;
+ // Returns in <*out> a sub-shape of <s>, with dimensions [start:end:stride].
+ // <start> and <end> can be negative, to index from the end of the shape.
+ // <start> and <end> are set to the rank of <s> if > rank of <s>.
+ // <stride> can be negative, to reverse the <s>.
+ Status Subshape(ShapeHandle s, int64 start, int64 end, int64 stride,
+ ShapeHandle* out) TF_MUST_USE_RESULT;
+
// Returns in <*out> the result of appending the dimensions of <s2> to those
// of <s1>.
Status Concatenate(ShapeHandle s1, ShapeHandle s2,