9a24f8c1a4f684e3413ba282badc3211fc59762b
[platform/core/ml/nnfw.git] / runtime / onert / core / src / util / ShapeInference.cc
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *      http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "util/Utils.h"
19 #include "ir/InternalType.h"
20 #include "ir/Shape.h"
21 #include "ir/operation/AvgPool2D.h"
22 #include "ir/operation/MaxPool2D.h"
23 #include "util/ShapeInference.h"
24 #include "util/logging.h"
25
26 #include <cassert>
27 #include <sstream>
28 #include <cmath>
29
30 namespace onert
31 {
32 namespace shape_inference
33 {
34
35 //
36 // Helper functions
37 //
38
39 namespace
40 {
41
42 template <typename T, typename U>
43 typename std::enable_if<std::is_integral<T>::value && std::is_integral<U>::value,
44                         typename std::common_type<T, U>::type>::type
45 ceil_div(T dividend, U divisor)
46 {
47   assert(dividend > 0 && divisor > 0 && "this implementations is for positive numbers only");
48   return (dividend + divisor - 1) / divisor;
49 }
50
51 // Calculate the result of broadcast of two shapes
52 ir::Shape broadcastShapes(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
53 {
54   ir::Shape out_shape;
55   auto max_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
56
57   for (int idx = 0; idx < max_rank; ++idx)
58   {
59     // Go over operands dimensions from right to left
60     int lhs_idx = lhs_shape.rank() - idx - 1;
61     int rhs_idx = rhs_shape.rank() - idx - 1;
62
63     int32_t lhs_dim = lhs_idx >= 0 ? lhs_shape.dim(lhs_idx) : 1;
64     int32_t rhs_dim = rhs_idx >= 0 ? rhs_shape.dim(rhs_idx) : 1;
65
66     if (lhs_dim != 1 && rhs_dim != 1 && lhs_dim != rhs_dim)
67       throw std::runtime_error("Incompatible shapes for broadcast");
68
69     out_shape.prepend(std::max(lhs_dim, rhs_dim));
70   }
71
72   return out_shape;
73 }
74
75 } // namespace
76
77 //
78 // Shape inference
79 //
80
81 // Calculate output height and width of convolution-like operation
82 std::pair<int, int> calcConvLikeHeightAndWidth(const int in_h, const int in_w, const int ker_h,
83                                                const int ker_w, const ir::Padding pad,
84                                                const ir::Stride stride)
85 {
86   int32_t out_h = 0, out_w = 0;
87
88   switch (pad.type)
89   {
90     case ir::PaddingType::SAME:
91       out_h = ceil_div(in_h, stride.vertical);
92       out_w = ceil_div(in_w, stride.horizontal);
93       break;
94     case ir::PaddingType::VALID:
95       out_h = ceil_div(in_h - ker_h + 1, stride.vertical);
96       out_w = ceil_div(in_w - ker_w + 1, stride.horizontal);
97       break;
98     case ir::PaddingType::EXPLICIT:
99       out_h = (in_h + pad.param.top + pad.param.bottom - ker_h) / stride.vertical + 1;
100       out_w = (in_w + pad.param.left + pad.param.right - ker_w) / stride.horizontal + 1;
101       break;
102     default:
103       assert(false);
104   }
105
106   return {out_h, out_w};
107 }
108
109 ir::Shape inferEltwiseShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape)
110 {
111   return broadcastShapes(lhs_shape, rhs_shape);
112 }
113
114 ir::Shape inferArgMaxShape(const ir::Shape &input_shape, int axis, int rank)
115 {
116   ir::Shape out_shape;
117   for (int idx = 0; idx < rank; ++idx)
118   {
119     if (idx != axis)
120     {
121       int32_t input_dim = input_shape.dim(idx);
122       out_shape.append(input_dim);
123     }
124   }
125
126   return out_shape;
127 }
128
129 ir::Shape inferAvgPoolShape(const ir::Shape &in_shape, const ir::operation::AvgPool2D::Param &param,
130                             const ir::Layout layout)
131 {
132   assert(layout == ir::Layout::NHWC);
133   auto ifm_shape = in_shape.asFeature(layout);
134   const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh, param.kw,
135                                                   param.padding, param.stride);
136   // Pooling don't change number of channels and batch size
137   return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, ifm_shape.C};
138 }
139
140 ir::Shape inferReduceShape(const ir::Shape &input_shape, const std::vector<int> &axes,
141                            bool keep_dims)
142 {
143   int num_axis = axes.size();
144   int input_num_dims = input_shape.rank();
145   if (input_num_dims == 0)
146   {
147     ir::Shape out_shape(0);
148     return out_shape;
149   }
150   if (keep_dims)
151   {
152     ir::Shape out_shape;
153     for (int idx = 0; idx < input_num_dims; ++idx)
154     {
155       bool is_axis = false;
156       for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
157       {
158         if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
159         {
160           is_axis = true;
161           break;
162         }
163       }
164       if (is_axis)
165       {
166         out_shape.append(1);
167       }
168       else
169       {
170         out_shape.append(input_shape.dim(idx));
171       }
172     }
173     return out_shape;
174   }
175   else
176   {
177     // Calculates size of reducing axis.
178     int num_reduce_axis = num_axis;
179     for (int i = 0; i < num_axis; ++i)
180     {
181       int current = axes[i];
182       if (current < 0)
183       {
184         current += input_num_dims;
185       }
186       assert(0 <= current && current < input_num_dims);
187       for (int j = 0; j < i; ++j)
188       {
189         int previous = axes[j];
190         if (previous < 0)
191         {
192           previous += input_num_dims;
193         }
194         if (current == previous)
195         {
196           --num_reduce_axis;
197           break;
198         }
199       }
200     }
201     // Determines output dimensions.
202     ir::Shape out_shape;
203     int num_skip_axis = 0;
204     for (int idx = 0; idx < input_num_dims; ++idx)
205     {
206       bool is_axis = false;
207       for (int axis_idx = 0; axis_idx < num_axis; ++axis_idx)
208       {
209         if (axes[axis_idx] == idx || axes[axis_idx] + input_num_dims == idx)
210         {
211           ++num_skip_axis;
212           is_axis = true;
213           break;
214         }
215       }
216       if (!is_axis)
217       {
218         out_shape.append(input_shape.dim(idx));
219       }
220     }
221     return out_shape;
222   }
223 }
224
225 ir::Shape inferBatchMatMulShape(const ir::Shape &lhs_shape, const ir::Shape &rhs_shape,
226                                 const ir::operation::BatchMatMul::Param &param)
227 {
228   bool adj_x = param.adj_x;
229   bool adj_y = param.adj_y;
230   ir::Shape output_shape;
231
232   int output_rank = std::max(lhs_shape.rank(), rhs_shape.rank());
233
234   // Extend lhs and rhs shape
235   ir::Shape extended_lhs_shape(lhs_shape);
236   ir::Shape extended_rhs_shape(rhs_shape);
237   extended_lhs_shape.extendRank(output_rank);
238   extended_rhs_shape.extendRank(output_rank);
239
240   for (int i = 0; i < output_rank - 2; i++)
241   {
242     const int lhs_dim = extended_lhs_shape.dim(i);
243     const int rhs_dim = extended_rhs_shape.dim(i);
244     int broadcast_dim = lhs_dim;
245     if (lhs_dim != rhs_dim)
246     {
247       if (lhs_dim == 1)
248       {
249         broadcast_dim = rhs_dim;
250       }
251       else if (rhs_dim != 1)
252       {
253         throw std::runtime_error{"BatchMatMul shape inference: invalid brodcasting input shape"};
254       }
255     }
256
257     output_shape.append(broadcast_dim);
258   }
259
260   // Fill in the matmul dimensions.
261   int lhs_rows_index = adj_x ? output_rank - 1 : output_rank - 2;
262   int rhs_cols_index = adj_y ? output_rank - 2 : output_rank - 1;
263
264   output_shape.append(extended_lhs_shape.dim(lhs_rows_index));
265   output_shape.append(extended_rhs_shape.dim(rhs_cols_index));
266
267   return output_shape;
268 }
269
270 ir::Shape inferBroadcastToShape(const ir::Shape wshape, const int32_t *shape_buffer)
271 {
272   const int num_elements = wshape.num_elements();
273
274   assert(num_elements != 0);
275   assert(shape_buffer);
276
277   ir::Shape new_shape(num_elements);
278
279   for (int i = 0; i < num_elements; ++i)
280   {
281     assert(shape_buffer[i] != 0); // It shouldn't be 0.
282     new_shape.dim(i) = shape_buffer[i];
283   }
284
285   return new_shape;
286 }
287
288 ir::Shape inferConcatShape(const Shapes &in_shapes, const ir::operation::Concat::Param &param)
289 {
290   const int32_t concat_axis = param.axis >= 0 ? param.axis : in_shapes[0].rank() + param.axis;
291   const auto &first_in_shape = in_shapes[0];
292
293   // Check that all shapes are equal except for concat axis dimension
294   for (const auto &in_shape : in_shapes)
295   {
296     if (in_shape.rank() != first_in_shape.rank())
297       throw std::runtime_error("Rank in all input tensors should be same");
298
299     for (int64_t dim_idx = 0; dim_idx < in_shape.rank(); ++dim_idx)
300       if (!(dim_idx == concat_axis || in_shape.dim(dim_idx) == first_in_shape.dim(dim_idx)))
301         throw std::runtime_error("All tensor should have same dimension "
302                                  "except dimension on passed axis");
303   }
304
305   // Calculate output shape
306   ir::Shape out_shape(first_in_shape);
307   out_shape.dim(concat_axis) = 0;
308   for (const auto &in_shape : in_shapes)
309     out_shape.dim(concat_axis) += in_shape.dim(concat_axis);
310   return out_shape;
311 }
312
313 ir::Shape inferConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
314                            const ir::operation::Conv2D::Param &param, ir::Layout layout)
315 {
316   auto ifm_shape = in_shape.asFeature(layout);
317
318   // Kernel format is [depth_out, kernel_height, kernel_width, depth_in]
319   auto kf_shape = ker_shape.asFeature(layout);
320   assert(ifm_shape.C == kf_shape.C);
321
322   const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
323                                                   param.padding, param.stride);
324
325   return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.N};
326 }
327
328 ir::Shape inferDepthwiseConv2DShape(const ir::Shape &in_shape, const ir::Shape &ker_shape,
329                                     const ir::operation::DepthwiseConv2D::Param &param,
330                                     ir::Layout layout)
331 {
332   assert(layout == ir::Layout::NHWC);
333   auto ifm_shape = in_shape.asFeature(layout);
334
335   // Kernel format is [1, kernel_height, kernel_width, depth_out]
336   auto kf_shape = ker_shape.asFeature(layout);
337   assert(kf_shape.C == static_cast<int32_t>(ifm_shape.C * param.multiplier));
338   assert(kf_shape.N == 1);
339
340   const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, kf_shape.H, kf_shape.W,
341                                                   param.padding, param.stride);
342
343   return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, kf_shape.C};
344 }
345
346 ir::Shape inferExpandDimsShape(const ir::Shape &in_shape, int32_t axis)
347 {
348   ir::Shape out_shape(in_shape.rank() + 1);
349
350   axis = ((axis >= 0) ? axis : /* when axis < 0 */ (out_shape.rank() + axis));
351   if (!(0 <= axis && axis <= in_shape.rank()))
352     throw std::runtime_error("axis of dim is out of range");
353
354   for (int x = 0, out_x = 0; out_x < out_shape.rank(); ++out_x)
355   {
356     if (out_x == axis)
357       out_shape.dim(out_x) = 1;
358     else
359       out_shape.dim(out_x) = in_shape.dim(x++);
360   }
361
362   return out_shape;
363 }
364
365 ir::Shape inferFillShape(const ir::Shape &in_shape, const int32_t *buffer)
366 {
367   ir::Shape out_shape(in_shape.dim(0));
368
369   for (int out_x = 0; out_x < out_shape.rank(); ++out_x)
370   {
371     out_shape.dim(out_x) = buffer[out_x];
372   }
373
374   return out_shape;
375 }
376
377 ir::Shape inferFullyConnectedShape(const ir::Shape &in_shape, const ir::Shape &ker_shape)
378 {
379   assert(in_shape.rank() >= 2);
380   assert(ker_shape.rank() == 2);
381
382   const auto input_size_with_batch = in_shape.num_elements();
383   const auto num_units = ker_shape.dim(0);
384   const auto input_size = ker_shape.dim(1);
385   const auto batch_size = input_size_with_batch / input_size;
386   assert(input_size_with_batch % input_size == 0);
387
388   return {ir::Shape({static_cast<int32_t>(batch_size), num_units})};
389 }
390
391 ir::Shape inferGatherShape(const ir::Shape &input_shape, const ir::Shape &indices_shape, int axis,
392                            int rank)
393 {
394   ir::Shape out_shape;
395   const int indices_rank = indices_shape.rank();
396   for (int idx = 0; idx < rank; ++idx)
397   {
398     if (idx == axis)
399     {
400       for (int indices_idx = 0; indices_idx < indices_rank; indices_idx++)
401       {
402         out_shape.append(indices_shape.dim(indices_idx));
403       }
404     }
405     else
406     {
407       out_shape.append(input_shape.dim(idx));
408     }
409   }
410
411   return out_shape;
412 }
413
414 ir::Shape inferMaxPoolShape(const ir::Shape &in_shape, const ir::operation::MaxPool2D::Param &param,
415                             const ir::Layout layout)
416 {
417   assert(layout == ir::Layout::NHWC);
418   auto ifm_shape = in_shape.asFeature(layout);
419   const auto out_h_w = calcConvLikeHeightAndWidth(ifm_shape.H, ifm_shape.W, param.kh, param.kw,
420                                                   param.padding, param.stride);
421   // Pooling don't change number of channels and batch size
422   return ir::Shape{ifm_shape.N, out_h_w.first, out_h_w.second, ifm_shape.C};
423 }
424
425 ir::Shape inferOnehotShape(const ir::Shape &input_shape, const int depth, int axis)
426 {
427   assert(depth >= 0);
428   const auto rank = input_shape.rank() + 1;
429   ir::Shape newShape(rank);
430
431   axis = (axis == -1) ? (rank - 1) : axis;
432
433   for (int i = 0; i < rank; ++i)
434   {
435     if (i < axis)
436     {
437       newShape.dim(i) = input_shape.dim(i);
438     }
439     else if (i == axis)
440     {
441       newShape.dim(i) = depth;
442     }
443     else
444     {
445       newShape.dim(i) = input_shape.dim(i - 1);
446     }
447   }
448
449   return newShape;
450 }
451
452 ir::Shape inferPackShape(const ir::Shape &input_shape, int axis, int rank, int num)
453 {
454   ir::Shape out_shape;
455   int in_idx = 0;
456
457   for (int out_idx = 0; out_idx < rank; ++out_idx)
458   {
459     if (out_idx == axis)
460     {
461       out_shape.append(num);
462     }
463     else
464     {
465       out_shape.append(input_shape.dim(in_idx++));
466     }
467   }
468
469   return out_shape;
470 }
471
472 ir::Shape inferPadShape(const ir::Shape &in_shape, const int32_t *pad_buf, const size_t num_pads)
473 {
474   assert(num_pads % 2 == 0);
475   const int32_t rank = num_pads / 2;
476
477   ir::Shape ret(rank);
478   for (int32_t i = 0; i < rank; ++i)
479   {
480     const auto before_padding = pad_buf[i * 2];
481     const auto after_padding = pad_buf[i * 2 + 1];
482
483     ret.dim(i) = in_shape.dim(i) + before_padding + after_padding;
484   }
485
486   return ret;
487 }
488
489 ir::Shape inferResizeBilinearShape(const ir::Shape &in_shape, const int32_t output_height,
490                                    const int32_t output_width)
491 {
492   assert(in_shape.rank() == 4);
493   ir::Shape ret(in_shape.rank());
494
495   ret.dim(0) = in_shape.dim(0);
496   ret.dim(1) = output_height;
497   ret.dim(2) = output_width;
498   ret.dim(3) = in_shape.dim(3);
499
500   return ret;
501 }
502
503 template <typename T> ir::Shape inferRangeShape(T start_val, T limit_val, T delta_val)
504 {
505   ir::Shape out_shape(static_cast<int>(1));
506
507   out_shape.dim(0) =
508       (std::is_integral<T>::value
509            ? ((std::abs(start_val - limit_val) + std::abs(delta_val) - 1) / std::abs(delta_val))
510            : std::ceil(std::abs((start_val - limit_val) / delta_val)));
511   return out_shape;
512 }
513
514 // template instantiation
515 template ir::Shape inferRangeShape(int start_val, int limit_val, int delta_val);
516 template ir::Shape inferRangeShape(float start_val, float limit_val, float delta_val);
517
518 ir::Shape inferReshapeShape(const int32_t *shape_buf, const int32_t shape_num_elements,
519                             const size_t total_num_elements)
520 {
521   ir::Shape ret(shape_num_elements);
522   int32_t flatten_dim = ir::Shape::UNSPECIFIED_DIM;
523   for (int32_t i = 0; i < shape_num_elements; ++i)
524   {
525     if (shape_buf[i] < 0)
526     {
527       if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
528         throw std::runtime_error("Reshape: 2nd param has special dim(for flatten) more than twice");
529       flatten_dim = i;
530       ret.dim(i) = 1;
531     }
532     else
533     {
534       ret.dim(i) = shape_buf[i];
535     }
536   }
537   if (flatten_dim != ir::Shape::UNSPECIFIED_DIM)
538     ret.dim(flatten_dim) = total_num_elements / ret.num_elements();
539
540   // Check reshapable
541   if (total_num_elements != static_cast<size_t>(ret.num_elements()))
542     throw std::runtime_error("Reshape: 2nd param is not compatible with the shape of input");
543
544   return ret;
545 }
546
547 ir::Shape inferSelectShape(const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
548                            const ir::Shape &input_false_shape)
549 {
550   auto haveSameShapes = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
551                            const ir::Shape &input_false_shape) {
552     if ((input_cond_shape.rank() != input_true_shape.rank()) ||
553         input_cond_shape.rank() != input_false_shape.rank())
554     {
555       return false;
556     }
557
558     int rank = input_cond_shape.rank();
559     for (int i = 0; i < rank; ++i)
560     {
561       if (input_cond_shape.dim(i) != input_true_shape.dim(i) ||
562           input_cond_shape.dim(i) != input_false_shape.dim(i))
563       {
564         return false;
565       }
566     }
567
568     return true;
569   };
570
571   auto calculateShape = [](const ir::Shape &input_cond_shape, const ir::Shape &input_true_shape,
572                            const ir::Shape &input_false_shape, ir::Shape &new_shape) {
573     ir::Shape cond_shape = input_cond_shape;
574     ir::Shape true_shape = input_true_shape;
575     ir::Shape false_shape = input_false_shape;
576     int most_rank =
577         (cond_shape.rank() >= true_shape.rank()) && (cond_shape.rank() >= false_shape.rank())
578             ? cond_shape.rank()
579             : (false_shape.rank() >= true_shape.rank() ? false_shape.rank() : true_shape.rank());
580
581     ir::Shape calculate_shape(most_rank);
582
583     cond_shape.extendRank(most_rank);
584     true_shape.extendRank(most_rank);
585     false_shape.extendRank(most_rank);
586
587     for (int i = 0; i < most_rank; ++i)
588     {
589       calculate_shape.dim(i) =
590           (cond_shape.dim(i) >= true_shape.dim(i)) && (cond_shape.dim(i) >= false_shape.dim(i))
591               ? cond_shape.dim(i)
592               : (false_shape.dim(i) >= true_shape.dim(i) ? false_shape.dim(i) : true_shape.dim(i));
593
594       if ((cond_shape.dim(i) != calculate_shape.dim(i) && cond_shape.dim(i) != 1) ||
595           (true_shape.dim(i) != calculate_shape.dim(i) && true_shape.dim(i) != 1) ||
596           (false_shape.dim(i) != calculate_shape.dim(i) && false_shape.dim(i) != 1))
597       {
598         return false;
599       }
600     }
601
602     new_shape = calculate_shape;
603
604     return true;
605   };
606
607   bool havesame = haveSameShapes(input_cond_shape, input_true_shape, input_false_shape);
608   if (havesame)
609   {
610     return input_cond_shape;
611   }
612
613   ir::Shape new_shape;
614   bool possible = calculateShape(input_cond_shape, input_true_shape, input_false_shape, new_shape);
615
616   if (!possible)
617   {
618     throw std::runtime_error("Broadcasting is not possible.");
619   }
620
621   return new_shape;
622 }
623
624 ir::Shape inferSliceShape(const ir::Shape &input_shape, const int32_t *begins, const int32_t *sizes)
625 {
626   const uint32_t rank = input_shape.rank();
627   ir::Shape out_shape(rank);
628
629   for (uint32_t idx = 0; idx < rank; ++idx)
630   {
631     const auto input_dim = input_shape.dim(idx);
632
633     // begin is zero-based
634     auto begin = begins[idx];
635     if (begin < 0)
636       throw std::runtime_error("shape inference Slice: Invalid begin.");
637
638     // size is one-based
639     auto size = sizes[idx];
640     if (size < -1)
641       throw std::runtime_error("shape inference Slice: Invalid size.");
642
643     if (size == -1)
644     {
645       size = input_dim - begin;
646     }
647     else
648     {
649       if (input_dim < begin + size)
650         throw std::runtime_error("shape inference Slice: Invalid begin and size.");
651     }
652     out_shape.dim(idx) = size;
653   }
654
655   return out_shape;
656 }
657
658 ir::Shape inferSpaceToBatchNDShape(const ir::Shape &input_shape, const ir::Shape &block_shape_shape,
659                                    const ir::Shape &padding_shape, const int32_t *block_shape_data,
660                                    const int32_t *padding_data)
661 {
662   const uint32_t rank = input_shape.rank();
663   ir::Shape out_shape(rank);
664
665   // Currently, only 4D NHWC input/output op_context are supported.
666   // The 4D array need to have exactly 2 spatial dimensions.
667   // TODO(nupurgarg): Support arbitrary dimension in SpaceToBatchND.
668   const int32_t kInputDimensionNum = 4;
669   const int32_t kBlockSizeDimensionNum = 1;
670   const int32_t kSpatialDimensionNum = 2;
671
672   UNUSED_RELEASE(kInputDimensionNum);
673   UNUSED_RELEASE(kBlockSizeDimensionNum);
674   UNUSED_RELEASE(block_shape_shape);
675   UNUSED_RELEASE(padding_shape);
676
677   assert(block_shape_shape.rank() == kBlockSizeDimensionNum);
678   assert(block_shape_shape.dim(0) == kSpatialDimensionNum);
679   assert(padding_shape.dim(0) == kSpatialDimensionNum);
680   assert(padding_shape.dim(1) == 2); // fixed, meaning left/right padding for each element
681   assert(padding_shape.rank() == 2); // fixed, meaning dimension(dim 0) and padding length(dim 1)
682
683   // Ensures the input height and width (with padding) is a multiple of block
684   // shape height and width.
685   for (int dim = 0; dim < kSpatialDimensionNum; ++dim)
686   {
687     int final_dim_size =
688         (input_shape.dim(dim + 1) + padding_data[dim * 2] + padding_data[dim * 2 + 1]);
689
690     assert(final_dim_size % block_shape_data[dim] == 0);
691
692     out_shape.dim(dim + 1) = final_dim_size / block_shape_data[dim];
693   }
694
695   const int output_batch_size = input_shape.dim(0) * block_shape_data[0] * block_shape_data[1];
696   const int output_channel_size = input_shape.dim(3);
697
698   out_shape.dim(0) = output_batch_size;
699   out_shape.dim(3) = output_channel_size;
700
701   return out_shape;
702 }
703
704 ir::Shape inferSplitShape(const ir::Shape input_shape, int axis_value, int num_splits)
705 {
706   ir::Shape newShape(input_shape);
707
708   assert(axis_value >= 0);
709   assert(axis_value < input_shape.rank());
710
711   const int input_size = input_shape.dim(axis_value);
712   assert(input_size % num_splits == 0);
713   const int slice_size = input_size / num_splits;
714
715   newShape.dim(axis_value) = slice_size;
716
717   return newShape;
718 }
719
720 ir::Shape inferSqueezeShape(const ir::Shape &in_shape, const ir::operation::Squeeze::Param &param)
721 {
722   const int ndims = param.ndim;
723   const int *squeeze_dims = param.dims;
724   bool should_squeeze[8] = {false};
725   int num_squeezed_dims = 0;
726   int shape_rank = in_shape.rank();
727   if (ndims == 0)
728   {
729     for (int idx = 0; idx < shape_rank; ++idx)
730     {
731       if (in_shape.dim(idx) == 1)
732       {
733         should_squeeze[idx] = true;
734         ++num_squeezed_dims;
735       }
736     }
737   }
738   else
739   {
740     for (int idx = 0; idx < ndims; ++idx)
741     {
742       int current = squeeze_dims[idx];
743       if (current < 0)
744       {
745         current += shape_rank;
746       }
747
748       if (!(current >= 0 && current < shape_rank && in_shape.dim(current) == 1))
749       {
750         throw std::runtime_error(
751             "The following conditions must be met: 0 <= dim < Shape rank, dim == 1");
752       }
753
754       if (!should_squeeze[current])
755       {
756         ++num_squeezed_dims;
757       }
758       should_squeeze[current] = true;
759     }
760   }
761
762   // Set output shape.
763   ir::Shape out_shape(shape_rank - num_squeezed_dims);
764   for (int in_idx = 0, out_idx = 0; in_idx < shape_rank; ++in_idx)
765   {
766     if (!should_squeeze[in_idx])
767     {
768       out_shape.dim(out_idx++) = in_shape.dim(in_idx);
769     }
770   }
771
772   return out_shape;
773 }
774
775 // helper for for StridedSlice
776 template <typename T>
777 StridedSliceParams buildStridedSliceParams(const T *begin, const T *end, const T *strides,
778                                            const uint32_t begin_mask, const uint32_t end_mask,
779                                            const uint32_t shrink_axis_mask, const uint8_t rank)
780 {
781   StridedSliceParams op_params;
782   op_params.start_indices_count = rank;
783   op_params.stop_indices_count = rank;
784   op_params.strides_count = rank;
785
786   for (int i = 0; i < op_params.strides_count; ++i)
787   {
788     op_params.start_indices[i] = begin[i];
789     op_params.stop_indices[i] = end[i];
790     op_params.strides[i] = strides[i];
791
792     assert(op_params.strides[i] != 0);
793   }
794
795   op_params.begin_mask = begin_mask;
796   op_params.ellipsis_mask = 0; // NYI
797   op_params.end_mask = end_mask;
798   op_params.new_axis_mask = 0; // NYI
799   op_params.shrink_axis_mask = shrink_axis_mask;
800
801   assert(sizeof(op_params.begin_mask) * 4 >= rank);
802
803   return op_params;
804 }
805
806 // template instantiation
807 template StridedSliceParams
808 buildStridedSliceParams(const uint32_t *begin, const uint32_t *end, const uint32_t *strides,
809                         const uint32_t begin_mask, const uint32_t end_mask,
810                         const uint32_t shrink_axis_mask, const uint8_t rank);
811
812 int Clamp(const int v, const int lo, const int hi)
813 {
814   assert(!(hi < lo));
815   if (hi < v)
816     return hi;
817   if (v < lo)
818     return lo;
819   return v;
820 }
821
822 int StartForAxis(const StridedSliceParams &params, const ir::Shape &input_shape, int axis)
823 {
824   const auto begin_mask = params.begin_mask;
825   const auto *start_indices = params.start_indices;
826   const auto *strides = params.strides;
827   // Begin with the specified index.
828   int start = start_indices[axis];
829
830   // begin_mask override
831   if (begin_mask & 1 << axis)
832   {
833     if (strides[axis] > 0)
834     {
835       // Forward iteration - use the first element. These values will get
836       // clamped below (Note: We could have set them to 0 and axis_size-1, but
837       // use lowest() and max() to maintain symmetry with StopForAxis())
838       start = std::numeric_limits<int>::lowest();
839     }
840     else
841     {
842       // Backward iteration - use the last element.
843       start = std::numeric_limits<int>::max();
844     }
845   }
846
847   // Handle negative indices
848   int axis_size = input_shape.dim(axis);
849   if (start < 0)
850   {
851     start += axis_size;
852   }
853
854   // Clamping
855   start = Clamp(start, 0, axis_size - 1);
856
857   return start;
858 }
859
860 // Return the "real" index for the end of iteration along that axis. This is an
861 // "end" in the traditional C sense, in that it points to one past the last
862 // element. ie. So if you were iterating through all elements of a 1D array of
863 // size 4, this function would return 4 as the stop, because it is one past the
864 // "real" indices of 0, 1, 2 & 3.
865 int StopForAxis(const StridedSliceParams &params, const ir::Shape &input_shape, int axis,
866                 int start_for_axis)
867 {
868   const auto end_mask = params.end_mask;
869   const auto shrink_axis_mask = params.shrink_axis_mask;
870   const auto *stop_indices = params.stop_indices;
871   const auto *strides = params.strides;
872
873   // Begin with the specified index
874   const bool shrink_axis = shrink_axis_mask & (1 << axis);
875   int stop = stop_indices[axis];
876
877   // When shrinking an axis, the end position does not matter (and can be
878   // incorrect when negative indexing is used, see Issue #19260). Always use
879   // start_for_axis + 1 to generate a length 1 slice, since start_for_axis has
880   // already been adjusted for negative indices.
881   if (shrink_axis)
882   {
883     stop = start_for_axis + 1;
884   }
885
886   // end_mask override
887   if (end_mask & (1 << axis))
888   {
889     if (strides[axis] > 0)
890     {
891       // Forward iteration - use the last element. These values will get
892       // clamped below
893       stop = std::numeric_limits<int>::max();
894     }
895     else
896     {
897       // Backward iteration - use the first element.
898       stop = std::numeric_limits<int>::lowest();
899     }
900   }
901
902   // Handle negative indices
903
904   const int axis_size = input_shape.dim(axis);
905   if (stop < 0)
906   {
907     stop += axis_size;
908   }
909
910   // Clamping
911   // Because the end index points one past the last element, we need slightly
912   // different clamping ranges depending on the direction.
913   if (strides[axis] > 0)
914   {
915     // Forward iteration
916     stop = Clamp(stop, 0, axis_size);
917   }
918   else
919   {
920     // Backward iteration
921     stop = Clamp(stop, -1, axis_size - 1);
922   }
923
924   return stop;
925 }
926
927 ir::Shape inferStridedSliceShape(const ir::Shape &input_shape, const StridedSliceParams &op_params,
928                                  uint32_t rank)
929 {
930   ir::Shape out_shape;
931
932   for (uint32_t idx = 0; idx < rank; ++idx)
933   {
934     int32_t stride = op_params.strides[idx];
935     int32_t begin = StartForAxis(op_params, input_shape, idx);
936     int32_t end = StopForAxis(op_params, input_shape, idx, begin);
937
938     // When shrinking an axis, the end position does not matter (and can be
939     // incorrect when negative indexing is used, see Issue #19260). Always use
940     // begin + 1 to generate a length 1 slice, since begin has
941     // already been adjusted for negative indices by StartForAxis.
942     const bool shrink_axis = op_params.shrink_axis_mask & (1 << idx);
943     if (shrink_axis)
944     {
945       end = begin + 1;
946     }
947
948     int32_t dim_shape = std::ceil((end - begin) / static_cast<float>(stride));
949     dim_shape = dim_shape < 0 ? 0 : dim_shape;
950     if (!shrink_axis)
951     {
952       out_shape.append(dim_shape);
953     }
954   }
955
956   return out_shape;
957 }
958
959 ir::Shape inferTileShape(const ir::Shape &in_shape, const int32_t *multiplier)
960 {
961   // assert(in_shape.rank() == multiplier.rank());
962   ir::Shape new_Shape(in_shape.rank());
963
964   for (int i = 0; i < in_shape.rank(); ++i)
965   {
966     assert(multiplier[i]); // multiplier[i] shuld not be 0.
967     new_Shape.dim(i) = in_shape.dim(i) * multiplier[i];
968   }
969   return new_Shape;
970 }
971
972 ir::Shape inferTransposeShape(const ir::Shape &in_shape, const std::vector<int> &perm)
973 {
974   if (static_cast<int>(perm.size()) > in_shape.rank())
975   {
976     throw std::runtime_error("inferTransposeShape failed, bad rank size: " +
977                              std::to_string(static_cast<int>(perm.size())));
978   }
979   ir::Shape out_shape(static_cast<int>(perm.size()));
980   for (int idx = 0; idx < static_cast<int>(perm.size()); idx++)
981   {
982     if (perm[idx] < 0 || perm[idx] >= static_cast<int>(perm.size()))
983     {
984       throw std::runtime_error("inferTransposeShape failed, bad perm value: " +
985                                std::to_string(perm[idx]));
986     }
987     out_shape.dim(idx) = in_shape.dim(perm[idx]);
988   }
989   return out_shape;
990 }
991
992 ir::Shape inferUnpackShape(const ir::Shape &input_shape, int axis, int rank)
993 {
994   ir::Shape out_shape;
995
996   for (int out_idx = 0; out_idx < rank; out_idx++)
997   {
998     if (out_idx != axis)
999     {
1000       out_shape.append(input_shape.dim(out_idx));
1001     }
1002   }
1003
1004   return out_shape;
1005 }
1006
1007 } // namespace shape_inference
1008 } // namespace onert