BatchToSpaceNDContext* op_context) {
TfLiteIntArray* input_size = op_context->input->dims;
const int* block_shape = GetTensorData<int32>(op_context->block_shape);
+ const int* crops = GetTensorData<int32>(op_context->crops);
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
kBlockSizeDimensionNum);
TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
kSpatialDimensionNum);
- // TODO(ycling): Add crops as part of calculation.
+ // TODO(ycling): Add crops as part of calculation. Remove check for a crops
+ // containing all zeroes.
+ TF_LITE_ENSURE_EQ(context, crops[0], 0);
+ TF_LITE_ENSURE_EQ(context, crops[1], 0);
+ TF_LITE_ENSURE_EQ(context, crops[2], 0);
+ TF_LITE_ENSURE_EQ(context, crops[3], 0);
+
// Number of batch must be multiple of (block_shape[0] * block_shape[1]).
TF_LITE_ENSURE_EQ(context,
input_size->data[0] % (block_shape[0] * block_shape[1]), 0);
"Cannot allocate tensors");
}
+TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
+ EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}),
+ "1 != 0");
+}
+
+TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
+ BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
+ m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+ m.SetBlockShape({2, 2});
+ m.SetCrops({0, 0, 1, 0});
+ EXPECT_DEATH(m.Invoke(), "1 != 0");
+}
+
} // namespace
} // namespace tflite