Imported Upstream version 1.18.0
[platform/core/ml/nnfw.git] / compiler / luci-interpreter / src / kernels / StridedSlice.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "kernels/StridedSlice.h"
19
20 #include "kernels/Utils.h"
21
22 #include <tensorflow/lite/kernels/internal/reference/strided_slice.h>
23
24 #include <stdexcept>
25
26 namespace luci_interpreter
27 {
28
29 namespace kernels
30 {
31
32 StridedSlice::StridedSlice(const Tensor *input, const Tensor *begin, const Tensor *end,
33                            const Tensor *strides, Tensor *output, const StridedSliceParams &params)
34   : KernelWithParams<StridedSliceParams>({input, begin, end, strides}, {output}, params)
35 {
36 }
37
38 void StridedSlice::configure()
39 {
40   assert(begin()->shape().num_dims() == 1);
41   assert(end()->shape().num_dims() == 1);
42   assert(strides()->shape().num_dims() == 1);
43   assert(input()->element_type() == output()->element_type());
44   assert(begin()->element_type() == DataType::S32);
45   assert(end()->element_type() == DataType::S32);
46   assert(strides()->element_type() == DataType::S32);
47   assert(input()->shape().num_dims() <= 4);
48   if (params().ellipsis_mask != 0)
49   {
50     throw std::runtime_error("ellipsis_mask is not implemented yet.");
51   }
52   if (params().new_axis_mask != 0)
53   {
54     throw std::runtime_error("new_axis_mask is not implemented yet.");
55   }
56   if (input()->element_type() == DataType::U8)
57   {
58     assert(input()->scale() == output()->scale());
59     assert(input()->zero_point() == output()->zero_point());
60   }
61   tflite::StridedSliceParams op_params{};
62   op_params.start_indices_count = input()->shape().num_dims();
63   op_params.stop_indices_count = input()->shape().num_dims();
64   op_params.strides_count = input()->shape().num_dims();
65
66   for (int i = 0; i < input()->shape().num_dims(); i++)
67   {
68     op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
69     op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
70     op_params.strides[i] = getTensorData<int32_t>(strides())[i];
71   }
72   op_params.begin_mask = params().begin_mask;
73   op_params.ellipsis_mask = 0;
74   op_params.end_mask = params().end_mask;
75   op_params.new_axis_mask = 0;
76   op_params.shrink_axis_mask = params().shrink_axis_mask;
77   std::vector<int32_t> output_shape_vector;
78   for (int i = 0; i < input()->shape().num_dims(); i++)
79   {
80     int idx = input()->shape().num_dims() - i - 1;
81     int32_t stride = getTensorData<int32_t>(strides())[idx];
82     assert(stride != 0);
83     int32_t begin = ::tflite::strided_slice::StartForAxis(op_params, getTensorShape(input()), idx);
84     int32_t end =
85       ::tflite::strided_slice::StopForAxis(op_params, getTensorShape(input()), idx, begin);
86
87     const bool shrink_axis = params().shrink_axis_mask & (1 << idx);
88     if (shrink_axis)
89     {
90       end = begin + 1;
91     }
92
93     int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
94     dim_shape = dim_shape < 0 ? 0 : dim_shape;
95     if (!shrink_axis)
96     {
97       output_shape_vector.push_back(dim_shape);
98     }
99   }
100   Shape output_shape = Shape(output_shape_vector.size());
101   for (size_t i = 0; i < output_shape_vector.size(); i++)
102   {
103     output_shape.dim(i) = output_shape_vector[output_shape_vector.size() - i - 1];
104   }
105   output()->resize(output_shape);
106 }
107
108 void StridedSlice::execute() const
109 {
110   tflite::StridedSliceParams op_params{};
111   op_params.start_indices_count = input()->shape().num_dims();
112   op_params.stop_indices_count = input()->shape().num_dims();
113   op_params.strides_count = input()->shape().num_dims();
114
115   for (int i = 0; i < input()->shape().num_dims(); i++)
116   {
117     op_params.start_indices[i] = getTensorData<int32_t>(begin())[i];
118     op_params.stop_indices[i] = getTensorData<int32_t>(end())[i];
119     op_params.strides[i] = getTensorData<int32_t>(strides())[i];
120   }
121   op_params.begin_mask = params().begin_mask;
122   op_params.ellipsis_mask = 0;
123   op_params.end_mask = params().end_mask;
124   op_params.new_axis_mask = 0;
125   op_params.shrink_axis_mask = params().shrink_axis_mask;
126
127   switch (input()->element_type())
128   {
129     case DataType::FLOAT32:
130       tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
131                                           getTensorData<float>(input()), getTensorShape(output()),
132                                           getTensorData<float>(output()));
133       break;
134     case DataType::U8:
135       tflite::reference_ops::StridedSlice(op_params, getTensorShape(input()),
136                                           getTensorData<uint8_t>(input()), getTensorShape(output()),
137                                           getTensorData<uint8_t>(output()));
138       break;
139     default:
140       throw std::runtime_error("Unsupported type.");
141   }
142 }
143
144 } // namespace kernels
145 } // namespace luci_interpreter