Implementation of Less
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 13 Apr 2018 07:12:41 +0000 (00:12 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 13 Apr 2018 07:14:59 +0000 (00:14 -0700)
PiperOrigin-RevId: 192728635

16 files changed:
tensorflow/contrib/lite/builtin_ops.h
tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
tensorflow/contrib/lite/kernels/BUILD
tensorflow/contrib/lite/kernels/comparisons.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/comparisons_test.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
tensorflow/contrib/lite/kernels/register.cc
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/nnapi_delegate.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/BUILD
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
tensorflow/contrib/lite/toco/tflite/operator.cc
tensorflow/contrib/lite/toco/tflite/operator_test.cc

index 1ceefaf..859bc7a 100644 (file)
@@ -82,6 +82,7 @@ typedef enum {
   kTfLiteBuiltinMaximum = 55,
   kTfLiteBuiltinArgMax = 56,
   kTfLiteBuiltinMinimum = 57,
+  kTfLiteBuiltinLess = 58,
 } TfLiteBuiltinOperator;
 
 #ifdef __cplusplus
index 61ea523..203924f 100644 (file)
@@ -302,6 +302,19 @@ Options {
 }
 ```
 
+**LESS**
+
+```
+Inputs {
+  0: a tensor
+  1: a tensor
+}
+Outputs {
+  0: a tensor of type bool, true whenever an element of the first tensor is less
+  than the corresponding element of the second tensor.
+}
+```
+
 **LOCAL_RESPONSE_NORMALIZATION**
 
 ```
index 914893c..800e2a9 100644 (file)
@@ -136,6 +136,7 @@ cc_library(
         "bidirectional_sequence_lstm.cc",
         "bidirectional_sequence_rnn.cc",
         "cast.cc",
+        "comparisons.cc",
         "concatenation.cc",
         "conv.cc",
         "depthwise_conv.cc",
@@ -818,6 +819,24 @@ tf_cc_test(
     ],
 )
 
+tf_cc_test(
+    name = "comparisons_test",
+    size = "small",
+    srcs = [
+        "comparisons_test.cc",
+    ],
+    tags = [
+        "tflite_not_portable_ios_arm64",
+        "tflite_not_portable_ios_x86_64",
+    ],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
 filegroup(
     name = "all_files",
     srcs = glob(
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
new file mode 100644 (file)
index 0000000..87c413c
--- /dev/null
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+#include "tensorflow/contrib/lite/kernels/op_macros.h"
+#include "tensorflow/contrib/lite/string_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace comparisons {
+
+constexpr int kInputTensor1 = 0;
+constexpr int kInputTensor2 = 1;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus LessPrepare(TfLiteContext* context, TfLiteNode* node) {
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 2);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+
+  TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  // Don't support string and bool.
+  TF_LITE_ENSURE(context,
+                 input1->type != kTfLiteString || input1->type != kTfLiteBool);
+  // Currently only support tensors have the same type.
+  TF_LITE_ENSURE_EQ(context, input1->type, input2->type);
+  output->type = kTfLiteBool;
+
+  bool requires_broadcast = !HaveSameShapes(input1, input2);
+
+  TfLiteIntArray* output_size = nullptr;
+  if (requires_broadcast) {
+    TF_LITE_ENSURE_OK(context, CalculateShapeForBroadcast(
+                                   context, input1, input2, &output_size));
+  } else {
+    output_size = TfLiteIntArrayCopy(input1->dims);
+  }
+
+  return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus LessEval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* input1 = GetInput(context, node, kInputTensor1);
+  TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  bool requires_broadcast = !HaveSameShapes(input1, input2);
+
+#define TF_LITE_LESS(type, opname)                                          \
+  reference_ops::opname(GetTensorData<type>(input1), GetTensorDims(input1), \
+                        GetTensorData<type>(input2), GetTensorDims(input2), \
+                        GetTensorData<bool>(output), GetTensorDims(output));
+
+  // TODO(renjieliu): Support quantized data.
+  if (requires_broadcast) {
+    switch (input1->type) {
+      case kTfLiteFloat32:
+        TF_LITE_LESS(float, BroadcastLess);
+        break;
+      case kTfLiteInt32:
+        TF_LITE_LESS(int32_t, BroadcastLess);
+        break;
+      case kTfLiteInt64:
+        TF_LITE_LESS(int64_t, BroadcastLess);
+        break;
+      default:
+        context->ReportError(context,
+                             "Does not support type other than float|int");
+        return kTfLiteError;
+    }
+  } else {
+    switch (input1->type) {
+      case kTfLiteFloat32:
+        TF_LITE_LESS(float, Less);
+        break;
+      case kTfLiteInt32:
+        TF_LITE_LESS(int32_t, Less);
+        break;
+      case kTfLiteInt64:
+        TF_LITE_LESS(int64_t, Less);
+        break;
+      default:
+        context->ReportError(context,
+                             "Does not support type other than float|int");
+        return kTfLiteError;
+    }
+  }
+#undef TF_LITE_LESS
+  return kTfLiteOk;
+}
+
+}  // namespace comparisons
+
+TfLiteRegistration* Register_LESS() {
+  static TfLiteRegistration r = {nullptr, nullptr, comparisons::LessPrepare,
+                                 comparisons::LessEval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
new file mode 100644 (file)
index 0000000..da2d7f8
--- /dev/null
@@ -0,0 +1,98 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class LessOpModel : public SingleOpModel {
+ public:
+  LessOpModel(std::initializer_list<int> input1_shape,
+              std::initializer_list<int> input2_shape, TensorType input_type) {
+    input1_ = AddInput(input_type);
+    input2_ = AddInput(input_type);
+    output_ = AddOutput(TensorType_BOOL);
+    SetBuiltinOp(BuiltinOperator_LESS, BuiltinOptions_LessOptions,
+                 CreateLessOptions(builder_).Union());
+    BuildInterpreter({input1_shape, input2_shape});
+  }
+
+  int input1() { return input1_; }
+  int input2() { return input2_; }
+
+  std::vector<bool> GetOutput() { return ExtractVector<bool>(output_); }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+  int input1_;
+  int input2_;
+  int output_;
+};
+
+TEST(ArgMaxOpTest, LessFloat) {
+  LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_FLOAT32);
+  model.PopulateTensor<float>(model.input1(), {0.1, 0.9, 0.7, 0.3});
+  model.PopulateTensor<float>(model.input2(), {0.1, 0.2, 0.6, 0.5});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({false, false, false, true}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ArgMaxOpTest, LessInt) {
+  LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 4}, TensorType_INT32);
+  model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+  model.PopulateTensor<int>(model.input2(), {1, 2, 6, 5});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ArgMaxOpTest, LessBroadcast) {
+  LessOpModel model({1, 1, 1, 4}, {1, 1, 1, 1}, TensorType_INT32);
+  model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3});
+  model.PopulateTensor<int>(model.input2(), {7});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 1, 4}));
+}
+
+TEST(ArgMaxOpTest, LessBroadcastTwoD) {
+  LessOpModel model({1, 1, 2, 4}, {1, 1, 1, 4}, TensorType_INT32);
+  model.PopulateTensor<int>(model.input1(), {-1, 9, 7, 3, 2, 4, 6, 8});
+  model.PopulateTensor<int>(model.input2(), {7, 1, 2, 4});
+  model.Invoke();
+
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({true, false, false, true,
+                                                   true, false, false, false}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({1, 1, 2, 4}));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
index c601939..6a89dbc 100644 (file)
@@ -3378,6 +3378,51 @@ inline void TransposeConv(const float* input_data, const Dims<4>& input_dims,
   }
 }
 
+template <typename T>
+inline void Less(int64_t num_elements, const T* input1, const T* input2,
+                 bool* output) {
+  for (int64_t i = 0; i < num_elements; ++i) {
+    output[i] = input1[i] < input2[i];
+  }
+}
+
+template <typename T>
+inline void Less(const T* input1_data, const Dims<4>& input1_dims,
+                 const T* input2_data, const Dims<4>& input2_dims,
+                 bool* output_data, const Dims<4>& output_dims) {
+  const int64_t batches =
+      MatchingArraySize(input1_dims, 3, input2_dims, 3, output_dims, 3);
+  const int64_t height =
+      MatchingArraySize(input1_dims, 2, input2_dims, 2, output_dims, 2);
+  const int64_t width =
+      MatchingArraySize(input1_dims, 1, input2_dims, 1, output_dims, 1);
+  const int64_t depth =
+      MatchingArraySize(input1_dims, 0, input2_dims, 0, output_dims, 0);
+  Less(batches * height * width * depth, input1_data, input2_data, output_data);
+}
+
+template <typename T1, typename T2>
+inline void BroadcastLess(T1* input1_data, const Dims<4>& input1_dims,
+                          T2* input2_data, const Dims<4>& input2_dims,
+                          bool* output_data, const Dims<4>& output_dims) {
+  gemmlowp::ScopedProfilingLabel label("BroadcastLess");
+  NdArrayDesc<4> desc1;
+  NdArrayDesc<4> desc2;
+  NdArrayDescsForElementwiseBroadcast(input1_dims, input2_dims, &desc1, &desc2);
+
+  for (int b = 0; b < ArraySize(output_dims, 3); ++b) {
+    for (int y = 0; y < ArraySize(output_dims, 2); ++y) {
+      for (int x = 0; x < ArraySize(output_dims, 1); ++x) {
+        for (int c = 0; c < ArraySize(output_dims, 0); ++c) {
+          output_data[Offset(output_dims, c, x, y, b)] =
+              input1_data[SubscriptToIndex(desc1, c, x, y, b)] <
+              input2_data[SubscriptToIndex(desc2, c, x, y, b)];
+        }
+      }
+    }
+  }
+}
+
 }  // namespace reference_ops
 }  // namespace tflite
 
index 67ba8d0..b07e7b6 100644 (file)
@@ -79,6 +79,7 @@ TfLiteRegistration* Register_PRELU();
 TfLiteRegistration* Register_MAXIMUM();
 TfLiteRegistration* Register_MINIMUM();
 TfLiteRegistration* Register_ARG_MAX();
+TfLiteRegistration* Register_LESS();
 
 BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -139,6 +140,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_MAXIMUM, Register_MAXIMUM());
   AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
   AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
+  AddBuiltin(BuiltinOperator_LESS, Register_LESS());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
index 0b65884..54b1460 100644 (file)
@@ -665,6 +665,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
+    case BuiltinOperator_LESS: {
+      break;
+    }
     case BuiltinOperator_DELEGATE: {
       // TODO(ycling): Revisit when supporting saving delegated models.
       error_reporter->Report("DELEGATE op shouldn't exist in model.");
index 08fb820..eab82ea 100644 (file)
@@ -353,6 +353,7 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       case tflite::BuiltinOperator_MAXIMUM:
       case tflite::BuiltinOperator_MINIMUM:
       case tflite::BuiltinOperator_ARG_MAX:
+      case tflite::BuiltinOperator_LESS:
         FATAL("Op code %d is currently not delegated to NNAPI", builtin);
         nn_op_type = -1;  // set to invalid
         break;
index fa82550..93980b1 100644 (file)
@@ -135,6 +135,7 @@ enum BuiltinOperator : byte {
   MAXIMUM = 55,
   ARG_MAX = 56,
   MINIMUM = 57,
+  LESS = 58,
 }
 
 // Options for the builtin operators.
@@ -179,6 +180,7 @@ union BuiltinOptions {
   DequantizeOptions,
   MaximumMinimumOptions,
   ArgMaxOptions,
+  LessOptions,
 }
 
 enum Padding : byte { SAME, VALID }
@@ -399,6 +401,9 @@ table ArgMaxOptions {
   output_type : TensorType;
 }
 
+table LessOptions {
+}
+
 // An OperatorCode can be an enum value (BuiltinOperator) if the operator is a
 // builtin, or a string if the operator is custom.
 table OperatorCode {
index 909c4cc..b2a799d 100755 (executable)
@@ -151,6 +151,9 @@ struct MaximumMinimumOptionsT;
 struct ArgMaxOptions;
 struct ArgMaxOptionsT;
 
+struct LessOptions;
+struct LessOptionsT;
+
 struct OperatorCode;
 struct OperatorCodeT;
 
@@ -267,11 +270,12 @@ enum BuiltinOperator {
   BuiltinOperator_MAXIMUM = 55,
   BuiltinOperator_ARG_MAX = 56,
   BuiltinOperator_MINIMUM = 57,
+  BuiltinOperator_LESS = 58,
   BuiltinOperator_MIN = BuiltinOperator_ADD,
-  BuiltinOperator_MAX = BuiltinOperator_MINIMUM
+  BuiltinOperator_MAX = BuiltinOperator_LESS
 };
 
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[56] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[57] {
   static BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -328,7 +332,8 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[56] {
     BuiltinOperator_PRELU,
     BuiltinOperator_MAXIMUM,
     BuiltinOperator_ARG_MAX,
-    BuiltinOperator_MINIMUM
+    BuiltinOperator_MINIMUM,
+    BuiltinOperator_LESS
   };
   return values;
 }
@@ -393,6 +398,7 @@ inline const char **EnumNamesBuiltinOperator() {
     "MAXIMUM",
     "ARG_MAX",
     "MINIMUM",
+    "LESS",
     nullptr
   };
   return names;
@@ -445,11 +451,12 @@ enum BuiltinOptions {
   BuiltinOptions_DequantizeOptions = 38,
   BuiltinOptions_MaximumMinimumOptions = 39,
   BuiltinOptions_ArgMaxOptions = 40,
+  BuiltinOptions_LessOptions = 41,
   BuiltinOptions_MIN = BuiltinOptions_NONE,
-  BuiltinOptions_MAX = BuiltinOptions_ArgMaxOptions
+  BuiltinOptions_MAX = BuiltinOptions_LessOptions
 };
 
-inline BuiltinOptions (&EnumValuesBuiltinOptions())[41] {
+inline BuiltinOptions (&EnumValuesBuiltinOptions())[42] {
   static BuiltinOptions values[] = {
     BuiltinOptions_NONE,
     BuiltinOptions_Conv2DOptions,
@@ -491,7 +498,8 @@ inline BuiltinOptions (&EnumValuesBuiltinOptions())[41] {
     BuiltinOptions_CastOptions,
     BuiltinOptions_DequantizeOptions,
     BuiltinOptions_MaximumMinimumOptions,
-    BuiltinOptions_ArgMaxOptions
+    BuiltinOptions_ArgMaxOptions,
+    BuiltinOptions_LessOptions
   };
   return values;
 }
@@ -539,6 +547,7 @@ inline const char **EnumNamesBuiltinOptions() {
     "DequantizeOptions",
     "MaximumMinimumOptions",
     "ArgMaxOptions",
+    "LessOptions",
     nullptr
   };
   return names;
@@ -713,6 +722,10 @@ template<> struct BuiltinOptionsTraits<ArgMaxOptions> {
   static const BuiltinOptions enum_value = BuiltinOptions_ArgMaxOptions;
 };
 
+template<> struct BuiltinOptionsTraits<LessOptions> {
+  static const BuiltinOptions enum_value = BuiltinOptions_LessOptions;
+};
+
 struct BuiltinOptionsUnion {
   BuiltinOptions type;
   void *value;
@@ -1064,6 +1077,14 @@ struct BuiltinOptionsUnion {
     return type == BuiltinOptions_ArgMaxOptions ?
       reinterpret_cast<const ArgMaxOptionsT *>(value) : nullptr;
   }
+  LessOptionsT *AsLessOptions() {
+    return type == BuiltinOptions_LessOptions ?
+      reinterpret_cast<LessOptionsT *>(value) : nullptr;
+  }
+  const LessOptionsT *AsLessOptions() const {
+    return type == BuiltinOptions_LessOptions ?
+      reinterpret_cast<const LessOptionsT *>(value) : nullptr;
+  }
 };
 
 bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -3927,6 +3948,46 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(
 
 flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatBufferBuilder &_fbb, const ArgMaxOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
 
+struct LessOptionsT : public flatbuffers::NativeTable {
+  typedef LessOptions TableType;
+  LessOptionsT() {
+  }
+};
+
+struct LessOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+  typedef LessOptionsT NativeTableType;
+  bool Verify(flatbuffers::Verifier &verifier) const {
+    return VerifyTableStart(verifier) &&
+           verifier.EndTable();
+  }
+  LessOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  void UnPackTo(LessOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+  static flatbuffers::Offset<LessOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct LessOptionsBuilder {
+  flatbuffers::FlatBufferBuilder &fbb_;
+  flatbuffers::uoffset_t start_;
+  explicit LessOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+        : fbb_(_fbb) {
+    start_ = fbb_.StartTable();
+  }
+  LessOptionsBuilder &operator=(const LessOptionsBuilder &);
+  flatbuffers::Offset<LessOptions> Finish() {
+    const auto end = fbb_.EndTable(start_);
+    auto o = flatbuffers::Offset<LessOptions>(end);
+    return o;
+  }
+};
+
+inline flatbuffers::Offset<LessOptions> CreateLessOptions(
+    flatbuffers::FlatBufferBuilder &_fbb) {
+  LessOptionsBuilder builder_(_fbb);
+  return builder_.Finish();
+}
+
+flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
 struct OperatorCodeT : public flatbuffers::NativeTable {
   typedef OperatorCode TableType;
   BuiltinOperator builtin_code;
@@ -4164,6 +4225,9 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
   const ArgMaxOptions *builtin_options_as_ArgMaxOptions() const {
     return builtin_options_type() == BuiltinOptions_ArgMaxOptions ? static_cast<const ArgMaxOptions *>(builtin_options()) : nullptr;
   }
+  const LessOptions *builtin_options_as_LessOptions() const {
+    return builtin_options_type() == BuiltinOptions_LessOptions ? static_cast<const LessOptions *>(builtin_options()) : nullptr;
+  }
   const flatbuffers::Vector<uint8_t> *custom_options() const {
     return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
   }
@@ -4350,6 +4414,10 @@ template<> inline const ArgMaxOptions *Operator::builtin_options_as<ArgMaxOption
   return builtin_options_as_ArgMaxOptions();
 }
 
+template<> inline const LessOptions *Operator::builtin_options_as<LessOptions>() const {
+  return builtin_options_as_LessOptions();
+}
+
 struct OperatorBuilder {
   flatbuffers::FlatBufferBuilder &fbb_;
   flatbuffers::uoffset_t start_;
@@ -5933,6 +6001,29 @@ inline flatbuffers::Offset<ArgMaxOptions> CreateArgMaxOptions(flatbuffers::FlatB
       _output_type);
 }
 
+inline LessOptionsT *LessOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+  auto _o = new LessOptionsT();
+  UnPackTo(_o, _resolver);
+  return _o;
+}
+
+inline void LessOptions::UnPackTo(LessOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+  (void)_o;
+  (void)_resolver;
+}
+
+inline flatbuffers::Offset<LessOptions> LessOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+  return CreateLessOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<LessOptions> CreateLessOptions(flatbuffers::FlatBufferBuilder &_fbb, const LessOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+  (void)_rehasher;
+  (void)_o;
+  struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const LessOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+  return tflite::CreateLessOptions(
+      _fbb);
+}
+
 inline OperatorCodeT *OperatorCode::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
   auto _o = new OperatorCodeT();
   UnPackTo(_o, _resolver);
@@ -6273,6 +6364,10 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
       auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj);
       return verifier.VerifyTable(ptr);
     }
+    case BuiltinOptions_LessOptions: {
+      auto ptr = reinterpret_cast<const LessOptions *>(obj);
+      return verifier.VerifyTable(ptr);
+    }
     default: return false;
   }
 }
@@ -6451,6 +6546,10 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
       auto ptr = reinterpret_cast<const ArgMaxOptions *>(obj);
       return ptr->UnPack(resolver);
     }
+    case BuiltinOptions_LessOptions: {
+      auto ptr = reinterpret_cast<const LessOptions *>(obj);
+      return ptr->UnPack(resolver);
+    }
     default: return nullptr;
   }
 }
@@ -6617,6 +6716,10 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
       auto ptr = reinterpret_cast<const ArgMaxOptionsT *>(value);
       return CreateArgMaxOptions(_fbb, ptr, _rehasher).Union();
     }
+    case BuiltinOptions_LessOptions: {
+      auto ptr = reinterpret_cast<const LessOptionsT *>(value);
+      return CreateLessOptions(_fbb, ptr, _rehasher).Union();
+    }
     default: return 0;
   }
 }
@@ -6783,6 +6886,10 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
       value = new ArgMaxOptionsT(*reinterpret_cast<ArgMaxOptionsT *>(u.value));
       break;
     }
+    case BuiltinOptions_LessOptions: {
+      value = new LessOptionsT(*reinterpret_cast<LessOptionsT *>(u.value));
+      break;
+    }
     default:
       break;
   }
@@ -6990,6 +7097,11 @@ inline void BuiltinOptionsUnion::Reset() {
       delete ptr;
       break;
     }
+    case BuiltinOptions_LessOptions: {
+      auto ptr = reinterpret_cast<LessOptionsT *>(value);
+      delete ptr;
+      break;
+    }
     default: break;
   }
   value = nullptr;
index 2c226e7..bd888a4 100644 (file)
@@ -34,6 +34,7 @@ gen_zipped_test_files(
         "global_batch_norm.zip",
         "l2_pool.zip",
         "l2norm.zip",
+        "less.zip",
         "local_response_norm.zip",
         "log_softmax.zip",
         "max_pool.zip",
index 4b4ccc0..53b41d2 100644 (file)
@@ -1997,6 +1997,39 @@ def make_arg_max_tests(zip_path):
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
 
+def make_less_tests(zip_path):
+  """Make a set of tests to do less."""
+
+  test_parameters = [{
+      "input_dtype": [tf.float32, tf.int32, tf.int64],
+      "input_shape_pair": [([1, 1, 1, 3], [1, 1, 1, 3]),
+                           ([2, 3, 4, 5], [2, 3, 4, 5]), ([2, 3, 3], [2, 3]),
+                           ([5, 5], [1]), ([10], [2, 4, 10])],
+  }]
+
+  def build_graph(parameters):
+    """Build the less op testing graph."""
+    input_value1 = tf.placeholder(
+        dtype=parameters["input_dtype"],
+        name="input1",
+        shape=parameters["input_shape_pair"][0])
+    input_value2 = tf.placeholder(
+        dtype=parameters["input_dtype"],
+        name="input2",
+        shape=parameters["input_shape_pair"][1])
+    out = tf.less(input_value1, input_value2)
+    return [input_value1, input_value2], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_value1 = create_tensor_data(parameters["input_dtype"],
+                                      parameters["input_shape_pair"][0])
+    input_value2 = create_tensor_data(parameters["input_dtype"],
+                                      parameters["input_shape_pair"][1])
+    return [input_value1, input_value2], sess.run(
+        outputs, feed_dict=dict(zip(inputs, [input_value1, input_value2])))
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
 # Toco binary path provided by the generate rule.
 bin_path = None
 
index 84ae1d5..9da8bd7 100644 (file)
@@ -280,6 +280,7 @@ INSTANTIATE_TESTS(squeeze)
 INSTANTIATE_TESTS(strided_slice)
 INSTANTIATE_TESTS(sub)
 INSTANTIATE_TESTS(transpose)
+INSTANTIATE_TESTS(less)
 
 }  // namespace testing
 }  // namespace tflite
index 0e057fd..f41a312 100644 (file)
@@ -895,6 +895,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
       "MAXIMUM", OperatorType::kTensorFlowMaximum));
   ops.emplace_back(new SimpleOperator<TensorFlowMinimumOperator>(
       "MINIMUM", OperatorType::kTensorFlowMinimum));
+  ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>(
+      "LESS", OperatorType::kTensorFlowLess));
 
   return ops;
 }
index a947630..36ed741 100644 (file)
@@ -113,6 +113,8 @@ TEST_F(OperatorTest, SimpleOperators) {
       "MAXIMUM", OperatorType::kTensorFlowMaximum);
   CheckSimpleOperator<TensorFlowMinimumOperator>(
       "MINIMUM", OperatorType::kTensorFlowMinimum);
+  CheckSimpleOperator<TensorFlowLessOperator>("LESS",
+                                              OperatorType::kTensorFlowLess);
 }
 
 TEST_F(OperatorTest, BuiltinAdd) {