Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Greater.cpp
index 68a0d47..b073a4a 100644 (file)
@@ -1,6 +1,5 @@
 /*
  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * limitations under the License.
  */
 
-#include "kernels/Greater.h"
+#include "Builders.h"
 #include "kernels/Utils.h"
+#include "TISOKernel.h"
 
-#include <tensorflow/lite/kernels/internal/reference/comparisons.h>
+#include "PALComparisons.h"
 
 namespace luci_interpreter
 {
 
-namespace kernels
+namespace
 {
+// TODO: reduce code duplication with less
+template <typename T>
+void evalGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
+                 BaseRuntimeGraph *runtime_graph)
+{
+  auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x));
+  if (x_data == nullptr)
+    x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x));
+
+  assert(x_data != nullptr);
+
+  auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y));
+  if (y_data == nullptr)
+    y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y));
+
+  assert(y_data != nullptr);
+
+  auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));
 
-Greater::Greater(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {}
+  luci_interpreter_pal::ComparisonParams op_params;
+  op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);
 
-void Greater::configure()
+  const int64_t flat_size = kernels::getTensorShape(x).flatSize();
+  luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data,
+                                               luci_interpreter_pal::GreaterFn);
+}
+
+} // namespace
+
+void configure_kernel_CircleGreater(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  LUCI_INTERPRETER_CHECK(x()->element_type() == y()->element_type());
-  LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::BOOL);
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-  if (x()->element_type() == DataType::U8)
-  {
-    quantizeMultiplierSmallerThanOneExp(x()->scale(), &_x_multiplier, &_x_shift);
-    quantizeMultiplierSmallerThanOneExp(y()->scale(), &_y_multiplier, &_y_shift);
-  }
-  // TODO: enable it only if kernel with dynamic shapes
-  output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+                         Tensor::element_type(kernel.input2()));
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::BOOL);
 }
 
-void Greater::execute() const
+void execute_kernel_CircleGreater(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  switch (x()->element_type())
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
+
+  switch (Tensor::element_type(kernel.input1()))
   {
-    case DataType::FLOAT32:
-      evalFloat();
-      break;
     case DataType::S64:
-      evalInteger<int64_t>();
+      evalGeneric<int64_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
       break;
     case DataType::S32:
-      evalInteger<int32_t>();
+      evalGeneric<int32_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
       break;
-    case DataType::U8:
-      evalQuantized();
+#ifndef DIS_FLOAT
+    case DataType::FLOAT32:
+      evalGeneric<float>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
       break;
+#endif // DIS_FLOAT
     default:
       assert(false && "Unsupported type.");
   }
 }
 
-void Greater::evalFloat() const
-{
-  const auto x_data = getTensorData<float>(x());
-  const auto y_data = getTensorData<float>(y());
-  auto output_data = getTensorData<bool>(output());
-
-  tflite::ComparisonParams op_params;
-  op_params.is_broadcast = x()->shape() != y()->shape();
-
-  if (op_params.is_broadcast)
-  {
-    tflite::reference_ops::Broadcast4DSlowGreater(op_params, getTensorShape(x()), x_data,
-                                                  getTensorShape(y()), y_data,
-                                                  getTensorShape(output()), output_data);
-  }
-  else
-  {
-    tflite::reference_ops::Greater(op_params, getTensorShape(x()), x_data, getTensorShape(y()),
-                                   y_data, getTensorShape(output()), output_data);
-  }
-}
-
-template <typename T> void Greater::evalInteger() const
-{
-  const auto x_data = getTensorData<T>(x());
-  const auto y_data = getTensorData<T>(y());
-  auto output_data = getTensorData<bool>(output());
-
-  tflite::ComparisonParams op_params;
-  op_params.is_broadcast = x()->shape() != y()->shape();
-
-  if (op_params.is_broadcast)
-  {
-    tflite::reference_ops::Broadcast4DSlowGreaterNoScaling(op_params, getTensorShape(x()), x_data,
-                                                           getTensorShape(y()), y_data,
-                                                           getTensorShape(output()), output_data);
-  }
-  else
-  {
-    tflite::reference_ops::GreaterNoScaling(op_params, getTensorShape(x()), x_data,
-                                            getTensorShape(y()), y_data, getTensorShape(output()),
-                                            output_data);
-  }
-}
-
-void Greater::evalQuantized() const
-{
-  const auto x_data = getTensorData<uint8_t>(x());
-  const auto y_data = getTensorData<uint8_t>(y());
-  auto output_data = getTensorData<bool>(output());
-
-  tflite::ComparisonParams op_params;
-  op_params.left_shift = 8;
-  op_params.input1_offset = -x()->zero_point(); // Note the '-'
-  op_params.input1_shift = _x_shift;
-  op_params.input1_multiplier = _x_multiplier;
-  op_params.input2_offset = -y()->zero_point(); // Note the '-'
-  op_params.input2_shift = _y_shift;
-  op_params.input2_multiplier = _y_multiplier;
-  op_params.is_broadcast = x()->shape() != y()->shape();
-
-  if (op_params.is_broadcast)
-  {
-    tflite::reference_ops::Broadcast4DSlowGreaterWithScaling(op_params, getTensorShape(x()), x_data,
-                                                             getTensorShape(y()), y_data,
-                                                             getTensorShape(output()), output_data);
-  }
-  else
-  {
-    tflite::reference_ops::GreaterWithScaling(op_params, getTensorShape(x()), x_data,
-                                              getTensorShape(y()), y_data, getTensorShape(output()),
-                                              output_data);
-  }
-}
-
-} // namespace kernels
 } // namespace luci_interpreter