From 874cf8e1d332175c8a90d7512f8385e98e2a7377 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 10 May 2018 16:09:00 -0700 Subject: [PATCH] Enable support for crops in BatchToSpaceNd PiperOrigin-RevId: 196186750 --- .../contrib/lite/kernels/batch_to_space_nd.cc | 22 ++++++++++++++-------- .../contrib/lite/kernels/batch_to_space_nd_test.cc | 8 ++++---- .../lite/testing/generated_examples_zip_test.cc | 4 ---- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc index 90edf4f..bd40575 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd.cc @@ -66,12 +66,10 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops), kSpatialDimensionNum); - // TODO(ycling): Add crops as part of calculation. Remove check for a crops - // containing all zeroes. - TF_LITE_ENSURE_EQ(context, crops[0], 0); - TF_LITE_ENSURE_EQ(context, crops[1], 0); - TF_LITE_ENSURE_EQ(context, crops[2], 0); - TF_LITE_ENSURE_EQ(context, crops[3], 0); + TF_LITE_ENSURE(context, crops[0] >= 0); + TF_LITE_ENSURE(context, crops[1] >= 0); + TF_LITE_ENSURE(context, crops[2] >= 0); + TF_LITE_ENSURE(context, crops[3] >= 0); // Number of batch must be multiple of (block_shape[0] * block_shape[1]). TF_LITE_ENSURE_EQ(context, @@ -79,8 +77,16 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context, const int output_batch_size = input_size->data[0] / (block_shape[0] * block_shape[1]); - const int output_height = input_size->data[1] * block_shape[0]; - const int output_width = input_size->data[2] * block_shape[1]; + + const int crops_top = crops[0]; + const int crops_bottom = crops[1]; + const int crops_left = crops[2]; + const int crops_right = crops[3]; + const int output_height = + input_size->data[1] * block_shape[0] - crops_top - crops_bottom; + const int output_width = + input_size->data[2] * block_shape[1] - crops_left - crops_right; + const int output_channel_size = input_size->data[3]; TfLiteIntArray* output_size = TfLiteIntArrayCopy(input_size); diff --git a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc index 8485cde..95b025c 100644 --- a/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc +++ b/tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc @@ -120,16 +120,16 @@ TEST(BatchToSpaceNDOpTest, InvalidShapeTest) { } TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) { - EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}), - "1 != 0"); + EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, -1}), + "crops.3. >= 0 was not true."); } TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) { BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1}); m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}); m.SetBlockShape({2, 2}); - m.SetCrops({0, 0, 1, 0}); - EXPECT_DEATH(m.Invoke(), "1 != 0"); + m.SetCrops({0, 0, -1, 0}); + EXPECT_DEATH(m.Invoke(), "crops.2. >= 0 was not true."); } } // namespace diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index a8714af..6ecaf2a 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -63,10 +63,6 @@ std::map kBrokenTests = { // L2Norm only supports tensors with 4D or fewer. {R"(^\/l2normdim=.*,epsilon=.*,input_shape=\[.,.,.,.,.*\])", "67963684"}, - // BatchToSpaceND doesn't support cropping. This catches test cases with - // non-const tensors as crops. - {R"(^\/batch_to_space_nd.*crops=\[\[1,1\],\[1,1\]\])", "70594634"}, - // SpaceToBatchND only supports 4D tensors. {R"(^\/space_to_batch_nd.*input_shape=\[1,4,4,4,1,1\])", "70848787"}, -- 2.7.4