a55f50b193961594733251e132a20dcbb964ea19
[platform/core/ml/nnfw.git] / compiler / luci / service / src / CircleShapeInferenceRule.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/CircleShapeInferenceRule.h"
18 #include "Check.h"
19
20 #include "ShapeInfer_StridedSlice.h"
21
22 #include <luci/IR/CircleNodes.h>
23 #include <luci/IR/CircleDialect.h>
24 #include <luci/IR/CircleNodeVisitor.h>
25 #include <luci/Log.h>
26
27 #include <oops/InternalExn.h>
28
29 #include <algorithm>
30 #include <cassert>
31 #include <cmath>
32 #include <stdexcept>
33
34 namespace
35 {
36
37 std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
38 {
39   os << "[";
40   for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
41   {
42     if (r)
43       os << ",";
44     os << tensor_shape.dim(r).value();
45   }
46   os << "]";
47   return os;
48 }
49
50 loco::TensorShape own_shape(const luci::CircleNode *node)
51 {
52   loco::TensorShape shape;
53   shape.rank(node->rank());
54   for (uint32_t r = 0; r < node->rank(); ++r)
55     shape.dim(r) = loco::Dimension(node->dim(r).value());
56   return shape;
57 }
58
59 loco::NodeShape use_own(const luci::CircleNode *node)
60 {
61   loco::TensorShape shape = own_shape(node);
62   return loco::NodeShape{shape};
63 }
64
65 /**
66  * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
67  *
68  * HOW TO USE:
69  *
70  *   auto expanded_tensor_shape = expand(tensor_shape).to(N);
71  */
72 class TensorShapeExpander
73 {
74 public:
75   TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
76   {
77     // DO NOTHING
78   }
79
80 public:
81   loco::TensorShape to(uint32_t output_rank)
82   {
83     auto const &input_shape = _shape;
84     uint32_t const input_rank = input_shape.rank();
85
86     assert(input_rank <= output_rank && "Cannot shrink rank");
87     uint32_t const axis_shift = output_rank - input_rank;
88
89     loco::TensorShape output_shape;
90
91     output_shape.rank(output_rank);
92     for (uint32_t axis = 0; axis < output_rank; ++axis)
93     {
94       output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
95     }
96
97     return output_shape;
98   }
99
100 private:
101   const loco::TensorShape _shape;
102 };
103
104 /**
105  * @breif  Expand shape x and y to same rank by align right and filling with 1
106  */
107 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
108 {
109   auto x_rank = x.rank();
110   auto y_rank = y.rank();
111
112   if (x_rank == y_rank)
113     return;
114
115   TensorShapeExpander x_exp(x);
116   TensorShapeExpander y_exp(y);
117
118   auto xy_rank = std::max(x_rank, y_rank);
119
120   x = x_rank > y_rank ? x : x_exp.to(xy_rank);
121   y = y_rank > x_rank ? y : y_exp.to(xy_rank);
122 }
123
124 /**
125  * @breif  Returns shape of expanded dimension of input x and y having same rank
126  */
127 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
128 {
129   assert(x.rank() == y.rank());
130
131   auto rank = x.rank();
132
133   loco::TensorShape output_shape;
134
135   output_shape.rank(rank);
136   for (uint32_t axis = 0; axis < rank; ++axis)
137   {
138     assert(x.dim(axis).known() && y.dim(axis).known());
139
140     auto x_dim = x.dim(axis).value();
141     auto y_dim = y.dim(axis).value();
142
143     // each dimension of x and y should be same or one must be 1 if different
144     if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
145       INTERNAL_EXN("Cannot produce expand_dimension of two shapes");
146
147     output_shape.dim(axis) = std::max(x_dim, y_dim);
148   }
149
150   return output_shape;
151 }
152
153 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
154 {
155   auto x_match = x;
156   auto y_match = y;
157
158   expand_rank(x_match, y_match);
159
160   auto output_shape = expand_dimension(x_match, y_match);
161
162   return output_shape;
163 }
164
165 /**
166  * @brief vector_from_constant will return int64_t vector from CircleConst node
167  */
168 template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::CircleConst *const_node)
169 {
170   std::vector<int64_t> result;
171
172   for (uint32_t idx = 0; idx < const_node->size<T>(); ++idx)
173     result.push_back(const_node->at<T>(idx));
174
175   return result;
176 }
177
178 template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
179 {
180   auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
181   auto y_shape = loco::shape_get(node->y()).template as<loco::TensorShape>();
182
183   auto output_shape = broadcast_shape(x_shape, y_shape);
184
185   return loco::NodeShape{output_shape};
186 }
187
188 template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
189 {
190   auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
191   return loco::NodeShape{x_shape};
192 }
193
194 template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node)
195 {
196   auto shape = loco::shape_get(node->logits()).template as<loco::TensorShape>();
197   return loco::NodeShape{shape};
198 }
199
200 template <class CIRCLENODE>
201 loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings)
202 {
203   const loco::DataType S32 = loco::DataType::S32;
204
205   auto input_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
206
207   // TODO support other data type
208   LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
209   LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2")
210
211   int32_t n = paddings->dim(0).value();
212   int32_t v = paddings->dim(1).value();
213
214   LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
215   LUCI_ASSERT(n == int32_t(input_shape.rank()),
216               "paddings [n, 2] should have same value of input rank");
217
218   loco::TensorShape output_shape;
219
220   output_shape.rank(input_shape.rank());
221   for (int32_t ni = 0; ni < n; ++ni)
222   {
223     int32_t idx = ni * 2;
224     int value = input_shape.dim(ni).value();
225     value += paddings->at<S32>(idx + 0); // left
226     value += paddings->at<S32>(idx + 1); // right
227     output_shape.dim(ni) = value;
228   }
229
230   return loco::NodeShape{output_shape};
231 }
232
233 loco::NodeShape infer_add_n(const luci::CircleAddN *node)
234 {
235   auto shape = loco::shape_get(node->inputs(0)).as<loco::TensorShape>();
236
237   for (uint32_t idx = 1; idx < node->arity(); ++idx)
238   {
239     auto shape_idx = loco::shape_get(node->inputs(idx)).as<loco::TensorShape>();
240     if (!(shape == shape_idx))
241     {
242       INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx);
243     }
244   }
245   return loco::NodeShape{shape};
246 }
247
248 loco::NodeShape infer_arg_max(const luci::CircleArgMax *node)
249 {
250   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
251   auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
252
253   int64_t select_axis = 0;
254   {
255     LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
256
257     // Only support node's shape() is CircleConst with S32/S64
258     // Support S32 for now.
259     auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
260     LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
261                 "Only support int32 CircleConst for CircleArgMax");
262
263     if (const_shape_node->rank() > 1)
264       INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
265                      oops::to_uint32(const_shape_node->rank()));
266
267     select_axis = const_shape_node->scalar<loco::DataType::S32>();
268   }
269   assert(select_axis < input_shape.rank());
270   assert(select_axis >= 0); // TODO support minus of this breaks
271
272   // NOTE select_axis is removed
273   loco::TensorShape shape_output;
274   uint32_t rank = input_shape.rank();
275   uint32_t shrink = static_cast<uint32_t>(select_axis);
276   assert(rank > 0);
277   shape_output.rank(rank - 1);
278   for (uint32_t r = 0, d = 0; r < rank; ++r)
279   {
280     if (r == shrink)
281       continue;
282     shape_output.dim(d++) = input_shape.dim(r);
283   }
284   return loco::NodeShape{shape_output};
285 }
286
287 loco::NodeShape infer_arg_min(const luci::CircleArgMin *node)
288 {
289   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
290   auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
291
292   int64_t select_axis = 0;
293   {
294     LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
295
296     // Only support node's shape() is CircleConst with S32/S64
297     // Support S32 for now.
298     auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
299     LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
300                 "Only support int32 CircleConst for CircleArgMin");
301
302     if (const_shape_node->rank() > 1)
303       INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
304                      oops::to_uint32(const_shape_node->rank()));
305
306     select_axis = const_shape_node->scalar<loco::DataType::S32>();
307   }
308   assert(select_axis < input_shape.rank());
309   assert(select_axis >= 0); // TODO support minus of this breaks
310
311   // NOTE select_axis is removed
312   loco::TensorShape shape_output;
313   uint32_t rank = input_shape.rank();
314   uint32_t shrink = static_cast<uint32_t>(select_axis);
315   assert(rank > 0);
316   shape_output.rank(rank - 1);
317   for (uint32_t r = 0, d = 0; r < rank; ++r)
318   {
319     if (r == shrink)
320       continue;
321     shape_output.dim(d++) = input_shape.dim(r);
322   }
323   return loco::NodeShape{shape_output};
324 }
325
326 // Call this for CircleAvgPool2D and CircleMaxPool2D only
327 template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
328 {
329   LUCI_ASSERT(loco::shape_known(node->value()), "Shape must be known");
330
331   auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
332   assert(ifm_shape.rank() == 4);
333
334   uint32_t input_height = ifm_shape.dim(1).value();
335   uint32_t input_width = ifm_shape.dim(2).value();
336   uint32_t stride_height = node->stride()->h();
337   uint32_t stride_width = node->stride()->w();
338   uint32_t window_height = node->filter()->h();
339   uint32_t window_width = node->filter()->w();
340   uint32_t dilation_height = 1; // dilation for CircleAvgPool2D and CircleMaxPool2D is 1
341   uint32_t dilation_width = 1;
342   uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
343   uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
344
345   uint32_t output_height = 0;
346   uint32_t output_width = 0;
347
348   if (node->padding() == luci::Padding::VALID)
349   {
350     output_height = (input_height + stride_height - effective_window_height) / stride_height;
351     output_width = (input_width + stride_width - effective_window_width) / stride_width;
352   }
353   else if (node->padding() == luci::Padding::SAME)
354   {
355     output_height = (input_height + stride_height - 1) / stride_height;
356     output_width = (input_width + stride_width - 1) / stride_width;
357   }
358   else
359     LUCI_ASSERT(false, "Wrong padding type");
360
361   loco::TensorShape ofm_shape;
362   ofm_shape.rank(4);
363   ofm_shape.dim(0) = ifm_shape.dim(0);
364   ofm_shape.dim(1) = output_height;
365   ofm_shape.dim(2) = output_width;
366   ofm_shape.dim(3) = ifm_shape.dim(3);
367
368   return loco::NodeShape{ofm_shape};
369 }
370
371 loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
372 {
373   const loco::DataType S32 = loco::DataType::S32;
374
375   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
376   // Support only input rank is 3 and 4
377   assert(input_shape.rank() == 3 || input_shape.rank() == 4);
378
379   // Only support block_shape() with S32 type CircleConst for now
380   auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
381   LUCI_ASSERT(const_block_shape->dtype() == loco::DataType::S32, "Only support int32 block_shape");
382
383   // Only support crops() with S32 type CircleConst for now
384   auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops());
385   LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops");
386
387   auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
388   auto const_crops_shape = loco::shape_get(const_crops).as<loco::TensorShape>();
389   assert(const_block_shape_shape.rank() == 1);
390   assert(const_crops_shape.rank() == 2);
391
392   int32_t input_spatial_dim = input_shape.rank() - 2;
393   assert(const_block_shape_shape.dim(0) == input_spatial_dim);
394   assert(const_crops_shape.dim(0) == input_spatial_dim);
395   assert(const_crops_shape.dim(1) == 2);
396
397   loco::TensorShape shape_output;
398
399   shape_output.rank(input_shape.rank());
400
401   int32_t output_batch_size = input_shape.dim(0).value();
402   for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
403   {
404     int dim_size = input_shape.dim(dim + 1).value() * const_block_shape->at<S32>(dim);
405     dim_size -= const_crops->at<S32>(dim * 2);
406     dim_size -= const_crops->at<S32>(dim * 2 + 1);
407     shape_output.dim(dim + 1) = dim_size;
408
409     assert(output_batch_size % const_block_shape->at<S32>(dim) == 0);
410     output_batch_size = output_batch_size / const_block_shape->at<S32>(dim);
411   }
412   shape_output.dim(0) = output_batch_size;
413   shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
414
415   return loco::NodeShape{shape_output};
416 }
417
418 struct OutputSize
419 {
420   uint32_t height = 0;
421   uint32_t width = 0;
422 };
423
424 template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node)
425 {
426   auto ifm_shape = loco::shape_get(node->input()).template as<loco::TensorShape>();
427   auto ker_shape = loco::shape_get(node->filter()).template as<loco::TensorShape>();
428   assert(ifm_shape.rank() == 4);
429   assert(ker_shape.rank() == 4);
430
431   uint32_t input_height = ifm_shape.dim(1).value();
432   uint32_t input_width = ifm_shape.dim(2).value();
433   uint32_t stride_height = node->stride()->h();
434   uint32_t stride_width = node->stride()->w();
435   uint32_t ker_height = ker_shape.dim(1).value();
436   uint32_t ker_width = ker_shape.dim(2).value();
437   uint32_t dilation_height = node->dilation()->h();
438   uint32_t dilation_width = node->dilation()->w();
439   uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
440   uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
441
442   uint32_t output_height = 0;
443   uint32_t output_width = 0;
444
445   if (node->padding() == luci::Padding::VALID)
446   {
447     output_height = (input_height + stride_height - effective_ker_height) / stride_height;
448     output_width = (input_width + stride_width - effective_ker_width) / stride_width;
449   }
450   else if (node->padding() == luci::Padding::SAME)
451   {
452     output_height = (input_height + stride_height - 1) / stride_height;
453     output_width = (input_width + stride_width - 1) / stride_width;
454   }
455   else
456     LUCI_ASSERT(false, "Wrong padding type");
457
458   OutputSize os{output_height, output_width};
459
460   return os;
461 }
462
463 // BatchMatMulV2 supports broadcasting in the batch dimensions(BatchMatMul doesn't)
464 // TODO Distinguish BatchMatMul and BatchMatMulV2
465 loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
466                                         const loco::TensorShape &y_shape, bool adj_x, bool adj_y)
467 {
468   uint32_t x_rank = x_shape.rank();
469   uint32_t y_rank = y_shape.rank();
470   assert(x_rank >= 2 && y_rank >= 2);
471
472   loco::TensorShape output_shape;
473   output_shape.rank(x_shape.rank());
474   // Braodcast in the batch dimension
475   if (x_rank > 2 || y_rank > 2)
476   {
477     loco::TensorShape dummy_x = x_shape;
478     loco::TensorShape dummy_y = y_shape;
479     expand_rank(dummy_x, dummy_y);
480     if (x_rank < y_rank)
481       expand_rank(output_shape, dummy_y);
482
483     for (uint32_t d = 0; d < output_shape.rank() - 2; d++)
484     {
485       uint32_t max_dim = std::max(dummy_x.dim(d).value(), dummy_y.dim(d).value());
486       if (dummy_x.dim(d) == dummy_y.dim(d) ||
487           dummy_x.dim(d).value() * dummy_y.dim(d).value() == max_dim)
488         output_shape.dim(d).set(max_dim);
489       else
490         INTERNAL_EXN("BatchMatMul has wrong shape");
491     }
492   }
493
494   loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
495   loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
496   loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
497   loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
498
499   if (not(x_rhs == y_lhs))
500     INTERNAL_EXN("x_rhs and y_lhs should be same");
501
502   uint32_t out_rank = output_shape.rank();
503   output_shape.dim(out_rank - 2) = x_lhs;
504   output_shape.dim(out_rank - 1) = y_rhs;
505
506   return loco::NodeShape{output_shape};
507 }
508
509 loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
510 {
511   // TODO Support when CircleConcatenation has 0 input
512   assert(node->numValues() > 0);
513
514   auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
515   auto axis = node->axis();
516   if (axis < 0)
517     axis += first_shape.rank();
518
519   assert(0 <= axis);
520   assert(first_shape.rank() > static_cast<uint32_t>(axis));
521
522   loco::TensorShape output_shape;
523
524   output_shape.rank(first_shape.rank());
525   for (uint32_t i = 0; i < output_shape.rank(); ++i)
526     output_shape.dim(i) = first_shape.dim(i);
527
528   for (uint32_t i = 1; i < node->numValues(); ++i)
529   {
530     auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
531
532     for (uint32_t j = 0; j < output_shape.rank(); ++j)
533     {
534       if (j == static_cast<uint32_t>(axis))
535         output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
536       else
537         assert(output_shape.dim(j) == input_shape.dim(j));
538     }
539   }
540
541   return loco::NodeShape{output_shape};
542 }
543
544 loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
545 {
546   LOGGER(l);
547
548   auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
549   auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
550
551   INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
552           << ")" << std::endl;
553
554   assert(ifm_shape.rank() == 4);
555   assert(ker_shape.rank() == 4);
556   assert(ifm_shape.dim(3) == ker_shape.dim(3));
557
558   auto os = infer_conv2d_type(node);
559
560   loco::TensorShape ofm_shape;
561   ofm_shape.rank(4);
562   ofm_shape.dim(0) = ifm_shape.dim(0);
563   ofm_shape.dim(1) = os.height;
564   ofm_shape.dim(2) = os.width;
565   ofm_shape.dim(3) = ker_shape.dim(0);
566
567   return loco::NodeShape{ofm_shape};
568 }
569
570 loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
571 {
572   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
573   LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
574
575   // Only data format NHWC is supported
576   // TODO need to clarify what to do with layout in this operator
577   int32_t height = input_shape.dim(1).value();
578   int32_t width = input_shape.dim(2).value();
579   int32_t depth = input_shape.dim(3).value();
580
581   int block_size = node->block_size();
582
583   if (block_size < 2)
584     INTERNAL_EXN("Block size must be >= 2");
585
586   if (depth % (block_size * block_size))
587   {
588     INTERNAL_EXN("The input tensor's depth must be divisible by block_size^2");
589   }
590
591   loco::TensorShape output_shape;
592   output_shape.rank(4);
593
594   output_shape.dim(0) = input_shape.dim(0).value();
595   output_shape.dim(1) = height * block_size;
596   output_shape.dim(2) = width * block_size;
597   output_shape.dim(3) = depth / (block_size * block_size);
598
599   return loco::NodeShape{output_shape};
600 }
601
602 loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
603 {
604   auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
605   auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
606
607   assert(ifm_shape.rank() == 4);
608   assert(ker_shape.rank() == 4);
609   assert(ker_shape.dim(0).value() == 1);
610
611   auto os = infer_conv2d_type(node);
612
613   loco::TensorShape ofm_shape;
614   ofm_shape.rank(4);
615   ofm_shape.dim(0) = ifm_shape.dim(0);
616   ofm_shape.dim(1) = os.height;
617   ofm_shape.dim(2) = os.width;
618   ofm_shape.dim(3) = ker_shape.dim(3);
619
620   return loco::NodeShape{ofm_shape};
621 }
622
623 loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
624 {
625   const loco::DataType S32 = loco::DataType::S32;
626   auto x_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
627   if (x_shape.rank() == 0)
628   {
629     // This maybe for unknown shape. We use shape from the node itself.
630     return use_own(node);
631   }
632   auto const_axis = loco::must_cast<luci::CircleConst *>(node->axis());
633   LUCI_ASSERT(const_axis->dtype() == S32, "Only support int32 CircleConst for axis");
634   if (const_axis->rank() != 0 && const_axis->rank() != 1)
635   {
636     INTERNAL_EXN_V("Non-scalar axis in OP", node->opnum());
637   }
638   int32_t axis = const_axis->at<S32>(0);
639   LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) &&
640                   (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
641               "Axis has to be between [-(D+1), D], where D is rank of input.");
642   size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis;
643   loco::TensorShape output_shape;
644   output_shape.rank(x_shape.rank() + 1);
645   size_t i = 0;
646   for (; i < positive_axis; i++)
647     output_shape.dim(i) = x_shape.dim(i);
648   output_shape.dim(i) = loco::Dimension(1);
649   for (; i < x_shape.rank(); i++)
650     output_shape.dim(i + 1) = x_shape.dim(i);
651   return loco::NodeShape{output_shape};
652 }
653
654 loco::NodeShape infer_fill(const luci::CircleFill *node)
655 {
656   loco::TensorShape shape;
657   {
658     LUCI_ASSERT(node->dims(), "dims input should not be nullptr");
659
660     auto dims_node = dynamic_cast<luci::CircleConst *>(node->dims());
661     if (dims_node != nullptr)
662     {
663       // Only support node with S32
664       LUCI_ASSERT(dims_node->dtype() == loco::DataType::S32, "Only support int32 CircleConst");
665
666       if (dims_node->rank() != 1)
667         INTERNAL_EXN_V("Only support rank 1 CircleConst", oops::to_uint32(dims_node->rank()));
668
669       shape.rank(dims_node->dim(0).value());
670
671       for (uint32_t axis = 0; axis < shape.rank(); ++axis)
672       {
673         shape.dim(axis) = dims_node->at<loco::DataType::S32>(axis);
674       }
675     }
676     else
677     {
678       shape = own_shape(node);
679     }
680   }
681
682   return loco::NodeShape{shape};
683 }
684
685 loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
686 {
687   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
688   auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();
689
690   // Checking shape capability for fully connected layer
691   // Input: a tensor of at least rank 2 [D1, D2, ... Dn]
692   // Weight: [# of units, K]
693   // Output: [D1 * D2 * ... * Dn / K, # of units]
694   if (input_shape.rank() < 2 || weights_shape.rank() != 2)
695   {
696     // Return node own shape if shape inference is not possible
697     return use_own(node);
698   }
699
700   uint32_t input_size = 1;
701   for (uint32_t i = 0; i < input_shape.rank(); i++)
702   {
703     input_size = input_size * input_shape.dim(i).value();
704   }
705   const uint32_t batch_size = input_size / weights_shape.dim(1).value();
706   loco::TensorShape out_shape;
707   out_shape.rank(2);
708   out_shape.dim(0) = batch_size;
709   out_shape.dim(1) = weights_shape.dim(0);
710
711   return loco::NodeShape{out_shape};
712 }
713
714 loco::NodeShape infer_gather(const luci::CircleGather *node)
715 {
716   loco::TensorShape output_shape;
717
718   const auto input_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
719   const auto positions_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
720   int32_t axis = node->axis();
721
722   // If CircleGather input has a dynamic shape, it can't inference this shape. So, it returns the
723   // shape that node already has.
724   if (input_shape.rank() == 0 || positions_shape.rank() == 0)
725     return use_own(node);
726
727   if (axis < 0)
728     axis += input_shape.rank();
729
730   output_shape.rank(input_shape.rank() - 1 + positions_shape.rank());
731   int32_t outdim_index = 0;
732   for (int32_t i = 0; i < axis; ++i)
733     output_shape.dim(outdim_index++) = input_shape.dim(i);
734   for (uint32_t i = 0; i < positions_shape.rank(); ++i)
735     output_shape.dim(outdim_index++) = positions_shape.dim(i);
736   for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
737     output_shape.dim(outdim_index++) = input_shape.dim(i);
738
739   return loco::NodeShape{output_shape};
740 }
741
742 loco::NodeShape infer_gather_nd(const luci::CircleGatherNd *node)
743 {
744   loco::TensorShape output_shape;
745
746   const auto params_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
747   const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
748
749   const auto params_rank = params_shape.rank();
750   const auto indices_rank = indices_shape.rank();
751
752   // see https://www.tensorflow.org/api_docs/python/tf/gather_nd
753   // output.shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]
754   // batch_dims isn't supported in tflite
755
756   // TODO: replace exceptions with setting shape to unknown?
757
758   if (!indices_shape.dim(indices_rank - 1).known())
759     INTERNAL_EXN("Last indices dimension is unknown");
760
761   auto indices_last_dim = indices_shape.dim(indices_rank - 1).value();
762
763   if (indices_last_dim > params_rank)
764     INTERNAL_EXN("Last indices dimension should be <= params rank");
765
766   const uint32_t output_rank = indices_rank + params_rank - indices_last_dim - 1;
767
768   output_shape.rank(output_rank);
769
770   uint32_t output_index = 0;
771   for (uint32_t i = 0; i < indices_rank - 1; ++i)
772   {
773     auto &dim = indices_shape.dim(i);
774     if (!dim.known())
775       INTERNAL_EXN("Unknown indices dimension is unsupported");
776     output_shape.dim(output_index++).set(dim.value());
777   }
778
779   for (uint32_t i = indices_last_dim; i < params_rank; ++i)
780   {
781     auto &dim = params_shape.dim(i);
782     if (!dim.known())
783       INTERNAL_EXN("Unknown params dimension is unsupported");
784     output_shape.dim(output_index++).set(dim.value());
785   }
786
787   return loco::NodeShape{output_shape};
788 }
789
790 loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
791 {
792   loco::TensorShape output_shape;
793
794   auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
795   auto rank = diagonal_shape.rank();
796
797   output_shape.rank(rank + 1);
798
799   for (uint32_t i = 0; i < rank; i++)
800   {
801     output_shape.dim(i) = diagonal_shape.dim(i);
802   }
803
804   output_shape.dim(rank) = diagonal_shape.dim(rank - 1);
805
806   return loco::NodeShape{output_shape};
807 }
808
809 loco::NodeShape infer_matrix_set_diag(const luci::CircleMatrixSetDiag *node)
810 {
811   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
812   auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
813
814   auto rank = diagonal_shape.rank();
815
816   LUCI_ASSERT(rank == input_shape.rank() - 1, "diagonal rank = input rank - 1");
817
818   for (uint32_t i = 0; i < rank - 1; i++)
819   {
820     LUCI_ASSERT(diagonal_shape.dim(i) == input_shape.dim(i), "diagonal dims = input dims");
821   }
822
823   auto dim = std::min(input_shape.dim(rank - 1).value(), input_shape.dim(rank).value());
824
825   LUCI_ASSERT(dim == diagonal_shape.dim(rank - 1), "Max diag len error");
826
827   return loco::NodeShape{input_shape};
828 }
829
830 loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indices, bool keep_dims)
831 {
832   const loco::DataType S32 = loco::DataType::S32;
833
834   auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
835   auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices);
836
837   { // Exceptions
838     // TODO support non-const case
839     // TODO support other data type
840     LUCI_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
841   }
842
843   std::vector<int32_t> reduction_values;
844
845   for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
846   {
847     int32_t axis = reduction_indices->at<S32>(i);
848     if (axis < 0)
849       axis += input_shape.rank();
850     if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
851       INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
852     reduction_values.push_back(axis);
853   }
854
855   loco::TensorShape output_shape;
856
857   if (keep_dims)
858   {
859     output_shape.rank(input_shape.rank());
860     for (uint32_t i = 0; i < input_shape.rank(); ++i)
861       output_shape.dim(i) = input_shape.dim(i);
862     for (uint32_t i = 0; i < reduction_values.size(); ++i)
863       output_shape.dim(reduction_values.at(i)) = 1;
864   }
865   else
866   {
867     std::vector<bool> check_reduce(input_shape.rank(), false);
868     for (uint32_t i = 0; i < reduction_values.size(); ++i)
869       check_reduce.at(reduction_values.at(i)) = true;
870
871     uint32_t reduce_cnt = 0;
872     for (uint32_t i = 0; i < check_reduce.size(); ++i)
873       if (check_reduce.at(i))
874         ++reduce_cnt;
875
876     output_shape.rank(input_shape.rank() - reduce_cnt);
877     for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
878       if (check_reduce.at(i) == false)
879         output_shape.dim(j++) = input_shape.dim(i);
880   }
881
882   return output_shape;
883 }
884
885 loco::NodeShape infer_mirror_pad(const luci::CircleMirrorPad *node)
886 {
887   // TODO support non-const case
888   auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
889   return use_paddings(node, paddings);
890 }
891
892 loco::NodeShape infer_one_hot(const luci::CircleOneHot *node)
893 {
894   const loco::DataType S32 = loco::DataType::S32;
895   auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
896   // Only support OneHot node's depth() is CircleConst with type S32
897   // TODO support depth with other types
898   auto depth = loco::must_cast<luci::CircleConst *>(node->depth());
899   LUCI_ASSERT(depth->dtype() == S32, "Only support int32 CircleConst");
900   if (depth->rank() != 0)
901     INTERNAL_EXN_V("Only support rank 0 CircleOneHot in Depth", oops::to_uint32(depth->rank()));
902   loco::TensorShape output_shape;
903   output_shape.rank(indices_shape.rank() + 1);
904   auto axis = node->axis();
905   if (axis < 0)
906     axis += indices_shape.rank() + 1;
907   LUCI_ASSERT(0 <= axis, "Axis is out of range");
908   LUCI_ASSERT(static_cast<uint32_t>(axis) <= indices_shape.rank(), "Axis is out of range");
909   uint32_t j = 0;
910   for (uint32_t i = 0; i < output_shape.rank(); i++)
911   {
912     if (i == static_cast<uint32_t>(axis))
913     {
914       output_shape.dim(i) = depth->at<S32>(0);
915     }
916     else
917     {
918       output_shape.dim(i) = indices_shape.dim(j++);
919     }
920   }
921   return loco::NodeShape{output_shape};
922 }
923
924 loco::NodeShape infer_pack(const luci::CirclePack *node)
925 {
926   LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs");
927
928   auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
929   // Make sure all inputs have the same shape.
930   for (uint32_t i = 1; i < node->values_count(); ++i)
931   {
932     auto in_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
933     LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape},
934                 "All inputs must have the same shape");
935   }
936
937   // Checking shape capability for pack layer
938   // Input: tensors [D1, D2, ... Dn]
939   // Axis: K
940   // Output: [D1, D2, ... , D_K-1, n, D_K+1, ... Dn]
941   auto axis = node->axis();
942   if (axis < 0)
943     axis += first_shape.rank() + 1;
944
945   LUCI_ASSERT(0 <= axis, "Axis is out of range");
946   LUCI_ASSERT(static_cast<uint32_t>(axis) <= first_shape.rank(), "Axis is out of range");
947
948   loco::TensorShape output_shape;
949   output_shape.rank(first_shape.rank() + 1);
950
951   uint32_t j = 0;
952   for (uint32_t i = 0; i < output_shape.rank(); ++i)
953   {
954     if (i == static_cast<uint32_t>(axis))
955     {
956       output_shape.dim(i) = node->values_count();
957     }
958     else
959     {
960       output_shape.dim(i) = first_shape.dim(j++);
961     }
962   }
963
964   return loco::NodeShape{output_shape};
965 }
966
967 loco::NodeShape infer_pad(const luci::CirclePad *node)
968 {
969   // TODO support non-const case
970   auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
971   return use_paddings(node, paddings);
972 }
973
974 loco::NodeShape infer_pad_v2(const luci::CirclePadV2 *node)
975 {
976   // TODO support non-const case
977   auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
978   if (!paddings)
979   {
980     auto node_shape = own_shape(node);
981     return loco::NodeShape{node_shape};
982   }
983   return use_paddings(node, paddings);
984 }
985
986 loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
987 {
988   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
989   auto alpha_shape = loco::shape_get(node->alpha()).as<loco::TensorShape>();
990
991   auto output_shape = broadcast_shape(input_shape, alpha_shape);
992
993   return loco::NodeShape{output_shape};
994 }
995
996 loco::NodeShape infer_range(const luci::CircleRange *node)
997 {
998   loco::TensorShape output_shape;
999   output_shape.rank(1);
1000
1001   auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
1002   auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
1003   auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());
1004
1005   if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
1006   {
1007     return use_own(node);
1008   }
1009
1010   double start = 0, limit = 0, delta = 0;
1011
1012 #define GET_RANGE_PARAM(DT)         \
1013   start = start_node->scalar<DT>(); \
1014   limit = limit_node->scalar<DT>(); \
1015   delta = delta_node->scalar<DT>();
1016
1017   switch (start_node->dtype())
1018   {
1019     case loco::DataType::FLOAT32:
1020       GET_RANGE_PARAM(loco::DataType::FLOAT32)
1021       break;
1022     case loco::DataType::S32:
1023       GET_RANGE_PARAM(loco::DataType::S32)
1024       break;
1025     default:
1026       INTERNAL_EXN("Range data type not supported");
1027   }
1028
1029 #undef GET_RANGE_PARAM
1030
1031   if (delta == 0)
1032     INTERNAL_EXN("Delta can not be zero");
1033
1034   output_shape.dim(0) = ceil((limit - start) / delta);
1035
1036   return loco::NodeShape{output_shape};
1037 }
1038
1039 loco::NodeShape infer_reshape(const luci::CircleReshape *node)
1040 {
1041   LOGGER(l);
1042
1043   const loco::DataType S32 = loco::DataType::S32;
1044
1045   loco::TensorShape shape_by_input;
1046   {
1047     LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
1048
1049     // Only support node's shape() is CircleConst with S32
1050     // TODO support other node with other types
1051     auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
1052     if (const_shape_node != nullptr)
1053     {
1054       LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
1055
1056       shape_by_input.rank(const_shape_node->size<S32>());
1057
1058       for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
1059       {
1060         shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
1061       }
1062     }
1063     else
1064     {
1065       // We use shape from the node itself
1066       shape_by_input = own_shape(node);
1067     }
1068   }
1069
1070   loco::TensorShape shape_by_attr;
1071   {
1072     shape_by_attr.rank(node->newShape()->rank());
1073
1074     for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
1075     {
1076       shape_by_attr.dim(axis) = node->newShape()->dim(axis);
1077     }
1078   }
1079
1080   if (!(shape_by_input == shape_by_attr))
1081   {
1082     INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
1083     INFO(l) << "   shape_by_input : " << shape_by_input << std::endl;
1084     INFO(l) << "   shape_by_attr : " << shape_by_attr << std::endl;
1085   }
1086
1087   loco::TensorShape output_shape = shape_by_input;
1088
1089   // One of the dimensions can have special value -1, meaning its actual value should be inferred.
1090   const auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
1091   const uint32_t input_element_count = loco::element_count(&input_shape);
1092   uint32_t output_element_count = 1;
1093   uint32_t unknown_dim_index = UINT32_MAX;
1094   for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
1095   {
1096     const uint32_t dim_value = output_shape.dim(dim_index).value();
1097     if (static_cast<int>(dim_value) == -1)
1098     {
1099       LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
1100       unknown_dim_index = dim_index;
1101     }
1102     else
1103     {
1104       output_element_count *= dim_value;
1105     }
1106   }
1107   if (unknown_dim_index != UINT32_MAX)
1108   {
1109     output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
1110   }
1111
1112   return loco::NodeShape{output_shape};
1113 }
1114
1115 loco::NodeShape infer_resize_bilinear(const luci::CircleResizeBilinear *node)
1116 {
1117   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1118
1119   if (input_shape.rank() != 4)
1120     INTERNAL_EXN("Expected ResizeBilinear input to have rank 4");
1121
1122   auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1123
1124   if (const_node->dtype() != loco::DataType::S32)
1125     INTERNAL_EXN("Only S32 datatype is supported for ResizeBilinear size");
1126
1127   if (const_node->rank() != 1)
1128     INTERNAL_EXN("Expected size tensor of rank 1");
1129
1130   if (const_node->dim(0).value() != 2)
1131     INTERNAL_EXN("Expected size tensor with shape [2]");
1132
1133   loco::TensorShape output_shape;
1134   output_shape.rank(4);
1135   output_shape.dim(0) = input_shape.dim(0);
1136   output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
1137   output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
1138   output_shape.dim(3) = input_shape.dim(3);
1139
1140   return loco::NodeShape{output_shape};
1141 }
1142
1143 loco::NodeShape infer_resize_nearest_neighbor(const luci::CircleResizeNearestNeighbor *node)
1144 {
1145   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1146
1147   if (input_shape.rank() != 4)
1148     INTERNAL_EXN("Expected ResizeNearesNeighbor input to have rank 4");
1149
1150   auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1151
1152   if (const_node->dtype() != loco::DataType::S32)
1153     INTERNAL_EXN("Only S32 datatype is supported for ResizeNearesNeighbor size");
1154
1155   if (const_node->rank() != 1)
1156     INTERNAL_EXN("Expected size tensor of rank 1");
1157
1158   if (const_node->dim(0).value() != 2)
1159     INTERNAL_EXN("Expected size tensor with shape [2]");
1160
1161   loco::TensorShape output_shape;
1162   output_shape.rank(4);
1163   output_shape.dim(0) = input_shape.dim(0);
1164   output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
1165   output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
1166   output_shape.dim(3) = input_shape.dim(3);
1167
1168   return loco::NodeShape{output_shape};
1169 }
1170
1171 loco::NodeShape infer_scatter_nd(const luci::CircleScatterNd *node)
1172 {
1173   loco::TensorShape output_shape;
1174
1175   auto shape_node = loco::must_cast<luci::CircleConst *>(node->shape());
1176
1177   const loco::DataType S32 = loco::DataType::S32;
1178   const loco::DataType S64 = loco::DataType::S64;
1179
1180   std::vector<int64_t> vect_shape;
1181
1182   if (shape_node->dtype() == S32)
1183     vect_shape = vector_from_constant<S32>(shape_node);
1184   else if (shape_node->dtype() == S64)
1185     vect_shape = vector_from_constant<S64>(shape_node);
1186   else
1187     LUCI_ASSERT(false, "Only support int32/int64 for shape()");
1188
1189   output_shape.rank(vect_shape.size());
1190   for (uint32_t i = 0; i < vect_shape.size(); ++i)
1191     output_shape.dim(i) = vect_shape[i];
1192
1193   return loco::NodeShape{output_shape};
1194 }
1195
1196 loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
1197 {
1198   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1199   auto segment_shape = loco::shape_get(node->segment_ids()).as<loco::TensorShape>();
1200
1201   LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor");
1202   LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(),
1203               "segment_ids size must be equal to the size of data's first dimension");
1204
1205   auto ids_shape_value = loco::must_cast<luci::CircleConst *>(node->segment_ids());
1206
1207   std::vector<int64_t> vect_ids;
1208
1209   if (ids_shape_value->dtype() == loco::DataType::S32)
1210     vect_ids = vector_from_constant<loco::DataType::S32>(ids_shape_value);
1211
1212   LUCI_ASSERT(std::is_sorted(vect_ids.begin(), vect_ids.end()),
1213               "segment_ids values should be sorted")
1214
1215   loco::TensorShape output_shape;
1216
1217   output_shape.rank(input_shape.rank());
1218
1219   for (uint32_t i = 1; i < input_shape.rank(); ++i)
1220     output_shape.dim(i) = input_shape.dim(i);
1221
1222   output_shape.dim(0) = vect_ids.back() + 1;
1223
1224   return loco::NodeShape{output_shape};
1225 }
1226
1227 loco::NodeShape infer_select(const luci::CircleSelect *node)
1228 {
1229   auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
1230   assert(t_shape == loco::shape_get(node->e()).as<loco::TensorShape>());
1231
1232   // condition shape validation
1233   auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
1234   if (c_shape.rank() != t_shape.rank())
1235   {
1236     if (c_shape.rank() != 0 && c_shape.rank() != 1)
1237       INTERNAL_EXN_V("CircleSelect condition rank is not 0 nor 1: ", c_shape.rank());
1238
1239     if (c_shape.rank() == 1)
1240     {
1241       if (c_shape.dim(0).value() != t_shape.dim(0).value())
1242         INTERNAL_EXN("CircleSelect condition dim(0) should match with t.dim(0)");
1243     }
1244   }
1245
1246   return loco::NodeShape{t_shape};
1247 }
1248
1249 loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
1250 {
1251   auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
1252   auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
1253   auto e_shape = loco::shape_get(node->e()).as<loco::TensorShape>();
1254
1255   // validate ability to broadcast shapes to each other
1256   auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape);
1257   return loco::NodeShape{b_shape};
1258 }
1259
1260 loco::NodeShape infer_shape(const luci::CircleShape *node)
1261 {
1262   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1263
1264   loco::TensorShape output_shape;
1265
1266   output_shape.rank(1);
1267   output_shape.dim(0) = input_shape.rank();
1268
1269   return loco::NodeShape{output_shape};
1270 }
1271
1272 loco::NodeShape infer_slice(const luci::CircleSlice *node)
1273 {
1274   const loco::DataType S32 = loco::DataType::S32;
1275   const loco::DataType S64 = loco::DataType::S64;
1276
1277   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1278
1279   auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin());
1280   auto const_size = loco::must_cast<luci::CircleConst *>(node->size());
1281
1282   loco::TensorShape output_shape;
1283   std::vector<int64_t> vect_begin; // to hold both S32/S64, we use int64_t
1284   std::vector<int64_t> vect_size;
1285
1286   if (const_begin->dtype() == S32)
1287     vect_begin = vector_from_constant<S32>(const_begin);
1288   else if (const_begin->dtype() == S64)
1289     vect_begin = vector_from_constant<S64>(const_begin);
1290   else
1291     LUCI_ASSERT(false, "Only support int32/int64 for begin()");
1292
1293   if (const_size->dtype() == S32)
1294     vect_size = vector_from_constant<S32>(const_size);
1295   else if (const_size->dtype() == S64)
1296     vect_size = vector_from_constant<S64>(const_size);
1297   else
1298     LUCI_ASSERT(false, "Only support int32/int64 for size()");
1299
1300   assert(input_shape.rank() == vect_begin.size());
1301   assert(input_shape.rank() == vect_size.size());
1302
1303   output_shape.rank(vect_begin.size());
1304   for (uint32_t idx = 0; idx < vect_begin.size(); ++idx)
1305   {
1306     auto size = vect_size.at(idx);
1307     if (size == -1)
1308     {
1309       size = input_shape.dim(idx).value() - vect_begin.at(idx);
1310     }
1311     output_shape.dim(idx) = size;
1312   }
1313
1314   return loco::NodeShape{output_shape};
1315 }
1316
1317 loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
1318 {
1319   const loco::DataType S32 = loco::DataType::S32;
1320
1321   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1322   // Support only input rank is 3 and 4
1323   assert(input_shape.rank() == 3 || input_shape.rank() == 4);
1324
1325   // Only support block_shape() with S32 type CircleConst for now
1326   auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
1327   LUCI_ASSERT(const_block_shape->dtype() == S32, "Only support int32 block_shape");
1328
1329   // Only support paddings() with S32 type CircleConst for now
1330   auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1331   LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings");
1332
1333   auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
1334   auto const_paddings_shape = loco::shape_get(const_paddings).as<loco::TensorShape>();
1335   assert(const_block_shape_shape.rank() == 1);
1336   assert(const_paddings_shape.rank() == 2);
1337
1338   int32_t input_spatial_dim = input_shape.rank() - 2;
1339   assert(const_block_shape_shape.dim(0) == input_spatial_dim);
1340   assert(const_paddings_shape.dim(0) == input_spatial_dim);
1341   assert(const_paddings_shape.dim(1) == 2);
1342
1343   // Check all values of block_shape >= 1
1344   uint32_t ele_count = const_block_shape->size<S32>();
1345   for (uint32_t e = 0; e < ele_count; ++e)
1346   {
1347     auto val = const_block_shape->at<S32>(e);
1348     if (val < 1)
1349     {
1350       INTERNAL_EXN_V("All values of block_shape >= 1: ", e);
1351     }
1352   }
1353
1354   loco::TensorShape shape_output;
1355
1356   shape_output.rank(input_shape.rank());
1357
1358   int32_t output_batch_size = input_shape.dim(0).value();
1359   for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
1360   {
1361     int dim_size = input_shape.dim(dim + 1).value();
1362     dim_size += const_paddings->at<S32>(dim * 2);
1363     dim_size += const_paddings->at<S32>(dim * 2 + 1);
1364     shape_output.dim(dim + 1) = dim_size / const_block_shape->at<S32>(dim);
1365
1366     assert(dim_size % const_block_shape->at<S32>(dim) == 0);
1367     output_batch_size = output_batch_size * const_block_shape->at<S32>(dim);
1368   }
1369   shape_output.dim(0) = output_batch_size;
1370   shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
1371
1372   return loco::NodeShape{shape_output};
1373 }
1374
1375 loco::NodeShape infer_space_to_depth(const luci::CircleSpaceToDepth *node)
1376 {
1377   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1378   LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
1379
1380   // Only data format NHWC is supported
1381   int32_t height = input_shape.dim(1).value();
1382   int32_t width = input_shape.dim(2).value();
1383   int32_t depth = input_shape.dim(3).value();
1384
1385   int block_size = node->block_size();
1386
1387   if (block_size < 2)
1388     INTERNAL_EXN("Block size must be >= 2");
1389
1390   if ((height % block_size) || (width % block_size))
1391   {
1392     INTERNAL_EXN("The input tensor's height and width must be divisible by block_size");
1393   }
1394
1395   loco::TensorShape output_shape;
1396   output_shape.rank(4);
1397
1398   output_shape.dim(0) = input_shape.dim(0).value();
1399   output_shape.dim(1) = height / block_size;
1400   output_shape.dim(2) = width / block_size;
1401   output_shape.dim(3) = block_size * block_size * depth;
1402
1403   return loco::NodeShape{output_shape};
1404 }
1405
1406 loco::NodeShape infer_sparse_to_dense(const luci::CircleSparseToDense *node)
1407 {
1408   loco::TensorShape shape;
1409   {
1410     LUCI_ASSERT(node->output_shape(), "dims input should not be nullptr");
1411
1412     auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape());
1413     if (output_shape_node != nullptr)
1414     {
1415       // Only support node with S32
1416       LUCI_ASSERT(output_shape_node->dtype() == loco::DataType::S32,
1417                   "Only support int32 CircleConst");
1418
1419       if (output_shape_node->rank() != 1)
1420         INTERNAL_EXN_V("Only support rank 1 CircleConst",
1421                        oops::to_uint32(output_shape_node->rank()));
1422
1423       shape.rank(output_shape_node->size<loco::DataType::S32>());
1424
1425       for (uint32_t axis = 0; axis < shape.rank(); ++axis)
1426       {
1427         shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
1428       }
1429     }
1430     else
1431     {
1432       shape = own_shape(node);
1433     }
1434   }
1435
1436   return loco::NodeShape{shape};
1437 }
1438
1439 loco::NodeShape infer_strided_slice(const luci::CircleStridedSlice *node)
1440 {
1441   auto begin_node = dynamic_cast<luci::CircleConst *>(node->begin());
1442   auto end_node = dynamic_cast<luci::CircleConst *>(node->end());
1443   auto strides_node = dynamic_cast<luci::CircleConst *>(node->strides());
1444
1445   if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr)
1446   {
1447     return use_own(node);
1448   }
1449
1450   loco::TensorShape shape = infer_output_shape(node);
1451   return loco::NodeShape{shape};
1452 }
1453
1454 loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
1455 {
1456   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1457
1458   // TODO input shape may be unknown before runtime
1459   std::vector<bool> do_squeeze(input_shape.rank(), false);
1460   uint32_t num_squeezed = 0;
1461
1462   if (!node->squeeze_dims().empty())
1463   {
1464     // SqueezeDims not empty, squeeze only dims specified
1465     for (int32_t raw_dim : node->squeeze_dims())
1466     {
1467       int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim;
1468
1469       if (dim < 0 || static_cast<uint32_t>(dim) >= input_shape.rank() ||
1470           input_shape.dim(dim).value() != 1)
1471       {
1472         INTERNAL_EXN("invalid dimention specified to Squeeze");
1473       }
1474
1475       if (!do_squeeze[dim])
1476         ++num_squeezed;
1477       do_squeeze[dim] = true;
1478     }
1479   }
1480   else
1481   {
1482     // SqueezeDims empty, squeeze any dims with size == 1
1483     for (uint32_t dim = 0; dim < input_shape.rank(); ++dim)
1484     {
1485       if (input_shape.dim(dim) == 1)
1486       {
1487         do_squeeze[dim] = true;
1488         ++num_squeezed;
1489       }
1490     }
1491   }
1492
1493   loco::TensorShape output_shape;
1494   output_shape.rank(input_shape.rank() - num_squeezed);
1495
1496   for (uint32_t in_dim = 0, out_dim = 0; in_dim < input_shape.rank(); ++in_dim)
1497   {
1498     if (!do_squeeze[in_dim])
1499     {
1500       output_shape.dim(out_dim++) = input_shape.dim(in_dim);
1501     }
1502   }
1503
1504   return loco::NodeShape{output_shape};
1505 }
1506
1507 loco::NodeShape infer_tile(const luci::CircleTile *node)
1508 {
1509   const loco::DataType S32 = loco::DataType::S32;
1510
1511   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1512   auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples());
1513
1514   // TODO support non-const case
1515   // TODO support S64 type
1516   LUCI_ASSERT(multiples->dtype() == S32, "Only support int32 multiples");
1517   LUCI_ASSERT(multiples->rank() == 1, "multiples should be rank 1")
1518
1519   uint32_t n = multiples->dim(0).value();
1520
1521   LUCI_ASSERT(n == input_shape.rank(), "length of multiples should be the same with input rank");
1522
1523   loco::TensorShape output_shape;
1524
1525   output_shape.rank(input_shape.rank());
1526   for (uint32_t ni = 0; ni < n; ++ni)
1527   {
1528     int32_t multiple = multiples->at<S32>(ni);
1529     output_shape.dim(ni) = input_shape.dim(ni).value() * static_cast<uint32_t>(multiple);
1530   }
1531
1532   return loco::NodeShape{output_shape};
1533 }
1534
1535 loco::NodeShape infer_transpose(const luci::CircleTranspose *node)
1536 {
1537   auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();
1538
1539   auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm());
1540
1541   loco::TensorShape output_shape;
1542   output_shape.rank(input_shape.rank());
1543
1544   assert(perm_node->dtype() == loco::DataType::S32);
1545   assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());
1546
1547   for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
1548   {
1549     auto in_axis = perm_node->template at<loco::DataType::S32>(out_axis);
1550     output_shape.dim(out_axis) = input_shape.dim(in_axis);
1551   }
1552
1553   return output_shape;
1554 }
1555
1556 loco::NodeShape infer_transpose_conv(const luci::CircleTransposeConv *node)
1557 {
1558   // TransposeConv's output shape is written in its 'inputSizes' argument
1559   auto input_sizes_const = loco::must_cast<luci::CircleConst *>(node->inputSizes());
1560   // TODO support non-const type
1561   LUCI_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
1562   LUCI_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
1563               "Only support rank 1 with 4 entries")
1564
1565   loco::TensorShape shape;
1566
1567   shape.rank(4);
1568   for (uint32_t axis = 0; axis < 4; ++axis)
1569     shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);
1570
1571   return loco::NodeShape{shape};
1572 }
1573
1574 loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
1575 {
1576   // CircleUnpack provides list(array) of Tensors which has one less dimension of the input
1577   // We'll set shape of CircleUnpack to shape of actual outputs
1578   // TODO fix this if any problem rises
1579   auto value_shape = loco::shape_get(node->value()).as<loco::TensorShape>();
1580
1581   auto axis = node->axis();
1582   auto num = node->num();
1583   auto rank = static_cast<int32_t>(value_shape.rank());
1584
1585   if (rank == 0)
1586   {
1587     // Unknown shape
1588     return use_own(node);
1589   }
1590
1591   LUCI_ASSERT(-rank <= axis && axis < rank, "Axis is out of range");
1592
1593   if (axis < 0)
1594     axis += rank;
1595
1596   LUCI_ASSERT(num == static_cast<int32_t>(value_shape.dim(axis).value()),
1597               "num, axis maybe incorrect");
1598
1599   loco::TensorShape output_shape;
1600   output_shape.rank(rank - 1);
1601
1602   for (int32_t i = 0, o = 0; i < rank; ++i)
1603   {
1604     if (i != axis)
1605       output_shape.dim(o++) = value_shape.dim(i);
1606   }
1607
1608   return loco::NodeShape{output_shape};
1609 }
1610
1611 loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node)
1612 {
1613   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1614   auto recurrent_to_output_weights =
1615       loco::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
1616   auto rank = input_shape.rank();
1617   loco::TensorShape output_shape;
1618   output_shape.rank(rank);
1619   for (uint32_t i = 0; i < rank - 1; i++)
1620   {
1621     output_shape.dim(i) = input_shape.dim(i);
1622   }
1623   output_shape.dim(rank - 1) = recurrent_to_output_weights.dim(1);
1624   return loco::NodeShape{output_shape};
1625 }
1626
1627 loco::NodeShape infer_unique(const luci::CircleUnique *node)
1628 {
1629   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1630
1631   assert(input_shape.rank() == 1);
1632
1633   loco::TensorShape shape_output;
1634   shape_output = own_shape(node);
1635
1636   return loco::NodeShape{shape_output};
1637 }
1638
1639 // Circle Only
1640 loco::NodeShape infer_bcq_fully_connected(const luci::CircleBCQFullyConnected *node)
1641 {
1642   loco::TensorShape out_shape;
1643
1644   auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1645   auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters());
1646
1647   LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2");
1648
1649   int32_t qbits_sum = 0;
1650   for (uint32_t i = 0; i < weights_clusters->dim(0).value(); ++i)
1651   {
1652     qbits_sum += weights_clusters->at<loco::DataType::S32>(i * 2 + 1);
1653   }
1654
1655   out_shape.rank(2);
1656   out_shape.dim(0) = qbits_sum;
1657   out_shape.dim(1) = input_shape.dim(1);
1658
1659   return loco::NodeShape{out_shape};
1660 }
1661
1662 loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node)
1663 {
1664   loco::TensorShape input_shape;
1665   loco::TensorShape output_shape;
1666
1667   const auto input_binary_shape = loco::shape_get(node->input_binary()).as<loco::TensorShape>();
1668   const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
1669   auto axis = node->axis();
1670
1671   auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters());
1672   auto qbits_sum = 0;
1673   for (uint32_t i = 0; i < input_clusters->dim(0).value(); ++i)
1674   {
1675     qbits_sum += input_clusters->at<loco::DataType::S32>(i * 2 + 1);
1676   }
1677
1678   input_shape.rank(2);
1679   input_shape.dim(0) = qbits_sum;
1680   input_shape.dim(1) = input_binary_shape.dim(1).value() * 32;
1681
1682   output_shape.rank(input_shape.rank() - 1 + indices_shape.rank());
1683   int32_t outdim_index = 0;
1684   for (int32_t i = 0; i < axis; ++i)
1685     output_shape.dim(outdim_index++) = input_shape.dim(i);
1686   for (uint32_t i = 0; i < indices_shape.rank(); ++i)
1687     output_shape.dim(outdim_index++) = indices_shape.dim(i);
1688   for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
1689     output_shape.dim(outdim_index++) = input_shape.dim(i);
1690
1691   return loco::NodeShape{output_shape};
1692 }
1693
1694 // Virtual
1695 loco::NodeShape infer_input(const luci::CircleInput *node)
1696 {
1697   loco::TensorShape shape;
1698
1699   shape.rank(node->rank());
1700   for (uint32_t axis = 0; axis < node->rank(); axis++)
1701     shape.dim(axis) = node->dim(axis);
1702
1703   return loco::NodeShape{shape};
1704 }
1705
1706 loco::NodeShape infer_output(const luci::CircleOutput *node)
1707 {
1708   auto graph_outputs = node->graph()->outputs();
1709   auto graph_output = graph_outputs->at(node->index());
1710   auto output_shape = graph_output->shape();
1711
1712   return loco::NodeShape{*output_shape};
1713 }
1714
1715 loco::NodeShape infer_if_out(const luci::CircleIfOut *node)
1716 {
1717   /**
1718    * @note  IF operator type and shape are that of the "then" and "else"
1719    *        Graph Outputs.
1720    */
1721   auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
1722   if (circle_if == nullptr)
1723   {
1724     INTERNAL_EXN("CircleIf IR is not configured correctly");
1725   }
1726
1727   auto index = node->index();
1728   auto then_graph = circle_if->then_graph();
1729   auto else_graph = circle_if->else_graph();
1730   assert(then_graph != nullptr);
1731   assert(else_graph != nullptr);
1732
1733   // shape and type are assumed to be same
1734   // these are checked at post_import_graph() in Import
1735   auto then_outputs = loco::output_nodes(then_graph);
1736   auto else_outputs = loco::output_nodes(else_graph);
1737   assert(then_outputs.size() == else_outputs.size());
1738   assert(index < static_cast<int32_t>(then_outputs.size()));
1739
1740   auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
1741   auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
1742
1743   auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
1744   auto else_graph_outputs = else_graph->outputs();
1745   assert(then_graph_outputs->size() == else_graph_outputs->size());
1746
1747   auto then_graph_output = then_graph_outputs->at(then_out->index());
1748   auto else_graph_output = else_graph_outputs->at(else_out->index());
1749   (void)else_graph_output; // make compiler happy for unused variable warnings
1750   assert(*then_graph_output->shape() == *else_graph_output->shape());
1751
1752   return loco::NodeShape{*then_graph_output->shape()};
1753 }
1754
1755 loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node)
1756 {
1757   const loco::DataType S32 = loco::DataType::S32;
1758
1759   auto nmsv4 = dynamic_cast<const luci::CircleNonMaxSuppressionV4 *>(node->input());
1760   if (nmsv4 == nullptr)
1761     INTERNAL_EXN("CircleNonMaxSuppressionV4 IR is not configured correctly");
1762
1763   auto index = node->index();
1764   if (index == 1)
1765     return loco::TensorShape({0});
1766
1767   assert(index == 0);
1768
1769   auto unknown = loco::TensorShape{loco::Dimension()};
1770   auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv4->max_output_size());
1771   if (max_output_size == nullptr)
1772     return unknown; // we need CircleConst for max output size
1773
1774   LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
1775
1776   if (max_output_size->size<S32>() < 1)
1777     return unknown;
1778
1779   auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
1780   return loco::TensorShape{max_output_size_value};
1781 }
1782
1783 loco::NodeShape infer_non_max_suppression_v5_out(const luci::CircleNonMaxSuppressionV5Out *node)
1784 {
1785   const loco::DataType S32 = loco::DataType::S32;
1786
1787   auto nmsv5 = dynamic_cast<const luci::CircleNonMaxSuppressionV5 *>(node->input());
1788   if (nmsv5 == nullptr)
1789     INTERNAL_EXN("CircleNonMaxSuppressionV5 IR is not configured correctly");
1790
1791   auto index = node->index();
1792   if (index == 2)
1793     return loco::TensorShape({0});
1794
1795   assert(index == 0 || index == 1);
1796
1797   auto unknown = loco::TensorShape{loco::Dimension()};
1798   auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv5->max_output_size());
1799   if (max_output_size == nullptr)
1800     return unknown; // we need CircleConst for max output size
1801
1802   LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
1803
1804   if (max_output_size->size<S32>() < 1)
1805     return unknown;
1806
1807   auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
1808   return loco::TensorShape{max_output_size_value};
1809 }
1810
1811 loco::NodeShape infer_split_out(const luci::CircleSplitOut *node)
1812 {
1813   const loco::DataType S32 = loco::DataType::S32;
1814
1815   auto split = dynamic_cast<const luci::CircleSplit *>(node->input());
1816   if (split == nullptr)
1817     INTERNAL_EXN("CircleSplit IR is not configured correctly");
1818
1819   loco::NodeShape unknown;
1820
1821   auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
1822
1823   auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
1824   if (split_dim == nullptr)
1825     return unknown; // we need CircleConst for split_dim
1826   LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
1827
1828   assert(split_dim->size<S32>() == 1);
1829   auto split_dim_axis = split_dim->at<S32>(0);
1830   if (split_dim_axis < 0)
1831     split_dim_axis += split_shape.rank();
1832
1833   auto split_dim_value = split_shape.dim(split_dim_axis).value();
1834   assert(split_dim_value % split->num_split() == 0);
1835   const int split_depth = split_dim_value / split->num_split();
1836
1837   loco::TensorShape output_shape = split_shape;
1838
1839   // All shapes are equally same
1840   output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
1841
1842   return loco::NodeShape{output_shape};
1843 }
1844
1845 loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
1846 {
1847   const loco::DataType S32 = loco::DataType::S32;
1848
1849   auto split = dynamic_cast<const luci::CircleSplitV *>(node->input());
1850   if (split == nullptr)
1851     INTERNAL_EXN("CircleSplit IR is not configured correctly");
1852
1853   loco::NodeShape unknown;
1854
1855   auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
1856
1857   auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
1858   if (size_splits == nullptr)
1859     return unknown; // we need CircleConst for size_splits
1860   LUCI_ASSERT(size_splits->dtype() == S32, "Only support int32 for size_splits");
1861
1862   auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
1863   if (split_dim == nullptr)
1864     return unknown; // we need CircleConst for split_dim
1865   LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
1866
1867   // fetch axis
1868   assert(split_dim->size<S32>() == 1);
1869   auto split_dim_axis = split_dim->at<S32>(0);
1870   if (split_dim_axis < 0)
1871     split_dim_axis += split_shape.rank();
1872
1873   // interpret size_splits values
1874   int32_t size_splits_count = static_cast<int32_t>(size_splits->size<S32>());
1875   assert(size_splits_count == split->num_split());
1876
1877   int64_t minus_one_count = 0, size_splits_sum = 0;
1878   for (int32_t idx = 0; idx < size_splits_count; ++idx)
1879   {
1880     auto size = size_splits->at<S32>(idx);
1881     assert(size >= -1);
1882     if (size == -1)
1883       ++minus_one_count;
1884     else
1885       size_splits_sum += size;
1886   }
1887   if (minus_one_count > 1)
1888     INTERNAL_EXN("CircleSplitV size_splits has more than two -1 values");
1889
1890   // calcuate this SplitVOut shape
1891   auto input_size = split_shape.dim(split_dim_axis).value();
1892   assert(size_splits_sum <= input_size);
1893
1894   auto index_this = node->index();
1895   assert(0 <= index_this && index_this < split->num_split());
1896   auto split_depth = size_splits->at<S32>(index_this);
1897   if (split_depth == -1)
1898     split_depth = input_size - size_splits_sum;
1899
1900   loco::TensorShape output_shape = split_shape;
1901
1902   output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
1903
1904   return loco::NodeShape{output_shape};
1905 }
1906
1907 loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node)
1908 {
1909   const loco::DataType S32 = loco::DataType::S32;
1910
1911   auto topkv2 = dynamic_cast<const luci::CircleTopKV2 *>(node->input());
1912   if (topkv2 == nullptr)
1913     INTERNAL_EXN("CircleSplit IR is not configured correctly");
1914
1915   // shape of topkv2 is same as topkv2->input()
1916   auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>();
1917
1918   auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
1919   LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
1920   assert(node_k->size<S32>() == 1);
1921
1922   loco::TensorShape output_shape;
1923
1924   output_shape.rank(input_shape.rank());
1925   for (uint32_t idx = 0; idx < input_shape.rank() - 1; ++idx)
1926   {
1927     output_shape.dim(idx) = input_shape.dim(idx);
1928   }
1929   output_shape.dim(input_shape.rank() - 1) = node_k->at<S32>(0);
1930
1931   return loco::NodeShape{output_shape};
1932 }
1933
1934 loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node)
1935 {
1936   if (node->index() == 0)
1937   {
1938     auto unique_shape = own_shape(node);
1939     return loco::NodeShape{unique_shape};
1940   }
1941   assert(node->index() == 1);
1942   auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
1943   auto unique_shape = loco::shape_get(unique->input()).as<loco::TensorShape>();
1944
1945   assert(unique_shape.rank() == 1);
1946
1947   loco::TensorShape shape_output;
1948   shape_output.rank(1);
1949   shape_output.dim(0) = unique_shape.dim(0);
1950   return loco::NodeShape{shape_output};
1951 }
1952
1953 loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node)
1954 {
1955   auto unpack = dynamic_cast<const luci::CircleUnpack *>(node->input());
1956   if (unpack == nullptr)
1957   {
1958     INTERNAL_EXN("CircleUnpack IR is not configured correctly");
1959   }
1960
1961   auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>();
1962
1963   return loco::NodeShape{unpack_shape};
1964 }
1965
1966 loco::NodeShape infer_while_out(const luci::CircleWhileOut *node)
1967 {
1968   /**
1969    * @note  WHILE operator's shape is the same with the "cond"
1970    *        Graph input.
1971    */
1972   auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
1973   if (circle_while == nullptr)
1974   {
1975     INTERNAL_EXN("CircleWhile IR is not configured correctly");
1976   }
1977
1978   auto index = node->index();
1979   auto cond_graph = circle_while->cond_graph();
1980   assert(cond_graph != nullptr);
1981
1982   // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
1983   // loco::input_nodes
1984   auto cond_inputs = loco::input_nodes(cond_graph);
1985   auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
1986
1987   auto cond_graph_inputs = cond_graph->inputs();
1988   auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
1989
1990   auto cond_graph_input_shape = *cond_graph_input->shape();
1991   auto this_shape = own_shape(node);
1992
1993   if (!(this_shape == cond_graph_input_shape))
1994   {
1995     LOGGER(l);
1996     WARN(l) << "Warning: CircleWhileOut '" << node->name() << "' shape mispatch " << this_shape
1997             << " vs " << cond_graph_input_shape;
1998   }
1999
2000   return loco::NodeShape{this_shape};
2001 }
2002
2003 /**
2004  * @brief Class to infer the shape of CircleNode
2005  *
2006  * @note All CircleNode's inputs and outputs are always loco::Domain::Tensor
2007  */
2008 class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeShape>
2009 {
2010 public:
2011   loco::NodeShape visit(const luci::CircleAbs *node) final { return use_x(node); }
2012
2013   loco::NodeShape visit(const luci::CircleAdd *node) final { return broadcast_xy(node); }
2014
2015   loco::NodeShape visit(const luci::CircleAddN *node) final { return infer_add_n(node); }
2016
2017   loco::NodeShape visit(const luci::CircleArgMax *node) final { return infer_arg_max(node); }
2018
2019   loco::NodeShape visit(const luci::CircleArgMin *node) final { return infer_arg_min(node); }
2020
2021   loco::NodeShape visit(const luci::CircleAveragePool2D *node) final
2022   {
2023     return infer_pool_2d_shape(node);
2024   }
2025
2026   loco::NodeShape visit(const luci::CircleBatchMatMul *node) final
2027   {
2028     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
2029     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
2030
2031     return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y());
2032   }
2033
2034   loco::NodeShape visit(const luci::CircleBatchToSpaceND *node) final
2035   {
2036     return infer_batch_to_space_nd(node);
2037   }
2038
2039   loco::NodeShape visit(const luci::CircleCast *node) final { return use_x(node); }
2040
2041   loco::NodeShape visit(const luci::CircleCeil *node) final { return use_x(node); }
2042
2043   loco::NodeShape visit(const luci::CircleConcatenation *node) final
2044   {
2045     return infer_concatenation(node);
2046   }
2047
2048   loco::NodeShape visit(const luci::CircleConst *node) final { return use_own(node); }
2049
2050   loco::NodeShape visit(const luci::CircleConv2D *node) final { return infer_conv2d(node); }
2051
2052   loco::NodeShape visit(const luci::CircleCos *node) final { return use_x(node); }
2053
2054   loco::NodeShape visit(const luci::CircleCustom *node) final { return use_own(node); }
2055
2056   loco::NodeShape visit(const luci::CircleDepthToSpace *node) final
2057   {
2058     return infer_depth_to_space(node);
2059   }
2060
2061   loco::NodeShape visit(const luci::CircleDepthwiseConv2D *node) final
2062   {
2063     return infer_depthwise_conv2d(node);
2064   }
2065
2066   loco::NodeShape visit(const luci::CircleDequantize *node) final
2067   {
2068     const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2069     return loco::NodeShape{input_shape};
2070   }
2071
2072   loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); }
2073
2074   loco::NodeShape visit(const luci::CircleElu *node) final
2075   {
2076     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2077
2078     return loco::NodeShape{input_shape};
2079   }
2080
2081   loco::NodeShape visit(const luci::CircleEqual *node) final { return broadcast_xy(node); }
2082
2083   loco::NodeShape visit(const luci::CircleExp *node) final { return use_x(node); }
2084
2085   loco::NodeShape visit(const luci::CircleExpandDims *node) final
2086   {
2087     return infer_expand_dims(node);
2088   }
2089
2090   loco::NodeShape visit(const luci::CircleFill *node) final { return infer_fill(node); }
2091
2092   loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); }
2093
2094   loco::NodeShape visit(const luci::CircleFloorDiv *node) final { return broadcast_xy(node); }
2095
2096   loco::NodeShape visit(const luci::CircleFloorMod *node) final { return broadcast_xy(node); }
2097
2098   loco::NodeShape visit(const luci::CircleFullyConnected *node) final
2099   {
2100     return infer_fully_connected(node);
2101   }
2102
2103   loco::NodeShape visit(const luci::CircleGather *node) final { return infer_gather(node); }
2104
2105   loco::NodeShape visit(const luci::CircleGatherNd *node) final { return infer_gather_nd(node); }
2106
2107   loco::NodeShape visit(const luci::CircleGreater *node) final { return broadcast_xy(node); }
2108
2109   loco::NodeShape visit(const luci::CircleGreaterEqual *node) final { return broadcast_xy(node); }
2110
2111   loco::NodeShape visit(const luci::CircleIf *node) final
2112   {
2113     // Shape of CircleIf is not used. Just use input 0
2114     assert(node->input_count() > 0);
2115     const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
2116     return loco::NodeShape{input_shape};
2117   }
2118
2119   loco::NodeShape visit(const luci::CircleL2Normalize *node) final { return use_x(node); }
2120
2121   loco::NodeShape visit(const luci::CircleL2Pool2D *node) final
2122   {
2123     return infer_pool_2d_shape(node);
2124   }
2125
2126   loco::NodeShape visit(const luci::CircleLeakyRelu *node) final
2127   {
2128     const auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2129     return loco::NodeShape{input_shape};
2130   }
2131
2132   loco::NodeShape visit(const luci::CircleLess *node) final { return broadcast_xy(node); }
2133
2134   loco::NodeShape visit(const luci::CircleLessEqual *node) final { return broadcast_xy(node); }
2135
2136   loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final
2137   {
2138     const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2139     return loco::NodeShape{input_shape};
2140   }
2141
2142   loco::NodeShape visit(const luci::CircleLog *node) final { return use_x(node); }
2143
2144   loco::NodeShape visit(const luci::CircleLogicalAnd *node) final { return use_x(node); }
2145
2146   loco::NodeShape visit(const luci::CircleLogicalNot *node) final { return use_x(node); }
2147
2148   loco::NodeShape visit(const luci::CircleLogicalOr *node) final { return use_x(node); }
2149
2150   loco::NodeShape visit(const luci::CircleLogistic *node) final { return use_x(node); }
2151
2152   loco::NodeShape visit(const luci::CircleLogSoftmax *node) final { return use_logits(node); }
2153
2154   loco::NodeShape visit(const luci::CircleMatrixDiag *node) final
2155   {
2156     return infer_matrix_diag(node);
2157   }
2158
2159   loco::NodeShape visit(const luci::CircleMatrixSetDiag *node) final
2160   {
2161     return infer_matrix_set_diag(node);
2162   }
2163
2164   loco::NodeShape visit(const luci::CircleMaximum *node) final { return broadcast_xy(node); }
2165
2166   loco::NodeShape visit(const luci::CircleMaxPool2D *node) final
2167   {
2168     return infer_pool_2d_shape(node);
2169   }
2170
2171   loco::NodeShape visit(const luci::CircleMean *node) final
2172   {
2173     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2174     return loco::NodeShape{output_shape};
2175   }
2176
2177   loco::NodeShape visit(const luci::CircleMinimum *node) final { return broadcast_xy(node); }
2178
2179   loco::NodeShape visit(const luci::CircleMirrorPad *node) final { return infer_mirror_pad(node); }
2180
2181   loco::NodeShape visit(const luci::CircleMul *node) final { return broadcast_xy(node); }
2182
2183   loco::NodeShape visit(const luci::CircleNeg *node) final { return use_x(node); }
2184
2185   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
2186   {
2187     const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
2188     return loco::NodeShape{boxes_shape};
2189   }
2190
2191   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5 *node) final
2192   {
2193     const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
2194     return loco::NodeShape{boxes_shape};
2195   }
2196
2197   loco::NodeShape visit(const luci::CircleNotEqual *node) final { return broadcast_xy(node); }
2198
2199   loco::NodeShape visit(const luci::CircleOneHot *node) final { return infer_one_hot(node); }
2200
2201   loco::NodeShape visit(const luci::CirclePack *node) final { return infer_pack(node); }
2202
2203   loco::NodeShape visit(const luci::CirclePad *node) final { return infer_pad(node); }
2204
2205   loco::NodeShape visit(const luci::CirclePadV2 *node) final { return infer_pad_v2(node); }
2206
2207   loco::NodeShape visit(const luci::CirclePow *node) final { return broadcast_xy(node); }
2208
2209   loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }
2210
2211   loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }
2212
2213   loco::NodeShape visit(const luci::CircleRank *) final
2214   {
2215     loco::TensorShape shape_output;
2216     shape_output.rank(0);
2217
2218     return loco::NodeShape{shape_output};
2219   }
2220
2221   loco::NodeShape visit(const luci::CircleReduceAny *node) final
2222   {
2223     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2224     return loco::NodeShape{output_shape};
2225   }
2226
2227   loco::NodeShape visit(const luci::CircleReduceMax *node) final
2228   {
2229     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2230     return loco::NodeShape{output_shape};
2231   }
2232
2233   loco::NodeShape visit(const luci::CircleReduceMin *node) final
2234   {
2235     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2236     return loco::NodeShape{output_shape};
2237   }
2238
2239   loco::NodeShape visit(const luci::CircleReduceProd *node) final
2240   {
2241     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2242     return loco::NodeShape{output_shape};
2243   }
2244
2245   loco::NodeShape visit(const luci::CircleRelu *node) final
2246   {
2247     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2248
2249     return loco::NodeShape{input_shape};
2250   }
2251
2252   loco::NodeShape visit(const luci::CircleRelu6 *node) final
2253   {
2254     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2255
2256     return loco::NodeShape{input_shape};
2257   }
2258
2259   loco::NodeShape visit(const luci::CircleReluN1To1 *node) final
2260   {
2261     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
2262
2263     return loco::NodeShape{input_shape};
2264   }
2265
2266   /**
2267    * @note  CircleReshape has new shape info in two places: 2nd input and attribute.
2268    *        This shape inference uses shape from input 'shape' node when it's constant.
2269    *        If not, shape will be from node itself. shape from attribute is not used.
2270    *
2271    * TODO Change this policy when not appropriate
2272    */
2273   loco::NodeShape visit(const luci::CircleReshape *node) final { return infer_reshape(node); }
2274
2275   loco::NodeShape visit(const luci::CircleResizeBilinear *node) final
2276   {
2277     return infer_resize_bilinear(node);
2278   }
2279
2280   loco::NodeShape visit(const luci::CircleResizeNearestNeighbor *node) final
2281   {
2282     return infer_resize_nearest_neighbor(node);
2283   }
2284
2285   loco::NodeShape visit(const luci::CircleReverseSequence *node) final
2286   {
2287     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2288
2289     return loco::NodeShape{input_shape};
2290   }
2291
2292   loco::NodeShape visit(const luci::CircleRound *node) final { return use_x(node); }
2293
2294   loco::NodeShape visit(const luci::CircleReverseV2 *node) final
2295   {
2296     auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
2297
2298     LUCI_ASSERT(loco::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
2299                 "Tensor must be 1-D");
2300
2301     return loco::NodeShape{input_shape};
2302   }
2303
2304   loco::NodeShape visit(const luci::CircleRsqrt *node) final { return use_x(node); }
2305
2306   loco::NodeShape visit(const luci::CircleScatterNd *node) final { return infer_scatter_nd(node); }
2307
2308   loco::NodeShape visit(const luci::CircleSegmentSum *node) final
2309   {
2310     return infer_segment_sum(node);
2311   }
2312
2313   loco::NodeShape visit(const luci::CircleSelect *node) final { return infer_select(node); }
2314
2315   loco::NodeShape visit(const luci::CircleSelectV2 *node) final { return infer_select_v2(node); }
2316
2317   loco::NodeShape visit(const luci::CircleShape *node) final { return infer_shape(node); }
2318
2319   loco::NodeShape visit(const luci::CircleSin *node) final { return use_x(node); }
2320
2321   loco::NodeShape visit(const luci::CircleSlice *node) final { return infer_slice(node); }
2322
2323   loco::NodeShape visit(const luci::CircleSoftmax *node) final { return use_logits(node); }
2324
2325   loco::NodeShape visit(const luci::CircleSpaceToBatchND *node) final
2326   {
2327     return infer_space_to_batch_nd(node);
2328   }
2329
2330   loco::NodeShape visit(const luci::CircleSpaceToDepth *node) final
2331   {
2332     return infer_space_to_depth(node);
2333   }
2334
2335   loco::NodeShape visit(const luci::CircleSparseToDense *node) final
2336   {
2337     return infer_sparse_to_dense(node);
2338   }
2339
2340   loco::NodeShape visit(const luci::CircleSplit *node) final
2341   {
2342     // We'll set Split output as same as input so that SplitOut can handle it's own shape
2343     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2344     return loco::NodeShape{input_shape};
2345   }
2346
2347   loco::NodeShape visit(const luci::CircleSplitV *node) final
2348   {
2349     // We'll set SplitV output as same as input so that SplitOut can handle it's own shape
2350     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2351     return loco::NodeShape{input_shape};
2352   }
2353
2354   loco::NodeShape visit(const luci::CircleSqrt *node) final { return use_x(node); }
2355
2356   loco::NodeShape visit(const luci::CircleSquare *node) final { return use_x(node); }
2357
2358   loco::NodeShape visit(const luci::CircleSquaredDifference *node) final
2359   {
2360     return broadcast_xy(node);
2361   }
2362
2363   loco::NodeShape visit(const luci::CircleStridedSlice *node) final
2364   {
2365     return infer_strided_slice(node);
2366   }
2367
2368   loco::NodeShape visit(const luci::CircleSqueeze *node) final { return infer_squeeze(node); }
2369
2370   loco::NodeShape visit(const luci::CircleSub *node) final { return broadcast_xy(node); }
2371
2372   loco::NodeShape visit(const luci::CircleSum *node) final
2373   {
2374     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2375     return loco::NodeShape{output_shape};
2376   }
2377
2378   loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
2379
2380   loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); }
2381
2382   loco::NodeShape visit(const luci::CircleTopKV2 *node) final
2383   {
2384     // set shape of this node as same as input
2385     const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2386     return loco::NodeShape{input_shape};
2387   }
2388
2389   loco::NodeShape visit(const luci::CircleTranspose *node) final { return infer_transpose(node); }
2390
2391   loco::NodeShape visit(const luci::CircleTransposeConv *node) final
2392   {
2393     return infer_transpose_conv(node);
2394   }
2395
2396   loco::NodeShape visit(const luci::CircleUnpack *node) final { return infer_unpack(node); }
2397
2398   loco::NodeShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
2399   {
2400     return infer_unidirectionalsequencelstm(node);
2401   }
2402
2403   loco::NodeShape visit(const luci::CircleUnique *node) final { return infer_unique(node); }
2404
2405   loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); }
2406
2407   loco::NodeShape visit(const luci::CircleWhile *node) final
2408   {
2409     // Shape of CircleWhile is not used. Just use input 0
2410     assert(node->arity() > 0);
2411     const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
2412     return loco::NodeShape{input_shape};
2413   }
2414
2415   loco::NodeShape visit(const luci::CircleZerosLike *node) final
2416   {
2417     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2418
2419     return loco::NodeShape{input_shape};
2420   }
2421
2422   // Circle Only
2423   loco::NodeShape visit(const luci::CircleBCQFullyConnected *node) final
2424   {
2425     return infer_bcq_fully_connected(node);
2426   }
2427
2428   loco::NodeShape visit(const luci::CircleBCQGather *node) final { return infer_bcq_gather(node); }
2429
2430   loco::NodeShape visit(const luci::CircleInstanceNorm *node) final
2431   {
2432     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
2433
2434     return loco::NodeShape{input_shape};
2435   }
2436
2437   // Virtual
2438   loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }
2439
2440   loco::NodeShape visit(const luci::CircleOutput *node) final { return infer_output(node); }
2441
2442   loco::NodeShape visit(const luci::CircleOutputDummy *node) final { return use_own(node); }
2443
2444   loco::NodeShape visit(const luci::CircleOutputExclude *node) final { return use_own(node); }
2445
2446   loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
2447
2448   loco::NodeShape visit(const luci::CircleIfOut *node) final { return infer_if_out(node); }
2449
2450   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
2451   {
2452     return infer_non_max_suppression_v4_out(node);
2453   }
2454
2455   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final
2456   {
2457     return infer_non_max_suppression_v5_out(node);
2458   }
2459
2460   loco::NodeShape visit(const luci::CircleSplitOut *node) final { return infer_split_out(node); }
2461
2462   loco::NodeShape visit(const luci::CircleSplitVOut *node) final { return infer_split_v_out(node); }
2463
2464   loco::NodeShape visit(const luci::CircleTopKV2Out *node) final
2465   {
2466     return infer_top_k_v2_out(node);
2467   }
2468
2469   loco::NodeShape visit(const luci::CircleUniqueOut *node) final { return infer_unique_out(node); }
2470
2471   loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
2472
2473   loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
2474 };
2475
2476 } // namespace
2477
2478 namespace luci
2479 {
2480
2481 bool CircleShapeInferenceRule::recognize(const loco::Dialect *d) const
2482 {
2483   return CircleDialect::get() == d;
2484 }
2485
2486 bool CircleShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
2487 {
2488   LOGGER(l);
2489
2490   assert(node->dialect() == CircleDialect::get());
2491
2492   ShapeInferenceAlgorithm alg;
2493   auto circle_node = loco::must_cast<const CircleNode *>(node);
2494
2495   bool is_shape_undefined = (circle_node->shape_status() == ShapeStatus::UNDEFINED);
2496   bool is_shape_none = (circle_node->shape_status() == ShapeStatus::NOSHAPE);
2497   bool is_scalar = (circle_node->rank() == 0);
2498
2499   if (is_shape_undefined)
2500     shape = circle_node->accept(&alg);
2501   else
2502   {
2503     if (is_shape_none || is_scalar)
2504       shape = own_shape(circle_node);
2505     else
2506       shape = circle_node->accept(&alg);
2507   }
2508
2509   VERBOSE(l, 1) << "[luci] shape: " << circle_node->name();
2510   VERBOSE(l, 1) << "              own_shape: " << own_shape(circle_node)
2511                 << " -> infer: " << shape.as<loco::TensorShape>();
2512
2513   return true;
2514 }
2515
2516 } // namespace luci