Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Equal.cpp
index 003d643..76968a3 100644 (file)
@@ -1,6 +1,5 @@
 /*
  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
  * limitations under the License.
  */
 
-#include "kernels/Equal.h"
+#include "Builders.h"
 #include "kernels/Utils.h"
+#include "TISOKernel.h"
 
-#include <tensorflow/lite/kernels/internal/reference/comparisons.h>
+#include "PALComparisons.h"
 
 namespace luci_interpreter
 {
 
-namespace kernels
+namespace
 {
-
-Equal::Equal(const Tensor *x, const Tensor *y, Tensor *output) : Kernel({x, y}, {output}) {}
-
-void Equal::configure()
+// TODO: reduce code duplication with less
+template <typename T>
+void evalGeneric(const circle::Tensor *x, const circle::Tensor *y, const circle::Tensor *output,
+                 BaseRuntimeGraph *runtime_graph)
 {
-  LUCI_INTERPRETER_CHECK(x()->element_type() == y()->element_type());
-  LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::BOOL);
+  auto x_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(x));
+  if (x_data == nullptr)
+    x_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(x));
 
-  if (x()->element_type() == DataType::U8)
-  {
-    quantizeMultiplierSmallerThanOneExp(x()->scale(), &_x_multiplier, &_x_shift);
-    quantizeMultiplierSmallerThanOneExp(y()->scale(), &_y_multiplier, &_y_shift);
-  }
-  // TODO: enable it only if kernel with dynamic shapes
-  output()->resize(calculateShapeForBroadcast(x()->shape(), y()->shape()));
-}
+  assert(x_data != nullptr);
 
-void Equal::execute() const
-{
-  switch (x()->element_type())
-  {
-    case DataType::FLOAT32:
-      evalFloat();
-      break;
-    case DataType::S64:
-      evalInteger<int64_t>();
-      break;
-    case DataType::S32:
-      evalInteger<int32_t>();
-      break;
-    case DataType::U8:
-      evalQuantized();
-      break;
-    default:
-      assert(false && "Unsupported type.");
-  }
-}
+  auto y_data = kernels::getTensorData<T>(runtime_graph->getDataByTensor(y));
+  if (y_data == nullptr)
+    y_data = kernels::getTensorData<T>(runtime_graph->getConstDataByTensor(y));
 
-void Equal::evalFloat() const
-{
-  const auto x_data = getTensorData<float>(x());
-  const auto y_data = getTensorData<float>(y());
-  auto output_data = getTensorData<bool>(output());
+  assert(y_data != nullptr);
+
+  auto output_data = kernels::getTensorData<bool>(runtime_graph->getDataByTensor(output));
 
-  tflite::ComparisonParams op_params;
-  op_params.is_broadcast = x()->shape() != y()->shape();
+  luci_interpreter_pal::ComparisonParams op_params;
+  op_params.is_broadcast = Tensor::num_elements(x) != Tensor::num_elements(y);
 
   if (op_params.is_broadcast)
   {
-    tflite::reference_ops::Broadcast4DSlowEqual(op_params, getTensorShape(x()), x_data,
-                                                getTensorShape(y()), y_data,
-                                                getTensorShape(output()), output_data);
+    luci_interpreter_pal::BroadcastComparison4DSlowNoScaling<T>(
+      op_params, kernels::getTensorShape(x), x_data, kernels::getTensorShape(y), y_data,
+      kernels::getTensorShape(output), output_data, luci_interpreter_pal::EqualFn);
   }
   else
   {
-    tflite::reference_ops::Equal(op_params, getTensorShape(x()), x_data, getTensorShape(y()),
-                                 y_data, getTensorShape(output()), output_data);
+    const int64_t flat_size = kernels::getTensorShape(x).flatSize();
+    luci_interpreter_pal::ComparisonNoScaling<T>(flat_size, x_data, y_data, output_data,
+                                                 luci_interpreter_pal::EqualFn);
   }
 }
 
-template <typename T> void Equal::evalInteger() const
-{
-  const auto x_data = getTensorData<T>(x());
-  const auto y_data = getTensorData<T>(y());
-  auto output_data = getTensorData<bool>(output());
+} // namespace
 
-  tflite::ComparisonParams op_params;
-  op_params.is_broadcast = x()->shape() != y()->shape();
+void configure_kernel_CircleEqual(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
+{
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-  if (op_params.is_broadcast)
-  {
-    tflite::reference_ops::Broadcast4DSlowEqualNoScaling(op_params, getTensorShape(x()), x_data,
-                                                         getTensorShape(y()), y_data,
-                                                         getTensorShape(output()), output_data);
-  }
-  else
-  {
-    tflite::reference_ops::EqualNoScaling(op_params, getTensorShape(x()), x_data,
-                                          getTensorShape(y()), y_data, getTensorShape(output()),
-                                          output_data);
-  }
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) ==
+                         Tensor::element_type(kernel.input2()));
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::BOOL);
 }
 
-void Equal::evalQuantized() const
+void execute_kernel_CircleEqual(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  const auto x_data = getTensorData<uint8_t>(x());
-  const auto y_data = getTensorData<uint8_t>(y());
-  auto output_data = getTensorData<bool>(output());
-
-  tflite::ComparisonParams op_params;
-  op_params.left_shift = 8;
-  op_params.input1_offset = -x()->zero_point(); // Note the '-'
-  op_params.input1_shift = _x_shift;
-  op_params.input1_multiplier = _x_multiplier;
-  op_params.input2_offset = -y()->zero_point(); // Note the '-'
-  op_params.input2_shift = _y_shift;
-  op_params.input2_multiplier = _y_multiplier;
-  op_params.is_broadcast = x()->shape() != y()->shape();
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-  if (op_params.is_broadcast)
-  {
-    tflite::reference_ops::Broadcast4DSlowEqualWithScaling(op_params, getTensorShape(x()), x_data,
-                                                           getTensorShape(y()), y_data,
-                                                           getTensorShape(output()), output_data);
-  }
-  else
+  switch (Tensor::element_type(kernel.input1()))
   {
-    tflite::reference_ops::EqualWithScaling(op_params, getTensorShape(x()), x_data,
-                                            getTensorShape(y()), y_data, getTensorShape(output()),
-                                            output_data);
+    case DataType::S64:
+      evalGeneric<int64_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
+      break;
+    case DataType::S32:
+      evalGeneric<int32_t>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
+      break;
+#ifndef DIS_FLOAT
+    case DataType::FLOAT32:
+      evalGeneric<float>(kernel.input1(), kernel.input2(), kernel.output(), runtime_graph);
+      break;
+#endif // DIS_FLOAT
+    default:
+      assert(false && "Unsupported type.");
   }
 }
 
-} // namespace kernels
 } // namespace luci_interpreter