From ea703f4e0e72d1e016f8157e206dcc9e80602862 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 14 Dec 2017 14:48:02 -0800 Subject: [PATCH] Implementation of Gather in TfLite PiperOrigin-RevId: 179101363 --- tensorflow/contrib/lite/builtin_op_data.h | 4 + tensorflow/contrib/lite/kernels/BUILD | 14 ++ tensorflow/contrib/lite/kernels/gather.cc | 130 ++++++++++++++ .../contrib/lite/kernels/gather_test.cc | 121 +++++++++++++ tensorflow/contrib/lite/kernels/register.cc | 2 + tensorflow/contrib/lite/kernels/test_util.cc | 13 ++ tensorflow/contrib/lite/kernels/test_util.h | 3 + tensorflow/contrib/lite/model.cc | 10 ++ tensorflow/contrib/lite/nnapi_delegate.cc | 1 + tensorflow/contrib/lite/schema/schema.fbs | 7 +- .../contrib/lite/schema/schema_generated.h | 162 +++++++++++++++++- tensorflow/contrib/lite/testing/BUILD | 1 + .../contrib/lite/testing/generate_examples.py | 45 ++++- .../testing/generated_examples_zip_test.cc | 1 + .../propagate_fixed_sizes.cc | 1 + tensorflow/contrib/lite/toco/model.h | 3 +- .../contrib/lite/toco/tflite/operator.cc | 20 ++- .../contrib/lite/toco/tflite/operator_test.cc | 8 +- 18 files changed, 533 insertions(+), 13 deletions(-) create mode 100644 tensorflow/contrib/lite/kernels/gather.cc create mode 100644 tensorflow/contrib/lite/kernels/gather_test.cc mode change 100755 => 100644 tensorflow/contrib/lite/schema/schema_generated.h diff --git a/tensorflow/contrib/lite/builtin_op_data.h b/tensorflow/contrib/lite/builtin_op_data.h index 7249d124e9..548864a1e9 100644 --- a/tensorflow/contrib/lite/builtin_op_data.h +++ b/tensorflow/contrib/lite/builtin_op_data.h @@ -165,6 +165,10 @@ typedef struct { TfLiteCombinerType combiner; } TfLiteEmbeddingLookupSparseParams; +typedef struct { + int axis; +} TfLiteGatherParams; + #ifdef __cplusplus } // extern "C" #endif // __cplusplus diff --git a/tensorflow/contrib/lite/kernels/BUILD b/tensorflow/contrib/lite/kernels/BUILD index 32bbe2670e..3908960c33 100644 --- a/tensorflow/contrib/lite/kernels/BUILD +++ b/tensorflow/contrib/lite/kernels/BUILD @@ -83,6 +83,7 @@ cc_library( "embedding_lookup.cc", "embedding_lookup_sparse.cc", "fully_connected.cc", + "gather.cc", "hashtable_lookup.cc", "kernel_util.cc", "l2norm.cc", @@ -263,6 +264,19 @@ tf_cc_test( ], ) +tf_cc_test( + name = "gather_test", + size = "small", + srcs = ["gather_test.cc"], + deps = [ + ":builtin_ops", + "//tensorflow/contrib/lite:builtin_op_data", + "//tensorflow/contrib/lite:framework", + "//tensorflow/contrib/lite/kernels:test_util", + "@com_google_googletest//:gtest", + ], +) + tf_cc_test( name = "resize_bilinear_test", size = "small", diff --git a/tensorflow/contrib/lite/kernels/gather.cc b/tensorflow/contrib/lite/kernels/gather.cc new file mode 100644 index 0000000000..f8df797daf --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather.cc @@ -0,0 +1,130 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/context.h" +#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h" +#include "tensorflow/contrib/lite/kernels/internal/tensor.h" +#include "tensorflow/contrib/lite/kernels/kernel_util.h" +#include "tensorflow/contrib/lite/kernels/op_macros.h" +#include "tensorflow/contrib/lite/string_util.h" + +namespace tflite { +namespace ops { +namespace builtin { +namespace gather { +constexpr int kInputTensor = 0; +constexpr int kInputPositions = 1; +constexpr int kOutputTensor = 0; + +TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + TF_LITE_ENSURE_EQ(context, NumInputs(node), 2); + TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1); + + const auto* params = + reinterpret_cast(node->builtin_data); + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + // Only INT32 positions are supported. + TF_LITE_ENSURE_EQ(context, positions->type, kTfLiteInt32); + // Check that input and output types match. + TF_LITE_ENSURE_EQ(context, input->type, output->type); + // TODO(mgubin): only 1D positions are currently supported. + TF_LITE_ENSURE_EQ(context, NumDimensions(positions), 1); + // TODO(mgubin): Only default axis == 0 is supported. + // Check conditions for different types. + switch (input->type) { + case kTfLiteFloat32: + case kTfLiteUInt8: + case kTfLiteInt32: { + // Fully supported by reference_ops::Gather. + } break; + + case kTfLiteString: { + // Only 1D input is supported. + TF_LITE_ENSURE_EQ(context, NumDimensions(input), 1); + } break; + default: + context->ReportError(context, + "Only float32 and string types are supported"); + return kTfLiteError; + } + const int num_dimensions = + NumDimensions(input) + NumDimensions(positions) - 1; + TF_LITE_ENSURE(context, params->axis < num_dimensions); + TfLiteIntArray* output_shape = TfLiteIntArrayCreate(num_dimensions); + int output_index = 0; + for (int i = 0; i < params->axis; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + for (int i = 0; i < positions->dims->size; ++i) { + output_shape->data[output_index++] = positions->dims->data[i]; + } + for (int i = params->axis + 1; i < input->dims->size; ++i) { + output_shape->data[output_index++] = input->dims->data[i]; + } + return context->ResizeTensor(context, output, output_shape); +} + +TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { + TfLiteTensor* input = GetInput(context, node, kInputTensor); + TfLiteTensor* positions = GetInput(context, node, kInputPositions); + TfLiteTensor* output = GetOutput(context, node, kOutputTensor); + const int input_rank = NumDimensions(input); +#define TF_LITE_GATHER(data_type, index_type) \ + optimized_ops::Gather( \ + GetTensorData(input), GetTensorDims(input), input_rank, \ + GetTensorData(positions), GetTensorDims(positions), \ + GetTensorData(output), GetTensorDims(output)); + switch (input->type) { + case kTfLiteFloat32: + TF_LITE_GATHER(float, int32_t); + break; + case kTfLiteUInt8: + TF_LITE_GATHER(uint8_t, int32_t); + break; + case kTfLiteInt32: + TF_LITE_GATHER(int32_t, int32_t); + break; + case kTfLiteString: { + DynamicBuffer buffer; + const int32* indexes = positions->data.i32; + const int num_strings = GetStringCount(input); + for (int i = 0; i < positions->dims->data[0]; ++i) { + const int pos = indexes[i]; + TF_LITE_ENSURE(context, pos < num_strings); + const auto string_ref = GetString(input, pos); + buffer.AddString(string_ref.str, string_ref.len); + } + buffer.WriteToTensor(output); + } break; + default: + return kTfLiteError; + } +#undef TF_LITE_GATHER + return kTfLiteOk; +} +} // namespace gather + +TfLiteRegistration* Register_GATHER() { + static TfLiteRegistration r = {nullptr, nullptr, gather::Prepare, + gather::Eval}; + return &r; +} + +} // namespace builtin +} // namespace ops +} // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/gather_test.cc b/tensorflow/contrib/lite/kernels/gather_test.cc new file mode 100644 index 0000000000..6343d3b4ef --- /dev/null +++ b/tensorflow/contrib/lite/kernels/gather_test.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include "tensorflow/contrib/lite/builtin_op_data.h" +#include "tensorflow/contrib/lite/interpreter.h" +#include "tensorflow/contrib/lite/kernels/register.h" +#include "tensorflow/contrib/lite/kernels/test_util.h" +#include "tensorflow/contrib/lite/model.h" + +namespace tflite { +namespace { + +using ::testing::ElementsAreArray; + +class GatherOpModel : public SingleOpModel { + public: + GatherOpModel(std::initializer_list input_shape, TensorType input_type, + std::initializer_list positions_shape) { + input_ = AddInput(input_type); + positions_ = AddInput(TensorType_INT32); + output_ = AddOutput(input_type); + SetBuiltinOp(BuiltinOperator_GATHER, BuiltinOptions_GatherOptions, + CreateGatherOptions(builder_, 0).Union()); + BuildInterpreter({input_shape, positions_shape}); + } + + void SetInputFloat(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInputUint8(std::initializer_list data) { + PopulateTensor(input_, data); + } + + void SetInput(std::initializer_list data) { + PopulateStringTensor(input_, data); + } + + void SetPositions(std::initializer_list data) { + PopulateTensor(positions_, data); + } + + std::vector GetOutputFloat() { return ExtractVector(output_); } + std::vector GetOutputUint8() { + return ExtractVector(output_); + } + std::vector GetOutputString() { + return ExtractVector(output_); + } + std::vector GetOutputShape() { return GetTensorShape(output_); } + + protected: + int input_; + int positions_; + int output_; +}; + +TEST(GatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 0}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({0.7, 0.8, -2, 0.2}))); +} + +TEST(FloatGatherOpTest, Duplicate) { + GatherOpModel m({1, 2, 2}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({0, 0}); + m.Invoke(); + EXPECT_THAT( + m.GetOutputFloat(), + ElementsAreArray(ArrayFloatNear({-2, 0.2, 0.7, 0.8, -2, 0.2, 0.7, 0.8}))); +} + +TEST(FloatGatherOpTest, Slice) { + GatherOpModel m({4, 1}, TensorType_FLOAT32, {2}); + m.SetInputFloat({-2.0, 0.2, 0.7, 0.8}); + m.SetPositions({1, 3}); + m.Invoke(); + EXPECT_THAT(m.GetOutputFloat(), ElementsAreArray(ArrayFloatNear({0.2, 0.8}))); +} + +TEST(Uint8tGatherOpTest, Shuffle) { + GatherOpModel m({2, 2}, TensorType_UINT8, {2}); + m.SetInputUint8({133, 134, 14, 15}); + m.SetPositions({1, 0}); + m.Invoke(); + + EXPECT_THAT(m.GetOutputUint8(), ElementsAreArray({14, 15, 133, 134})); +} + +TEST(GatherOpTest, SimpleString) { + GatherOpModel m({3}, TensorType_STRING, {2}); + m.SetInput({"A", "B", "C"}); + m.SetPositions({0, 2}); + m.Invoke(); + ASSERT_THAT(m.GetOutputShape(), ElementsAreArray({2})); + EXPECT_THAT(m.GetOutputString(), ElementsAreArray({"A", "C"})); +} +} // namespace +} // namespace tflite + +int main(int argc, char** argv) { + ::tflite::LogToStderr(); + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/contrib/lite/kernels/register.cc b/tensorflow/contrib/lite/kernels/register.cc index 12d360f15c..3d1edeef01 100644 --- a/tensorflow/contrib/lite/kernels/register.cc +++ b/tensorflow/contrib/lite/kernels/register.cc @@ -49,6 +49,7 @@ TfLiteRegistration* Register_RESHAPE(); TfLiteRegistration* Register_RESIZE_BILINEAR(); TfLiteRegistration* Register_SKIP_GRAM(); TfLiteRegistration* Register_SPACE_TO_DEPTH(); +TfLiteRegistration* Register_GATHER(); BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RELU, Register_RELU()); @@ -84,6 +85,7 @@ BuiltinOpResolver::BuiltinOpResolver() { AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR()); AddBuiltin(BuiltinOperator_SKIP_GRAM, Register_SKIP_GRAM()); AddBuiltin(BuiltinOperator_SPACE_TO_DEPTH, Register_SPACE_TO_DEPTH()); + AddBuiltin(BuiltinOperator_GATHER, Register_GATHER()); } TfLiteRegistration* BuiltinOpResolver::FindOp( diff --git a/tensorflow/contrib/lite/kernels/test_util.cc b/tensorflow/contrib/lite/kernels/test_util.cc index f716ba8741..b69f2b3e4b 100644 --- a/tensorflow/contrib/lite/kernels/test_util.cc +++ b/tensorflow/contrib/lite/kernels/test_util.cc @@ -180,4 +180,17 @@ int32_t SingleOpModel::GetTensorSize(int index) const { return total_size; } +template <> +std::vector SingleOpModel::ExtractVector(int index) { + TfLiteTensor* tensor_ptr = interpreter_->tensor(index); + CHECK(tensor_ptr != nullptr); + const int num_strings = GetStringCount(tensor_ptr); + std::vector result; + result.reserve(num_strings); + for (int i = 0; i < num_strings; ++i) { + const auto str = GetString(tensor_ptr, i); + result.emplace_back(str.str, str.len); + } + return result; +} } // namespace tflite diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h index adcdeddbfc..531c1366a8 100644 --- a/tensorflow/contrib/lite/kernels/test_util.h +++ b/tensorflow/contrib/lite/kernels/test_util.h @@ -192,6 +192,9 @@ class SingleOpModel { std::map> custom_registrations_; }; +// Strings have a special implementation that is in test_util.cc +template <> +std::vector SingleOpModel::ExtractVector(int index); } // namespace tflite #endif // THIRD_PARTY_TENSORFLOW_CONTRIB_LITE_KERNELS_TEST_UTIL_H_ diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc index 2f508d9d60..4ef2c942c1 100644 --- a/tensorflow/contrib/lite/model.cc +++ b/tensorflow/contrib/lite/model.cc @@ -508,6 +508,16 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type, builtin_data = reinterpret_cast(params); break; } + case BuiltinOperator_GATHER: { + TfLiteGatherParams* params = MallocPOD(); + params->axis = 0; + if (auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + builtin_data = reinterpret_cast(params); + break; + } } return builtin_data; } diff --git a/tensorflow/contrib/lite/nnapi_delegate.cc b/tensorflow/contrib/lite/nnapi_delegate.cc index 86f2afbaf2..6b93a70bff 100644 --- a/tensorflow/contrib/lite/nnapi_delegate.cc +++ b/tensorflow/contrib/lite/nnapi_delegate.cc @@ -306,6 +306,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter, case tflite::BuiltinOperator_CALL: case tflite::BuiltinOperator_SKIP_GRAM: case tflite::BuiltinOperator_RELU1: + case tflite::BuiltinOperator_GATHER: FATAL("Op code %d is currently not delegated to NNAPI", builtin); nn_op_type = -1; // set to invalid break; diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index d1302bdc63..8b48543fc8 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -106,6 +106,7 @@ enum BuiltinOperator : byte { EMBEDDING_LOOKUP_SPARSE = 33, PAD = 34, UNIDIRECTIONAL_SEQUENCE_RNN = 35, + GATHER = 36, } // Options for the builtin operators. @@ -132,6 +133,7 @@ union BuiltinOptions { EmbeddingLookupSparseOptions, MulOptions, PadOptions, + GatherOptions, } enum Padding : byte { SAME, VALID } @@ -276,6 +278,10 @@ table EmbeddingLookupSparseOptions { combiner:CombinerType; } +table GatherOptions { + axis: int; +} + // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a // builtin, or a string if the operator is custom. table OperatorCode { @@ -351,4 +357,3 @@ table Model { } root_type Model; - diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h old mode 100755 new mode 100644 index ba645c2764..7de205e1e4 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -1,4 +1,4 @@ -/* Copyright 2015 The TensorFlow Authors. All Rights Reserved. +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -94,6 +94,9 @@ struct SpaceToDepthOptionsT; struct EmbeddingLookupSparseOptions; struct EmbeddingLookupSparseOptionsT; +struct GatherOptions; +struct GatherOptionsT; + struct OperatorCode; struct OperatorCodeT; @@ -172,11 +175,12 @@ enum BuiltinOperator { BuiltinOperator_EMBEDDING_LOOKUP_SPARSE = 33, BuiltinOperator_PAD = 34, BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN = 35, + BuiltinOperator_GATHER = 36, BuiltinOperator_MIN = BuiltinOperator_ADD, - BuiltinOperator_MAX = BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN + BuiltinOperator_MAX = BuiltinOperator_GATHER }; -inline BuiltinOperator (&EnumValuesBuiltinOperator())[33] { +inline BuiltinOperator (&EnumValuesBuiltinOperator())[34] { static BuiltinOperator values[] = { BuiltinOperator_ADD, BuiltinOperator_AVERAGE_POOL_2D, @@ -210,7 +214,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[33] { BuiltinOperator_CUSTOM, BuiltinOperator_EMBEDDING_LOOKUP_SPARSE, BuiltinOperator_PAD, - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN}; + BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOperator_GATHER}; return values; } @@ -251,6 +256,7 @@ inline const char **EnumNamesBuiltinOperator() { "EMBEDDING_LOOKUP_SPARSE", "PAD", "UNIDIRECTIONAL_SEQUENCE_RNN", + "GATHER", nullptr}; return names; } @@ -284,11 +290,12 @@ enum BuiltinOptions { BuiltinOptions_EmbeddingLookupSparseOptions = 20, BuiltinOptions_MulOptions = 21, BuiltinOptions_PadOptions = 22, + BuiltinOptions_GatherOptions = 23, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_PadOptions + BuiltinOptions_MAX = BuiltinOptions_GatherOptions }; -inline BuiltinOptions (&EnumValuesBuiltinOptions())[23] { +inline BuiltinOptions (&EnumValuesBuiltinOptions())[24] { static BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -312,7 +319,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[23] { BuiltinOptions_SpaceToDepthOptions, BuiltinOptions_EmbeddingLookupSparseOptions, BuiltinOptions_MulOptions, - BuiltinOptions_PadOptions}; + BuiltinOptions_PadOptions, + BuiltinOptions_GatherOptions}; return values; } @@ -340,6 +348,7 @@ inline const char **EnumNamesBuiltinOptions() { "EmbeddingLookupSparseOptions", "MulOptions", "PadOptions", + "GatherOptions", nullptr}; return names; } @@ -468,6 +477,11 @@ struct BuiltinOptionsTraits { static const BuiltinOptions enum_value = BuiltinOptions_PadOptions; }; +template <> +struct BuiltinOptionsTraits { + static const BuiltinOptions enum_value = BuiltinOptions_GatherOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -735,6 +749,16 @@ struct BuiltinOptionsUnion { ? reinterpret_cast(value) : nullptr; } + GatherOptionsT *AsGatherOptions() { + return type == BuiltinOptions_GatherOptions + ? reinterpret_cast(value) + : nullptr; + } + const GatherOptionsT *AsGatherOptions() const { + return type == BuiltinOptions_GatherOptions + ? reinterpret_cast(value) + : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, @@ -2681,6 +2705,59 @@ CreateEmbeddingLookupSparseOptions( const EmbeddingLookupSparseOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct GatherOptionsT : public flatbuffers::NativeTable { + typedef GatherOptions TableType; + int32_t axis; + GatherOptionsT() : axis(0) {} +}; + +struct GatherOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef GatherOptionsT NativeTableType; + enum { VT_AXIS = 4 }; + int32_t axis() const { return GetField(VT_AXIS, 0); } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_AXIS) && verifier.EndTable(); + } + GatherOptionsT *UnPack( + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo( + GatherOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset Pack( + flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct GatherOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_axis(int32_t axis) { + fbb_.AddElement(GatherOptions::VT_AXIS, axis, 0); + } + explicit GatherOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + GatherOptionsBuilder &operator=(const GatherOptionsBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateGatherOptions( + flatbuffers::FlatBufferBuilder &_fbb, int32_t axis = 0) { + GatherOptionsBuilder builder_(_fbb); + builder_.add_axis(axis); + return builder_.Finish(); +} + +flatbuffers::Offset CreateGatherOptions( + flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct OperatorCodeT : public flatbuffers::NativeTable { typedef OperatorCode TableType; BuiltinOperator builtin_code; @@ -2918,6 +2995,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { ? static_cast(builtin_options()) : nullptr; } + const GatherOptions *builtin_options_as_GatherOptions() const { + return builtin_options_type() == BuiltinOptions_GatherOptions + ? static_cast(builtin_options()) + : nullptr; + } const flatbuffers::Vector *custom_options() const { return GetPointer *>(VT_CUSTOM_OPTIONS); } @@ -3074,6 +3156,12 @@ inline const PadOptions *Operator::builtin_options_as() const { return builtin_options_as_PadOptions(); } +template <> +inline const GatherOptions *Operator::builtin_options_as() + const { + return builtin_options_as_GatherOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -4658,6 +4746,45 @@ CreateEmbeddingLookupSparseOptions( return tflite::CreateEmbeddingLookupSparseOptions(_fbb, _combiner); } +inline GatherOptionsT *GatherOptions::UnPack( + const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new GatherOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void GatherOptions::UnPackTo( + GatherOptionsT *_o, + const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { + auto _e = axis(); + _o->axis = _e; + }; +} + +inline flatbuffers::Offset GatherOptions::Pack( + flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + return CreateGatherOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset CreateGatherOptions( + flatbuffers::FlatBufferBuilder &_fbb, const GatherOptionsT *_o, + const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { + flatbuffers::FlatBufferBuilder *__fbb; + const GatherOptionsT *__o; + const flatbuffers::rehasher_function_t *__rehasher; + } _va = {&_fbb, _o, _rehasher}; + (void)_va; + auto _axis = _o->axis; + return tflite::CreateGatherOptions(_fbb, _axis); +} + inline OperatorCodeT *OperatorCode::UnPack( const flatbuffers::resolver_function_t *_resolver) const { auto _o = new OperatorCodeT(); @@ -5134,6 +5261,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, auto ptr = reinterpret_cast(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } default: return false; } @@ -5246,6 +5377,10 @@ inline void *BuiltinOptionsUnion::UnPack( auto ptr = reinterpret_cast(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } @@ -5345,6 +5480,10 @@ inline flatbuffers::Offset BuiltinOptionsUnion::Pack( auto ptr = reinterpret_cast(value); return CreatePadOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(value); + return CreateGatherOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } @@ -5454,6 +5593,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) value = new PadOptionsT(*reinterpret_cast(u.value)); break; } + case BuiltinOptions_GatherOptions: { + value = new GatherOptionsT(*reinterpret_cast(u.value)); + break; + } default: break; } @@ -5571,6 +5714,11 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_GatherOptions: { + auto ptr = reinterpret_cast(value); + delete ptr; + break; + } default: break; } diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD index b9c5cbe715..b63c0c058c 100644 --- a/tensorflow/contrib/lite/testing/BUILD +++ b/tensorflow/contrib/lite/testing/BUILD @@ -25,6 +25,7 @@ gen_zipped_test_files( "depthwiseconv.zip", "fully_connected.zip", "fused_batch_norm.zip", + "gather.zip", "global_batch_norm.zip", "l2_pool.zip", "l2norm.zip", diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 4848ca8062..4c01fedb1e 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -94,6 +94,8 @@ KNOWN_BUGS = { r"softmax.*input_shape=\[1,3,4,3\]": "67749831", # SpaceToDepth only supports float32. r"space_to_depth.*(float16|int32|uint8|int64)": "68018134", + # Gather doesn't support int64 indices. + r"gather.*indices_dtype=int64": "XXXX", } @@ -120,7 +122,7 @@ def toco_options(data_types, # to change if data_types[0] == "QUANTIZED_UINT8": inference_type = "QUANTIZED_UINT8" - s = (" --input_types=%s" % ",".join(data_types) + + s = (" --input_data_types=%s" % ",".join(data_types) + " --inference_type=%s" % inference_type + " --input_format=TENSORFLOW_GRAPHDEF" + " --output_format=TFLITE" + " --input_arrays=%s" % ",".join(input_arrays) + @@ -704,6 +706,46 @@ def make_mul_tests(zip_path): make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) +def make_gather_tests(zip_path): + """Make a set of tests to do gather.""" + + test_parameters = [{ + # TODO(mgubin): add string tests when they are supported by Toco. + # TODO(mgubin): add tests for Nd indices when they are supported by + # TfLite. + # TODO(mgubin): add tests for axis != 0 when it is supported by TfLite. + "params_dtype": [tf.float32, tf.int32], + "params_shape": [[10], [1, 2, 20]], + "indices_dtype": [tf.int32], + "indices_shape": [[3], [5]], + "axis": [0], # axis!=0 is GatherV2 + }] + + def build_graph(parameters): + """Build the gather op testing graph.""" + params = tf.placeholder( + dtype=parameters["params_dtype"], + name="params", + shape=parameters["params_shape"]) + indices = tf.placeholder( + dtype=parameters["indices_dtype"], + name="indices", + shape=parameters["indices_shape"]) + out = tf.gather(params, indices, axis=parameters["axis"]) + return [params, indices], [out] + + def build_inputs(parameters, sess, inputs, outputs): + params = create_tensor_data(parameters["params_dtype"], + parameters["params_shape"]) + indices = create_tensor_data(parameters["indices_dtype"], + parameters["indices_shape"], 0, + parameters["params_shape"][0] - 1) + return [params, indices], sess.run( + outputs, feed_dict=dict(zip(inputs, [params, indices]))) + + make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs) + + def make_global_batch_norm_tests(zip_path): """Make a set of tests to do batch_norm_with_global_normalization.""" @@ -1190,6 +1232,7 @@ def main(unused_args): "concat.zip": make_concatenation_tests, "fully_connected.zip": make_fully_connected_tests, "global_batch_norm.zip": make_global_batch_norm_tests, + "gather.zip": make_gather_tests, "fused_batch_norm.zip": make_fused_batch_norm_tests, "l2norm.zip": make_l2norm_tests, "local_response_norm.zip": make_local_response_norm_tests, diff --git a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc index 76e8767617..29f0c68ba4 100644 --- a/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc +++ b/tensorflow/contrib/lite/testing/generated_examples_zip_test.cc @@ -248,6 +248,7 @@ INSTANTIATE_TESTS(conv) INSTANTIATE_TESTS(depthwiseconv) INSTANTIATE_TESTS(fully_connected) INSTANTIATE_TESTS(fused_batch_norm) +INSTANTIATE_TESTS(gather) INSTANTIATE_TESTS(global_batch_norm) INSTANTIATE_TESTS(l2norm) INSTANTIATE_TESTS(l2_pool) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc index 308dadfdeb..786d3da7cf 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc +++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc @@ -821,6 +821,7 @@ void ProcessGatherOperator(Model* model, GatherOperator* op) { // Copy the input dimensions to the output except for dimension 0, // where the dimension of indices_shape is used. + // TODO(mgubin): if axis != 0 this is not true, change when it's supported. auto output_dims = output_array.mutable_shape()->mutable_dims(); output_dims->push_back(indices_shape.dims(0)); for (int dim = 1; dim < input_shape.dimensions_count(); dim++) { diff --git a/tensorflow/contrib/lite/toco/model.h b/tensorflow/contrib/lite/toco/model.h index a481d8801c..a53c751d3c 100644 --- a/tensorflow/contrib/lite/toco/model.h +++ b/tensorflow/contrib/lite/toco/model.h @@ -1208,7 +1208,8 @@ struct FloorOperator : Operator { // TensorFlow equivalent: Gather struct GatherOperator : Operator { GatherOperator() : Operator(OperatorType::kGather) {} - int input_rank; + int axis = 0; + int input_rank = 0; }; // ResizeBilinear operator. It resizes input images with bilinear interpolation. diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 8d25336bb7..7fee47a90b 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -211,6 +211,22 @@ class FullyConnected } }; +class Gather : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + return ::tflite::CreateGatherOptions(*builder, op.axis); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + op->axis = options.axis(); + } +}; + class Svdf : public BuiltinOperator { public: @@ -564,6 +580,8 @@ std::vector> BuildOperatorList() { OperatorType::kDepthwiseConv)); ops.emplace_back(new FullyConnected(::tflite::BuiltinOperator_FULLY_CONNECTED, OperatorType::kFullyConnected)); + ops.emplace_back( + new Gather(::tflite::BuiltinOperator_GATHER, OperatorType::kGather)); ops.emplace_back( new L2Normalization(::tflite::BuiltinOperator_L2_NORMALIZATION, OperatorType::kL2Normalization)); @@ -606,8 +624,6 @@ std::vector> BuildOperatorList() { "DEQUANTIZE", OperatorType::kDequantize)); ops.emplace_back( new SimpleOperator("FLOOR", OperatorType::kFloor)); - ops.emplace_back( - new SimpleOperator("GATHER", OperatorType::kGather)); ops.emplace_back( new SimpleOperator("RELU", OperatorType::kRelu)); ops.emplace_back( diff --git a/tensorflow/contrib/lite/toco/tflite/operator_test.cc b/tensorflow/contrib/lite/toco/tflite/operator_test.cc index fe079e833d..caecbd0325 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator_test.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator_test.cc @@ -101,7 +101,6 @@ TEST_F(OperatorTest, SimpleOperators) { CheckSimpleOperator("DEQUANTIZE", OperatorType::kDequantize); CheckSimpleOperator("FLOOR", OperatorType::kFloor); - CheckSimpleOperator("GATHER", OperatorType::kGather); CheckSimpleOperator("RELU", OperatorType::kRelu); CheckSimpleOperator("RELU1", OperatorType::kRelu1); CheckSimpleOperator("RELU6", OperatorType::kRelu6); @@ -167,6 +166,13 @@ TEST_F(OperatorTest, CustomFullyConnected) { output_toco_op->fused_activation_function); } +TEST_F(OperatorTest, BuiltinGather) { + GatherOperator op; + auto output_toco_op = + SerializeAndDeserialize(GetOperator("GATHER", OperatorType::kGather), op); + ASSERT_NE(nullptr, output_toco_op.get()); +} + TEST_F(OperatorTest, BuiltinL2Pool) { L2PoolOperator op; op.stride_width = 123; -- 2.34.1