Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / src / kernels / StridedSlice.cpp
index 654fd3c..3968fb9 100644 (file)
  * limitations under the License.
  */
 
-#include "kernels/StridedSlice.h"
-
+#include "Builders.h"
 #include "kernels/Utils.h"
+#include "MISOKernel.h"
 
-#include <tensorflow/lite/kernels/internal/reference/strided_slice.h>
+#include "PALStridedSlice.h"
 
 namespace luci_interpreter
 {
-
-namespace kernels
+namespace
 {
 
-StridedSlice::StridedSlice(const Tensor *input, const Tensor *begin, const Tensor *end,
-                           const Tensor *strides, Tensor *output, const StridedSliceParams &params)
-  : KernelWithParams<StridedSliceParams>({input, begin, end, strides}, {output}, params)
+luci_interpreter_pal::StridedSliceParams
+buildStridedSliceParams(int32_t dims, const int32_t *begin, const int32_t *end,
+                        const int32_t *strides, const circle::StridedSliceOptions *options)
 {
-}
+  luci_interpreter_pal::StridedSliceParams op_params;
+  op_params.start_indices_count = dims;
+  op_params.stop_indices_count = dims;
+  op_params.strides_count = dims;
 
-void StridedSlice::configure()
-{
-  assert(begin()->shape().num_dims() == 1);
-  assert(end()->shape().num_dims() == 1);
-  assert(strides()->shape().num_dims() == 1);
-  assert(input()->element_type() == output()->element_type());
-  assert(begin()->element_type() == DataType::S32);
-  assert(end()->element_type() == DataType::S32);
-  assert(strides()->element_type() == DataType::S32);
-  assert(input()->shape().num_dims() <= 4);
-  if (params().ellipsis_mask != 0)
-  {
-    assert(false && "ellipsis_mask is not implemented yet.");
-  }
-  if (params().new_axis_mask != 0)
-  {
-    assert(false && "new_axis_mask is not implemented yet.");
-  }
-  if (input()->element_type() == DataType::U8)
+  for (int i = 0; i < dims; ++i)
   {
-    assert(input()->scale() == output()->scale());
-    assert(input()->zero_point() == output()->zero_point());
+    op_params.start_indices[i] = begin[i];
+    op_params.stop_indices[i] = end[i];
+    op_params.strides[i] = strides[i];
   }
-  tflite::StridedSliceParams op_params{};
-  op_params.start_indices_count = input()->shape().num_dims();
-  op_params.stop_indices_count = input()->shape().num_dims();
-  op_params.strides_count = input()->shape().num_dims();
 
-  for (int i = 0; i < input()->shape().num_dims(); i++)
-  {
-    op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
-    op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
-    op_params.strides[i] = getTensorData<int32_t>(strides())[i];
-  }
-  op_params.begin_mask = params().begin_mask;
+  op_params.begin_mask = options->begin_mask();
   op_params.ellipsis_mask = 0;
-  op_params.end_mask = params().end_mask;
+  op_params.end_mask = options->end_mask();
   op_params.new_axis_mask = 0;
-  op_params.shrink_axis_mask = params().shrink_axis_mask;
-  std::vector<int32_t> output_shape_vector;
-  for (int i = 0; i < input()->shape().num_dims(); i++)
-  {
-    int idx = input()->shape().num_dims() - i - 1;
-    int32_t stride = getTensorData<int32_t>(strides())[idx];
-    assert(stride != 0);
-    int32_t begin = ::tflite::strided_slice::StartForAxis(op_params, getTensorShape(input()), idx);
-    int32_t end =
-      ::tflite::strided_slice::StopForAxis(op_params, getTensorShape(input()), idx, begin);
-
-    const bool shrink_axis = params().shrink_axis_mask & (1 << idx);
-    if (shrink_axis)
-    {
-      end = begin + 1;
-    }
-
-    int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
-    dim_shape = dim_shape < 0 ? 0 : dim_shape;
-    if (!shrink_axis)
-    {
-      output_shape_vector.push_back(dim_shape);
-    }
-  }
-  // TODO: enable it only if kernel with dynamic shapes
-  Shape output_shape = Shape(output_shape_vector.size());
-  for (size_t i = 0; i < output_shape_vector.size(); i++)
-  {
-    output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1];
-  }
-  output()->resize(output_shape);
+  op_params.shrink_axis_mask = options->shrink_axis_mask();
+  return op_params;
 }
 
-void StridedSlice::execute() const
+} // namespace
+
+void configure_kernel_CircleStridedSlice(const circle::Operator *cur_op,
+                                         BaseRuntimeGraph *runtime_graph)
 {
-  tflite::StridedSliceParams op_params{};
-  op_params.start_indices_count = input()->shape().num_dims();
-  op_params.stop_indices_count = input()->shape().num_dims();
-  op_params.strides_count = input()->shape().num_dims();
+  kernels::MISOKernel miso_kernel(cur_op, runtime_graph);
 
-  for (int i = 0; i < input()->shape().num_dims(); i++)
-  {
-    op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
-    op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
-    op_params.strides[i] = getTensorData<int32_t>(strides())[i];
-  }
-  op_params.begin_mask = params().begin_mask;
-  op_params.ellipsis_mask = 0;
-  op_params.end_mask = params().end_mask;
-  op_params.new_axis_mask = 0;
-  op_params.shrink_axis_mask = params().shrink_axis_mask;
+  const circle::Tensor *input = miso_kernel.input1();
+  const circle::Tensor *begin = miso_kernel.input2();
+  const circle::Tensor *end = miso_kernel.input3();
+  const circle::Tensor *strides = miso_kernel.input4();
+
+  LUCI_INTERPRETER_CHECK(strides != nullptr);
+
+  const circle::Tensor *output = miso_kernel.output();
+
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(begin) == DataType::S32);
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(end) == DataType::S32);
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(strides) == DataType::S32);
+  LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == Tensor::element_type(output));
+}
+
+void execute_kernel_CircleStridedSlice(const circle::Operator *cur_op,
+                                       BaseRuntimeGraph *runtime_graph)
+{
+  kernels::MISOKernel miso_kernel(cur_op, runtime_graph);
+
+  const circle::Tensor *input = miso_kernel.input1();
+  const circle::Tensor *begin = miso_kernel.input2();
+  const circle::Tensor *end = miso_kernel.input3();
+  const circle::Tensor *strides = miso_kernel.input4();
+  const circle::Tensor *output = miso_kernel.output();
+
+  const int32_t dims = Tensor::num_dims(input);
+
+  const uint8_t *input_data = runtime_graph->getDataByTensor(input);
+  const int32_t *begin_data =
+    kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(begin));
+  const int32_t *end_data =
+    kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(end));
+  const int32_t *strides_data =
+    kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(strides));
+  uint8_t *output_data = runtime_graph->getDataByTensor(output);
 
-  switch (input()->element_type())
+  LUCI_INTERPRETER_CHECK(input_data != nullptr);
+  LUCI_INTERPRETER_CHECK(begin_data != nullptr);
+  LUCI_INTERPRETER_CHECK(end_data != nullptr);
+  LUCI_INTERPRETER_CHECK(strides_data != nullptr);
+  LUCI_INTERPRETER_CHECK(output_data != nullptr);
+
+  const auto *options = cur_op->builtin_options_as_StridedSliceOptions();
+
+  auto op_params = buildStridedSliceParams(dims, begin_data, end_data, strides_data, options);
+
+  switch (Tensor::element_type(input))
   {
+#ifndef DIS_FLOAT
     case DataType::FLOAT32:
-      tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
-                                          getTensorData<float>(input()), getTensorShape(output()),
-                                          getTensorData<float>(output()));
+      luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input),
+                                         kernels::getTensorData<float>(input_data),
+                                         kernels::getTensorData<float>(output_data));
       break;
+#endif // DIS_FLOAT
+#ifndef DIS_QUANT
     case DataType::U8:
-      tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
-                                          getTensorData<uint8_t>(input()), getTensorShape(output()),
-                                          getTensorData<uint8_t>(output()));
+      luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input), input_data,
+                                         output_data);
+      break;
+    case DataType::S8:
+      luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input), input_data,
+                                         output_data);
       break;
+#endif
     case DataType::S32:
-      tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
-                                          getTensorData<int32_t>(input()), getTensorShape(output()),
-                                          getTensorData<int32_t>(output()));
+      luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input),
+                                         kernels::getTensorData<int32_t>(input_data),
+                                         kernels::getTensorData<int32_t>(output_data));
       break;
     default:
-      assert(false && "Unsupported type.");
+      assert(false && "Unsupported type");
   }
 }
 
-} // namespace kernels
 } // namespace luci_interpreter