Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Logistic.cpp
index 4e8cba8..4dbc153 100644 (file)
@@ -1,6 +1,5 @@
 /*
  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
- * Copyright 2019 The TensorFlow Authors. All Rights Reserved.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
 
 #include "Builders.h"
 #include "kernels/Utils.h"
-
-#include <tensorflow/lite/kernels/internal/reference/logistic.h>
+#include "PALLogistic.h"
 
 namespace luci_interpreter
 {
-namespace
-{
-
-#ifndef DIS_FLOAT
-void evalFloat(const circle::Tensor *input, const circle::Tensor *output, bool is_inplace,
-               BaseRuntimeGraph *runtime_graph)
-{
-  const float *input_data = reinterpret_cast<const float *>(runtime_graph->getDataByTensor(input));
-  float *output_data = reinterpret_cast<float *>(runtime_graph->getDataByTensor(output));
-
-  if (is_inplace)
-  {
-    output_data = const_cast<float *>(input_data);
-  }
-
-  assert(input_data != nullptr);
-  assert(output_data != nullptr);
-
-  tflite::reference_ops::Logistic(kernels::getTensorShape(input), input_data,
-                                  kernels::getTensorShape(output), output_data);
-  if (is_inplace)
-  {
-    runtime_graph->makeInplaceOperation(input, output);
-  }
-}
-#endif // DIS_FLOAT
-
-#ifndef DIS_QUANT
-void evalQuantized(const circle::Tensor *input, const circle::Tensor *output, bool is_inplace,
-                   BaseRuntimeGraph *runtime_graph)
-{
-  const int8_t *input_data =
-    reinterpret_cast<const int8_t *>(runtime_graph->getDataByTensor(input));
-  int8_t *output_data = reinterpret_cast<int8_t *>(runtime_graph->getDataByTensor(output));
-  if (is_inplace)
-    output_data = const_cast<int8_t *>(input_data);
-
-  tflite::reference_ops::Logistic(kernels::getTensorShape(input), input_data, Tensor::scale(input),
-                                  Tensor::zero_point(input), kernels::getTensorShape(output),
-                                  output_data, Tensor::scale(output), Tensor::zero_point(output));
-  if (is_inplace)
-  {
-    runtime_graph->makeInplaceOperation(input, output);
-  }
-}
-#endif // DIS_QUANT
-
-} // namespace
 
 void configure_kernel_CircleLogistic(const circle::Operator *cur_op,
                                      BaseRuntimeGraph *runtime_graph)
@@ -96,8 +46,7 @@ void configure_kernel_CircleLogistic(const circle::Operator *cur_op,
 #endif // DIS_QUANT
 }
 
-void execute_kernel_CircleLogistic(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph,
-                                   bool is_inplace)
+void execute_kernel_CircleLogistic(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
   const auto input_index = cur_op->inputs()->operator[](0);
   const auto output_index = cur_op->outputs()->operator[](0);
@@ -111,21 +60,45 @@ void execute_kernel_CircleLogistic(const circle::Operator *cur_op, BaseRuntimeGr
   assert(input != nullptr);
   assert(output != nullptr);
 
+  bool is_inplace = runtime_graph->is_inplace_op(cur_op);
+
+  const uint8_t *input_data = runtime_graph->getDataByTensor(input);
+  uint8_t *output_data = runtime_graph->getDataByTensor(output);
+
+  if (is_inplace)
+  {
+    output_data = const_cast<uint8_t *>(input_data);
+  }
+
+  assert(input_data != nullptr);
+  assert(output_data != nullptr);
+
+  const int flat_size = kernels::getTensorRuntimeShape(input, runtime_graph).flatSize();
+
   switch (Tensor::element_type(input))
   {
 #ifndef DIS_FLOAT
     case DataType::FLOAT32:
-      evalFloat(input, output, is_inplace, runtime_graph);
+      luci_interpreter_pal::Logistic(flat_size, kernels::getTensorData<float>(input_data),
+                                     kernels::getTensorData<float>(output_data));
       break;
 #endif // DIS_FLOAT
 #ifndef DIS_QUANT
     case DataType::S8:
-      evalQuantized(input, output, is_inplace, runtime_graph);
+      luci_interpreter_pal::Logistic(flat_size, kernels::getTensorData<int8_t>(input_data),
+                                     Tensor::scale(input), Tensor::zero_point(input),
+                                     kernels::getTensorData<int8_t>(output_data),
+                                     Tensor::scale(output), Tensor::zero_point(output));
       break;
 #endif // DIS_QUANT
     default:
       assert(false && "Unsupported type.");
   }
+
+  if (is_inplace)
+  {
+    runtime_graph->makeInplaceOperation(input, output);
+  }
 }
 
 } // namespace luci_interpreter