Internal change.
authorNupur Garg <nupurgarg@google.com>
Tue, 30 Jan 2018 23:11:33 +0000 (15:11 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 30 Jan 2018 23:19:19 +0000 (15:19 -0800)
PiperOrigin-RevId: 183900332

tensorflow/contrib/lite/kernels/batch_to_space_nd.cc
tensorflow/contrib/lite/kernels/batch_to_space_nd_test.cc

index d84a77039bfc065051d6439ae98656303dbe48ac..889239f93215a309d5434b209ebfc1f584c47849 100644 (file)
@@ -57,6 +57,7 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
                                 BatchToSpaceNDContext* op_context) {
   TfLiteIntArray* input_size = op_context->input->dims;
   const int* block_shape = GetTensorData<int32>(op_context->block_shape);
+  const int* crops = GetTensorData<int32>(op_context->crops);
 
   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->block_shape),
                     kBlockSizeDimensionNum);
@@ -65,7 +66,13 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
   TF_LITE_ENSURE_EQ(context, NumDimensions(op_context->crops),
                     kSpatialDimensionNum);
 
-  // TODO(ycling): Add crops as part of calculation.
+  // TODO(ycling): Add crops as part of calculation. Remove check for a crops
+  // containing all zeroes.
+  TF_LITE_ENSURE_EQ(context, crops[0], 0);
+  TF_LITE_ENSURE_EQ(context, crops[1], 0);
+  TF_LITE_ENSURE_EQ(context, crops[2], 0);
+  TF_LITE_ENSURE_EQ(context, crops[3], 0);
+
   // Number of batch must be multiple of (block_shape[0] * block_shape[1]).
   TF_LITE_ENSURE_EQ(context,
                     input_size->data[0] % (block_shape[0] * block_shape[1]), 0);
index c9152bf96729c07e4e06a606f193aac269aad260..8485cde1b40066f2070855bca91ea78a9f80e83c 100644 (file)
@@ -119,6 +119,19 @@ TEST(BatchToSpaceNDOpTest, InvalidShapeTest) {
                "Cannot allocate tensors");
 }
 
+TEST(BatchToSpaceNDOpTest, InvalidCropsConstTest) {
+  EXPECT_DEATH(BatchToSpaceNDOpConstModel({3, 2, 2, 1}, {2, 2}, {0, 0, 0, 1}),
+               "1 != 0");
+}
+
+TEST(BatchToSpaceNDOpTest, InvalidCropsDynamicTest) {
+  BatchToSpaceNDOpDynamicModel m({4, 2, 2, 1});
+  m.SetInput({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16});
+  m.SetBlockShape({2, 2});
+  m.SetCrops({0, 0, 1, 0});
+  EXPECT_DEATH(m.Invoke(), "1 != 0");
+}
+
 }  // namespace
 }  // namespace tflite