From 2b211b681ac6264c61372d10c496e234bf2eda9b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 7 Mar 2018 08:52:39 -0800 Subject: [PATCH] Add support for the "DEQUANTIZE" op. This cover only ops that are generated by TOCO in order to handle UINT8 input to floating-point models. PiperOrigin-RevId: 188182372 --- tensorflow/contrib/lite/builtin_ops.h | 1 + tensorflow/contrib/lite/kernels/BUILD | 14 +++ tensorflow/contrib/lite/kernels/dequantize.cc | 77 +++++++++++++ tensorflow/contrib/lite/kernels/dequantize_test.cc | 65 +++++++++++ tensorflow/contrib/lite/kernels/register.cc | 2 + tensorflow/contrib/lite/model.cc | 1 + tensorflow/contrib/lite/nnapi_delegate.cc | 1 + tensorflow/contrib/lite/schema/schema.fbs | 6 +- tensorflow/contrib/lite/schema/schema_generated.h | 121 ++++++++++++++++++++- 9 files changed, 282 insertions(+), 6 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/dequantize.cc create mode 100644 tensorflow/contrib/lite/kernels/dequantize_test.cc diff --git a/tensorflow/contrib/lite/builtin_ops.h b/tensorflow/contrib/lite/builtin_ops.h index 7e08500..2218ea8 100644 --- a/tensorflow/contrib/lite/builtin_ops.h +++ b/tensorflow/contrib/lite/builtin_ops.h @@ -32,6 +32,7 @@ typedef enum { kTfLiteBuiltinConcatenation = 2, kTfLiteBuiltinConv2d = 3, kTfLiteBuiltinDepthwiseConv2d = 4, + kTfLiteBuiltinDequantize = 6, kTfLiteBuiltinEmbeddingLookup = 7, kTfLiteBuiltinFullyConnected = 9, kTfLiteBuiltinHashtableLookup = 10, diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index a6be410..8e9d427 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -121,6 +121,7 @@ cc_library( "concatenation.cc", "conv.cc", "depthwise_conv.cc", + "dequantize.cc", "div.cc", "embedding_lookup.cc", "embedding_lookup_sparse.cc", @@ -296,6 +297,19 @@ tf_cc_test( ) tf_cc_test( + name = "dequantize_test", + size = "small", + srcs = ["dequantize_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_absl//absl/memory", + "@com_google_googletest//:gtest", + ], +) + +tf_cc_test( name = "basic_rnn_test", size = "small", srcs = ["basic_rnn_test.cc"], diff --git a/tensorflow/contrib/lite/kernels/dequantize.cc b/tensorflow/contrib/lite/kernels/dequantize.cc new file mode 100644 index 0000000..e685f24 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/dequantize.cc @@ -0,0 +1,77 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace dequantize { + +struct OpContext { + OpContext(TfLiteContext* context, TfLiteNode* node) { + input = GetInput(context, node, 0); + output = GetOutput(context, node, 0); + } + TfLiteTensor* input; + TfLiteTensor* output; +}; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 1); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + OpContext op_context(context, node); + + TF_LITE_ENSURE(context, op_context.input->type == kTfLiteUInt8); + + op_context.output->type = kTfLiteFloat32; + return context->ResizeTensor(context, op_context.output, + TfLiteIntArrayCopy(op_context.input->dims)); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + OpContext op_context(context, node); + + auto zero_point = op_context.input->params.zero_point; + auto scale = op_context.input->params.scale; + + optimized_ops::Dequantize(GetTensorData(op_context.input), + GetTensorDims(op_context.input), zero_point, scale, + GetTensorData(op_context.output), + GetTensorDims(op_context.output)); + return kTfLiteOk; +} + +} // namespace dequantize + +TfLiteRegistration* Register_DEQUANTIZE_OPT() { + static TfLiteRegistration r = {nullptr, nullptr, dequantize::Prepare, + dequantize::Eval}; + return &r; +} + +TfLiteRegistration* Register_DEQUANTIZE() { return Register_DEQUANTIZE_OPT(); } + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/dequantize_test.cc b/tensorflow/contrib/lite/kernels/dequantize_test.cc new file mode 100644 index 0000000..fcd7420 --- /dev/null +++ b/tensorflow/contrib/lite/kernels/dequantize_test.cc @@ -0,0 +1,65 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class DequantizeOpModel : public SingleOpModel { + public: + DequantizeOpModel(std::initializer_list shape, float min, float max) { + input_ = AddInput({TensorType_UINT8, shape, min, max}); + output_ = AddOutput({TensorType_FLOAT32, shape}); + SetBuiltinOp(BuiltinOperator_DEQUANTIZE, BuiltinOptions_DequantizeOptions, + CreateDequantizeOptions(builder_).Union()); + + BuildInterpreter({GetShape(input_)}); + } + + void SetInput(std::initializer_list data) { + PopulateTensor(input_, data); + } + + std::vector GetOutput() { return ExtractVector(output_); } + + private: + int input_; + int output_; +}; + +TEST(SplitOpTest, FourDimensional) { + DequantizeOpModel m({2, 5}, -63.5, 64); + + m.SetInput({0, 1, 2, 3, 4, 251, 252, 253, 254, 255}); + m.Invoke(); + EXPECT_THAT(m.GetOutput(), + ElementsAreArray(ArrayFloatNear( + {-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64}))); +} + +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 06b7ce4..9537b79 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -66,6 +66,7 @@ TfLiteRegistration* Register_EXP(); TfLiteRegistration* Register_TOPK_V2(); TfLiteRegistration* Register_LOG_SOFTMAX(); TfLiteRegistration* Register_CAST(); +TfLiteRegistration* Register_DEQUANTIZE(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -121,6 +122,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_TOPK_V2, Register_TOPK_V2()); AddBuiltin(BuiltinOperator_LOG_SOFTMAX, Register_LOG_SOFTMAX()); AddBuiltin(BuiltinOperator_CAST, Register_CAST()); + AddBuiltin(BuiltinOperator_DEQUANTIZE, Register_DEQUANTIZE()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 141d04a..8c456e7 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -288,6 +288,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, case BuiltinOperator_TOPK_V2: case BuiltinOperator_LOG_SOFTMAX: case BuiltinOperator_CAST: + case BuiltinOperator_DEQUANTIZE: break; case BuiltinOperator_LSH_PROJECTION: { TfLiteLSHProjectionParams* params = diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 80036d8..9d00d96 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -346,6 +346,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_STRIDED_SLICE: case tflite::BuiltinOperator_EXP: case tflite::BuiltinOperator_LOG_SOFTMAX: + case tflite::BuiltinOperator_DEQUANTIZE: case tflite::BuiltinOperator_DELEGATE: case tflite::BuiltinOperator_CAST: FATAL("Op code %d is currently not delegated to NNAPI", builtin); diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 5f617a7..04387fe 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -75,7 +75,7 @@ enum BuiltinOperator : byte { CONV_2D = 3, DEPTHWISE_CONV_2D = 4, // DEPTH_TO_SPACE = 5, - // DEQUANTIZE = 6, + DEQUANTIZE = 6, EMBEDDING_LOOKUP = 7, // FLOOR = 8, FULLY_CONNECTED = 9, @@ -171,6 +171,7 @@ union BuiltinOptions { SplitOptions, LogSoftmaxOptions, CastOptions, + DequantizeOptions, } enum Padding : byte { SAME, VALID } @@ -379,6 +380,9 @@ table LogSoftmaxOptions { table CastOptions { } +table DequantizeOptions { +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index fcacc98..b922de2 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -142,6 +142,9 @@ struct LogSoftmaxOptionsT; struct CastOptions; struct CastOptionsT; +struct DequantizeOptions; +struct DequantizeOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -204,6 +207,7 @@ enum BuiltinOperator { BuiltinOperator_CONCATENATION = 2, BuiltinOperator_CONV_2D = 3, BuiltinOperator_DEPTHWISE_CONV_2D = 4, + BuiltinOperator_DEQUANTIZE = 6, BuiltinOperator_EMBEDDING_LOOKUP = 7, BuiltinOperator_FULLY_CONNECTED = 9, BuiltinOperator_HASHTABLE_LOOKUP = 10, @@ -254,13 +258,14 @@ enum BuiltinOperator { BuiltinOperator_MAX = BuiltinOperator_CAST }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[51] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[52] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, BuiltinOperator_CONCATENATION, BuiltinOperator_CONV_2D, BuiltinOperator_DEPTHWISE_CONV_2D, + BuiltinOperator_DEQUANTIZE, BuiltinOperator_EMBEDDING_LOOKUP, BuiltinOperator_FULLY_CONNECTED, BuiltinOperator_HASHTABLE_LOOKUP, @@ -319,7 +324,7 @@ inline const char **EnumNamesBuiltinOperator() { "CONV_2D", "DEPTHWISE_CONV_2D", "", - "", + "DEQUANTIZE", "EMBEDDING_LOOKUP", "", "FULLY_CONNECTED", @@ -416,11 +421,12 @@ enum BuiltinOptions { BuiltinOptions_SplitOptions = 35, BuiltinOptions_LogSoftmaxOptions = 36, BuiltinOptions_CastOptions = 37, + BuiltinOptions_DequantizeOptions = 38, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_CastOptions + BuiltinOptions_MAX = BuiltinOptions_DequantizeOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[38] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[39] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -459,7 +465,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[38] { BuiltinOptions_TopKV2Options, BuiltinOptions_SplitOptions, BuiltinOptions_LogSoftmaxOptions, - BuiltinOptions_CastOptions + BuiltinOptions_CastOptions, + BuiltinOptions_DequantizeOptions }; return values; } @@ -504,6 +511,7 @@ inline const char **EnumNamesBuiltinOptions() { "SplitOptions", "LogSoftmaxOptions", "CastOptions", + "DequantizeOptions", nullptr }; return names; @@ -666,6 +674,10 @@ template<> struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_CastOptions; }; +template<> struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_DequantizeOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -993,6 +1005,14 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_CastOptions ? reinterpret_cast(value) : nullptr; } + DequantizeOptionsT *AsDequantizeOptions() { + return type == BuiltinOptions_DequantizeOptions ? + reinterpret_cast(value) : nullptr; + } + const DequantizeOptionsT *AsDequantizeOptions() const { + return type == BuiltinOptions_DequantizeOptions ? + reinterpret_cast(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -3696,6 +3716,46 @@ inline flatbuffers::Offset CreateCastOptions( flatbuffers::Offset CreateCastOptions(flatbuffers::FlatBufferBuilder &_fbb, const CastOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct DequantizeOptionsT : public flatbuffers::NativeTable { + typedef DequantizeOptions TableType; + DequantizeOptionsT() { + } +}; + +struct DequantizeOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DequantizeOptionsT NativeTableType; + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + verifier.EndTable(); + } + DequantizeOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(DequantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct DequantizeOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + explicit DequantizeOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DequantizeOptionsBuilder &operator=(const DequantizeOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDequantizeOptions( + flatbuffers::FlatBufferBuilder &_fbb) { + DequantizeOptionsBuilder builder_(_fbb); + return builder_.Finish(); +} + +flatbuffers::Offset CreateDequantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -3924,6 +3984,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const CastOptions *builtin_options_as_CastOptions() const { return builtin_options_type() == BuiltinOptions_CastOptions ? static_cast(builtin_options()) : nullptr; } + const DequantizeOptions *builtin_options_as_DequantizeOptions() const { + return builtin_options_type() == BuiltinOptions_DequantizeOptions ? static_cast(builtin_options()) : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -4098,6 +4161,10 @@ template<> inline const CastOptions *Operator::builtin_options_as() return builtin_options_as_CastOptions(); } +template<> inline const DequantizeOptions *Operator::builtin_options_as() const { + return builtin_options_as_DequantizeOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -5603,6 +5670,29 @@ inline flatbuffers::Offset CreateCastOptions(flatbuffers::FlatBuffe _fbb); } +inline DequantizeOptionsT *DequantizeOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new DequantizeOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void DequantizeOptions::UnPackTo(DequantizeOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; +} + +inline flatbuffers::Offset DequantizeOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateDequantizeOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateDequantizeOptions(flatbuffers::FlatBufferBuilder &_fbb, const DequantizeOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const DequantizeOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + return tflite::CreateDequantizeOptions( + _fbb); +} + inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); UnPackTo(_o, _resolver); @@ -5931,6 +6021,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -6097,6 +6191,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -6251,6 +6349,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast(value); return CreateCastOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(value); + return CreateDequantizeOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -6405,6 +6507,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new CastOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_DequantizeOptions: { + value = new DequantizeOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -6597,6 +6703,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_DequantizeOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } value = nullptr; -- 2.7.4