Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / ShapeInfer_StridedSlice.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2018 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "ShapeInfer_StridedSlice.h"
19 #include "Check.h"
20 #include "CircleShapeInferenceHelper.h"
21
22 #include <luci/IR/CircleNode.h>
23 #include <loco/IR/DataType.h>
24 #include <loco/IR/NodeShape.h>
25 #include <oops/InternalExn.h>
26
27 #include <algorithm>
28 #include <cmath>
29 #include <cstdint>
30 #include <limits>
31
32 // code referenced from
33 // https://github.com/tensorflow/tensorflow/blob/3f878cff5b698b82eea85db2b60d65a2e320850e/
34 //    tensorflow/lite/kernels/strided_slice.cc
35 //    tensorflow/lite/kernels/internal/strided_slice_logic.h
36
37 namespace
38 {
39
40 // This Op only supports 1-5D cases and since we use the reference 4D
41 // implementation, the 1-3D tensors are mapped to 4D.
42 const int kMaxDim = 5;
43
44 const loco::DataType S32 = loco::DataType::S32;
45
46 struct StridedSliceParams
47 {
48   int8_t start_indices_count = 0;
49   int32_t start_indices[kMaxDim];
50   int8_t stop_indices_count = 0;
51   int32_t stop_indices[kMaxDim];
52   int8_t strides_count = 0;
53   int32_t strides[kMaxDim];
54
55   int16_t begin_mask = 0;
56   int16_t ellipsis_mask = 0;
57   int16_t end_mask = 0;
58   int16_t new_axis_mask = 0;
59   int16_t shrink_axis_mask = 0;
60 };
61
62 struct StridedSliceContext
63 {
64   StridedSliceContext(const luci::CircleStridedSlice *node)
65   {
66     // check overflow issues
67     assert(static_cast<int16_t>(node->begin_mask()) == node->begin_mask());
68     assert(static_cast<int16_t>(node->ellipsis_mask()) == node->ellipsis_mask());
69     assert(static_cast<int16_t>(node->end_mask()) == node->end_mask());
70     assert(static_cast<int16_t>(node->new_axis_mask()) == node->new_axis_mask());
71     assert(static_cast<int16_t>(node->shrink_axis_mask()) == node->shrink_axis_mask());
72
73     params.begin_mask = node->begin_mask();
74     params.ellipsis_mask = node->ellipsis_mask();
75     params.end_mask = node->end_mask();
76     params.new_axis_mask = node->new_axis_mask();
77     params.shrink_axis_mask = node->shrink_axis_mask();
78
79     input = loco::must_cast<luci::CircleNode *>(node->input());
80     begin = loco::must_cast<luci::CircleConst *>(node->begin());
81     end = loco::must_cast<luci::CircleConst *>(node->end());
82     strides = loco::must_cast<luci::CircleConst *>(node->strides());
83
84     loco::TensorShape input_shape = luci::shape_get(input).as<loco::TensorShape>();
85     input_dims = input_shape.rank();
86   }
87   StridedSliceParams params;
88   luci::CircleNode *input = nullptr;
89   luci::CircleConst *begin = nullptr;
90   luci::CircleConst *end = nullptr;
91   luci::CircleConst *strides = nullptr;
92
93   // Equivalent input shape after adding axis according to new_axis_mask.
94   loco::TensorShape effective_input_shape;
95   int64_t input_dims = 0;
96 };
97
98 // Use until std::clamp() is available from C++17.
99 inline int Clamp(const int32_t v, const int32_t lo, const int32_t hi)
100 {
101   LUCI_ASSERT(!(hi < lo), "Clamp hi < lo");
102   if (hi < v)
103     return hi;
104   if (v < lo)
105     return lo;
106   return v;
107 }
108
109 // Return the index for the first element along that axis. This index will be a
110 // positive integer between [0, axis_size - 1] that can be used to index
111 // directly into the data.
112 inline int64_t StartForAxis(const StridedSliceParams &params, const loco::TensorShape &input_shape,
113                             int64_t axis)
114 {
115   const auto begin_mask = params.begin_mask;
116   const auto *start_indices = params.start_indices;
117   const auto *strides = params.strides;
118   const int64_t axis_size = static_cast<int64_t>(input_shape.dim(axis).value());
119   if (axis_size == 0)
120   {
121     return 0;
122   }
123   // Begin with the specified index.
124   int64_t start = start_indices[axis];
125
126   // begin_mask override
127   if (begin_mask & (1 << axis))
128   {
129     if (strides[axis] > 0)
130     {
131       // Forward iteration - use the first element. These values will get
132       // clamped below (Note: We could have set them to 0 and axis_size-1, but
133       // use lowest() and max() to maintain symmetry with StopForAxis())
134       start = std::numeric_limits<int32_t>::lowest();
135     }
136     else
137     {
138       // Backward iteration - use the last element.
139       start = std::numeric_limits<int32_t>::max();
140     }
141   }
142
143   // Handle negative indices
144   if (start < 0)
145   {
146     start += axis_size;
147   }
148
149   // Clamping
150   if (strides[axis] > 0)
151   {
152     // Forward iteration
153     start = Clamp(start, 0, axis_size);
154   }
155   else
156   {
157     // Backward iteration
158     start = Clamp(start, -1, axis_size - 1);
159   }
160
161   return start;
162 }
163
164 // Return the "real" index for the end of iteration along that axis. This is an
165 // "end" in the traditional C sense, in that it points to one past the last
166 // element. ie. So if you were iterating through all elements of a 1D array of
167 // size 4, this function would return 4 as the stop, because it is one past the
168 // "real" indices of 0, 1, 2 & 3.
169 inline int64_t StopForAxis(const StridedSliceParams &params, const loco::TensorShape &input_shape,
170                            int64_t axis, int64_t start_for_axis)
171 {
172   const auto end_mask = params.end_mask;
173   const auto shrink_axis_mask = params.shrink_axis_mask;
174   const auto *stop_indices = params.stop_indices;
175   const auto *strides = params.strides;
176   const int64_t axis_size = static_cast<int64_t>(input_shape.dim(axis).value());
177   if (axis_size == 0)
178   {
179     return 0;
180   }
181
182   // Begin with the specified index
183   const bool shrink_axis = shrink_axis_mask & (1 << axis);
184   int64_t stop = stop_indices[axis];
185
186   // When shrinking an axis, the end position does not matter (and can be
187   // incorrect when negative indexing is used, see Issue #19260). Always use
188   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
189   // already been adjusted for negative indices.
190   if (shrink_axis)
191   {
192     return start_for_axis + 1;
193   }
194
195   // end_mask override
196   if (end_mask & (1 << axis))
197   {
198     if (strides[axis] > 0)
199     {
200       // Forward iteration - use the last element. These values will get
201       // clamped below
202       stop = std::numeric_limits<int32_t>::max();
203     }
204     else
205     {
206       // Backward iteration - use the first element.
207       stop = std::numeric_limits<int32_t>::lowest();
208     }
209   }
210
211   // Handle negative indices
212   if (stop < 0)
213   {
214     stop += axis_size;
215   }
216
217   // Clamping
218   // Because the end index points one past the last element, we need slightly
219   // different clamping ranges depending on the direction.
220   if (strides[axis] > 0)
221   {
222     // Forward iteration
223     stop = Clamp(stop, 0, axis_size);
224   }
225   else
226   {
227     // Backward iteration
228     stop = Clamp(stop, -1, axis_size - 1);
229   }
230
231   return stop;
232 }
233
234 StridedSliceParams BuildStridedSliceParams(StridedSliceContext *op_context)
235 {
236   StridedSliceParams op_params;
237
238   // The ellipsis_mask and new_axis_mask in op_params are not used. Those masks
239   // are processed here to update begin_mask, end_mask and the index range.
240   op_params.begin_mask = 0;
241   op_params.ellipsis_mask = 0;
242   op_params.end_mask = 0;
243   op_params.new_axis_mask = 0;
244   op_params.shrink_axis_mask = 0;
245
246   // Count indexes where the new_axis_mask is set but the ellipsis_mask is not.
247   loco::TensorShape begin_shape = luci::shape_get(op_context->begin).as<loco::TensorShape>();
248   const int64_t begin_count = static_cast<int64_t>(begin_shape.dim(0).value());
249   int64_t num_add_axis = 0;
250   for (int64_t i = 0; i < begin_count; ++i)
251   {
252     if (!((1 << i) & op_context->params.ellipsis_mask) &&
253         ((1 << i) & op_context->params.new_axis_mask))
254     {
255       num_add_axis++;
256     }
257   }
258
259   // Calculate the dims of input after adding new axises.
260   const int64_t effective_dims = op_context->input_dims + num_add_axis;
261
262   // If begin, end and strides are not fully provided, it means Ellipsis should
263   // be expanded to multiple dimensions (Ex: for spec [Ellipsis, 2] on a 3D
264   // input, the Ellipsis should be applied for the first 2 dimensions). Besides,
265   // If the new_axis_mask and the ellipsis_mask are set at the same index, the
266   // new_axis_mask will have no effect.
267   int64_t effective_ellipsis_mask = 0, effective_new_axis_mask = 0;
268   int64_t ellipsis_start_idx = effective_dims, expanded_ellipsis = 0;
269   for (int64_t i = 0; i < effective_dims;)
270   {
271     if ((1 << i) & op_context->params.ellipsis_mask)
272     {
273       ellipsis_start_idx = i;
274       int64_t ellipsis_end_idx =
275         std::max(i + 1, std::min(i + 1 + num_add_axis + op_context->input_dims - begin_count,
276                                  effective_dims));
277       expanded_ellipsis = ellipsis_end_idx - ellipsis_start_idx - 1;
278
279       // Set bit for effective_ellipsis_mask.
280       for (; i < ellipsis_end_idx; ++i)
281       {
282         effective_ellipsis_mask |= (1 << i);
283       }
284       continue;
285     }
286
287     if ((1 << (i - expanded_ellipsis)) & op_context->params.new_axis_mask)
288     {
289       effective_new_axis_mask |= (1 << i);
290     }
291     ++i;
292   }
293
294   // Calculate effective_input_shape and its corresponding begin, end, strides.
295   loco::TensorShape input_shape = luci::shape_get(op_context->input).as<loco::TensorShape>();
296   int64_t added_ellipsis = 0, added_axises = 0;
297   op_context->effective_input_shape.rank(effective_dims);
298
299   for (int64_t i = 0; i < effective_dims; ++i)
300   {
301     if ((1 << i) & effective_ellipsis_mask)
302     {
303       // If ellipsis_mask, set the begin_mask and end_mask at that index.
304       added_ellipsis = std::max(int64_t(0), i - ellipsis_start_idx);
305       assert(i < 16);
306       op_params.begin_mask |= (1 << i);
307       op_params.end_mask |= (1 << i);
308       op_params.strides[i] = 1;
309       op_context->effective_input_shape.dim(i) = input_shape.dim(i - added_axises);
310     }
311     else if ((1 << i) & effective_new_axis_mask)
312     {
313       // If new_axis_mask is set, it is equivalent to adding a new dim of 1 to
314       // input tensor. Store added shape to effective_input_shape.
315       op_params.start_indices[i] = 0;
316       op_params.stop_indices[i] = 1;
317       op_params.strides[i] = 1;
318       op_context->effective_input_shape.dim(i) = loco::Dimension(1);
319       added_axises++;
320     }
321     else if (i >= begin_count + expanded_ellipsis)
322     {
323       op_params.start_indices[i] = 0;
324       op_params.stop_indices[i] = 0;
325       op_params.strides[i] = 1;
326       assert(i < 16);
327       op_params.begin_mask |= (1 << i);
328       op_params.end_mask |= (1 << i);
329       op_context->effective_input_shape.dim(i) = input_shape.dim(i - added_axises);
330     }
331     else
332     {
333       const int64_t orig_idx = i - added_ellipsis;
334       op_params.start_indices[i] = op_context->begin->at<S32>(orig_idx);
335       op_params.stop_indices[i] = op_context->end->at<S32>(orig_idx);
336       op_params.strides[i] = op_context->strides->at<S32>(orig_idx);
337       if (op_context->params.begin_mask & (1 << orig_idx))
338       {
339         assert(i < 16);
340         op_params.begin_mask |= (1 << i);
341       }
342       if (op_context->params.end_mask & (1 << orig_idx))
343       {
344         assert(i < 16);
345         op_params.end_mask |= (1 << i);
346       }
347       if (op_context->params.shrink_axis_mask & (1 << orig_idx))
348       {
349         assert(i < 16);
350         op_params.shrink_axis_mask |= (1 << i);
351       }
352       op_context->effective_input_shape.dim(i) = input_shape.dim(i - added_axises);
353     }
354   }
355
356   // make sure no overflow
357   assert(static_cast<int8_t>(effective_dims) == static_cast<int32_t>(effective_dims));
358
359   op_params.start_indices_count = effective_dims;
360   op_params.stop_indices_count = effective_dims;
361   op_params.strides_count = effective_dims;
362
363   return op_params;
364 }
365
366 } // namespace
367
368 namespace luci
369 {
370
371 loco::TensorShape infer_output_shape(const CircleStridedSlice *node)
372 {
373   loco::TensorShape output_shape;
374
375   auto input_node = loco::must_cast<luci::CircleNode *>(node->input());
376
377   auto begin_node = dynamic_cast<luci::CircleConst *>(node->begin());
378   auto end_node = dynamic_cast<luci::CircleConst *>(node->end());
379   auto strides_node = dynamic_cast<luci::CircleConst *>(node->strides());
380   if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr)
381   {
382     INTERNAL_EXN("StridedSlice begin/end/strides nodes are not Constant");
383   }
384
385   LUCI_ASSERT(begin_node->dtype() == S32, "Only support S32 for begin_node");
386   LUCI_ASSERT(end_node->dtype() == S32, "Only support S32 for end_node");
387   LUCI_ASSERT(strides_node->dtype() == S32, "Only support S32 for strides_node");
388
389   LUCI_ASSERT(begin_node->rank() == 1, "Only support rank 1 for begin_node");
390   LUCI_ASSERT(end_node->rank() == 1, "Only support rank 1 for end_node");
391   LUCI_ASSERT(strides_node->rank() == 1, "Only support rank 1 for strides_node");
392
393   loco::TensorShape input_shape = luci::shape_get(input_node).as<loco::TensorShape>();
394
395   assert(begin_node->size<S32>() <= input_shape.rank());
396   assert(end_node->size<S32>() <= input_shape.rank());
397   assert(strides_node->size<S32>() <= input_shape.rank());
398
399   StridedSliceContext op_context(node);
400   auto op_params = BuildStridedSliceParams(&op_context);
401   auto effective_input_shape = op_context.effective_input_shape;
402   std::vector<int64_t> output_shape_vector;
403
404   for (int32_t idx = effective_input_shape.rank() - 1; idx >= 0; --idx)
405   {
406     int32_t stride = op_params.strides[idx];
407     LUCI_ASSERT(stride != 0, "stride value has to be non-zero");
408
409     int64_t begin = StartForAxis(op_params, effective_input_shape, idx);
410     int64_t end = StopForAxis(op_params, effective_input_shape, idx, begin);
411
412     // When shrinking an axis, the end position does not matter (and can be
413     // incorrect when negative indexing is used, see Issue #19260). Always use
414     // begin + 1 to generate a length 1 slice, since begin has
415     // already been adjusted for negative indices by GetBeginValueAtIndex.
416     const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
417     if (shrink_axis)
418     {
419       end = begin + 1;
420     }
421
422     // This is valid for both positive and negative strides
423     int64_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
424     dim_shape = dim_shape < 0 ? 0 : dim_shape;
425     if (!shrink_axis)
426     {
427       output_shape_vector.push_back(dim_shape);
428     }
429   }
430
431   auto shape_size = output_shape_vector.size();
432   output_shape.rank(shape_size);
433   for (uint32_t idx = 0; idx < shape_size; ++idx)
434   {
435     int64_t dim = output_shape_vector.at(shape_size - 1u - idx);
436     LUCI_ASSERT(0 <= dim && dim < 0xfffffffL, "Dimension size exceeds limit");
437     // reverse copy
438     output_shape.dim(idx) = static_cast<uint32_t>(dim);
439   }
440
441   return output_shape;
442 }
443
444 } // namespace luci