Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / ArgMax.cpp
index 5ac4bcb..be6291c 100644 (file)
  * limitations under the License.
  */
 
-#include "kernels/ArgMax.h"
+#include "Builders.h"
 #include "kernels/Utils.h"
-#include "PALArgMax.h"
+#include "TISOKernel.h"
+
+#include "PALArgMinMax.h"
 
 namespace luci_interpreter
 {
-namespace kernels
-{
 
-ArgMax::ArgMax(const Tensor *input, const Tensor *axis, Tensor *output, const ArgMaxParams &params)
-  : KernelWithParams<ArgMaxParams>({input, axis}, {output}, params)
+void configure_kernel_CircleArgMax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
+  // dim tensor must be a scalar or has one element
+  LUCI_INTERPRETER_CHECK(Tensor::num_dims(kernel.input2()) == 0 or
+                         Tensor::num_elements(kernel.input2()) == 1);
+  // value and output type must match
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.output()) == DataType::S32);
+
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(kernel.input2()) == DataType::S32);
 }
 
-void ArgMax::configure()
+void execute_kernel_CircleArgMax(const circle::Operator *cur_op, BaseRuntimeGraph *runtime_graph)
 {
-  assert(axis()->element_type() == DataType::S32 || axis()->element_type() == DataType::S64);
-  assert(input()->shape().num_dims() >= 1);
-  const Shape &input_shape = input()->shape();
-  const int num_dims = input_shape.num_dims();
-  Shape output_shape(num_dims - 1);
-
-  // If axis value is negative, then update by adding input_shape's num_dims.
-  // If updated value also negative, then assert.
-  assert(axis()->shape().num_elements() == 1);
-  int axis_value = getTensorData<int32_t>(axis())[0];
-  if (axis_value < 0)
-    axis_value = axis_value + num_dims;
-  assert(axis_value >= 0);
-
-  int j = 0;
-  for (int i = 0; i < num_dims; i++)
-  {
-    if (i == axis_value)
-      continue;
-    output_shape.dim(j++) = input_shape.dim(i);
-  }
+  kernels::TISOKernel kernel(cur_op, runtime_graph);
 
-  assert(output()->element_type() == _params.output_type);
+  const circle::Tensor *input = kernel.input1();
+  const circle::Tensor *output = kernel.output();
 
-  // TODO: enable it only if kernel with dynamic shapes
-  output()->resize(output_shape);
-}
-
-void ArgMax::execute() const
-{
+  kernels::TISOData tiso_data = kernel.readData();
+  const auto input_data = tiso_data.input1_data;
+  const auto axis_data = tiso_data.input2_data;
+  auto output_data = tiso_data.output_data;
 
-#define TF_LITE_ARG_MAX(data_type, axis_type, output_type)                                    \
-  luci_interpreter_pal::ArgMinMax(getTensorShape(input()), getTensorData<data_type>(input()), \
-                                  getTensorData<axis_type>(axis()), getTensorShape(output()), \
-                                  getTensorData<output_type>(output()), std::greater<data_type>())
-  if (axis()->element_type() == DataType::S32)
-  {
-    switch (_params.output_type)
-    {
-      case DataType::S32:
-        switch (input()->element_type())
-        {
-          case DataType::FLOAT32:
-            TF_LITE_ARG_MAX(float, int32_t, int32_t);
-            break;
-          case DataType::U8:
-            TF_LITE_ARG_MAX(uint8_t, int32_t, int32_t);
-            break;
-          default:
-            assert(false && "Unsupported input type.");
-        }
-        break;
-      case DataType::S64:
-        switch (input()->element_type())
-        {
-          case DataType::FLOAT32:
-            TF_LITE_ARG_MAX(float, int32_t, int64_t);
-            break;
-          case DataType::U8:
-            TF_LITE_ARG_MAX(uint8_t, int32_t, int64_t);
-            break;
-          default:
-            assert(false && "Unsupported input type.");
-        }
-        break;
-      default:
-        assert(false && "Unsupported output type.");
-    }
-  }
-  else
+  switch (Tensor::element_type(input))
   {
-    switch (_params.output_type)
+#ifndef DIS_FLOAT
+    case DataType::FLOAT32:
     {
-      case DataType::S32:
-        switch (input()->element_type())
-        {
-          case DataType::FLOAT32:
-            TF_LITE_ARG_MAX(float, int64_t, int32_t);
-            break;
-          case DataType::U8:
-            TF_LITE_ARG_MAX(uint8_t, int64_t, int32_t);
-            break;
-          default:
-            assert(false && "Unsupported input type.");
-        }
-        break;
-      case DataType::S64:
-        switch (input()->element_type())
-        {
-          case DataType::FLOAT32:
-            TF_LITE_ARG_MAX(float, int64_t, int64_t);
-            break;
-          case DataType::U8:
-            TF_LITE_ARG_MAX(uint8_t, int64_t, int64_t);
-            break;
-          default:
-            assert(false && "Unsupported input type.");
-        }
-        break;
-      default:
-        assert(false && "Unsupported output type.");
+      luci_interpreter_pal::ArgMinMax(
+        kernels::getTensorRuntimeShape(input, runtime_graph),
+        kernels::getTensorData<float>(input_data), kernels::getTensorData<int32_t>(axis_data),
+        kernels::getTensorRuntimeShape(output, runtime_graph),
+        kernels::getTensorData<int32_t>(output_data), std::greater<float>());
     }
+    break;
+#endif // DIS_FLOAT
+    default:
+      assert(false && "Unsupported ArgMax input type");
   }
-#undef TF_LITE_ARG_MAX
 }
 
-} // namespace kernels
 } // namespace luci_interpreter