Implement floor operator
authorA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 27 Apr 2018 02:35:10 +0000 (19:35 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Fri, 27 Apr 2018 02:38:26 +0000 (19:38 -0700)
PiperOrigin-RevId: 194490433

14 files changed:
tensorflow/contrib/lite/builtin_ops.h
tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md
tensorflow/contrib/lite/kernels/BUILD
tensorflow/contrib/lite/kernels/floor.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/floor_test.cc [new file with mode: 0644]
tensorflow/contrib/lite/kernels/register.cc
tensorflow/contrib/lite/model.cc
tensorflow/contrib/lite/nnapi_delegate.cc
tensorflow/contrib/lite/schema/schema.fbs
tensorflow/contrib/lite/schema/schema_generated.h
tensorflow/contrib/lite/testing/BUILD
tensorflow/contrib/lite/testing/generate_examples.py
tensorflow/contrib/lite/testing/generated_examples_zip_test.cc
tensorflow/contrib/lite/toco/tflite/operator.cc

index 859bc7a..21e0e04 100644 (file)
@@ -33,6 +33,7 @@ typedef enum {
   kTfLiteBuiltinDepthwiseConv2d = 4,
   kTfLiteBuiltinDequantize = 6,
   kTfLiteBuiltinEmbeddingLookup = 7,
+  kTfLiteBuiltinFloor = 8,
   kTfLiteBuiltinFullyConnected = 9,
   kTfLiteBuiltinHashtableLookup = 10,
   kTfLiteBuiltinL2Normalization = 11,
index 203924f..aa28f8d 100644 (file)
@@ -132,7 +132,6 @@ TensorFlow operation not listed above are likely unsupported. Notably, the
 following common ops are not supported at the moment:
 
 *   [tf.depth_to_space](https://www.tensorflow.org/api_docs/python/tf/depth_to_space)
-*   [tf.floor](https://www.tensorflow.org/api_docs/python/tf/floor)
 *   [tf.gather](https://www.tensorflow.org/api_docs/python/tf/gather)
 *   [tf.image.resize_bilinear](https://www.tensorflow.org/api_docs/python/tf/image/resize_bilinear)
 *   [tf.slice](https://www.tensorflow.org/api_docs/python/tf/slice)
@@ -254,6 +253,17 @@ Outputs {
 }
 ```
 
+**FLOOR**
+
+```
+inputs {
+  0: tensor
+}
+outputs: {
+  0: result of computing element-wise floor of the input tensor
+}
+```
+
 **FULLY_CONNECTED**
 
 ```
index 80cefe8..689f9bf 100644 (file)
@@ -145,6 +145,7 @@ cc_library(
         "embedding_lookup.cc",
         "embedding_lookup_sparse.cc",
         "exp.cc",
+        "floor.cc",
         "fully_connected.cc",
         "gather.cc",
         "hashtable_lookup.cc",
@@ -438,6 +439,19 @@ tf_cc_test(
 )
 
 tf_cc_test(
+    name = "floor_test",
+    size = "small",
+    srcs = ["floor_test.cc"],
+    tags = ["tflite_not_portable_ios"],
+    deps = [
+        ":builtin_ops",
+        "//tensorflow/contrib/lite:framework",
+        "//tensorflow/contrib/lite/kernels:test_util",
+        "@com_google_googletest//:gtest",
+    ],
+)
+
+tf_cc_test(
     name = "unidirectional_sequence_lstm_test",
     size = "small",
     srcs = ["unidirectional_sequence_lstm_test.cc"],
diff --git a/tensorflow/contrib/lite/kernels/floor.cc b/tensorflow/contrib/lite/kernels/floor.cc
new file mode 100644 (file)
index 0000000..4b4395f
--- /dev/null
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/lite/context.h"
+#include "tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h"
+#include "tensorflow/contrib/lite/kernels/internal/tensor.h"
+#include "tensorflow/contrib/lite/kernels/kernel_util.h"
+
+namespace tflite {
+namespace ops {
+namespace builtin {
+namespace floor {
+
+constexpr int kInputTensor = 0;
+constexpr int kOutputTensor = 0;
+
+TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);
+  TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
+  output->type = input->type;
+  TfLiteIntArray* output_size = TfLiteIntArrayCopy(input->dims);
+  return context->ResizeTensor(context, output, output_size);
+}
+
+TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
+  TfLiteTensor* input = GetInput(context, node, kInputTensor);
+  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
+
+  optimized_ops::Floor(GetTensorData<float>(input), GetTensorDims(input),
+                       GetTensorData<float>(output), GetTensorDims(output));
+  return kTfLiteOk;
+}
+}  // namespace floor
+
+TfLiteRegistration* Register_FLOOR() {
+  static TfLiteRegistration r = {/*init=*/nullptr,
+                                 /*free=*/nullptr, floor::Prepare, floor::Eval};
+  return &r;
+}
+
+}  // namespace builtin
+}  // namespace ops
+}  // namespace tflite
diff --git a/tensorflow/contrib/lite/kernels/floor_test.cc b/tensorflow/contrib/lite/kernels/floor_test.cc
new file mode 100644 (file)
index 0000000..b71e040
--- /dev/null
@@ -0,0 +1,83 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <gtest/gtest.h>
+#include "tensorflow/contrib/lite/interpreter.h"
+#include "tensorflow/contrib/lite/kernels/register.h"
+#include "tensorflow/contrib/lite/kernels/test_util.h"
+#include "tensorflow/contrib/lite/model.h"
+
+namespace tflite {
+namespace {
+
+using ::testing::ElementsAreArray;
+
+class FloorOpModel : public SingleOpModel {
+ public:
+  FloorOpModel(std::initializer_list<int> input_shape, TensorType input_type) {
+    input_ = AddInput(TensorType_FLOAT32);
+    output_ = AddOutput(TensorType_FLOAT32);
+    SetBuiltinOp(BuiltinOperator_FLOOR, BuiltinOptions_NONE, 0);
+    BuildInterpreter({
+        input_shape,
+    });
+  }
+
+  int input() { return input_; }
+
+  std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+  std::vector<int> GetOutputShape() { return GetTensorShape(output_); }
+
+ private:
+  int input_;
+  int output_;
+};
+
+TEST(FloorOpTest, SingleDim) {
+  FloorOpModel model({2}, TensorType_FLOAT32);
+  model.PopulateTensor<float>(model.input(), {8.5, 0.0});
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(), ElementsAreArray({8, 0}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2}));
+}
+
+TEST(FloorOpTest, MultiDims) {
+  FloorOpModel model({2, 1, 1, 5}, TensorType_FLOAT32);
+  model.PopulateTensor<float>(model.input(), {
+                                                 0.0001,
+                                                 8.0001,
+                                                 0.9999,
+                                                 9.9999,
+                                                 0.5,
+                                                 -0.0001,
+                                                 -8.0001,
+                                                 -0.9999,
+                                                 -9.9999,
+                                                 -0.5,
+                                             });
+  model.Invoke();
+  EXPECT_THAT(model.GetOutput(),
+              ElementsAreArray({0, 8, 0, 9, 0, -1, -9, -1, -10, -1}));
+  EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 1, 1, 5}));
+}
+
+}  // namespace
+}  // namespace tflite
+
+int main(int argc, char** argv) {
+  ::tflite::LogToStderr();
+  ::testing::InitGoogleTest(&argc, argv);
+  return RUN_ALL_TESTS();
+}
index b07e7b6..f91d188 100644 (file)
@@ -80,6 +80,7 @@ TfLiteRegistration* Register_MAXIMUM();
 TfLiteRegistration* Register_MINIMUM();
 TfLiteRegistration* Register_ARG_MAX();
 TfLiteRegistration* Register_LESS();
+TfLiteRegistration* Register_FLOOR();
 
 BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_RELU, Register_RELU());
@@ -141,6 +142,7 @@ BuiltinOpResolver::BuiltinOpResolver() {
   AddBuiltin(BuiltinOperator_MINIMUM, Register_MINIMUM());
   AddBuiltin(BuiltinOperator_ARG_MAX, Register_ARG_MAX());
   AddBuiltin(BuiltinOperator_LESS, Register_LESS());
+  AddBuiltin(BuiltinOperator_FLOOR, Register_FLOOR());
 
   // TODO(andrewharp, ahentz): Move these somewhere more appropriate so that
   // custom ops aren't always included by default.
index f45f39d..6fd3d9f 100644 (file)
@@ -347,6 +347,7 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
     case BuiltinOperator_LOG_SOFTMAX:
     case BuiltinOperator_DEQUANTIZE:
     case BuiltinOperator_PRELU:
+    case BuiltinOperator_FLOOR:
       break;
     case BuiltinOperator_CAST: {
       TfLiteCastParams* params = MallocPOD<TfLiteCastParams>();
index eab82ea..6a78f30 100644 (file)
@@ -278,6 +278,9 @@ void AddOpsAndParams(tflite::Interpreter* interpreter,
       case tflite::BuiltinOperator_TANH:
         nn_op_type = ANEURALNETWORKS_TANH;
         break;
+      case tflite::BuiltinOperator_FLOOR:
+        nn_op_type = ANEURALNETWORKS_FLOOR;
+        break;
       case tflite::BuiltinOperator_LOGISTIC:
         nn_op_type = ANEURALNETWORKS_LOGISTIC;
         break;
index 20d68ce..b16baf0 100644 (file)
@@ -78,7 +78,7 @@ enum BuiltinOperator : byte {
   // DEPTH_TO_SPACE = 5,
   DEQUANTIZE = 6,
   EMBEDDING_LOOKUP = 7,
-  // FLOOR = 8,
+  FLOOR = 8,
   FULLY_CONNECTED = 9,
   HASHTABLE_LOOKUP = 10,
   L2_NORMALIZATION = 11,
index 0b9961d..25ed9ab 100755 (executable)
@@ -221,6 +221,7 @@ enum BuiltinOperator {
   BuiltinOperator_DEPTHWISE_CONV_2D = 4,
   BuiltinOperator_DEQUANTIZE = 6,
   BuiltinOperator_EMBEDDING_LOOKUP = 7,
+  BuiltinOperator_FLOOR = 8,
   BuiltinOperator_FULLY_CONNECTED = 9,
   BuiltinOperator_HASHTABLE_LOOKUP = 10,
   BuiltinOperator_L2_NORMALIZATION = 11,
@@ -275,7 +276,7 @@ enum BuiltinOperator {
   BuiltinOperator_MAX = BuiltinOperator_LESS
 };
 
-inline BuiltinOperator (&EnumValuesBuiltinOperator())[57] {
+inline BuiltinOperator (&EnumValuesBuiltinOperator())[58] {
   static BuiltinOperator values[] = {
     BuiltinOperator_ADD,
     BuiltinOperator_AVERAGE_POOL_2D,
@@ -284,6 +285,7 @@ inline BuiltinOperator (&EnumValuesBuiltinOperator())[57] {
     BuiltinOperator_DEPTHWISE_CONV_2D,
     BuiltinOperator_DEQUANTIZE,
     BuiltinOperator_EMBEDDING_LOOKUP,
+    BuiltinOperator_FLOOR,
     BuiltinOperator_FULLY_CONNECTED,
     BuiltinOperator_HASHTABLE_LOOKUP,
     BuiltinOperator_L2_NORMALIZATION,
@@ -348,7 +350,7 @@ inline const char **EnumNamesBuiltinOperator() {
     "",
     "DEQUANTIZE",
     "EMBEDDING_LOOKUP",
-    "",
+    "FLOOR",
     "FULLY_CONNECTED",
     "HASHTABLE_LOOKUP",
     "L2_NORMALIZATION",
@@ -1485,8 +1487,8 @@ struct Conv2DOptionsT : public flatbuffers::NativeTable {
         stride_w(0),
         stride_h(0),
         fused_activation_function(ActivationFunctionType_NONE),
-        dilation_w_factor(0),
-        dilation_h_factor(0) {
+        dilation_w_factor(1),
+        dilation_h_factor(1) {
   }
 };
 
@@ -1513,10 +1515,10 @@ struct Conv2DOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
     return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
   }
   int32_t dilation_w_factor() const {
-    return GetField<int32_t>(VT_DILATION_W_FACTOR, 0);
+    return GetField<int32_t>(VT_DILATION_W_FACTOR, 1);
   }
   int32_t dilation_h_factor() const {
-    return GetField<int32_t>(VT_DILATION_H_FACTOR, 0);
+    return GetField<int32_t>(VT_DILATION_H_FACTOR, 1);
   }
   bool Verify(flatbuffers::Verifier &verifier) const {
     return VerifyTableStart(verifier) &&
@@ -1549,10 +1551,10 @@ struct Conv2DOptionsBuilder {
     fbb_.AddElement<int8_t>(Conv2DOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
   }
   void add_dilation_w_factor(int32_t dilation_w_factor) {
-    fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 0);
+    fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_W_FACTOR, dilation_w_factor, 1);
   }
   void add_dilation_h_factor(int32_t dilation_h_factor) {
-    fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 0);
+    fbb_.AddElement<int32_t>(Conv2DOptions::VT_DILATION_H_FACTOR, dilation_h_factor, 1);
   }
   explicit Conv2DOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
         : fbb_(_fbb) {
@@ -1572,8 +1574,8 @@ inline flatbuffers::Offset<Conv2DOptions> CreateConv2DOptions(
     int32_t stride_w = 0,
     int32_t stride_h = 0,
     ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
-    int32_t dilation_w_factor = 0,
-    int32_t dilation_h_factor = 0) {
+    int32_t dilation_w_factor = 1,
+    int32_t dilation_h_factor = 1) {
   Conv2DOptionsBuilder builder_(_fbb);
   builder_.add_dilation_h_factor(dilation_h_factor);
   builder_.add_dilation_w_factor(dilation_w_factor);
index bd888a4..a1162ce 100644 (file)
@@ -28,6 +28,7 @@ gen_zipped_test_files(
         "depthwiseconv.zip",
         "div.zip",
         "exp.zip",
+        "floor.zip",
         "fully_connected.zip",
         "fused_batch_norm.zip",
         "gather.zip",
index 9c9acf6..2f8f7a1 100644 (file)
@@ -2034,6 +2034,33 @@ def make_less_tests(zip_path):
 
   make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
 
+
+def make_floor_tests(zip_path):
+  """Make a set of tests to do floor."""
+
+  test_parameters = [{
+      "input_dtype": [tf.float32],
+      "input_shape": [[1], [1, 2], [5, 6, 7, 8], [3, 4, 5, 6]],
+  }]
+
+  def build_graph(parameters):
+    """Build the floor op testing graph."""
+    input_value = tf.placeholder(
+        dtype=parameters["input_dtype"],
+        name="input1",
+        shape=parameters["input_shape"])
+    out = tf.floor(input_value)
+    return [input_value], [out]
+
+  def build_inputs(parameters, sess, inputs, outputs):
+    input_value = create_tensor_data(parameters["input_dtype"],
+                                     parameters["input_shape"])
+    return [input_value], sess.run(
+        outputs, feed_dict={inputs[0]: input_value})
+
+  make_zip_of_tests(zip_path, test_parameters, build_graph, build_inputs)
+
+
 # Toco binary path provided by the generate rule.
 bin_path = None
 
index 9da8bd7..34abb21 100644 (file)
@@ -251,23 +251,25 @@ INSTANTIATE_TESTS(conv)
 INSTANTIATE_TESTS(depthwiseconv)
 INSTANTIATE_TESTS(div)
 INSTANTIATE_TESTS(exp)
+INSTANTIATE_TESTS(floor)
 INSTANTIATE_TESTS(fully_connected)
 INSTANTIATE_TESTS(fused_batch_norm)
 INSTANTIATE_TESTS(gather)
 INSTANTIATE_TESTS(global_batch_norm)
 INSTANTIATE_TESTS(l2_pool)
 INSTANTIATE_TESTS(l2norm)
+INSTANTIATE_TESTS(less)
 INSTANTIATE_TESTS(local_response_norm)
 INSTANTIATE_TESTS(log_softmax)
-INSTANTIATE_TESTS(maximum)
 INSTANTIATE_TESTS(max_pool)
+INSTANTIATE_TESTS(maximum)
 INSTANTIATE_TESTS(mean)
 INSTANTIATE_TESTS(minimum)
 INSTANTIATE_TESTS(mul)
 INSTANTIATE_TESTS(pad)
+// INSTANTIATE_TESTS(prelu)
 INSTANTIATE_TESTS(relu)
 INSTANTIATE_TESTS(relu1)
-// INSTANTIATE_TESTS(prelu)
 INSTANTIATE_TESTS(relu6)
 INSTANTIATE_TESTS(reshape)
 INSTANTIATE_TESTS(resize_bilinear)
@@ -280,7 +282,6 @@ INSTANTIATE_TESTS(squeeze)
 INSTANTIATE_TESTS(strided_slice)
 INSTANTIATE_TESTS(sub)
 INSTANTIATE_TESTS(transpose)
-INSTANTIATE_TESTS(less)
 
 }  // namespace testing
 }  // namespace tflite
index d2e14ac..fce3bad 100644 (file)
@@ -901,6 +901,8 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList() {
       "MINIMUM", OperatorType::kTensorFlowMinimum));
   ops.emplace_back(new SimpleOperator<TensorFlowLessOperator>(
       "LESS", OperatorType::kTensorFlowLess));
+  ops.emplace_back(
+      new SimpleOperator<FloorOperator>("FLOOR", OperatorType::kFloor));
 
   return ops;
 }