Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / Gather.cpp
index d26b718..b146b89 100644 (file)
  * limitations under the License.
  */
 
-#include "kernels/Gather.h"
+#include "Builders.h"
 #include "kernels/Utils.h"
-#include "PALGather.h"
+#include "TISOKernel.h"
 
 #include <cassert>
 
 namespace luci_interpreter
 {
-
-namespace kernels
+namespace
 {
 
-Gather::Gather(const Tensor *params, const Tensor *indices, Tensor *output,
-               const GatherParams &gparams)
-  : KernelWithParams<GatherParams>({params, indices}, {output}, gparams)
+template <typename InputT, typename CoordsT = int32_t>
+void gather(const circle::GatherOptions *options, kernels::TISOKernel *kernel)
 {
-}
-
-void Gather::configure()
-{
-  if (params()->element_type() == DataType::FLOAT32)
-  {
-    LUCI_INTERPRETER_CHECK(output()->element_type() == DataType::FLOAT32);
-  }
-  else
-  {
-    assert(false && "Unsupported type.");
-  }
-
-  LUCI_INTERPRETER_CHECK(indices()->element_type() == DataType::S32 ||
-                         indices()->element_type() == DataType::S64);
+  kernels::TISOData tiso_data = kernel->readData();
 
-  // refer tensorflow/lite/kernels/gather.cc
+  const InputT *input_data = kernels::getTensorData<InputT>(tiso_data.input1_data);
+  const CoordsT *coords_data = kernels::getTensorData<CoordsT>(tiso_data.input2_data);
+  InputT *output_data = kernels::getTensorData<InputT>(tiso_data.output_data);
 
-  const Shape &params_shape = params()->shape();
-  const Shape &indices_shape = indices()->shape();
+  const circle::Tensor *input = kernel->input1();
+  const circle::Tensor *coords = kernel->input2();
 
-  int axis = _params.axis;
+  const int input_dims_size = Tensor::num_dims(input);
+  int axis = options->axis();
   if (axis < 0)
   {
-    axis += params_shape.num_dims();
+    axis += input_dims_size;
   }
-  LUCI_INTERPRETER_CHECK(0 <= axis && axis < params_shape.num_dims());
 
-  int batch_dims = _params.batch_dims;
-  // batch_dims should be in range: [-rank(indices), rank(indices)].
-  // Negative batch_dims is added with rank of positions.
+  int batch_dims = options->batch_dims();
+  // batch_dims should be in range: [-rank(coords), rank(coords)].
+  // Negative batch_dims is added with rank of coords.
+  const int coords_dims_size = Tensor::num_dims(coords);
   if (batch_dims < 0)
   {
-    batch_dims += indices_shape.num_dims();
+    batch_dims += coords_dims_size;
   }
-  LUCI_INTERPRETER_CHECK(batch_dims <= axis);
-  LUCI_INTERPRETER_CHECK(0 <= batch_dims && batch_dims < params_shape.num_dims());
-  LUCI_INTERPRETER_CHECK(batch_dims <= indices_shape.num_dims());
+
+  const int axis_size = Tensor::dim(input, axis);
+
+  int batch_size = 1;
   for (int i = 0; i < batch_dims; ++i)
   {
-    LUCI_INTERPRETER_CHECK(params_shape.dim(i) == indices_shape.dim(i));
+    batch_size *= Tensor::dim(input, i);
   }
-
-  const int num_dimensions = params_shape.num_dims() + indices_shape.num_dims() - 1 - batch_dims;
-
-  Shape output_shape(num_dimensions);
-  int output_index = 0;
-  for (int i = 0; i < axis; ++i)
+  int outer_size = 1;
+  for (int i = batch_dims; i < axis; ++i)
   {
-    output_shape.dim(output_index++) = params_shape.dim(i);
+    outer_size *= Tensor::dim(input, i);
   }
-  for (int i = batch_dims; i < indices_shape.num_dims(); ++i)
+  int inner_size = 1;
+  for (int i = axis + 1; i < input_dims_size; ++i)
   {
-    output_shape.dim(output_index++) = indices_shape.dim(i);
+    inner_size *= Tensor::dim(input, i);
   }
-  for (int i = axis + 1; i < params_shape.num_dims(); ++i)
+  int coord_size = 1;
+  for (int i = batch_dims; i < coords_dims_size; ++i)
   {
-    output_shape.dim(output_index++) = params_shape.dim(i);
+    coord_size *= Tensor::dim(coords, i);
   }
-  // TODO: enable it only if kernel with dynamic shapes
-  output()->resize(output_shape);
-}
 
-void Gather::execute() const
-{
-  switch (params()->element_type())
+  for (int batch = 0; batch < batch_size; ++batch)
   {
-    case DataType::FLOAT32:
-      evalFloat();
-      break;
-    default:
-      assert(false && "Unsupported type.");
+    for (int outer = 0; outer < outer_size; ++outer)
+    {
+      for (int coord = 0; coord < coord_size; ++coord)
+      {
+        auto x = coords_data[coord];
+        std::memcpy(
+          output_data + (((batch * outer_size) + outer) * coord_size + coord) * inner_size,
+          input_data +
+            (((batch * outer_size) + outer) * axis_size + coords_data[batch * coord_size + coord]) *
+              inner_size,
+          sizeof(InputT) * inner_size);
+      }
+    }
   }
 }
 
-void Gather::evalFloat() const
+} // namespace
+
+void configure_kernel_CircleGather(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  assert(indices()->element_type() == DataType::S32 || indices()->element_type() == DataType::S64);
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-  const auto params_data = getTensorData<float>(params());
-  auto output_data = getTensorData<float>(output());
+  const auto *options = cur_op->builtin_options_as_GatherOptions();
 
-  tflite::GatherParams tparams;
-  tparams.axis = _params.axis;
-  tparams.batch_dims = _params.batch_dims;
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32);
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input1()) == DataType::FLOAT32 or
+                         Tensor::element_type(kernel.input1()) == DataType::S8 or
+                         Tensor::element_type(kernel.input1()) == DataType::S32);
 
-  if (indices()->element_type() == DataType::S32)
+  int32_t axis = options->axis();
+  int32_t num_dims = Tensor::num_dims(kernel.input1());
+  if (axis < 0)
   {
-    const auto indices_data = getTensorData<int32_t>(indices());
+    axis += num_dims;
+  }
+
+  LUCI_INTERPRETER_CHECK(axis >= 0 and axis < num_dims);
 
-    luci_interpreter_pal::Gather<float, int32_t>(tparams, getTensorShape(params()), params_data,
-                                                 getTensorShape(indices()), indices_data,
-                                                 getTensorShape(output()), output_data);
+  int32_t batch_dims = options->batch_dims();
+  int32_t coords_num_dims = Tensor::num_dims(kernel.input2());
+  // batch_dims should be in range: [-rank(coords), rank(coords)].
+  // Negative batch_dims is added with rank of coords.
+  if (batch_dims < 0)
+  {
+    batch_dims += coords_num_dims;
   }
-  else
+  LUCI_INTERPRETER_CHECK(batch_dims <= axis);
+  LUCI_INTERPRETER_CHECK(batch_dims >= 0 and batch_dims < num_dims);
+  LUCI_INTERPRETER_CHECK(batch_dims <= coords_num_dims);
+  for (int i = 0; i < batch_dims; ++i)
   {
-    const auto indices_data = getTensorData<int64_t>(indices());
+    LUCI_INTERPRETER_CHECK(Tensor::dim(kernel.input1(), i) == Tensor::dim(kernel.input2(), i));
+  }
+}
+
+void execute_kernel_CircleGather(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
+{
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-    luci_interpreter_pal::Gather<float, int64_t>(tparams, getTensorShape(params()), params_data,
-                                                 getTensorShape(indices()), indices_data,
-                                                 getTensorShape(output()), output_data);
+  const auto *options = cur_op->builtin_options_as_GatherOptions();
+
+  switch (Tensor::element_type(kernel.input1()))
+  {
+#ifndef DIS_FLOAT
+    case DataType::FLOAT32:
+      return gather<float, int32_t>(options, &kernel);
+#endif // DIS_FLOAT
+#ifndef DIS_QUANT
+    case DataType::S8:
+      return gather<int8_t, int32_t>(options, &kernel);
+#endif // DIS_QUANT
+    case DataType::S32:
+      return gather<int32_t, int32_t>(options, &kernel);
+    default:
+      assert(false && "Unsupported type");
   }
 }
 
-} // namespace kernels
 } // namespace luci_interpreter