Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / onert-micro / luci-interpreter / pal / common / PALStridedSlice.h
1 /*
2  * Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2020 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #ifndef LUCI_INTERPRETER_PAL_STRIDED_SLICE_H
19 #define LUCI_INTERPRETER_PAL_STRIDED_SLICE_H
20
21 #include "Params.h"
22
23 namespace luci_interpreter_pal
24 {
25
26 namespace
27 {
28 // Use until std::clamp() is available from C++17.
29 inline int clamp(const int v, const int lo, const int hi)
30 {
31   if (hi < v)
32     return hi;
33   if (v < lo)
34     return lo;
35   return v;
36 }
37
38 inline bool loopCondition(int index, int stop, int stride)
39 {
40   // True when we have reached the end of an axis and should loop.
41   return stride > 0 ? index >= stop : index <= stop;
42 }
43
44 // Return the "real" index for the end of iteration along that axis. This is an
45 // "end" in the traditional C sense, in that it points to one past the last
46 // element. ie. So if you were iterating through all elements of a 1D array of
47 // size 4, this function would return 4 as the stop, because it is one past the
48 // "real" indices of 0, 1, 2 & 3.
49 inline int stopForAxis(const StridedSliceParams &params,
50                        const luci_interpreter::RuntimeShape &input_shape, int axis,
51                        int start_for_axis)
52 {
53   const auto end_mask = params.end_mask;
54   const auto shrink_axis_mask = params.shrink_axis_mask;
55   const auto *stop_indices = params.stop_indices;
56   const auto *strides = params.strides;
57   const int axis_size = input_shape.dims(axis);
58   if (axis_size == 0)
59   {
60     return 0;
61   }
62
63   // Begin with the specified index
64   const bool shrink_axis = shrink_axis_mask & (1 << axis);
65   int stop = stop_indices[axis];
66
67   // When shrinking an axis, the end position does not matter (and can be
68   // incorrect when negative indexing is used, see Issue #19260). Always use
69   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
70   // already been adjusted for negative indices.
71   if (shrink_axis)
72   {
73     return start_for_axis + 1;
74   }
75
76   // end_mask override
77   if (end_mask & (1 << axis))
78   {
79     if (strides[axis] > 0)
80     {
81       // Forward iteration - use the last element. These values will get
82       // clamped below
83       stop = std::numeric_limits<int>::max();
84     }
85     else
86     {
87       // Backward iteration - use the first element.
88       stop = std::numeric_limits<int>::lowest();
89     }
90   }
91
92   // Handle negative indices
93   if (stop < 0)
94   {
95     stop += axis_size;
96   }
97
98   // Clamping
99   // Because the end index points one past the last element, we need slightly
100   // different clamping ranges depending on the direction.
101   if (strides[axis] > 0)
102   {
103     // Forward iteration
104     stop = clamp(stop, 0, axis_size);
105   }
106   else
107   {
108     // Backward iteration
109     stop = clamp(stop, -1, axis_size - 1);
110   }
111
112   return stop;
113 }
114
115 // Return the index for the first element along that axis. This index will be a
116 // positive integer between [0, axis_size] (or [-1, axis_size -1] if stride < 0)
117 // that can be used to index directly into the data.
118 inline int startForAxis(const StridedSliceParams &params,
119                         const luci_interpreter::RuntimeShape &input_shape, int axis)
120 {
121   const auto begin_mask = params.begin_mask;
122   const auto *start_indices = params.start_indices;
123   const auto *strides = params.strides;
124   const int axis_size = input_shape.dims(axis);
125   if (axis_size == 0)
126   {
127     return 0;
128   }
129   // Begin with the specified index.
130   int start = start_indices[axis];
131
132   // begin_mask override
133   if (begin_mask & 1 << axis)
134   {
135     if (strides[axis] > 0)
136     {
137       // Forward iteration - use the first element. These values will get
138       // clamped below (Note: We could have set them to 0 and axis_size-1, but
139       // use lowest() and max() to maintain symmetry with StopForAxis())
140       start = std::numeric_limits<int>::lowest();
141     }
142     else
143     {
144       // Backward iteration - use the last element.
145       start = std::numeric_limits<int>::max();
146     }
147   }
148
149   // Handle negative indices
150   if (start < 0)
151   {
152     start += axis_size;
153   }
154
155   // Clamping
156   if (strides[axis] > 0)
157   {
158     // Forward iteration
159     start = clamp(start, 0, axis_size);
160   }
161   else
162   {
163     // Backward iteration
164     start = clamp(start, -1, axis_size - 1);
165   }
166
167   return start;
168 }
169
170 inline void stridedSlicePadIndices(StridedSliceParams *p, int dim_count)
171 {
172   const int pad_count = dim_count - p->start_indices_count;
173
174   // Pad indices at start, so move arrays by pad_count.
175   for (int i = p->start_indices_count - 1; i >= 0; --i)
176   {
177     p->strides[i + pad_count] = p->strides[i];
178     p->start_indices[i + pad_count] = p->start_indices[i];
179     p->stop_indices[i + pad_count] = p->stop_indices[i];
180   }
181   for (int i = 0; i < pad_count; ++i)
182   {
183     p->start_indices[i] = 0;
184     p->stop_indices[i] = 1;
185     p->strides[i] = 1;
186   }
187
188   // Pad masks with 0s or 1s as required.
189   p->shrink_axis_mask <<= pad_count;
190   p->ellipsis_mask <<= pad_count;
191   p->new_axis_mask <<= pad_count;
192   p->begin_mask <<= pad_count;
193   p->end_mask <<= pad_count;
194   p->begin_mask |= (1 << pad_count) - 1;
195   p->end_mask |= (1 << pad_count) - 1;
196
197   p->start_indices_count = dim_count;
198   p->stop_indices_count = dim_count;
199   p->strides_count = dim_count;
200 }
201
202 } // namespace
203
204 template <typename T>
205 inline void StridedSlice(StridedSliceParams &op_params,
206                          const luci_interpreter::RuntimeShape &unextended_input_shape,
207                          const T *input_data, T *output_data)
208 {
209   const luci_interpreter::RuntimeShape input_shape =
210     luci_interpreter::RuntimeShape::extendedShape(5, unextended_input_shape);
211
212   // Reverse and pad to 5 dimensions because that is what the runtime code
213   // requires (ie. all shapes must be 5D and are given backwards).
214   stridedSlicePadIndices(&op_params, 5);
215
216   const int start_0 = startForAxis(op_params, input_shape, 0);
217   const int stop_0 = stopForAxis(op_params, input_shape, 0, start_0);
218   const int start_1 = startForAxis(op_params, input_shape, 1);
219   const int stop_1 = stopForAxis(op_params, input_shape, 1, start_1);
220   const int start_2 = startForAxis(op_params, input_shape, 2);
221   const int stop_2 = stopForAxis(op_params, input_shape, 2, start_2);
222   const int start_3 = startForAxis(op_params, input_shape, 3);
223   const int stop_3 = stopForAxis(op_params, input_shape, 3, start_3);
224   const int start_4 = startForAxis(op_params, input_shape, 4);
225   const int stop_4 = stopForAxis(op_params, input_shape, 4, start_4);
226
227   for (int offset_0 = start_0 * input_shape.dims(1), end_0 = stop_0 * input_shape.dims(1),
228            step_0 = op_params.strides[0] * input_shape.dims(1);
229        !loopCondition(offset_0, end_0, op_params.strides[0]); offset_0 += step_0)
230   {
231     for (int offset_1 = (offset_0 + start_1) * input_shape.dims(2),
232              end_1 = (offset_0 + stop_1) * input_shape.dims(2),
233              step_1 = op_params.strides[1] * input_shape.dims(2);
234          !loopCondition(offset_1, end_1, op_params.strides[1]); offset_1 += step_1)
235     {
236       for (int offset_2 = (offset_1 + start_2) * input_shape.dims(3),
237                end_2 = (offset_1 + stop_2) * input_shape.dims(3),
238                step_2 = op_params.strides[2] * input_shape.dims(3);
239            !loopCondition(offset_2, end_2, op_params.strides[2]); offset_2 += step_2)
240       {
241         for (int offset_3 = (offset_2 + start_3) * input_shape.dims(4),
242                  end_3 = (offset_2 + stop_3) * input_shape.dims(4),
243                  step_3 = op_params.strides[3] * input_shape.dims(4);
244              !loopCondition(offset_3, end_3, op_params.strides[3]); offset_3 += step_3)
245         {
246           for (int offset_4 = offset_3 + start_4, end_4 = offset_3 + stop_4;
247                !loopCondition(offset_4, end_4, op_params.strides[4]);
248                offset_4 += op_params.strides[4])
249           {
250             *output_data++ = input_data[offset_4];
251           }
252         }
253       }
254     }
255   }
256 }
257
258 } // namespace luci_interpreter_pal
259
260 #endif // LUCI_INTERPRETER_PAL_STRIDED_SLICE_H