Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Concatenation.cpp
index a9eff09..c8c096e 100644 (file)
@@ -18,7 +18,7 @@
 #include "Builders.h"
 #include "kernels/Utils.h"
 
-#include <tensorflow/lite/kernels/internal/reference/concatenation.h>
+#include "PALConcatenation.h"
 
 namespace luci_interpreter
 {
@@ -27,7 +27,7 @@ namespace
 {
 
 template <typename T>
-void evalGeneric(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph, bool)
+void evalGeneric(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
   const auto output_index = cur_op->outputs()->operator[](0);
 
@@ -44,37 +44,38 @@ void evalGeneric(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph
   const auto input_sizes = cur_op->inputs()->size();
 
   std::vector<const T *> all_input_data;
-  std::vector<tflite::RuntimeShape> all_shape;
-  std::vector<tflite::RuntimeShape *> all_shape_ptr;
-
-  all_input_data.reserve(input_sizes);
-  all_shape.reserve(input_sizes);
-  all_shape_ptr.reserve(input_sizes);
+  std::vector<luci_interpreter::RuntimeShape> all_shape;
+  std::vector<luci_interpreter::RuntimeShape *> all_shape_ptr;
 
   for (int32_t i = 0; i < input_sizes; ++i)
   {
     auto input_index = cur_op->inputs()->operator[](i);
     const auto *tensor = runtime_graph->getCircleTensorByIndex(input_index);
 
-    auto *data = reinterpret_cast<const T *>(runtime_graph->getDataByTensor(tensor));
+    const auto *tensor_data = runtime_graph->getDataByTensor(tensor);
+    if (tensor_data == nullptr)
+      tensor_data = runtime_graph->getConstDataByTensor(tensor);
+
+    auto *data = reinterpret_cast<const T *>(tensor_data);
+
+    auto runtime_shape = kernels::getTensorRuntimeShape(tensor, runtime_graph);
 
     all_input_data.push_back(data);
-    all_shape.push_back(kernels::getTensorShape(tensor));
+    all_shape.push_back(runtime_shape);
   }
 
-  for (tflite::RuntimeShape &shape : all_shape)
+  for (luci_interpreter::RuntimeShape &shape : all_shape)
   {
     all_shape_ptr.push_back(&shape);
   }
 
   auto *output_data = reinterpret_cast<T *>(runtime_graph->getDataByTensor(output));
 
-  // kernels::VectorOfTensors<T, true> inputs(_inputs);
-  tflite::ConcatenationParams params{};
+  luci_interpreter_pal::ConcatenationParams params{};
   params.axis = axis;
-  params.inputs_count = input_sizes;
-  tflite::reference_ops::Concatenation(params, all_shape_ptr.data(), all_input_data.data(),
-                                       kernels::getTensorShape(output), output_data);
+  params.inputs_count = all_shape.size();
+  luci_interpreter_pal::Concatenation(params, all_shape_ptr.data(), all_input_data.data(),
+                                      kernels::getTensorShape(output), output_data);
 }
 
 } // namespace
@@ -104,24 +105,12 @@ void configure_kernel_CircleConcatenation(const circle::Operator *cur_op,
     axis += Tensor::num_dims(t0);
   LUCI_INTERPRETER_CHECK(axis >= 0 && axis < Tensor::num_dims(t0));
 
-  int32_t sum_axis = Tensor::dim(t0, axis);
   for (int i = 1; i < num_inputs; ++i)
   {
     input_index = cur_op->inputs()->operator[](i);
     const auto *tensor = runtime_graph->getCircleTensorByIndex(input_index);
     LUCI_INTERPRETER_CHECK(Tensor::element_type(tensor) == Tensor::element_type(t0));
     LUCI_INTERPRETER_CHECK(Tensor::num_dims(tensor) == Tensor::num_dims(t0));
-    for (int d = 0; d < Tensor::num_dims(t0); ++d)
-    {
-      if (d == axis)
-      {
-        sum_axis += Tensor::dim(tensor, axis);
-      }
-      else
-      {
-        LUCI_INTERPRETER_CHECK(Tensor::dim(tensor, d) == Tensor::dim(t0, d));
-      }
-    }
   }
 
 #ifndef DIS_QUANT
@@ -145,7 +134,7 @@ void configure_kernel_CircleConcatenation(const circle::Operator *cur_op,
 }
 
 void execute_kernel_CircleConcatenation(const circle::Operator *cur_op,
-                                        BaseRuntimeGraph *runtime_graph, bool is_inplace)
+                                        BaseRuntimeGraph *runtime_graph)
 {
   int num_inputs = cur_op->inputs()->size();
   LUCI_INTERPRETER_CHECK(num_inputs > 0);
@@ -158,20 +147,20 @@ void execute_kernel_CircleConcatenation(const circle::Operator *cur_op,
   {
 #ifndef DIS_FLOAT
     case DataType::FLOAT32:
-      evalGeneric<float>(cur_op, runtime_graph, is_inplace);
+      evalGeneric<float>(cur_op, runtime_graph);
       break;
 #endif // DIS_FLOAT
 #ifndef DIS_QUANT
     case DataType::S8:
-      evalGeneric<int8_t>(cur_op, runtime_graph, is_inplace);
+      evalGeneric<int8_t>(cur_op, runtime_graph);
       break;
+#endif // DIS_QUANT
     case DataType::S32:
-      evalGeneric<int32_t>(cur_op, runtime_graph, is_inplace);
+      evalGeneric<int32_t>(cur_op, runtime_graph);
       break;
     case DataType::S64:
-      evalGeneric<int64_t>(cur_op, runtime_graph, is_inplace);
+      evalGeneric<int64_t>(cur_op, runtime_graph);
       break;
-#endif
     default:
       assert(false && "Unsupported type.");
   }