* limitations under the License.
*/
-#include "kernels/StridedSlice.h"
-
+#include "Builders.h"
#include "kernels/Utils.h"
+#include "MISOKernel.h"
-#include <tensorflow/lite/kernels/internal/reference/strided_slice.h>
+#include "PALStridedSlice.h"
namespace luci_interpreter
{
-
-namespace kernels
+namespace
{
-StridedSlice::StridedSlice(const Tensor *input, const Tensor *begin, const Tensor *end,
- const Tensor *strides, Tensor *output, const StridedSliceParams ¶ms)
- : KernelWithParams<StridedSliceParams>({input, begin, end, strides}, {output}, params)
+luci_interpreter_pal::StridedSliceParams
+buildStridedSliceParams(int32_t dims, const int32_t *begin, const int32_t *end,
+ const int32_t *strides, const circle::StridedSliceOptions *options)
{
-}
+ luci_interpreter_pal::StridedSliceParams op_params;
+ op_params.start_indices_count = dims;
+ op_params.stop_indices_count = dims;
+ op_params.strides_count = dims;
-void StridedSlice::configure()
-{
- assert(begin()->shape().num_dims() == 1);
- assert(end()->shape().num_dims() == 1);
- assert(strides()->shape().num_dims() == 1);
- assert(input()->element_type() == output()->element_type());
- assert(begin()->element_type() == DataType::S32);
- assert(end()->element_type() == DataType::S32);
- assert(strides()->element_type() == DataType::S32);
- assert(input()->shape().num_dims() <= 4);
- if (params().ellipsis_mask != 0)
- {
- assert(false && "ellipsis_mask is not implemented yet.");
- }
- if (params().new_axis_mask != 0)
- {
- assert(false && "new_axis_mask is not implemented yet.");
- }
- if (input()->element_type() == DataType::U8)
+ for (int i = 0; i < dims; ++i)
{
- assert(input()->scale() == output()->scale());
- assert(input()->zero_point() == output()->zero_point());
+ op_params.start_indices[i] = begin[i];
+ op_params.stop_indices[i] = end[i];
+ op_params.strides[i] = strides[i];
}
- tflite::StridedSliceParams op_params{};
- op_params.start_indices_count = input()->shape().num_dims();
- op_params.stop_indices_count = input()->shape().num_dims();
- op_params.strides_count = input()->shape().num_dims();
- for (int i = 0; i < input()->shape().num_dims(); i++)
- {
- op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
- op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
- op_params.strides[i] = getTensorData<int32_t>(strides())[i];
- }
- op_params.begin_mask = params().begin_mask;
+ op_params.begin_mask = options->begin_mask();
op_params.ellipsis_mask = 0;
- op_params.end_mask = params().end_mask;
+ op_params.end_mask = options->end_mask();
op_params.new_axis_mask = 0;
- op_params.shrink_axis_mask = params().shrink_axis_mask;
- std::vector<int32_t> output_shape_vector;
- for (int i = 0; i < input()->shape().num_dims(); i++)
- {
- int idx = input()->shape().num_dims() - i - 1;
- int32_t stride = getTensorData<int32_t>(strides())[idx];
- assert(stride != 0);
- int32_t begin = ::tflite::strided_slice::StartForAxis(op_params, getTensorShape(input()), idx);
- int32_t end =
- ::tflite::strided_slice::StopForAxis(op_params, getTensorShape(input()), idx, begin);
-
- const bool shrink_axis = params().shrink_axis_mask & (1 << idx);
- if (shrink_axis)
- {
- end = begin + 1;
- }
-
- int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
- dim_shape = dim_shape < 0 ? 0 : dim_shape;
- if (!shrink_axis)
- {
- output_shape_vector.push_back(dim_shape);
- }
- }
- // TODO: enable it only if kernel with dynamic shapes
- Shape output_shape = Shape(output_shape_vector.size());
- for (size_t i = 0; i < output_shape_vector.size(); i++)
- {
- output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1];
- }
- output()->resize(output_shape);
+ op_params.shrink_axis_mask = options->shrink_axis_mask();
+ return op_params;
}
-void StridedSlice::execute() const
+} // namespace
+
+void configure_kernel_CircleStridedSlice(const circle::Operator *cur_op,
+ BaseRuntimeGraph *runtime_graph)
{
- tflite::StridedSliceParams op_params{};
- op_params.start_indices_count = input()->shape().num_dims();
- op_params.stop_indices_count = input()->shape().num_dims();
- op_params.strides_count = input()->shape().num_dims();
+ kernels::MISOKernel miso_kernel(cur_op, runtime_graph);
- for (int i = 0; i < input()->shape().num_dims(); i++)
- {
- op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
- op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
- op_params.strides[i] = getTensorData<int32_t>(strides())[i];
- }
- op_params.begin_mask = params().begin_mask;
- op_params.ellipsis_mask = 0;
- op_params.end_mask = params().end_mask;
- op_params.new_axis_mask = 0;
- op_params.shrink_axis_mask = params().shrink_axis_mask;
+ const circle::Tensor *input = miso_kernel.input1();
+ const circle::Tensor *begin = miso_kernel.input2();
+ const circle::Tensor *end = miso_kernel.input3();
+ const circle::Tensor *strides = miso_kernel.input4();
+
+ LUCI_INTERPRETER_CHECK(strides != nullptr);
+
+ const circle::Tensor *output = miso_kernel.output();
+
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(begin) == DataType::S32);
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(end) == DataType::S32);
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(strides) == DataType::S32);
+ LUCI_INTERPRETER_CHECK(Tensor::element_type(input) == Tensor::element_type(output));
+}
+
+void execute_kernel_CircleStridedSlice(const circle::Operator *cur_op,
+ BaseRuntimeGraph *runtime_graph)
+{
+ kernels::MISOKernel miso_kernel(cur_op, runtime_graph);
+
+ const circle::Tensor *input = miso_kernel.input1();
+ const circle::Tensor *begin = miso_kernel.input2();
+ const circle::Tensor *end = miso_kernel.input3();
+ const circle::Tensor *strides = miso_kernel.input4();
+ const circle::Tensor *output = miso_kernel.output();
+
+ const int32_t dims = Tensor::num_dims(input);
+
+ const uint8_t *input_data = runtime_graph->getDataByTensor(input);
+ const int32_t *begin_data =
+ kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(begin));
+ const int32_t *end_data =
+ kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(end));
+ const int32_t *strides_data =
+ kernels::getTensorData<int32_t>(runtime_graph->getConstDataByTensor(strides));
+ uint8_t *output_data = runtime_graph->getDataByTensor(output);
- switch (input()->element_type())
+ LUCI_INTERPRETER_CHECK(input_data != nullptr);
+ LUCI_INTERPRETER_CHECK(begin_data != nullptr);
+ LUCI_INTERPRETER_CHECK(end_data != nullptr);
+ LUCI_INTERPRETER_CHECK(strides_data != nullptr);
+ LUCI_INTERPRETER_CHECK(output_data != nullptr);
+
+ const auto *options = cur_op->builtin_options_as_StridedSliceOptions();
+
+ auto op_params = buildStridedSliceParams(dims, begin_data, end_data, strides_data, options);
+
+ switch (Tensor::element_type(input))
{
+#ifndef DIS_FLOAT
case DataType::FLOAT32:
- tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
- getTensorData<float>(input()), getTensorShape(output()),
- getTensorData<float>(output()));
+ luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input),
+ kernels::getTensorData<float>(input_data),
+ kernels::getTensorData<float>(output_data));
break;
+#endif // DIS_FLOAT
+#ifndef DIS_QUANT
case DataType::U8:
- tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
- getTensorData<uint8_t>(input()), getTensorShape(output()),
- getTensorData<uint8_t>(output()));
+ luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input), input_data,
+ output_data);
+ break;
+ case DataType::S8:
+ luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input), input_data,
+ output_data);
break;
+#endif
case DataType::S32:
- tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
- getTensorData<int32_t>(input()), getTensorShape(output()),
- getTensorData<int32_t>(output()));
+ luci_interpreter_pal::StridedSlice(op_params, kernels::getTensorShape(input),
+ kernels::getTensorData<int32_t>(input_data),
+ kernels::getTensorData<int32_t>(output_data));
break;
default:
- assert(false && "Unsupported type.");
+ assert(false && "Unsupported type");
}
}
-} // namespace kernels
} // namespace luci_interpreter