Supports op exp (tf.exp) in Toco and TensorFlow Lite.
author A. Unique TensorFlower <gardener@tensorflow.org>
Wed, 14 Feb 2018 22:40:29 +0000 (14:40 -0800)
committer TensorFlower Gardener <gardener@tensorflow.org>
Wed, 14 Feb 2018 22:44:24 +0000 (14:44 -0800)
PiperOrigin-RevId: 185747281

18 files changed:
tensorflow/contrib/lite/kernels/BUILD
tensorflow/contrib/lite/kernels/exp.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/exp_test.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/kernels/register.cc
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/nnapi_delegate.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/BUILD
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
tensorflow/contrib/lite/toco/import_tensorflow.cc
tensorflow/contrib/lite/toco/model.h
tensorflow/contrib/lite/toco/tflite/operator.cc
tensorflow/contrib/lite/toco/tflite/operator_test.cc
tensorflow/contrib/lite/toco/tooling_util.cc

index a8ef0da..5d553de 100644 (file)
@@ -111,6 +111,7 @@ cc_library(
         "div.cc",
         "embedding_lookup.cc",
         "embedding_lookup_sparse.cc",
+        "exp.cc",
         "fully_connected.cc",
         "gather.cc",
         "hashtable_lookup.cc",
@@ -328,6 +329,18 @@ tf_cc_test(
 )
 
 tf_cc_test(
+    name = "exp_test",
+    size = "small",
+    srcs = ["exp_test.cc"],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+tf_cc_test(
     name = "mean_test",
     size = "small",
     srcs = ["mean_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/exp.cc b/tensorflow/contrib/lite/kernels/exp.cc
new file mode 100644 (file)
index 0000000..a9e79b7
--- /dev/null
@@ -0,0 +1,92 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <string.h>
+#include <vector>
+#include "tensorflow/contrib/lite/builtin_op_data.h"
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace exp {
+
+// Reference implementation of the element-wise Exp kernel (y = e^x).
+enum KernelType {
+  kReference,  // Only a reference kernel exists so far; see TODO below.
+};
+
+struct ExpContext {  // Bundles the node's single input/output tensor pair.
+  ExpContext(TfLiteContext* context, TfLiteNode* node) {
+    input = GetInput(context, node, 0);
+    output = GetOutput(context, node, 0);
+  }
+  TfLiteTensor* input;   // Not owned; lifetime managed by the interpreter.
+  TfLiteTensor* output;  // Not owned; lifetime managed by the interpreter.
+};
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {  // Shape/type inference.
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  ExpContext op_context(context, node);
+  TfLiteIntArray* output_dims = TfLiteIntArrayCopy(op_context.input->dims);  // Output shape mirrors input.
+  op_context.output->type = op_context.input->type;  // Output dtype follows input dtype.
+  return context->ResizeTensor(context, op_context.output, output_dims);
+}
+
+template <KernelType kernel_type>
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  ExpContext op_context(context, node);
+
+#define TF_LITE_EXP(kernel_type, data_type)                               \
+  kernel_type::Exp<data_type>(GetTensorData<data_type>(op_context.input), \
+                              NumElements(op_context.input),              \
+                              GetTensorData<data_type>(op_context.output))
+
+  // TODO(kanlig): supports half, bfloat16, float64, complex64, and complex128.
+  if (kernel_type == kReference) {
+    switch (op_context.input->type) {
+      case kTfLiteFloat32:
+        TF_LITE_EXP(reference_ops, float);
+        break;
+      default:  // Only float32 is supported; reject other dtypes explicitly.
+        context->ReportError(context,
+                             "Type %d is currently not supported by Exp.",
+                             op_context.input->type);
+        return kTfLiteError;
+    }
+  }
+#undef TF_LITE_EXP
+  return kTfLiteOk;
+}
+
+}  // namespace exp
+
+TfLiteRegistration* Register_EXP_REF() {  // init/free are null: kernel is stateless.
+  static TfLiteRegistration r = {nullptr, nullptr, exp::Prepare,
+                                 exp::Eval<exp::kReference>};
+  return &r;
+}
+
+// TODO(kanlig): add optimized implementation of Exp.
+TfLiteRegistration* Register_EXP() { return Register_EXP_REF(); }
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/exp_test.cc b/tensorflow/contrib/lite/kernels/exp_test.cc
new file mode 100644 (file)
index 0000000..eed6736
--- /dev/null
@@ -0,0 +1,70 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class ExpOpModel : public SingleOpModel {  // Minimal model wrapping one EXP op.
+ public:
+  ExpOpModel(const TensorData& input, const TensorType& output) {
+    input_ = AddInput(input);
+    output_ = AddOutput(output);
+    SetBuiltinOp(BuiltinOperator_EXP, BuiltinOptions_ExpOptions,
+                 CreateExpOptions(builder_).Union());  // ExpOptions carries no fields.
+    BuildInterpreter({GetShape(input_)});
+  }
+
+  template <class T>
+  void SetInput(std::initializer_list<T> data) {
+    PopulateTensor(input_, data);
+  }
+
+  template <class T>
+  std::vector<T> GetOutput() {
+    return ExtractVector<T>(output_);
+  }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ protected:
+  int input_;   // Interpreter tensor index of the input.
+  int output_;  // Interpreter tensor index of the output.
+};
+
+TEST(ExpOpTest, FloatTest) {  // Checks e^1, e^0, e^-1 over a 3x1x2 tensor.
+  std::initializer_list<float> data = {1.0, 0.0, -1.0, 1.0, 1.0, -1.0};
+  ExpOpModel m({TensorType_FLOAT32, {3, 1, 2}}, TensorType_FLOAT32);
+  m.SetInput<float>(data);
+  m.Invoke();
+  EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({3, 1, 2}));
+  EXPECT_THAT(m.GetOutput<float>(),
+              ElementsAreArray(ArrayFloatNear(
+                  {2.71828, 1, 0.367879, 2.71828, 2.71828, 0.367879})));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
index f18543f..2e03766 100644 (file)
@@ -2763,6 +2763,14 @@ inline void Slice(const T* input_data, const Dims<4>& input_dims,
 }
 
 template <typename T>
+inline void Exp(const T* input_data, const size_t num_elements,
+                T* output_data) {
+  for (size_t idx = 0; idx < num_elements; ++idx) {
+    output_data[idx] = exp(input_data[idx]);
+  }
+}
+
+template <typename T>
 inline void Mean(T* input_data, const int* input_dims, const int input_num_dims,
                  T* output_data, const int* output_dims,
                  const int output_num_dims, const int* axis,
index 1fb779f..0f36507 100644 (file)
@@ -60,6 +60,7 @@ TfLiteRegistration* Register_TRANSPOSE();
 TfLiteRegistration* Register_MEAN();
 TfLiteRegistration* Register_SQUEEZE();
 TfLiteRegistration* Register_STRIDED_SLICE();
+TfLiteRegistration* Register_EXP();
 
 BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -108,6 +109,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_SUB, Register_SUB());
   AddBuiltin(BuiltinOperator_SQUEEZE, Register_SQUEEZE());
   AddBuiltin(BuiltinOperator_STRIDED_SLICE, Register_STRIDED_SLICE());
+  AddBuiltin(BuiltinOperator_EXP, Register_EXP());
 }
 
 TfLiteRegistration* BuiltinOpResolver::FindOp(
index 14b6709..2ee0cac 100644 (file)
@@ -278,6 +278,7 @@ void* ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_RELU_N1_TO_1:
     case BuiltinOperator_RELU6:
     case BuiltinOperator_CONCAT_EMBEDDINGS:
+    case BuiltinOperator_EXP:
       break;
     case BuiltinOperator_LSH_PROJECTION: {
       TfLiteLSHProjectionParams* params =
index da9ceec..77084b4 100644 (file)
@@ -341,6 +341,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       case tflite::BuiltinOperator_SUB:
       case tflite::BuiltinOperator_SQUEEZE:
       case tflite::BuiltinOperator_STRIDED_SLICE:
+      case tflite::BuiltinOperator_EXP:
         FATAL("Op code %d is currently not delegated to NNAPI", builtin);
         nn_op_type = -1;  // set to invalid
         break;
index 36cc272..ef8f39c 100644 (file)
@@ -120,6 +120,7 @@ enum BuiltinOperator : byte {
   UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
   STRIDED_SLICE = 45,
   BIDIRECTIONAL_SEQUENCE_RNN = 46,
+  EXP = 47,
 }
 
 // Options for the builtin operators.
@@ -156,6 +157,7 @@ union BuiltinOptions {
   SqueezeOptions,
   SequenceRNNOptions,
   StridedSliceOptions,
+  ExpOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -332,6 +334,9 @@ table GatherOptions {
 table TransposeOptions {
 }
 
+table ExpOptions {
+}
+
 table MeanOptions {
   keep_dims: bool;
 }
index e2ac0b9..40b50ba 100755 (executable)
@@ -117,6 +117,9 @@ struct GatherOptionsT;
 struct TransposeOptions;
 struct TransposeOptionsT;
 
+struct ExpOptions;
+struct ExpOptionsT;
+
 struct MeanOptions;
 struct MeanOptionsT;
 
@@ -215,11 +218,12 @@ enum BuiltinOperator {
   BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM = 44,
   BuiltinOperator_STRIDED_SLICE = 45,
   BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN = 46,
+  BuiltinOperator_EXP = 47,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN
+  BuiltinOperator_MAX = BuiltinOperator_EXP
 };
 
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[44] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[45] {
   static BuiltinOperator values[] = {
       BuiltinOperator_ADD,
       BuiltinOperator_AVERAGE_POOL_2D,
@@ -264,7 +268,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[44] {
       BuiltinOperator_SQUEEZE,
       BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM,
       BuiltinOperator_STRIDED_SLICE,
-      BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN};
+      BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
+      BuiltinOperator_EXP};
   return values;
 }
 
@@ -316,6 +321,7 @@ inline const char **EnumNamesBuiltinOperator() {
                                 "UNIDIRECTIONAL_SEQUENCE_LSTM",
                                 "STRIDED_SLICE",
                                 "BIDIRECTIONAL_SEQUENCE_RNN",
+                                "EXP",
                                 nullptr};
   return names;
 }
@@ -359,11 +365,12 @@ enum BuiltinOptions {
   BuiltinOptions_SqueezeOptions = 30,
   BuiltinOptions_SequenceRNNOptions = 31,
   BuiltinOptions_StridedSliceOptions = 32,
+  BuiltinOptions_ExpOptions = 33,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_StridedSliceOptions
+  BuiltinOptions_MAX = BuiltinOptions_ExpOptions
 };
 
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[34] {
   static BuiltinOptions values[] = {
       BuiltinOptions_NONE,
       BuiltinOptions_Conv2DOptions,
@@ -397,7 +404,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[33] {
       BuiltinOptions_DivOptions,
       BuiltinOptions_SqueezeOptions,
       BuiltinOptions_SequenceRNNOptions,
-      BuiltinOptions_StridedSliceOptions};
+      BuiltinOptions_StridedSliceOptions,
+      BuiltinOptions_ExpOptions};
   return values;
 }
 
@@ -435,6 +443,7 @@ inline const char **EnumNamesBuiltinOptions() {
                                 "SqueezeOptions",
                                 "SequenceRNNOptions",
                                 "StridedSliceOptions",
+                                "ExpOptions",
                                 nullptr};
   return names;
 }
@@ -613,6 +622,11 @@ struct BuiltinOptionsTraits<StridedSliceOptions> {
   static const BuiltinOptions enum_value = BuiltinOptions_StridedSliceOptions;
 };
 
+template <>
+struct BuiltinOptionsTraits<ExpOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_ExpOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -980,6 +994,16 @@ struct BuiltinOptionsUnion {
                ? reinterpret_cast<const StridedSliceOptionsT *>(value)
                : nullptr;
   }
+  ExpOptionsT *AsExpOptions() {
+    return type == BuiltinOptions_ExpOptions
+               ? reinterpret_cast<ExpOptionsT *>(value)
+               : nullptr;
+  }
+  const ExpOptionsT *AsExpOptions() const {
+    return type == BuiltinOptions_ExpOptions
+               ? reinterpret_cast<const ExpOptionsT *>(value)
+               : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj,
@@ -2627,16 +2651,13 @@ flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
 struct ResizeBilinearOptionsT : public flatbuffers::NativeTable {
   typedef ResizeBilinearOptions TableType;
   bool align_corners;
-  ResizeBilinearOptionsT()
-      : align_corners(false) {
-  }
+  ResizeBilinearOptionsT() : align_corners(false) {}
 };
 
-struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS
+    : private flatbuffers::Table {
   typedef ResizeBilinearOptionsT NativeTableType;
-  enum {
-    VT_ALIGN_CORNERS = 8
-  };
+  enum { VT_ALIGN_CORNERS = 8 };
   bool align_corners() const {
     return GetField<uint8_t>(VT_ALIGN_CORNERS, 0) != 0;
   }
@@ -2645,16 +2666,22 @@ struct ResizeBilinearOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Tabl
            VerifyField<uint8_t>(verifier, VT_ALIGN_CORNERS) &&
            verifier.EndTable();
   }
-  ResizeBilinearOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
-  void UnPackTo(ResizeBilinearOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
-  static flatbuffers::Offset<ResizeBilinearOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+  ResizeBilinearOptionsT *UnPack(
+      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(
+      ResizeBilinearOptionsT *_o,
+      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<ResizeBilinearOptions> Pack(
+      flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
+      const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 };
 
 struct ResizeBilinearOptionsBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
   void add_align_corners(bool align_corners) {
-    fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS, static_cast<uint8_t>(align_corners), 0);
+    fbb_.AddElement<uint8_t>(ResizeBilinearOptions::VT_ALIGN_CORNERS,
+                             static_cast<uint8_t>(align_corners), 0);
   }
   explicit ResizeBilinearOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
       : fbb_(_fbb) {
@@ -2669,14 +2696,15 @@ struct ResizeBilinearOptionsBuilder {
 };
 
 inline flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
-    flatbuffers::FlatBufferBuilder &_fbb,
-    bool align_corners = false) {
+    flatbuffers::FlatBufferBuilder &_fbb, bool align_corners = false) {
   ResizeBilinearOptionsBuilder builder_(_fbb);
   builder_.add_align_corners(align_corners);
   return builder_.Finish();
 }
 
-flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
+    flatbuffers::FlatBufferBuilder &_fbb, const ResizeBilinearOptionsT *_o,
+    const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
 struct CallOptionsT : public flatbuffers::NativeTable {
   typedef CallOptions TableType;
@@ -3345,6 +3373,51 @@ flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
     flatbuffers::FlatBufferBuilder &_fbb, const TransposeOptionsT *_o,
     const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct ExpOptionsT : public flatbuffers::NativeTable {
+  typedef ExpOptions TableType;
+  ExpOptionsT() {}
+};
+
+struct ExpOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef ExpOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) && verifier.EndTable();
+  }
+  ExpOptionsT *UnPack(
+      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(
+      ExpOptionsT *_o,
+      const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<ExpOptions> Pack(
+      flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o,
+      const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct ExpOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit ExpOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+      : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  ExpOptionsBuilder &operator=(const ExpOptionsBuilder &);
+  flatbuffers::Offset<ExpOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<ExpOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<ExpOptions> CreateExpOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  ExpOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<ExpOptions> CreateExpOptions(
+    flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o,
+    const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct MeanOptionsT : public flatbuffers::NativeTable {
   typedef MeanOptions TableType;
   bool keep_dims;
@@ -3858,6 +3931,11 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
                ? static_cast<const StridedSliceOptions *>(builtin_options())
                : nullptr;
   }
+  const ExpOptions *builtin_options_as_ExpOptions() const {
+    return builtin_options_type() == BuiltinOptions_ExpOptions
+               ? static_cast<const ExpOptions *>(builtin_options())
+               : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -4071,6 +4149,11 @@ Operator::builtin_options_as<StridedSliceOptions>() const {
   return builtin_options_as_StridedSliceOptions();
 }
 
+template <>
+inline const ExpOptions *Operator::builtin_options_as<ExpOptions>() const {
+  return builtin_options_as_ExpOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -5449,6 +5532,10 @@ inline void ResizeBilinearOptions::UnPackTo(
     const flatbuffers::resolver_function_t *_resolver) const {
   (void)_o;
   (void)_resolver;
+  {
+    auto _e = align_corners();
+    _o->align_corners = _e;
+  };
 }
 
 inline flatbuffers::Offset<ResizeBilinearOptions> ResizeBilinearOptions::Pack(
@@ -5468,7 +5555,8 @@ inline flatbuffers::Offset<ResizeBilinearOptions> CreateResizeBilinearOptions(
     const flatbuffers::rehasher_function_t *__rehasher;
   } _va = {&_fbb, _o, _rehasher};
   (void)_va;
-  return tflite::CreateResizeBilinearOptions(_fbb);
+  auto _align_corners = _o->align_corners;
+  return tflite::CreateResizeBilinearOptions(_fbb, _align_corners);
 }
 
 inline CallOptionsT *CallOptions::UnPack(
@@ -5935,6 +6023,39 @@ inline flatbuffers::Offset<TransposeOptions> CreateTransposeOptions(
   return tflite::CreateTransposeOptions(_fbb);
 }
 
+inline ExpOptionsT *ExpOptions::UnPack(
+    const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new ExpOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void ExpOptions::UnPackTo(
+    ExpOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<ExpOptions> ExpOptions::Pack(
+    flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o,
+    const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateExpOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<ExpOptions> CreateExpOptions(
+    flatbuffers::FlatBufferBuilder &_fbb, const ExpOptionsT *_o,
+    const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs {
+    flatbuffers::FlatBufferBuilder *__fbb;
+    const ExpOptionsT *__o;
+    const flatbuffers::rehasher_function_t *__rehasher;
+  } _va = {&_fbb, _o, _rehasher};
+  (void)_va;
+  return tflite::CreateExpOptions(_fbb);
+}
+
 inline MeanOptionsT *MeanOptions::UnPack(
     const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new MeanOptionsT();
@@ -6595,6 +6716,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier,
       auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_ExpOptions: {
+      auto ptr = reinterpret_cast<const ExpOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default:
       return false;
   }
@@ -6604,6 +6729,7 @@ inline bool VerifyBuiltinOptionsVector(
     flatbuffers::Verifier &verifier,
     const flatbuffers::Vector<flatbuffers::Offset<void>> *values,
     const flatbuffers::Vector<uint8_t> *types) {
+  if (!values || !types) return !values && !types;
   if (values->size() != types->size()) return false;
   for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) {
     if (!VerifyBuiltinOptions(verifier, values->Get(i),
@@ -6747,6 +6873,10 @@ inline void *BuiltinOptionsUnion::UnPack(
       auto ptr = reinterpret_cast<const StridedSliceOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_ExpOptions: {
+      auto ptr = reinterpret_cast<const ExpOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default:
       return nullptr;
   }
@@ -6886,6 +7016,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(
       auto ptr = reinterpret_cast<const StridedSliceOptionsT *>(value);
       return CreateStridedSliceOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_ExpOptions: {
+      auto ptr = reinterpret_cast<const ExpOptionsT *>(value);
+      return CreateExpOptions(_fbb, ptr, _rehasher).Union();
+    }
     default:
       return 0;
   }
@@ -7041,6 +7175,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u)
           *reinterpret_cast<StridedSliceOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_ExpOptions: {
+      value = new ExpOptionsT(*reinterpret_cast<ExpOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -7208,6 +7346,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
       break;
     }
+    case BuiltinOptions_ExpOptions: {
+      auto ptr = reinterpret_cast<ExpOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default:
       break;
   }
index 9351cf9..8739ffb 100644 (file)
@@ -25,6 +25,7 @@ gen_zipped_test_files(
         "conv.zip",
         "depthwiseconv.zip",
         "div.zip",
+        "exp.zip",
         "fully_connected.zip",
         "fused_batch_norm.zip",
         "gather.zip",
index a86c648..7fe4616 100644 (file)
@@ -745,6 +745,33 @@ def make_mean_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def make_exp_tests(zip_path):
+  """Make a set of tests to do exp."""
+
+  test_parameters = [{
+      "input_dtype": [tf.float32],
+      "input_shape": [[3], [1, 100], [4, 2, 3], [5, 224, 224, 3]],
+  }]
+
+  def build_graph(parameters):
+    """Build the exp op testing graph."""
+    input_tensor = tf.placeholder(
+        dtype=parameters["input_dtype"],
+        name="input",
+        shape=parameters["input_shape"])
+
+    out = tf.exp(input_tensor)
+    return [input_tensor], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    values = [
+        create_tensor_data(parameters["input_dtype"], parameters["input_shape"])
+    ]
+    return values, sess.run(outputs, feed_dict=dict(zip(inputs, values)))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
 def make_binary_op_tests_func(binary_operator):
   """Return a function that does a test on a binary operator."""
   return lambda zip_path: make_binary_op_tests(zip_path, binary_operator)
@@ -1715,6 +1742,7 @@ def main(unused_args):
         "mean.zip": make_mean_tests,
         "squeeze.zip": make_squeeze_tests,
         "strided_slice.zip": make_strided_slice_tests,
+        "exp.zip": make_exp_tests,
     }
     out = FLAGS.zip_to_output
     bin_path = FLAGS.toco
index 5ea3e21..80e806a 100644 (file)
@@ -242,6 +242,7 @@ INSTANTIATE_TESTS(constant)
 INSTANTIATE_TESTS(control_dep)
 INSTANTIATE_TESTS(conv)
 INSTANTIATE_TESTS(depthwiseconv)
+INSTANTIATE_TESTS(exp)
 INSTANTIATE_TESTS(fully_connected)
 INSTANTIATE_TESTS(fused_batch_norm)
 INSTANTIATE_TESTS(gather)
index 7cddbad..ddcc038 100644 (file)
@@ -1327,6 +1327,7 @@ bool PropagateFixedSizes::Run(Model* model, std::size_t op_index) {
     case OperatorType::kTensorFlowAssert:
     case OperatorType::kCast:
     case OperatorType::kFloor:
+    case OperatorType::kExp:
       ProcessSimpleOperator(model, op);
       break;
     case OperatorType::kGather:
index 41d6c83..02c3b2e 100644 (file)
@@ -1460,6 +1460,17 @@ void ConvertBatchToSpaceNDOperator(const NodeDef& node,
   model->operators.emplace_back(op);
 }
 
+void ConvertExpOperator(const NodeDef& node,
+                        const TensorFlowImportFlags& tf_import_flags,
+                        Model* model) {
+  CHECK_EQ(node.op(), "Exp");
+  CheckInputsCount(node, tf_import_flags, 1);
+  auto* op = new ExpOperator;
+  op->inputs.push_back(node.input(0));
+  op->outputs.push_back(node.name());
+  model->operators.emplace_back(op);
+}
+
 void ConvertMeanOperator(const NodeDef& node,
                          const TensorFlowImportFlags& tf_import_flags,
                          Model* model) {
@@ -1986,6 +1997,8 @@ std::unique_ptr<Model> ImportTensorFlowGraphDef(
       ConvertTransposeOperator(node, tf_import_flags, model);
     } else if (node.op() == "ArgMax") {
       ConvertArgMaxOperator(node, tf_import_flags, model);
+    } else if (node.op() == "Exp") {
+      ConvertExpOperator(node, tf_import_flags, model);
     } else {
       ConvertUnsupportedOperator(node, tf_import_flags, model);
     }
index 0bee694..4c44f3f 100644 (file)
@@ -44,6 +44,7 @@ enum class OperatorType {
   kSpaceToDepth,
   kDequantize,
   kDiv,
+  kExp,
   kExpandDims,
   kFill,
   kFloorDiv,
@@ -852,6 +853,17 @@ struct TransposeConvOperator : Operator {
   int stride_height = 0;
 };
 
+// Given a tensor input, this operation calculates element-wise exponential
+// (y = e^x).
+//
+// Inputs:
+//   inputs[0]: required: input tensor
+//
+// TensorFlow equivalent: Exp
+struct ExpOperator : Operator {
+  ExpOperator() : Operator(OperatorType::kExp) {}
+};
+
 // Given a tensor input, this operation inserts a dimension of 1 at the
 // dimension index axis of input's shape. The dimension index axis starts at
 // zero; if you specify a negative number for axis it is counted backward from
index ff54b35..2583ec0 100644 (file)
@@ -835,6 +835,7 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
       "LOGISTIC", OperatorType::kLogistic));
   ops.emplace_back(
       new SimpleOperator<TanhOperator>("TANH", OperatorType::kTanh));
+  ops.emplace_back(new SimpleOperator<ExpOperator>("EXP", OperatorType::kExp));
 
   return ops;
 }
index 796534b..05c325e 100644 (file)
@@ -106,6 +106,7 @@ TEST_F(OperatorTest, SimpleOperators) {
   CheckSimpleOperator<Relu6Operator>("RELU6", OperatorType::kRelu6);
   CheckSimpleOperator<LogisticOperator>("LOGISTIC", OperatorType::kLogistic);
   CheckSimpleOperator<TanhOperator>("TANH", OperatorType::kTanh);
+  CheckSimpleOperator<ExpOperator>("EXP", OperatorType::kExp);
 }
 
 TEST_F(OperatorTest, BuiltinAdd) {
index a5fed23..6275415 100644 (file)
@@ -313,6 +313,7 @@ const char* OperatorTypeName(OperatorType type) {
     HANDLE_OPERATORTYPENAME_CASE(Svdf)
     HANDLE_OPERATORTYPENAME_CASE(ArgMax)
     HANDLE_OPERATORTYPENAME_CASE(TensorFlowUnsupported)
+    HANDLE_OPERATORTYPENAME_CASE(Exp)
     default:
       LOG(FATAL) << "Unhandled op type";
 #undef HANDLE_OPERATORTYPENAME_CASE