From 224874002f93fec471e401488e23d97d4f36c4fc Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Fri, 2 Feb 2018 11:24:17 -0800 Subject: [PATCH] Allow ResizeBilinear to resize the output tensor in Prepare(), if the size tensor is const. PiperOrigin-RevId: 184309687 --- tensorflow/contrib/lite/kernels/resize_bilinear.cc | 33 +++++---- .../contrib/lite/kernels/resize_bilinear_test.cc | 81 +++++++++++++++++++--- 2 files changed, 92 insertions(+), 22 deletions(-) diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear.cc b/tensorflow/contrib/lite/kernels/resize_bilinear.cc index 9a419af..4a2101f 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear.cc @@ -36,6 +36,17 @@ constexpr int kInputTensor = 0; constexpr int kSizeTensor = 1; constexpr int kOutputTensor = 0; +TfLiteStatus ResizeOutputTensor(TfLiteContext* context, TfLiteTensor* input, + TfLiteTensor* size, TfLiteTensor* output) { + TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); + output_size->data[0] = input->dims->data[0]; + const int32* size_data = GetTensorData(size); + output_size->data[1] = size_data[0]; + output_size->data[2] = size_data[1]; + output_size->data[3] = input->dims->data[3]; + return context->ResizeTensor(context, output, output_size); +} + TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); @@ -55,9 +66,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { // integers. output->type = kTfLiteFloat32; - // TODO(ahentz): if the input is constant, we can allocate here. - output->allocation_type = kTfLiteDynamic; - return kTfLiteOk; + if (!IsConstantTensor(size)) { + SetTensorToDynamic(output); + return kTfLiteOk; + } + return ResizeOutputTensor(context, input, size, output); } template @@ -66,15 +79,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* size = GetInput(context, node, kSizeTensor); - // TODO(ahentz): we only need to do this here if it wasn't done in Eval(). - TfLiteIntArray* output_size = TfLiteIntArrayCreate(4); - output_size->data[0] = input->dims->data[0]; - const int32* size_data = GetTensorData(size); - output_size->data[1] = size_data[0]; - output_size->data[2] = size_data[1]; - output_size->data[3] = input->dims->data[3]; - context->ResizeTensor(context, output, output_size); - TfLiteTensorRealloc(output->bytes, output); + if (IsDynamicTensor(output)) { + TF_LITE_ENSURE_OK(context, + ResizeOutputTensor(context, input, size, output)); + TfLiteTensorRealloc(output->bytes, output); + } if (output->type == kTfLiteFloat32) { #define TF_LITE_RESIZE_BILINEAR(type) \ diff --git a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc index 2b1aaf6..4e03f38 100644 --- a/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc +++ b/tensorflow/contrib/lite/kernels/resize_bilinear_test.cc @@ -25,14 +25,24 @@ using ::testing::ElementsAreArray; class ResizeBilinearOpModel : public SingleOpModel { public: - ResizeBilinearOpModel(std::initializer_list input_shape) { - input_ = AddInput(TensorType_FLOAT32); - size_ = AddInput(TensorType_INT32); - output_ = AddOutput(TensorType_FLOAT32); + ResizeBilinearOpModel(const TensorData& input, + std::initializer_list size_data = {}) { + bool const_size = size_data.size() != 0; + input_ = AddInput(input); + if (const_size) { + size_ = AddConstInput(TensorType_INT32, size_data, {2}); + } else { + size_ = AddInput({TensorType_INT32, {2}}); + } + output_ = AddOutput(TensorType_FLOAT32); // Always float. SetBuiltinOp(BuiltinOperator_RESIZE_BILINEAR, BuiltinOptions_ResizeBilinearOptions, CreateResizeBilinearOptions(builder_).Union()); - BuildInterpreter({input_shape, {2}}); + if (const_size) { + BuildInterpreter({GetShape(input_)}); + } else { + BuildInterpreter({GetShape(input_), GetShape(size_)}); + } } void SetInput(std::initializer_list data) { @@ -49,23 +59,33 @@ class ResizeBilinearOpModel : public SingleOpModel { }; TEST(ResizeBilinearOpTest, HorizontalResize) { - ResizeBilinearOpModel m({1, 1, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 1, 2, 1}}); m.SetInput({3, 6}); m.SetSize({1, 3}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 1, 2, 1}}, {1, 3}); + const_m.SetInput({3, 6}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 5, 6}))); } TEST(ResizeBilinearOpTest, VerticalResize) { - ResizeBilinearOpModel m({1, 2, 1, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 1, 1}}); m.SetInput({3, 9}); m.SetSize({3, 1}); m.Invoke(); EXPECT_THAT(m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 1, 1}}, {3, 1}); + const_m.SetInput({3, 9}); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({3, 7, 9}))); } TEST(ResizeBilinearOpTest, TwoDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12 // @@ -77,10 +97,22 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResize) { 7, 9, 10, // 9, 11, 12, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + }))); } TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { - ResizeBilinearOpModel m({2, 2, 2, 1}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {2, 2, 2, 1}}); m.SetInput({ 3, 6, // 9, 12, // @@ -97,10 +129,27 @@ TEST(ResizeBilinearOpTest, TwoDimensionalResizeWithTwoBatches) { 8, 12, 14, // 10, 14, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {2, 2, 2, 1}}, {3, 3}); + const_m.SetInput({ + 3, 6, // + 9, 12, // + 4, 10, // + 10, 16 // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 5, 6, // + 7, 9, 10, // + 9, 11, 12, // + 4, 8, 10, // + 8, 12, 14, // + 10, 14, 16, // + }))); } TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { - ResizeBilinearOpModel m({1, 2, 2, 2}); + ResizeBilinearOpModel m({TensorType_FLOAT32, {1, 2, 2, 2}}); m.SetInput({ 3, 4, 6, 10, // 9, 10, 12, 16, // @@ -112,6 +161,18 @@ TEST(ResizeBilinearOpTest, ThreeDimensionalResize) { 7, 8, 9, 12, 10, 14, // 9, 10, 11, 14, 12, 16, // }))); + + ResizeBilinearOpModel const_m({TensorType_FLOAT32, {1, 2, 2, 2}}, {3, 3}); + const_m.SetInput({ + 3, 4, 6, 10, // + 9, 10, 12, 16, // + }); + const_m.Invoke(); + EXPECT_THAT(const_m.GetOutput(), ElementsAreArray(ArrayFloatNear({ + 3, 4, 5, 8, 6, 10, // + 7, 8, 9, 12, 10, 14, // + 9, 10, 11, 14, 12, 16, // + }))); } } // namespace -- 2.7.4