} TfLiteLSTMParams;
typedef struct {
+ bool align_corners;
} TfLiteResizeBilinearParams;
typedef struct {
template <KernelType kernel_type>
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+ auto* params =
+ reinterpret_cast<TfLiteResizeBilinearParams*>(node->builtin_data);
+
TfLiteTensor* input = GetInput(context, node, kInputTensor);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
TfLiteTensor* size = GetInput(context, node, kSizeTensor);
}
if (output->type == kTfLiteFloat32) {
-#define TF_LITE_RESIZE_BILINEAR(type) \
- type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
- GetTensorData<int32>(size), GetTensorDims(size), \
- GetTensorData<float>(output), GetTensorDims(output))
+#define TF_LITE_RESIZE_BILINEAR(type) \
+ type::ResizeBilinear(GetTensorData<float>(input), GetTensorDims(input), \
+ GetTensorData<int32>(size), GetTensorDims(size), \
+ GetTensorData<float>(output), GetTensorDims(output), \
+ params->align_corners)
if (kernel_type == kReference) {
TF_LITE_RESIZE_BILINEAR(reference_ops);
auto* params = MallocPOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
op->builtin_options_as_ResizeBilinearOptions()) {
+ params->align_corners = schema_params->align_corners();
}
builtin_data = reinterpret_cast<void*>(params);
break;
}
table ResizeBilinearOptions {
+ new_height: int (deprecated);
+ new_width: int (deprecated);
+ align_corners: bool;
}
// A call operation options
struct ResizeBilinearOptionsT : public flatbuffers::NativeTable {
typedef ResizeBilinearOptions TableType;
- ResizeBilinearOptionsT() {}
+ bool align_corners;
+ ResizeBilinearOptionsT()
+ : align_corners(false) {
+ }
};
-struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS
- : private flatbuffers::Table {
+struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
typedef ResizeBilinearOptionsT NativeTableType;
+ enum {
+ VT_ALIGN_CORNERS = 8
+ };
+ bool align_corners() const {
+ return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
- return VerifyTableStart(verifier) && verifier.EndTable();
+ return VerifyTableStart(verifier) &&
+ VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS) &&
+ verifier.EndTable();
}
- ResizeBilinearOptionsT *UnPack(
- const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- void UnPackTo(
- ResizeBilinearOptionsT *_o,
- const flatbuffers::resolver_function_t *_resolver = nullptr) const;
- static flatbuffers::Offset<ResizeBilinearOptions> Pack(
- flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
- const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+ ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<ResizeBilinearOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
};
struct ResizeBilinearOptionsBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
+ void add_align_corners(bool align_corners) {
+ fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast<uint8_t>(align_corners), 0);
+ }
explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
};
inline flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
- flatbuffers::FlatBufferBuilder &_fbb) {
+ flatbuffers::FlatBufferBuilder &_fbb,
+ bool align_corners = false) {
ResizeBilinearOptionsBuilder builder_(_fbb);
+ builder_.add_align_corners(align_corners);
return builder_.Finish();
}
-flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
- flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
- const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
struct CallOptionsT : public flatbuffers::NativeTable {
typedef CallOptions TableType;
// ResizeBilinear looks completely incompatible with Tensorflow
{R"(^\/resize_bilinear.*dtype=tf.int32)", "72401107"},
- {R"(^\/resize_bilinearalign_corners=True,.*,size=\[2,2\])", "72401483"},
- {R"(^\/resize_bilinearalign_corners=True,.*,size=\[4,3\])", "72401483"},
- {R"(^\/resize_bilinearalign_corners=True,.*,size=\[5,6\])", "72401483"},
// Transpose only supports 1D-4D input tensors.
{R"(^\/transpose.*input_shape=\[.,.,.,.,.\])", "71545879"},
}
};
+class ResizeBilinear
+ : public BuiltinOperator<ResizeBilinearOperator,
+ ::tflite::ResizeBilinearOptions,
+ ::tflite::BuiltinOptions_ResizeBilinearOptions> {
+ public:
+ using BuiltinOperator::BuiltinOperator;
+ flatbuffers::Offset<TfLiteOptions> WriteOptions(
+ const TocoOperator& op,
+ flatbuffers::FlatBufferBuilder* builder) const override {
+ return ::tflite::CreateResizeBilinearOptions(*builder, op.align_corners);
+ }
+
+ void ReadOptions(const TfLiteOptions& options,
+ TocoOperator* op) const override {
+ op->align_corners = options.align_corners();
+ }
+};
+
class Squeeze
: public BuiltinOperator<SqueezeOperator, ::tflite::SqueezeOptions,
::tflite::BuiltinOptions_SqueezeOptions> {
OperatorType::kTranspose));
ops.emplace_back(
new Mean(::tflite::BuiltinOperator_MEAN, OperatorType::kMean));
+ ops.emplace_back(new ResizeBilinear(::tflite::BuiltinOperator_RESIZE_BILINEAR,
+ OperatorType::kResizeBilinear));
ops.emplace_back(
new Squeeze(::tflite::BuiltinOperator_SQUEEZE, OperatorType::kSqueeze));
ops.emplace_back(new StridedSlice(::tflite::BuiltinOperator_STRIDED_SLICE,
new SimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1));
ops.emplace_back(
new SimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6));
- ops.emplace_back(new SimpleOperator<ResizeBilinearOperator>(
- "RESIZE_BILINEAR", OperatorType::kResizeBilinear));
ops.emplace_back(new SimpleOperator<LogisticOperator>(
"LOGISTIC", OperatorType::kLogistic));
ops.emplace_back(
CheckSimpleOperator<ReluOperator>("RELU", OperatorType::kRelu);
CheckSimpleOperator<Relu1Operator>("RELU_N1_TO_1", OperatorType::kRelu1);
CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
- CheckSimpleOperator<ResizeBilinearOperator>("RESIZE_BILINEAR",
- OperatorType::kResizeBilinear);
CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
}
output_toco_op->fused_activation_function);
}
+TEST_F(OperatorTest, ResizeBilinear) {
+ ResizeBilinearOperator op;
+ op.align_corners = true;
+ auto output_toco_op = SerializeAndDeserialize(
+ GetOperator("RESIZE_BILINEAR", OperatorType::kResizeBilinear), op);
+ EXPECT_EQ(op.align_corners, output_toco_op->align_corners);
+}
+
TEST_F(OperatorTest, Svdf) {
SvdfOperator op;
op.fused_activation_function = FusedActivationFunctionType::kRelu;