Allow ResizeBilinear to resize the output tensor in Prepare(), if the size tensor...
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Feb 2018 19:24:17 +0000 (11:24 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 2 Feb 2018 19:32:17 +0000 (11:32 -0800)
PiperOrigin-RevId: 184309687

tensorflow/contrib/lite/kernels/resize_bilinear.cc
tensorflow/contrib/lite/kernels/resize_bilinear_test.cc

index 9a419af..4a2101f 100644 (file)
@@ -36,6 +36,17 @@ constexpr int kInputTensor = 0;
 constexpr int kSizeTensor = 1;
 constexpr int kOutputTensor = 0;
 
+TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input,
+                                TfLiteTensor* size, TfLiteTensor* output) {
+  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
+  output_size->data[0] = input->dims->data[0];
+  const int32* size_data = GetTensorData<int32>(size);
+  output_size->data[1] = size_data[0];
+  output_size->data[2] = size_data[1];
+  output_size->data[3] = input->dims->data[3];
+  return context->ResizeTensor(context, output, output_size);
+}
+
 TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
   TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
@@ -55,9 +66,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // integers.
   output->type = kTfLiteFloat32;
 
-  // TODO(ahentz): if the input is constant, we can allocate here.
-  output->allocation_type = kTfLiteDynamic;
-  return kTfLiteOk;
+  if (!IsConstantTensor(size)) {
+    SetTensorToDynamic(output);
+    return kTfLiteOk;
+  }
+  return ResizeOutputTensor(context, input, size, output);
 }
 
 template <KernelType kernel_type>
@@ -66,15 +79,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   TfLiteTensor* size = GetInput(context, node, kSizeTensor);
 
-  // TODO(ahentz): we only need to do this here if it wasn't done in Eval().
-  TfLiteIntArray* output_size = TfLiteIntArrayCreate(4);
-  output_size->data[0] = input->dims->data[0];
-  const int32* size_data = GetTensorData<int32>(size);
-  output_size->data[1] = size_data[0];
-  output_size->data[2] = size_data[1];
-  output_size->data[3] = input->dims->data[3];
-  context->ResizeTensor(context, output, output_size);
-  TfLiteTensorRealloc(output->bytes, output);
+  if (IsDynamicTensor(output)) {
+    TF_LITE_ENSURE_OK(context,
+                      ResizeOutputTensor(context, input, size, output));
+    TfLiteTensorRealloc(output->bytes, output);
+  }
 
   if (output->type == kTfLiteFloat32) {
 #define TF_LITE_RESIZE_BILINEAR(type)                                     \
index 2b1aaf6..4e03f38 100644 (file)
@@ -25,14 +25,24 @@ using ::testing::ElementsAreArray;
 
 class ResizeBilinearOpModel : public SingleOpModel {
  public:
-  ResizeBilinearOpModel(std::initializer_list<int> input_shape) {
-    input_ = AddInput(TensorType_FLOAT32);
-    size_ = AddInput(TensorType_INT32);
-    output_ = AddOutput(TensorType_FLOAT32);
+  ResizeBilinearOpModel(const TensorData& input,
+                        std::initializer_list<int> size_data = {}) {
+    bool const_size = size_data.size() != 0;
+    input_ = AddInput(input);
+    if (const_size) {
+      size_ = AddConstInput(TensorType_INT32, size_data, {2});
+    } else {
+      size_ = AddInput({TensorType_INT32, {2}});
+    }
+    output_ = AddOutput(TensorType_FLOAT32);  // Always float.
     SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR,
                  BuiltinOptions_ResizeBilinearOptions,
                  CreateResizeBilinearOptions(builder_).Union());
-    BuildInterpreter({input_shape, {2}});
+    if (const_size) {
+      BuildInterpreter({GetShape(input_)});
+    } else {
+      BuildInterpreter({GetShape(input_), GetShape(size_)});
+    }
   }
 
   void SetInput(std::initializer_list<float> data) {
@@ -49,23 +59,33 @@ class ResizeBilinearOpModel : public SingleOpModel {
 };
 
 TEST(ResizeBilinearOpTest, HorizontalResize) {
-  ResizeBilinearOpModel m({1, 1, 2, 1});
+  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}});
   m.SetInput({3, 6});
   m.SetSize({1, 3});
   m.Invoke();
   EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
+
+  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3});
+  const_m.SetInput({3, 6});
+  const_m.Invoke();
+  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6})));
 }
 
 TEST(ResizeBilinearOpTest, VerticalResize) {
-  ResizeBilinearOpModel m({1, 2, 1, 1});
+  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}});
   m.SetInput({3, 9});
   m.SetSize({3, 1});
   m.Invoke();
   EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
+
+  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1});
+  const_m.SetInput({3, 9});
+  const_m.Invoke();
+  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9})));
 }
 
 TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
-  ResizeBilinearOpModel m({1, 2, 2, 1});
+  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}});
   m.SetInput({
       3, 6,  //
       9, 12  //
@@ -77,10 +97,22 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) {
                                  7, 9, 10,   //
                                  9, 11, 12,  //
                              })));
+
+  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3});
+  const_m.SetInput({
+      3, 6,  //
+      9, 12  //
+  });
+  const_m.Invoke();
+  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+                                       3, 5, 6,    //
+                                       7, 9, 10,   //
+                                       9, 11, 12,  //
+                                   })));
 }
 
 TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
-  ResizeBilinearOpModel m({2, 2, 2, 1});
+  ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}});
   m.SetInput({
       3, 6,   //
       9, 12,  //
@@ -97,10 +129,27 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) {
                                  8, 12, 14,   //
                                  10, 14, 16,  //
                              })));
+
+  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3});
+  const_m.SetInput({
+      3, 6,   //
+      9, 12,  //
+      4, 10,  //
+      10, 16  //
+  });
+  const_m.Invoke();
+  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+                                       3, 5, 6,     //
+                                       7, 9, 10,    //
+                                       9, 11, 12,   //
+                                       4, 8, 10,    //
+                                       8, 12, 14,   //
+                                       10, 14, 16,  //
+                                   })));
 }
 
 TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
-  ResizeBilinearOpModel m({1, 2, 2, 2});
+  ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}});
   m.SetInput({
       3, 4, 6, 10,    //
       9, 10, 12, 16,  //
@@ -112,6 +161,18 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) {
                                  7, 8, 9, 12, 10, 14,    //
                                  9, 10, 11, 14, 12, 16,  //
                              })));
+
+  ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3});
+  const_m.SetInput({
+      3, 4, 6, 10,    //
+      9, 10, 12, 16,  //
+  });
+  const_m.Invoke();
+  EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({
+                                       3, 4, 5, 8, 6, 10,      //
+                                       7, 8, 9, 12, 10, 14,    //
+                                       9, 10, 11, 14, 12, 16,  //
+                                   })));
 }
 
 }  // namespace