Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / util / ShapeInference.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "util/Utils.h"
19 #include "ir/InternalType.h"
20 #include "ir/Shape.h"
21 #include "util/ShapeInference.h"
22 #include "util/logging.h"
23
24 #include <cassert>
25 #include <sstream>
26 #include <cmath>
27
28 namespace onert
29 {
30 namespace shape_inference
31 {
32
33 //
34 // Helper functions
35 //
36
37 namespace
38 {
39
40 template <typename T, typename U>
41 typename std::enable_if<std::is_integral<T>::value && std::is_integral<U>::value,
42                         typename std::common_type<T, U>::type>::type
43 ceil_div(T dividend, U divisor)
44 {
45   assert(dividend > 0 && divisor > 0 && "this implementations is for positive numbers only");
46   return (dividend + divisor - 1) / divisor;
47 }
48
49 // Calculate the result of broadcast of two shapes
50 ir::Shape broadcastShapes(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
51 {
52   ir::Shape out_shape;
53   auto max_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
54
55   for (int idx = 0; idx < max_rank; ++idx)
56   {
57     // Go over operands dimensions from right to left
58     int lhs_idx = lhs_shape.rank() - idx - 1;
59     int rhs_idx = rhs_shape.rank() - idx - 1;
60
61     int32_t lhs_dim = lhs_idx >= 0 ? lhs_shape.dim(lhs_idx) : 1;
62     int32_t rhs_dim = rhs_idx >= 0 ? rhs_shape.dim(rhs_idx) : 1;
63
64     if (lhs_dim != 1 && rhs_dim != 1 && lhs_dim != rhs_dim)
65       throw std::runtime_error("Incompatible shapes for broadcast");
66
67     out_shape.prepend(std::max(lhs_dim, rhs_dim));
68   }
69
70   return out_shape;
71 }
72
73 } // namespace
74
75 //
76 // Shape inference
77 //
78
79 // Calculate output height and width of convolution-like operation
80 std::pair<int, int> calcConvLikeHeightAndWidth(const int in_h, const int in_w, const int ker_h,
81                                                const int ker_w, const ir::Padding pad,
82                                                const ir::Stride stride,
83                                                const ir::Dilation dilation = {1, 1})
84 {
85   int32_t out_h = 0, out_w = 0;
86   int32_t effective_filter_w_size = (ker_w - 1) * dilation.width_factor + 1;
87   int32_t effective_filter_h_size = (ker_h - 1) * dilation.height_factor + 1;
88   switch (pad.type)
89   {
90     case ir::PaddingType::SAME:
91       out_h = ceil_div(in_h, stride.vertical);
92       out_w = ceil_div(in_w, stride.horizontal);
93       break;
94     case ir::PaddingType::VALID:
95       out_h = ceil_div(in_h - effective_filter_h_size + 1, stride.vertical);
96       out_w = ceil_div(in_w - effective_filter_w_size + 1, stride.horizontal);
97       break;
98     case ir::PaddingType::EXPLICIT:
99       out_h =
100           (in_h + pad.param.top + pad.param.bottom - effective_filter_h_size) / stride.vertical + 1;
101       out_w =
102           (in_w + pad.param.left + pad.param.right - effective_filter_w_size) / stride.horizontal +
103           1;
104       break;
105     default:
106       assert(false);
107   }
108
109   return {out_h, out_w};
110 }
111
112 ir::Shape inferEltwiseShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
113 {
114   return broadcastShapes(lhs_shape, rhs_shape);
115 }
116
117 ir::Shape inferArgMaxShape(const ir::Shape &input_shape, int axis, int rank)
118 {
119   ir::Shape out_shape;
120   for (int idx = 0; idx < rank; ++idx)
121   {
122     if (idx != axis)
123     {
124       int32_t input_dim = input_shape.dim(idx);
125       out_shape.append(input_dim);
126     }
127   }
128
129   return out_shape;
130 }
131
132 ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int> &axes,
133                            bool keep_dims)
134 {
135   int num_axis = axes.size();
136   int input_num_dims = input_shape.rank();
137   if (input_num_dims == 0)
138   {
139     ir::Shape out_shape(0);
140     return out_shape;
141   }
142   if (keep_dims)
143   {
144     ir::Shape out_shape;
145     for (int idx = 0; idx < input_num_dims; ++idx)
146     {
147       bool is_axis = false;
148       for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
149       {
150         if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
151         {
152           is_axis = true;
153           break;
154         }
155       }
156       if (is_axis)
157       {
158         out_shape.append(1);
159       }
160       else
161       {
162         out_shape.append(input_shape.dim(idx));
163       }
164     }
165     return out_shape;
166   }
167   else
168   {
169     // Calculates size of reducing axis.
170     int num_reduce_axis = num_axis;
171     for (int i = 0; i < num_axis; ++i)
172     {
173       int current = axes[i];
174       if (current < 0)
175       {
176         current += input_num_dims;
177       }
178       assert(0 <= current && current < input_num_dims);
179       for (int j = 0; j < i; ++j)
180       {
181         int previous = axes[j];
182         if (previous < 0)
183         {
184           previous += input_num_dims;
185         }
186         if (current == previous)
187         {
188           --num_reduce_axis;
189           break;
190         }
191       }
192     }
193     // Determines output dimensions.
194     ir::Shape out_shape;
195     int num_skip_axis = 0;
196     for (int idx = 0; idx < input_num_dims; ++idx)
197     {
198       bool is_axis = false;
199       for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
200       {
201         if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
202         {
203           ++num_skip_axis;
204           is_axis = true;
205           break;
206         }
207       }
208       if (!is_axis)
209       {
210         out_shape.append(input_shape.dim(idx));
211       }
212     }
213     return out_shape;
214   }
215 }
216
217 ir::Shape inferBatchMatMulShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape,
218                                 const ir::operation::BatchMatMul::Param &param)
219 {
220   bool adj_x = param.adj_x;
221   bool adj_y = param.adj_y;
222   ir::Shape output_shape;
223
224   int output_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
225
226   // Extend lhs and rhs shape
227   ir::Shape extended_lhs_shape(lhs_shape);
228   ir::Shape extended_rhs_shape(rhs_shape);
229   extended_lhs_shape.extendRank(output_rank);
230   extended_rhs_shape.extendRank(output_rank);
231
232   for (int i = 0; i < output_rank - 2; i++)
233   {
234     const int lhs_dim = extended_lhs_shape.dim(i);
235     const int rhs_dim = extended_rhs_shape.dim(i);
236     int broadcast_dim = lhs_dim;
237     if (lhs_dim != rhs_dim)
238     {
239       if (lhs_dim == 1)
240       {
241         broadcast_dim = rhs_dim;
242       }
243       else if (rhs_dim != 1)
244       {
245         throw std::runtime_error{"BatchMatMul shape inference: invalid brodcasting input shape"};
246       }
247     }
248
249     output_shape.append(broadcast_dim);
250   }
251
252   // Fill in the matmul dimensions.
253   int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
254   int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
255
256   output_shape.append(extended_lhs_shape.dim(lhs_rows_index));
257   output_shape.append(extended_rhs_shape.dim(rhs_cols_index));
258
259   return output_shape;
260 }
261
262 ir::Shape inferBroadcastToShape(const ir::Shape wshape, const int32_t *shape_buffer)
263 {
264   const int num_elements = wshape.num_elements();
265
266   assert(num_elements != 0);
267   assert(shape_buffer);
268
269   ir::Shape new_shape(num_elements);
270
271   for (int i = 0; i < num_elements; ++i)
272   {
273     assert(shape_buffer[i] != 0); // It shouldn't be 0.
274     new_shape.dim(i) = shape_buffer[i];
275   }
276
277   return new_shape;
278 }
279
280 ir::Shape inferConcatShape(const Shapes &in_shapes, const ir::operation::Concat::Param &param)
281 {
282   const int32_t concat_axis = param.axis >= 0 ? param.axis : in_shapes[0].rank() + param.axis;
283   const auto &first_in_shape = in_shapes[0];
284
285   // Check that all shapes are equal except for concat axis dimension
286   for (const auto &in_shape : in_shapes)
287   {
288     if (in_shape.rank() != first_in_shape.rank())
289       throw std::runtime_error("Rank in all input tensors should be same");
290
291     for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
292       if (!(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx)))
293         throw std::runtime_error("All tensor should have same dimension "
294                                  "except dimension on passed axis");
295   }
296
297   // Calculate output shape
298   ir::Shape out_shape(first_in_shape);
299   out_shape.dim(concat_axis) = 0;
300   for (const auto &in_shape : in_shapes)
301     out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
302   return out_shape;
303 }
304
305 ir::Shape inferConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
306                            const ir::operation::Conv2D::Param &param, ir::Layout layout)
307 {
308   auto ifm_shape = in_shape.asFeature(layout);
309
310   // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
311   auto kf_shape = ker_shape.asFeature(layout);
312   assert(ifm_shape.C == kf_shape.C);
313
314   const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
315                                                   param.padding, param.stride, param.dilation);
316
317   return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.N};
318 }
319
320 ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
321                                     const ir::operation::DepthwiseConv2D::Param &param,
322                                     ir::Layout layout)
323 {
324   assert(layout == ir::Layout::NHWC);
325   auto ifm_shape = in_shape.asFeature(layout);
326
327   // Kernel format is [1, kernel_height, kernel_width, depth_out]
328   auto kf_shape = ker_shape.asFeature(layout);
329   assert(kf_shape.C == static_cast<int32_t>(ifm_shape.C * param.multiplier));
330   assert(kf_shape.N == 1);
331
332   const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
333                                                   param.padding, param.stride);
334
335   return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.C};
336 }
337
338 ir::Shape inferExpandDimsShape(const ir::Shape &in_shape, int32_t axis)
339 {
340   ir::Shape out_shape(in_shape.rank() + 1);
341
342   axis = ((axis >= 0) ? axis : /* when axis < 0 */ (out_shape.rank() + axis));
343   if (!(0 <= axis && axis <= in_shape.rank()))
344     throw std::runtime_error("axis of dim is out of range");
345
346   for (int x = 0, out_x = 0; out_x < out_shape.rank(); ++out_x)
347   {
348     if (out_x == axis)
349       out_shape.dim(out_x) = 1;
350     else
351       out_shape.dim(out_x) = in_shape.dim(x++);
352   }
353
354   return out_shape;
355 }
356
357 ir::Shape inferFillShape(const ir::Shape &in_shape, const int32_t *buffer)
358 {
359   ir::Shape out_shape(in_shape.dim(0));
360
361   for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
362   {
363     out_shape.dim(out_x) = buffer[out_x];
364   }
365
366   return out_shape;
367 }
368
369 ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape)
370 {
371   assert(in_shape.rank() >= 2);
372   assert(ker_shape.rank() == 2);
373
374   const auto input_size_with_batch = in_shape.num_elements();
375   const auto num_units = ker_shape.dim(0);
376   const auto input_size = ker_shape.dim(1);
377   const auto batch_size = input_size_with_batch / input_size;
378   assert(input_size_with_batch % input_size == 0);
379
380   return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
381 }
382
383 ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
384                            int rank)
385 {
386   ir::Shape out_shape;
387   const int indices_rank = indices_shape.rank();
388   for (int idx = 0; idx < rank; ++idx)
389   {
390     if (idx == axis)
391     {
392       for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
393       {
394         out_shape.append(indices_shape.dim(indices_idx));
395       }
396     }
397     else
398     {
399       out_shape.append(input_shape.dim(idx));
400     }
401   }
402
403   return out_shape;
404 }
405
406 ir::Shape inferOnehotShape(const ir::Shape &input_shape, const int depth, int axis)
407 {
408   assert(depth >= 0);
409   const auto rank = input_shape.rank() + 1;
410   ir::Shape newShape(rank);
411
412   axis = (axis == -1) ? (rank - 1) : axis;
413
414   for (int i = 0; i < rank; ++i)
415   {
416     if (i < axis)
417     {
418       newShape.dim(i) = input_shape.dim(i);
419     }
420     else if (i == axis)
421     {
422       newShape.dim(i) = depth;
423     }
424     else
425     {
426       newShape.dim(i) = input_shape.dim(i - 1);
427     }
428   }
429
430   return newShape;
431 }
432
433 ir::Shape inferPackShape(const ir::Shape &input_shape, int axis, int rank, int num)
434 {
435   ir::Shape out_shape;
436   int in_idx = 0;
437
438   for (int out_idx = 0; out_idx < rank; ++out_idx)
439   {
440     if (out_idx == axis)
441     {
442       out_shape.append(num);
443     }
444     else
445     {
446       out_shape.append(input_shape.dim(in_idx++));
447     }
448   }
449
450   return out_shape;
451 }
452
453 ir::Shape inferPadShape(const ir::Shape &in_shape, const int32_t *pad_buf, const size_t num_pads)
454 {
455   assert(num_pads % 2 == 0);
456   const int32_t rank = num_pads / 2;
457
458   ir::Shape ret(rank);
459   for (int32_t i = 0; i < rank; ++i)
460   {
461     const auto before_padding = pad_buf[i * 2];
462     const auto after_padding = pad_buf[i * 2 + 1];
463
464     ret.dim(i) = in_shape.dim(i) + before_padding + after_padding;
465   }
466
467   return ret;
468 }
469
470 ir::Shape inferPoolShape(const ir::Shape &in_shape, const ir::operation::Pool2D::Param &param,
471                          const ir::Layout layout)
472 {
473   assert(layout == ir::Layout::NHWC);
474   auto ifm_shape = in_shape.asFeature(layout);
475   const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh, param.kw,
476                                                   param.padding, param.stride);
477   // Pooling don't change number of channels and batch size
478   return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, ifm_shape.C};
479 }
480
481 ir::Shape inferResizeBilinearShape(const ir::Shape &in_shape, const int32_t output_height,
482                                    const int32_t output_width)
483 {
484   assert(in_shape.rank() == 4);
485   ir::Shape ret(in_shape.rank());
486
487   ret.dim(0) = in_shape.dim(0);
488   ret.dim(1) = output_height;
489   ret.dim(2) = output_width;
490   ret.dim(3) = in_shape.dim(3);
491
492   return ret;
493 }
494
495 template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delta_val)
496 {
497   ir::Shape out_shape(static_cast<int>(1));
498
499   out_shape.dim(0) =
500       (std::is_integral<T>::value
501            ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
502            : std::ceil(std::abs((start_val - limit_val) / delta_val)));
503   return out_shape;
504 }
505
506 // template instantiation
507 template ir::Shape inferRangeShape(int start_val, int limit_val, int delta_val);
508 template ir::Shape inferRangeShape(float start_val, float limit_val, float delta_val);
509
510 ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_elements,
511                             const size_t total_num_elements)
512 {
513   ir::Shape ret(shape_num_elements);
514   int32_t flatten_dim = ir::Shape::UNSPECIFIED_DIM;
515   for (int32_t i = 0; i < shape_num_elements; ++i)
516   {
517     if (shape_buf[i] < 0)
518     {
519       if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
520         throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
521       flatten_dim = i;
522       ret.dim(i) = 1;
523     }
524     else
525     {
526       ret.dim(i) = shape_buf[i];
527     }
528   }
529   if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
530     ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
531
532   // Check reshapable
533   if (total_num_elements != static_cast<size_t>(ret.num_elements()))
534     throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
535
536   return ret;
537 }
538
539 ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
540                            const ir::Shape &input_false_shape)
541 {
542   auto haveSameShapes = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
543                            const ir::Shape &input_false_shape) {
544     if ((input_cond_shape.rank() != input_true_shape.rank()) ||
545         input_cond_shape.rank() != input_false_shape.rank())
546     {
547       return false;
548     }
549
550     int rank = input_cond_shape.rank();
551     for (int i = 0; i < rank; ++i)
552     {
553       if (input_cond_shape.dim(i) != input_true_shape.dim(i) ||
554           input_cond_shape.dim(i) != input_false_shape.dim(i))
555       {
556         return false;
557       }
558     }
559
560     return true;
561   };
562
563   auto calculateShape = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
564                            const ir::Shape &input_false_shape, ir::Shape &new_shape) {
565     ir::Shape cond_shape = input_cond_shape;
566     ir::Shape true_shape = input_true_shape;
567     ir::Shape false_shape = input_false_shape;
568     int most_rank =
569         (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
570             ? cond_shape.rank()
571             : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
572
573     ir::Shape calculate_shape(most_rank);
574
575     cond_shape.extendRank(most_rank);
576     true_shape.extendRank(most_rank);
577     false_shape.extendRank(most_rank);
578
579     for (int i = 0; i < most_rank; ++i)
580     {
581       calculate_shape.dim(i) =
582           (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
583               ? cond_shape.dim(i)
584               : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
585
586       if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
587           (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
588           (false_shape.dim(i) != calculate_shape.dim(i) && false_shape.dim(i) != 1))
589       {
590         return false;
591       }
592     }
593
594     new_shape = calculate_shape;
595
596     return true;
597   };
598
599   bool havesame = haveSameShapes(input_cond_shape, input_true_shape, input_false_shape);
600   if (havesame)
601   {
602     return input_cond_shape;
603   }
604
605   ir::Shape new_shape;
606   bool possible = calculateShape(input_cond_shape, input_true_shape, input_false_shape, new_shape);
607
608   if (!possible)
609   {
610     throw std::runtime_error("Broadcasting is not possible.");
611   }
612
613   return new_shape;
614 }
615
616 ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, const int32_t *sizes)
617 {
618   const uint32_t rank = input_shape.rank();
619   ir::Shape out_shape(rank);
620
621   for (uint32_t idx = 0; idx < rank; ++idx)
622   {
623     const auto input_dim = input_shape.dim(idx);
624
625     // begin is zero-based
626     auto begin = begins[idx];
627     if (begin < 0)
628       throw std::runtime_error("shape inference Slice: Invalid begin.");
629
630     // size is one-based
631     auto size = sizes[idx];
632     if (size < -1)
633       throw std::runtime_error("shape inference Slice: Invalid size.");
634
635     if (size == -1)
636     {
637       size = input_dim - begin;
638     }
639     else
640     {
641       if (input_dim < begin + size)
642         throw std::runtime_error("shape inference Slice: Invalid begin and size.");
643     }
644     out_shape.dim(idx) = size;
645   }
646
647   return out_shape;
648 }
649
650 ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape &block_shape_shape,
651                                    const ir::Shape &padding_shape, const int32_t *block_shape_data,
652                                    const int32_t *padding_data)
653 {
654   const uint32_t rank = input_shape.rank();
655   ir::Shape out_shape(rank);
656
657   // Currently, only 4D NHWC input/output op_context are supported.
658   // The 4D array need to have exactly 2 spatial dimensions.
659   // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
660   const int32_t kInputDimensionNum = 4;
661   const int32_t kBlockSizeDimensionNum = 1;
662   const int32_t kSpatialDimensionNum = 2;
663
664   UNUSED_RELEASE(kInputDimensionNum);
665   UNUSED_RELEASE(kBlockSizeDimensionNum);
666   UNUSED_RELEASE(block_shape_shape);
667   UNUSED_RELEASE(padding_shape);
668
669   assert(block_shape_shape.rank() == kBlockSizeDimensionNum);
670   assert(block_shape_shape.dim(0) == kSpatialDimensionNum);
671   assert(padding_shape.dim(0) == kSpatialDimensionNum);
672   assert(padding_shape.dim(1) == 2); // fixed, meaning left/right padding for each element
673   assert(padding_shape.rank() == 2); // fixed, meaning dimension(dim 0) and padding length(dim 1)
674
675   // Ensures the input height and width (with padding) is a multiple of block
676   // shape height and width.
677   for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
678   {
679     int final_dim_size =
680         (input_shape.dim(dim + 1) + padding_data[dim * 2] + padding_data[dim * 2 + 1]);
681
682     assert(final_dim_size % block_shape_data[dim] == 0);
683
684     out_shape.dim(dim + 1) = final_dim_size / block_shape_data[dim];
685   }
686
687   const int output_batch_size = input_shape.dim(0) * block_shape_data[0] * block_shape_data[1];
688   const int output_channel_size = input_shape.dim(3);
689
690   out_shape.dim(0) = output_batch_size;
691   out_shape.dim(3) = output_channel_size;
692
693   return out_shape;
694 }
695
696 ir::Shape inferSplitShape(const ir::Shape input_shape, int axis_value, int num_splits)
697 {
698   ir::Shape newShape(input_shape);
699
700   assert(axis_value >= 0);
701   assert(axis_value < input_shape.rank());
702
703   const int input_size = input_shape.dim(axis_value);
704   assert(input_size % num_splits == 0);
705   const int slice_size = input_size / num_splits;
706
707   newShape.dim(axis_value) = slice_size;
708
709   return newShape;
710 }
711
712 ir::Shape inferSqueezeShape(const ir::Shape &in_shape, const ir::operation::Squeeze::Param &param)
713 {
714   const int ndims = param.ndim;
715   const int *squeeze_dims = param.dims;
716   bool should_squeeze[8] = {false};
717   int num_squeezed_dims = 0;
718   int shape_rank = in_shape.rank();
719   if (ndims == 0)
720   {
721     for (int idx = 0; idx < shape_rank; ++idx)
722     {
723       if (in_shape.dim(idx) == 1)
724       {
725         should_squeeze[idx] = true;
726         ++num_squeezed_dims;
727       }
728     }
729   }
730   else
731   {
732     for (int idx = 0; idx < ndims; ++idx)
733     {
734       int current = squeeze_dims[idx];
735       if (current < 0)
736       {
737         current += shape_rank;
738       }
739
740       if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
741       {
742         throw std::runtime_error(
743             "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
744       }
745
746       if (!should_squeeze[current])
747       {
748         ++num_squeezed_dims;
749       }
750       should_squeeze[current] = true;
751     }
752   }
753
754   // Set output shape.
755   ir::Shape out_shape(shape_rank - num_squeezed_dims);
756   for (int in_idx = 0, out_idx = 0; in_idx < shape_rank; ++in_idx)
757   {
758     if (!should_squeeze[in_idx])
759     {
760       out_shape.dim(out_idx++) = in_shape.dim(in_idx);
761     }
762   }
763
764   return out_shape;
765 }
766
767 // helper for for StridedSlice
768 template <typename T>
769 StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides,
770                                            const uint32_t begin_mask, const uint32_t end_mask,
771                                            const uint32_t shrink_axis_mask, const uint8_t rank)
772 {
773   StridedSliceParams op_params;
774   op_params.start_indices_count = rank;
775   op_params.stop_indices_count = rank;
776   op_params.strides_count = rank;
777
778   for (int i = 0; i < op_params.strides_count; ++i)
779   {
780     op_params.start_indices[i] = begin[i];
781     op_params.stop_indices[i] = end[i];
782     op_params.strides[i] = strides[i];
783
784     assert(op_params.strides[i] != 0);
785   }
786
787   op_params.begin_mask = begin_mask;
788   op_params.ellipsis_mask = 0; // NYI
789   op_params.end_mask = end_mask;
790   op_params.new_axis_mask = 0; // NYI
791   op_params.shrink_axis_mask = shrink_axis_mask;
792
793   assert(sizeof(op_params.begin_mask) * 4 >= rank);
794
795   return op_params;
796 }
797
798 // template instantiation
799 template StridedSliceParams
800 buildStridedSliceParams(const uint32_t *begin, const uint32_t *end, const uint32_t *strides,
801                         const uint32_t begin_mask, const uint32_t end_mask,
802                         const uint32_t shrink_axis_mask, const uint8_t rank);
803
804 int Clamp(const int v, const int lo, const int hi)
805 {
806   assert(!(hi < lo));
807   if (hi < v)
808     return hi;
809   if (v < lo)
810     return lo;
811   return v;
812 }
813
814 int StartForAxis(const StridedSliceParams &params, const ir::Shape &input_shape, int axis)
815 {
816   const auto begin_mask = params.begin_mask;
817   const auto *start_indices = params.start_indices;
818   const auto *strides = params.strides;
819   // Begin with the specified index.
820   int start = start_indices[axis];
821
822   // begin_mask override
823   if (begin_mask & 1 << axis)
824   {
825     if (strides[axis] > 0)
826     {
827       // Forward iteration - use the first element. These values will get
828       // clamped below (Note: We could have set them to 0 and axis_size-1, but
829       // use lowest() and max() to maintain symmetry with StopForAxis())
830       start = std::numeric_limits<int>::lowest();
831     }
832     else
833     {
834       // Backward iteration - use the last element.
835       start = std::numeric_limits<int>::max();
836     }
837   }
838
839   // Handle negative indices
840   int axis_size = input_shape.dim(axis);
841   if (start < 0)
842   {
843     start += axis_size;
844   }
845
846   // Clamping
847   start = Clamp(start, 0, axis_size - 1);
848
849   return start;
850 }
851
852 // Return the "real" index for the end of iteration along that axis. This is an
853 // "end" in the traditional C sense, in that it points to one past the last
854 // element. ie. So if you were iterating through all elements of a 1D array of
855 // size 4, this function would return 4 as the stop, because it is one past the
856 // "real" indices of 0, 1, 2 & 3.
857 int StopForAxis(const StridedSliceParams &params, const ir::Shape &input_shape, int axis,
858                 int start_for_axis)
859 {
860   const auto end_mask = params.end_mask;
861   const auto shrink_axis_mask = params.shrink_axis_mask;
862   const auto *stop_indices = params.stop_indices;
863   const auto *strides = params.strides;
864
865   // Begin with the specified index
866   const bool shrink_axis = shrink_axis_mask & (1 << axis);
867   int stop = stop_indices[axis];
868
869   // When shrinking an axis, the end position does not matter (and can be
870   // incorrect when negative indexing is used, see Issue #19260). Always use
871   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
872   // already been adjusted for negative indices.
873   if (shrink_axis)
874   {
875     stop = start_for_axis + 1;
876   }
877
878   // end_mask override
879   if (end_mask & (1 << axis))
880   {
881     if (strides[axis] > 0)
882     {
883       // Forward iteration - use the last element. These values will get
884       // clamped below
885       stop = std::numeric_limits<int>::max();
886     }
887     else
888     {
889       // Backward iteration - use the first element.
890       stop = std::numeric_limits<int>::lowest();
891     }
892   }
893
894   // Handle negative indices
895
896   const int axis_size = input_shape.dim(axis);
897   if (stop < 0)
898   {
899     stop += axis_size;
900   }
901
902   // Clamping
903   // Because the end index points one past the last element, we need slightly
904   // different clamping ranges depending on the direction.
905   if (strides[axis] > 0)
906   {
907     // Forward iteration
908     stop = Clamp(stop, 0, axis_size);
909   }
910   else
911   {
912     // Backward iteration
913     stop = Clamp(stop, -1, axis_size - 1);
914   }
915
916   return stop;
917 }
918
919 ir::Shape inferStridedSliceShape(const ir::Shape &input_shape, const StridedSliceParams &op_params,
920                                  uint32_t rank)
921 {
922   ir::Shape out_shape;
923
924   for (uint32_t idx = 0; idx < rank; ++idx)
925   {
926     int32_t stride = op_params.strides[idx];
927     int32_t begin = StartForAxis(op_params, input_shape, idx);
928     int32_t end = StopForAxis(op_params, input_shape, idx, begin);
929
930     // When shrinking an axis, the end position does not matter (and can be
931     // incorrect when negative indexing is used, see Issue #19260). Always use
932     // begin + 1 to generate a length 1 slice, since begin has
933     // already been adjusted for negative indices by StartForAxis.
934     const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
935     if (shrink_axis)
936     {
937       end = begin + 1;
938     }
939
940     int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
941     dim_shape = dim_shape < 0 ? 0 : dim_shape;
942     if (!shrink_axis)
943     {
944       out_shape.append(dim_shape);
945     }
946   }
947
948   return out_shape;
949 }
950
951 ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier)
952 {
953   // assert(in_shape.rank() == multiplier.rank());
954   ir::Shape new_Shape(in_shape.rank());
955
956   for (int i = 0; i < in_shape.rank(); ++i)
957   {
958     assert(multiplier[i]); // multiplier[i] shuld not be 0.
959     new_Shape.dim(i) = in_shape.dim(i) * multiplier[i];
960   }
961   return new_Shape;
962 }
963
964 ir::Shape inferTransposeShape(const ir::Shape &in_shape, const std::vector<int> &perm)
965 {
966   if (static_cast<int>(perm.size()) > in_shape.rank())
967   {
968     throw std::runtime_error("inferTransposeShape failed, bad rank size: " +
969                              std::to_string(static_cast<int>(perm.size())));
970   }
971   ir::Shape out_shape(static_cast<int>(perm.size()));
972   for (int idx = 0; idx < static_cast<int>(perm.size()); idx++)
973   {
974     if (perm[idx] < 0 || perm[idx] >= static_cast<int>(perm.size()))
975     {
976       throw std::runtime_error("inferTransposeShape failed, bad perm value: " +
977                                std::to_string(perm[idx]));
978     }
979     out_shape.dim(idx) = in_shape.dim(perm[idx]);
980   }
981   return out_shape;
982 }
983
984 ir::Shape inferUnpackShape(const ir::Shape &input_shape, int axis, int rank)
985 {
986   ir::Shape out_shape;
987
988   for (int out_idx = 0; out_idx < rank; out_idx++)
989   {
990     if (out_idx != axis)
991     {
992       out_shape.append(input_shape.dim(out_idx));
993     }
994   }
995
996   return out_shape;
997 }
998
999 } // namespace shape_inference
1000 } // namespace onert