Export align_corners to TF Lite
authorA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 6 Feb 2018 02:45:13 +0000 (18:45 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Tue, 6 Feb 2018 02:48:59 +0000 (18:48 -0800)
PiperOrigin-RevId: 184622482

tensorflow/contrib/lite/builtin_op_data.h
tensorflow/contrib/lite/kernels/resize_bilinear.cc
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
tensorflow/contrib/lite/toco/tflite/operator.cc
tensorflow/contrib/lite/toco/tflite/operator_test.cc

index a1037a5..5dbeadd 100644 (file)
@@ -151,6 +151,7 @@ typedef struct {
 } TfLiteLSTMParams;
 
 typedef struct {
+  bool align_corners;
 } TfLiteResizeBilinearParams;
 
 typedef struct {
index 4a2101f..c5d60ca 100644 (file)
@@ -75,6 +75,9 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
 
 template <KernelType kernel_type>
 TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  auto* params =
+      reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+
   TfLiteTensor* input = GetInput(context, node, kInputTensor);
   TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
   TfLiteTensor* size = GetInput(context, node, kSizeTensor);
@@ -86,10 +89,11 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
   }
 
   if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type)                                     \
-  type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
-                       GetTensorData<int32>(size), GetTensorDims(size),   \
-                       GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_RESIZE_BILINEAR(type)                                       \
+  type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input),   \
+                       GetTensorData<int32>(size), GetTensorDims(size),     \
+                       GetTensorData<float>(output), GetTensorDims(output), \
+                       params->align_corners)
 
     if (kernel_type == kReference) {
       TF_LITE_RESIZE_BILINEAR(reference_ops);
index b36bfce..14b6709 100644 (file)
@@ -469,6 +469,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
       auto* params = MallocPOD<TfLiteResizeBilinearParams>();
       if (auto* schema_params =
               op->builtin_options_as_ResizeBilinearOptions()) {
+        params->align_corners = schema_params->align_corners();
       }
       builtin_data = reinterpret_cast<void*>(params);
       break;
index c0b220e..36cc272 100644 (file)
@@ -273,6 +273,9 @@ table LSTMOptions {
 }
 
 table ResizeBilinearOptions {
+  new_height: int (deprecated);
+  new_width: int (deprecated);
+  align_corners: bool;
 }
 
 // A call operation options
index 29f3a17..e2ac0b9 100755 (executable)
@@ -2626,28 +2626,36 @@ flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
 
 struct ResizeBilinearOptionsT : public flatbuffers::NativeTable {
   typedef ResizeBilinearOptions TableType;
-  ResizeBilinearOptionsT() {}
+  bool align_corners;
+  ResizeBilinearOptionsT()
+      : align_corners(false) {
+  }
 };
 
-struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS
-    : private flatbuffers::Table {
+struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   typedef ResizeBilinearOptionsT NativeTableType;
+  enum {
+    VT_ALIGN_CORNERS = 8
+  };
+  bool align_corners() const {
+    return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0;
+  }
   bool Verify(flatbuffers::Verifier &verifier) const {
-    return VerifyTableStart(verifier) && verifier.EndTable();
+    return VerifyTableStart(verifier) &&
+           VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS) &&
+           verifier.EndTable();
   }
-  ResizeBilinearOptionsT *UnPack(
-      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
-  void UnPackTo(
-      ResizeBilinearOptionsT *_o,
-      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
-  static flatbuffers::Offset<ResizeBilinearOptions> Pack(
-      flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
-      const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+  ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<ResizeBilinearOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 };
 
 struct ResizeBilinearOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
+  void add_align_corners(bool align_corners) {
+    fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast<uint8_t>(align_corners), 0);
+  }
   explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
       : fbb_(_fbb) {
     start_ = fbb_.StartTable();
@@ -2661,14 +2669,14 @@ struct ResizeBilinearOptionsBuilder {
 };
 
 inline flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
-    flatbuffers::FlatBufferBuilder &_fbb) {
+    flatbuffers::FlatBufferBuilder &_fbb,
+    bool align_corners = false) {
   ResizeBilinearOptionsBuilder builder_(_fbb);
+  builder_.add_align_corners(align_corners);
   return builder_.Finish();
 }
 
-flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
-    flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
-    const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
 struct CallOptionsT : public flatbuffers::NativeTable {
   typedef CallOptions TableType;
index e8b425a..5ea3e21 100644 (file)
@@ -89,9 +89,6 @@ std::map<string, string> kBrokenTests = {
 
     // ResizeBilinear looks completely incompatible with Tensorflow
     {R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
-    {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"},
-    {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"},
-    {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"},
 
     // Transpose only supports 1D-4D input tensors.
     {R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
index 461494f..04aaedd 100644 (file)
@@ -540,6 +540,24 @@ class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,
   }
 };
 
+class ResizeBilinear
+    : public BuiltinOperator<ResizeBilinearOperator,
+                             ::tflite::ResizeBilinearOptions,
+                             ::tflite::BuiltinOptions_ResizeBilinearOptions> {
+ public:
+  using BuiltinOperator::BuiltinOperator;
+  flatbuffers::Offset<TfLiteOptions> WriteOptions(
+      const TocoOperator& op,
+      flatbuffers::FlatBufferBuilder* builder) const override {
+    return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
+  }
+
+  void ReadOptions(const TfLiteOptions& options,
+                   TocoOperator* op) const override {
+    op->align_corners = options.align_corners();
+  }
+};
+
 class Squeeze
     : public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
                              ::tflite::BuiltinOptions_SqueezeOptions> {
@@ -755,6 +773,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
                                  OperatorType::kTranspose));
   ops.emplace_back(
       new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
+  ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
+                                      OperatorType::kResizeBilinear));
   ops.emplace_back(
       new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
   ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
@@ -787,8 +807,6 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
       new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
   ops.emplace_back(
       new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
-  ops.emplace_back(new SimpleOperator<ResizeBilinearOperator>(
-      "RESIZE_BILINEAR", OperatorType::kResizeBilinear));
   ops.emplace_back(new SimpleOperator<LogisticOperator>(
       "LOGISTIC", OperatorType::kLogistic));
   ops.emplace_back(
index 6daa296..796534b 100644 (file)
@@ -104,8 +104,6 @@ TEST_F(OperatorTest, SimpleOperators) {
   CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
   CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
   CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
-  CheckSimpleOperator<ResizeBilinearOperator>("RESIZE_BILINEAR",
-                                              OperatorType::kResizeBilinear);
   CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
   CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
 }
@@ -331,6 +329,14 @@ TEST_F(OperatorTest, BuiltinMul) {
             output_toco_op->fused_activation_function);
 }
 
+TEST_F(OperatorTest, ResizeBilinear) {
+  ResizeBilinearOperator op;
+  op.align_corners = true;
+  auto output_toco_op = SerializeAndDeserialize(
+      GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
+  EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
+}
+
 TEST_F(OperatorTest, Svdf) {
   SvdfOperator op;
   op.fused_activation_function = FusedActivationFunctionType::kRelu;