Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / CircleShapeInferenceRule.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  * Copyright 2017 The TensorFlow Authors. All Rights Reserved.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17
18 #include "luci/Service/CircleShapeInferenceRule.h"
19 #include "Check.h"
20
21 #include "CircleShapeInferenceHelper.h"
22 #include "ShapeInfer_StridedSlice.h"
23
24 #include <luci/IR/CircleNodes.h>
25 #include <luci/IR/CircleDialect.h>
26 #include <luci/IR/CircleNodeVisitor.h>
27 #include <luci/Log.h>
28
29 #include <oops/InternalExn.h>
30
31 #include <algorithm>
32 #include <cassert>
33 #include <cmath>
34 #include <stdexcept>
35
36 namespace
37 {
38
39 std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
40 {
41   os << "[";
42   for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
43   {
44     if (r)
45       os << ",";
46
47     if (tensor_shape.dim(r).known())
48       os << tensor_shape.dim(r).value();
49     else
50       os << "?";
51   }
52   os << "]";
53   return os;
54 }
55
56 loco::TensorShape own_shape(const luci::CircleNode *node)
57 {
58   loco::TensorShape shape;
59   shape.rank(node->rank());
60   for (uint32_t r = 0; r < node->rank(); ++r)
61   {
62     // Shape inference rules in this file did not consider unknown dimension.
63     // If some node has unknown dimension, 0 is inserted and wrong shape
64     // inference was done as a result.
65     // To fix this, new shape inference algorithm is being implemented.
66     // Until new inference algorithm is fully implemented, unknown dimension
67     // would be represented as 1 along with TFLite expression.
68     shape.dim(r) = node->dim(r).known() ? node->dim(r).value() : 1;
69   }
70   return shape;
71 }
72
73 loco::NodeShape use_own(const luci::CircleNode *node)
74 {
75   loco::TensorShape shape = own_shape(node);
76   return loco::NodeShape{shape};
77 }
78
79 /**
80  * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
81  *
82  * HOW TO USE:
83  *
84  *   auto expanded_tensor_shape = expand(tensor_shape).to(N);
85  */
86 class TensorShapeExpander
87 {
88 public:
89   TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
90   {
91     // DO NOTHING
92   }
93
94 public:
95   loco::TensorShape to(uint32_t output_rank)
96   {
97     auto const &input_shape = _shape;
98     uint32_t const input_rank = input_shape.rank();
99
100     assert(input_rank <= output_rank && "Cannot shrink rank");
101     uint32_t const axis_shift = output_rank - input_rank;
102
103     loco::TensorShape output_shape;
104
105     output_shape.rank(output_rank);
106     for (uint32_t axis = 0; axis < output_rank; ++axis)
107     {
108       output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
109     }
110
111     return output_shape;
112   }
113
114 private:
115   const loco::TensorShape _shape;
116 };
117
118 /**
119  * @brief  Expand shape x and y to same rank by align right and filling with 1
120  */
121 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
122 {
123   auto x_rank = x.rank();
124   auto y_rank = y.rank();
125
126   if (x_rank == y_rank)
127     return;
128
129   TensorShapeExpander x_exp(x);
130   TensorShapeExpander y_exp(y);
131
132   auto xy_rank = std::max(x_rank, y_rank);
133
134   x = x_rank > y_rank ? x : x_exp.to(xy_rank);
135   y = y_rank > x_rank ? y : y_exp.to(xy_rank);
136 }
137
138 /**
139  * @brief  Returns shape of expanded dimension of input x and y having same rank
140  */
141 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
142 {
143   assert(x.rank() == y.rank());
144
145   auto rank = x.rank();
146
147   loco::TensorShape output_shape;
148
149   output_shape.rank(rank);
150   for (uint32_t axis = 0; axis < rank; ++axis)
151   {
152     auto x_dim = x.dim(axis).known() ? x.dim(axis).value() : 1;
153     auto y_dim = y.dim(axis).known() ? y.dim(axis).value() : 1;
154
155     // each dimension of x and y should be same or one must be 1 if different
156     if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
157       INTERNAL_EXN("Cannot produce expand_dimension of two shapes");
158
159     output_shape.dim(axis) = std::max(x_dim, y_dim);
160   }
161
162   return output_shape;
163 }
164
165 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
166 {
167   auto x_match = x;
168   auto y_match = y;
169
170   expand_rank(x_match, y_match);
171
172   auto output_shape = expand_dimension(x_match, y_match);
173
174   return output_shape;
175 }
176
177 /**
178  * @brief vector_from_constant will return int64_t vector from CircleConst node
179  */
180 template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::CircleConst *const_node)
181 {
182   std::vector<int64_t> result;
183
184   for (uint32_t idx = 0; idx < const_node->size<T>(); ++idx)
185     result.push_back(const_node->at<T>(idx));
186
187   return result;
188 }
189
190 template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
191 {
192   auto x_shape = luci::shape_get(node->x()).template as<loco::TensorShape>();
193   auto y_shape = luci::shape_get(node->y()).template as<loco::TensorShape>();
194
195   auto output_shape = broadcast_shape(x_shape, y_shape);
196
197   return loco::NodeShape{output_shape};
198 }
199
200 #define DECLARE_USE_SINGLE(NAME)                                                        \
201   template <class CIRCLENODE> loco::NodeShape use_##NAME(const CIRCLENODE *node)        \
202   {                                                                                     \
203     auto inputs_shape = luci::shape_get(node->NAME()).template as<loco::TensorShape>(); \
204     return loco::NodeShape{inputs_shape};                                               \
205   }
206
207 DECLARE_USE_SINGLE(input);
208 DECLARE_USE_SINGLE(inputs);
209 DECLARE_USE_SINGLE(x);
210 DECLARE_USE_SINGLE(logits);
211
212 #undef DECLARE_USE_SINGLE
213
214 template <class CIRCLENODE>
215 loco::NodeShape use_paddings(const CIRCLENODE *node, const luci::CircleConst *paddings)
216 {
217   const loco::DataType S32 = loco::DataType::S32;
218
219   auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
220
221   // TODO support other data type
222   LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
223   LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2")
224
225   int32_t n = paddings->dim(0).value();
226   int32_t v = paddings->dim(1).value();
227
228   LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
229   LUCI_ASSERT(n == int32_t(input_shape.rank()),
230               "paddings [n, 2] should have same value of input rank");
231
232   loco::TensorShape output_shape;
233
234   output_shape.rank(input_shape.rank());
235   for (int32_t ni = 0; ni < n; ++ni)
236   {
237     int32_t idx = ni * 2;
238     int value = input_shape.dim(ni).value();
239     value += paddings->at<S32>(idx + 0); // left
240     value += paddings->at<S32>(idx + 1); // right
241     output_shape.dim(ni) = value;
242   }
243
244   return loco::NodeShape{output_shape};
245 }
246
247 loco::NodeShape infer_add_n(const luci::CircleAddN *node)
248 {
249   auto shape = luci::shape_get(node->inputs(0)).as<loco::TensorShape>();
250
251   for (uint32_t idx = 1; idx < node->arity(); ++idx)
252   {
253     auto shape_idx = luci::shape_get(node->inputs(idx)).as<loco::TensorShape>();
254     if (!(shape == shape_idx))
255     {
256       INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx);
257     }
258   }
259   return loco::NodeShape{shape};
260 }
261
262 template <class CIRCLENODE> loco::NodeShape infer_arg_maxmin(const CIRCLENODE *node)
263 {
264   auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
265   auto dimension_shape = luci::shape_get(node->dimension()).template as<loco::TensorShape>();
266
267   int64_t select_axis = 0;
268   {
269     LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
270
271     // Only support node's shape() is CircleConst with S32/S64
272     // Support S32 for now.
273     auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
274     LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
275                 "Only support int32 CircleConst for CircleArgMax/CircleArgMin");
276
277     if (const_shape_node->rank() > 1)
278       INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
279                      oops::to_uint32(const_shape_node->rank()));
280
281     select_axis = const_shape_node->template scalar<loco::DataType::S32>();
282   }
283
284   assert(select_axis < input_shape.rank());
285
286   if (select_axis < 0)
287     select_axis += static_cast<int64_t>(input_shape.rank());
288
289   // NOTE select_axis is removed
290   loco::TensorShape shape_output;
291   uint32_t rank = input_shape.rank();
292   uint32_t shrink = static_cast<uint32_t>(select_axis);
293   assert(rank > 0);
294   shape_output.rank(rank - 1);
295   for (uint32_t r = 0, d = 0; r < rank; ++r)
296   {
297     if (r == shrink)
298       continue;
299     shape_output.dim(d++) = input_shape.dim(r);
300   }
301   return loco::NodeShape{shape_output};
302 }
303
304 // Call this for CircleAvgPool2D and CircleMaxPool2D only
305 template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
306 {
307   auto ifm_shape = luci::shape_get(node->value()).template as<loco::TensorShape>();
308   assert(ifm_shape.rank() == 4);
309   assert(ifm_shape.dim(1).known());
310   assert(ifm_shape.dim(2).known());
311
312   uint32_t input_height = ifm_shape.dim(1).value();
313   uint32_t input_width = ifm_shape.dim(2).value();
314   uint32_t stride_height = node->stride()->h();
315   uint32_t stride_width = node->stride()->w();
316   uint32_t window_height = node->filter()->h();
317   uint32_t window_width = node->filter()->w();
318   uint32_t dilation_height = 1; // dilation for CircleAvgPool2D and CircleMaxPool2D is 1
319   uint32_t dilation_width = 1;
320   uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
321   uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
322
323   uint32_t output_height = 0;
324   uint32_t output_width = 0;
325
326   if (node->padding() == luci::Padding::VALID)
327   {
328     LUCI_ASSERT(input_height + stride_height > effective_window_height, "Invalid shape");
329     LUCI_ASSERT(input_width + stride_width > effective_window_width, "Invalid shape");
330     output_height = (input_height + stride_height - effective_window_height) / stride_height;
331     output_width = (input_width + stride_width - effective_window_width) / stride_width;
332   }
333   else if (node->padding() == luci::Padding::SAME)
334   {
335     output_height = (input_height + stride_height - 1) / stride_height;
336     output_width = (input_width + stride_width - 1) / stride_width;
337   }
338   else
339     LUCI_ASSERT(false, "Wrong padding type");
340
341   loco::TensorShape ofm_shape;
342   ofm_shape.rank(4);
343   ofm_shape.dim(0) = ifm_shape.dim(0);
344   ofm_shape.dim(1) = output_height;
345   ofm_shape.dim(2) = output_width;
346   ofm_shape.dim(3) = ifm_shape.dim(3);
347
348   return loco::NodeShape{ofm_shape};
349 }
350
351 loco::NodeShape infer_batch_to_space_nd(const luci::CircleBatchToSpaceND *node)
352 {
353   const loco::DataType S32 = loco::DataType::S32;
354
355   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
356   // Support only input rank is 3 and 4
357   assert(input_shape.rank() == 3 || input_shape.rank() == 4);
358
359   // Only support block_shape() with S32 type CircleConst for now
360   auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
361   LUCI_ASSERT(const_block_shape->dtype() == loco::DataType::S32, "Only support int32 block_shape");
362
363   // Only support crops() with S32 type CircleConst for now
364   auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops());
365   LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops");
366
367   auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
368   auto const_crops_shape = luci::shape_get(const_crops).as<loco::TensorShape>();
369   assert(const_block_shape_shape.rank() == 1);
370   assert(const_crops_shape.rank() == 2);
371
372   int32_t input_spatial_dim = input_shape.rank() - 2;
373   assert(const_block_shape_shape.dim(0) == input_spatial_dim);
374   assert(const_crops_shape.dim(0) == input_spatial_dim);
375   assert(const_crops_shape.dim(1) == 2);
376
377   loco::TensorShape shape_output;
378
379   shape_output.rank(input_shape.rank());
380
381   int32_t output_batch_size = input_shape.dim(0).value();
382   for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
383   {
384     int dim_size = input_shape.dim(dim + 1).value() * const_block_shape->at<S32>(dim);
385     dim_size -= const_crops->at<S32>(dim * 2);
386     dim_size -= const_crops->at<S32>(dim * 2 + 1);
387     shape_output.dim(dim + 1) = dim_size;
388
389     assert(output_batch_size % const_block_shape->at<S32>(dim) == 0);
390     output_batch_size = output_batch_size / const_block_shape->at<S32>(dim);
391   }
392   shape_output.dim(0) = output_batch_size;
393   shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
394
395   return loco::NodeShape{shape_output};
396 }
397
398 struct OutputSize
399 {
400   uint32_t height = 0;
401   uint32_t width = 0;
402 };
403
404 template <class Conv2DType> OutputSize infer_conv2d_type(const Conv2DType *node)
405 {
406   auto ifm_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
407   auto ker_shape = luci::shape_get(node->filter()).template as<loco::TensorShape>();
408   assert(ifm_shape.rank() == 4);
409   assert(ker_shape.rank() == 4);
410   assert(ifm_shape.dim(1).known());
411   assert(ifm_shape.dim(2).known());
412   assert(ker_shape.dim(1).known());
413   assert(ker_shape.dim(2).known());
414
415   uint32_t input_height = ifm_shape.dim(1).value();
416   uint32_t input_width = ifm_shape.dim(2).value();
417   uint32_t stride_height = node->stride()->h();
418   uint32_t stride_width = node->stride()->w();
419   uint32_t ker_height = ker_shape.dim(1).value();
420   uint32_t ker_width = ker_shape.dim(2).value();
421   uint32_t dilation_height = node->dilation()->h();
422   uint32_t dilation_width = node->dilation()->w();
423   uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
424   uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
425
426   uint32_t output_height = 0;
427   uint32_t output_width = 0;
428
429   if (node->padding() == luci::Padding::VALID)
430   {
431     LUCI_ASSERT(input_height + stride_height > effective_ker_height, "Invalid shape");
432     LUCI_ASSERT(input_width + stride_width > effective_ker_width, "Invalid shape");
433     output_height = (input_height + stride_height - effective_ker_height) / stride_height;
434     output_width = (input_width + stride_width - effective_ker_width) / stride_width;
435   }
436   else if (node->padding() == luci::Padding::SAME)
437   {
438     output_height = (input_height + stride_height - 1) / stride_height;
439     output_width = (input_width + stride_width - 1) / stride_width;
440   }
441   else
442     LUCI_ASSERT(false, "Wrong padding type");
443
444   OutputSize os{output_height, output_width};
445
446   return os;
447 }
448
449 // BatchMatMulV2 supports broadcasting in the batch dimensions(BatchMatMul doesn't)
450 // TODO Distinguish BatchMatMul and BatchMatMulV2
451 loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
452                                         const loco::TensorShape &y_shape, bool adj_x, bool adj_y)
453 {
454   uint32_t x_rank = x_shape.rank();
455   uint32_t y_rank = y_shape.rank();
456   assert(x_rank >= 2 && y_rank >= 2);
457
458   loco::TensorShape output_shape;
459   output_shape.rank(x_shape.rank());
460   // Braodcast in the batch dimension
461   if (x_rank > 2 || y_rank > 2)
462   {
463     loco::TensorShape dummy_x = x_shape;
464     loco::TensorShape dummy_y = y_shape;
465     expand_rank(dummy_x, dummy_y);
466     if (x_rank < y_rank)
467       expand_rank(output_shape, dummy_y);
468
469     for (uint32_t d = 0; d < output_shape.rank() - 2; d++)
470     {
471       uint32_t max_dim = std::max(dummy_x.dim(d).value(), dummy_y.dim(d).value());
472       if (dummy_x.dim(d) == dummy_y.dim(d) ||
473           dummy_x.dim(d).value() * dummy_y.dim(d).value() == max_dim)
474         output_shape.dim(d).set(max_dim);
475       else
476         INTERNAL_EXN("BatchMatMul has wrong shape");
477     }
478   }
479
480   loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
481   loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
482   loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
483   loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
484
485   if (x_rhs.known() && y_lhs.known() && not(x_rhs == y_lhs))
486     INTERNAL_EXN("x_rhs and y_lhs should be same");
487
488   uint32_t out_rank = output_shape.rank();
489   output_shape.dim(out_rank - 2) = x_lhs;
490   output_shape.dim(out_rank - 1) = y_rhs;
491
492   return loco::NodeShape{output_shape};
493 }
494
495 loco::NodeShape infer_concatenation(const luci::CircleConcatenation *node)
496 {
497   // TODO Support when CircleConcatenation has 0 input
498   assert(node->numValues() > 0);
499
500   auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
501   auto axis = node->axis();
502   if (axis < 0)
503     axis += first_shape.rank();
504
505   assert(0 <= axis);
506   assert(first_shape.rank() > static_cast<uint32_t>(axis));
507
508   loco::TensorShape output_shape;
509
510   output_shape.rank(first_shape.rank());
511   for (uint32_t i = 0; i < output_shape.rank(); ++i)
512     output_shape.dim(i) = first_shape.dim(i);
513
514   for (uint32_t i = 1; i < node->numValues(); ++i)
515   {
516     auto input_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
517
518     for (uint32_t j = 0; j < output_shape.rank(); ++j)
519     {
520       if (j == static_cast<uint32_t>(axis))
521       {
522         // If dimension is unknown, value() will return 0.
523         // This is wrong but until new inference algorithm is implemented,
524         // this code will not be modified to keep compatibility.
525         output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
526       }
527       else
528         assert(!output_shape.dim(j).known() || !input_shape.dim(j).known() ||
529                output_shape.dim(j) == input_shape.dim(j));
530     }
531   }
532
533   return loco::NodeShape{output_shape};
534 }
535
536 loco::NodeShape infer_conv2d(const luci::CircleConv2D *node)
537 {
538   LOGGER(l);
539
540   auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
541   auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
542
543   assert(ifm_shape.rank() == 4);
544   assert(ker_shape.rank() == 4);
545   assert(ifm_shape.dim(3) == ker_shape.dim(3));
546
547   auto os = infer_conv2d_type(node);
548
549   loco::TensorShape ofm_shape;
550   ofm_shape.rank(4);
551   ofm_shape.dim(0) = ifm_shape.dim(0);
552   ofm_shape.dim(1) = os.height;
553   ofm_shape.dim(2) = os.width;
554   ofm_shape.dim(3) = ker_shape.dim(0);
555
556   INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker(" << ker_shape.rank()
557           << ") output(" << ofm_shape.dim(0).value() << "," << ofm_shape.dim(1).value() << ","
558           << ofm_shape.dim(2).value() << "," << ofm_shape.dim(3).value() << ") " << node->name()
559           << std::endl;
560
561   return loco::NodeShape{ofm_shape};
562 }
563
564 loco::NodeShape infer_depth_to_space(const luci::CircleDepthToSpace *node)
565 {
566   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
567   LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
568
569   // Only data format NHWC is supported
570   // TODO need to clarify what to do with layout in this operator
571   int32_t height = input_shape.dim(1).value();
572   int32_t width = input_shape.dim(2).value();
573   int32_t depth = input_shape.dim(3).value();
574
575   int block_size = node->block_size();
576
577   if (block_size < 2)
578     INTERNAL_EXN("Block size must be >= 2");
579
580   if (depth % (block_size * block_size))
581   {
582     INTERNAL_EXN("The input tensor's depth must be divisible by block_size^2");
583   }
584
585   loco::TensorShape output_shape;
586   output_shape.rank(4);
587
588   output_shape.dim(0) = input_shape.dim(0).value();
589   output_shape.dim(1) = height * block_size;
590   output_shape.dim(2) = width * block_size;
591   output_shape.dim(3) = depth / (block_size * block_size);
592
593   return loco::NodeShape{output_shape};
594 }
595
596 loco::NodeShape infer_depthwise_conv2d(const luci::CircleDepthwiseConv2D *node)
597 {
598   auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
599   auto ker_shape = luci::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
600
601   assert(ifm_shape.rank() == 4);
602   assert(ker_shape.rank() == 4);
603   assert(ker_shape.dim(0).value() == 1);
604   assert(ifm_shape.dim(3).value() * node->depthMultiplier() == ker_shape.dim(3).value());
605
606   auto os = infer_conv2d_type(node);
607
608   loco::TensorShape ofm_shape;
609   ofm_shape.rank(4);
610   ofm_shape.dim(0) = ifm_shape.dim(0);
611   ofm_shape.dim(1) = os.height;
612   ofm_shape.dim(2) = os.width;
613   ofm_shape.dim(3) = ker_shape.dim(3);
614
615   return loco::NodeShape{ofm_shape};
616 }
617
618 loco::NodeShape infer_expand_dims(const luci::CircleExpandDims *node)
619 {
620   const loco::DataType S32 = loco::DataType::S32;
621   auto x_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
622   if (x_shape.rank() == 0)
623   {
624     // This maybe for unknown shape. We use shape from the node itself.
625     return use_own(node);
626   }
627   auto const_axis = loco::must_cast<luci::CircleConst *>(node->axis());
628   LUCI_ASSERT(const_axis->dtype() == S32, "Only support int32 CircleConst for axis");
629   if (const_axis->rank() != 0 && const_axis->rank() != 1)
630   {
631     INTERNAL_EXN_V("Non-scalar axis in OP", node->opnum());
632   }
633   int32_t axis = const_axis->at<S32>(0);
634   LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) &&
635                 (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
636               "Axis has to be between [-(D+1), D], where D is rank of input.");
637   size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis;
638   loco::TensorShape output_shape;
639   output_shape.rank(x_shape.rank() + 1);
640   size_t i = 0;
641   for (; i < positive_axis; i++)
642     output_shape.dim(i) = x_shape.dim(i);
643   output_shape.dim(i) = loco::Dimension(1);
644   for (; i < x_shape.rank(); i++)
645     output_shape.dim(i + 1) = x_shape.dim(i);
646   return loco::NodeShape{output_shape};
647 }
648
649 loco::NodeShape infer_fill(const luci::CircleFill *node)
650 {
651   loco::TensorShape shape;
652   {
653     LUCI_ASSERT(node->dims(), "dims input should not be nullptr");
654
655     auto dims_node = dynamic_cast<luci::CircleConst *>(node->dims());
656     if (dims_node != nullptr)
657     {
658       // Only support node with S32
659       LUCI_ASSERT(dims_node->dtype() == loco::DataType::S32, "Only support int32 CircleConst");
660
661       if (dims_node->rank() != 1)
662         INTERNAL_EXN_V("Only support rank 1 CircleConst", oops::to_uint32(dims_node->rank()));
663
664       shape.rank(dims_node->dim(0).value());
665
666       for (uint32_t axis = 0; axis < shape.rank(); ++axis)
667       {
668         shape.dim(axis) = dims_node->at<loco::DataType::S32>(axis);
669       }
670     }
671     else
672     {
673       shape = own_shape(node);
674     }
675   }
676
677   return loco::NodeShape{shape};
678 }
679
680 loco::NodeShape infer_fully_connected(const luci::CircleFullyConnected *node)
681 {
682   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
683   auto weights_shape = luci::shape_get(node->weights()).as<loco::TensorShape>();
684
685   loco::TensorShape out_shape;
686
687   // NOTE Some recipes in some repositories are using rank 4 input for FullyConnected.
688   //      Until they are all fixed, disable following assert.
689   // TODO Enable following assert after related fixes are applied
690   // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L194
691   // LUCI_ASSERT(input_shape.rank() == 2 || input_shape.rank() == 3,
692   //             "Input rank of FullyConnected should be 2 or 3");
693
694   // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L225
695   LUCI_ASSERT(weights_shape.rank() == 2, "Weights of FullyConnected should be 2");
696
697   // https://github.com/tensorflow/tensorflow/blob/ea33c1e7a25d8025e8ee405ad8ab7be261798d76/tensorflow/lite/kernels/fully_connected.cc#L353-L367
698   if (node->keep_num_dims())
699   {
700     out_shape.rank(input_shape.rank());
701     for (uint32_t i = 0; i < input_shape.rank(); ++i)
702       out_shape.dim(i) = input_shape.dim(i);
703     out_shape.dim(out_shape.rank() - 1) = weights_shape.dim(0);
704   }
705   else
706   {
707     uint32_t input_size = 1;
708     for (uint32_t i = 0; i < input_shape.rank(); i++)
709     {
710       input_size = input_size * input_shape.dim(i).value();
711     }
712     const uint32_t batch_size = input_size / weights_shape.dim(1).value();
713     out_shape.rank(2);
714     out_shape.dim(0) = batch_size;
715     out_shape.dim(1) = weights_shape.dim(0);
716   }
717
718   return loco::NodeShape{out_shape};
719 }
720
721 loco::NodeShape infer_gather(const luci::CircleGather *node)
722 {
723   loco::TensorShape output_shape;
724
725   const auto input_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
726   const auto positions_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
727   int32_t axis = node->axis();
728
729   // If CircleGather input has a dynamic shape, it can't inference this shape. So, it returns the
730   // shape that node already has.
731   if (input_shape.rank() == 0 || positions_shape.rank() == 0)
732     return use_own(node);
733
734   if (axis < 0)
735     axis += input_shape.rank();
736
737   output_shape.rank(input_shape.rank() - 1 + positions_shape.rank());
738   int32_t outdim_index = 0;
739   for (int32_t i = 0; i < axis; ++i)
740     output_shape.dim(outdim_index++) = input_shape.dim(i);
741   for (uint32_t i = 0; i < positions_shape.rank(); ++i)
742     output_shape.dim(outdim_index++) = positions_shape.dim(i);
743   for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
744     output_shape.dim(outdim_index++) = input_shape.dim(i);
745
746   return loco::NodeShape{output_shape};
747 }
748
749 loco::NodeShape infer_gather_nd(const luci::CircleGatherNd *node)
750 {
751   loco::TensorShape output_shape;
752
753   const auto params_shape = luci::shape_get(node->params()).as<loco::TensorShape>();
754   const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
755
756   const auto params_rank = params_shape.rank();
757   const auto indices_rank = indices_shape.rank();
758
759   // see https://www.tensorflow.org/api_docs/python/tf/gather_nd
760   // output.shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]
761   // batch_dims isn't supported in tflite
762
763   // TODO: replace exceptions with setting shape to unknown?
764
765   if (!indices_shape.dim(indices_rank - 1).known())
766     INTERNAL_EXN("Last indices dimension is unknown");
767
768   auto indices_last_dim = indices_shape.dim(indices_rank - 1).value();
769
770   if (indices_last_dim > params_rank)
771     INTERNAL_EXN("Last indices dimension should be <= params rank");
772
773   const uint32_t output_rank = indices_rank + params_rank - indices_last_dim - 1;
774
775   output_shape.rank(output_rank);
776
777   uint32_t output_index = 0;
778   for (uint32_t i = 0; i < indices_rank - 1; ++i)
779   {
780     auto &dim = indices_shape.dim(i);
781     if (!dim.known())
782       INTERNAL_EXN("Unknown indices dimension is unsupported");
783     output_shape.dim(output_index++).set(dim.value());
784   }
785
786   for (uint32_t i = indices_last_dim; i < params_rank; ++i)
787   {
788     auto &dim = params_shape.dim(i);
789     if (!dim.known())
790       INTERNAL_EXN("Unknown params dimension is unsupported");
791     output_shape.dim(output_index++).set(dim.value());
792   }
793
794   return loco::NodeShape{output_shape};
795 }
796
797 loco::NodeShape infer_matrix_diag(const luci::CircleMatrixDiag *node)
798 {
799   loco::TensorShape output_shape;
800
801   auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
802   auto rank = diagonal_shape.rank();
803
804   output_shape.rank(rank + 1);
805
806   for (uint32_t i = 0; i < rank; i++)
807   {
808     output_shape.dim(i) = diagonal_shape.dim(i);
809   }
810
811   output_shape.dim(rank) = diagonal_shape.dim(rank - 1);
812
813   return loco::NodeShape{output_shape};
814 }
815
816 loco::NodeShape infer_matrix_set_diag(const luci::CircleMatrixSetDiag *node)
817 {
818   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
819   auto diagonal_shape = luci::shape_get(node->diagonal()).as<loco::TensorShape>();
820
821   auto rank = diagonal_shape.rank();
822
823   LUCI_ASSERT(rank == input_shape.rank() - 1, "diagonal rank = input rank - 1");
824
825   for (uint32_t i = 0; i < rank - 1; i++)
826   {
827     LUCI_ASSERT(diagonal_shape.dim(i) == input_shape.dim(i), "diagonal dims = input dims");
828   }
829
830   auto dim = std::min(input_shape.dim(rank - 1).value(), input_shape.dim(rank).value());
831
832   LUCI_ASSERT(dim == diagonal_shape.dim(rank - 1), "Max diag len error");
833
834   return loco::NodeShape{input_shape};
835 }
836
837 loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indices, bool keep_dims)
838 {
839   const loco::DataType S32 = loco::DataType::S32;
840
841   auto input_shape = luci::shape_get(input).as<loco::TensorShape>();
842   auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices);
843
844   { // Exceptions
845     // TODO support non-const case
846     // TODO support other data type
847     LUCI_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
848   }
849
850   std::vector<int32_t> reduction_values;
851
852   for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
853   {
854     int32_t axis = reduction_indices->at<S32>(i);
855     if (axis < 0)
856       axis += input_shape.rank();
857     if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
858       INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
859     reduction_values.push_back(axis);
860   }
861
862   loco::TensorShape output_shape;
863
864   if (keep_dims)
865   {
866     output_shape.rank(input_shape.rank());
867     for (uint32_t i = 0; i < input_shape.rank(); ++i)
868       output_shape.dim(i) = input_shape.dim(i);
869     for (uint32_t i = 0; i < reduction_values.size(); ++i)
870       output_shape.dim(reduction_values.at(i)) = 1;
871   }
872   else
873   {
874     std::vector<bool> check_reduce(input_shape.rank(), false);
875     for (uint32_t i = 0; i < reduction_values.size(); ++i)
876       check_reduce.at(reduction_values.at(i)) = true;
877
878     uint32_t reduce_cnt = 0;
879     for (uint32_t i = 0; i < check_reduce.size(); ++i)
880       if (check_reduce.at(i))
881         ++reduce_cnt;
882
883     output_shape.rank(input_shape.rank() - reduce_cnt);
884     for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
885       if (check_reduce.at(i) == false)
886         output_shape.dim(j++) = input_shape.dim(i);
887   }
888
889   return output_shape;
890 }
891
892 loco::NodeShape infer_mirror_pad(const luci::CircleMirrorPad *node)
893 {
894   // TODO support non-const case
895   auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
896   return use_paddings(node, paddings);
897 }
898
899 loco::NodeShape infer_one_hot(const luci::CircleOneHot *node)
900 {
901   const loco::DataType S32 = loco::DataType::S32;
902   auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
903   // Only support OneHot node's depth() is CircleConst with type S32
904   // TODO support depth with other types
905   auto depth = loco::must_cast<luci::CircleConst *>(node->depth());
906   LUCI_ASSERT(depth->dtype() == S32, "Only support int32 CircleConst");
907   if (depth->rank() != 0)
908     INTERNAL_EXN_V("Only support rank 0 CircleOneHot in Depth", oops::to_uint32(depth->rank()));
909   loco::TensorShape output_shape;
910   output_shape.rank(indices_shape.rank() + 1);
911   auto axis = node->axis();
912   if (axis < 0)
913     axis += indices_shape.rank() + 1;
914   LUCI_ASSERT(0 <= axis, "Axis is out of range");
915   LUCI_ASSERT(static_cast<uint32_t>(axis) <= indices_shape.rank(), "Axis is out of range");
916   uint32_t j = 0;
917   for (uint32_t i = 0; i < output_shape.rank(); i++)
918   {
919     if (i == static_cast<uint32_t>(axis))
920     {
921       output_shape.dim(i) = depth->at<S32>(0);
922     }
923     else
924     {
925       output_shape.dim(i) = indices_shape.dim(j++);
926     }
927   }
928   return loco::NodeShape{output_shape};
929 }
930
931 loco::NodeShape infer_pack(const luci::CirclePack *node)
932 {
933   LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs");
934
935   auto first_shape = luci::shape_get(node->values(0)).as<loco::TensorShape>();
936   // Make sure all inputs have the same shape.
937   for (uint32_t i = 1; i < node->values_count(); ++i)
938   {
939     auto in_shape = luci::shape_get(node->values(i)).as<loco::TensorShape>();
940     LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape},
941                 "All inputs must have the same shape");
942   }
943
944   // Checking shape capability for pack layer
945   // Input: tensors [D1, D2, ... Dn]
946   // Axis: K
947   // Output: [D1, D2, ... , D_K-1, n, D_K+1, ... Dn]
948   auto axis = node->axis();
949   if (axis < 0)
950     axis += first_shape.rank() + 1;
951
952   LUCI_ASSERT(0 <= axis, "Axis is out of range");
953   LUCI_ASSERT(static_cast<uint32_t>(axis) <= first_shape.rank(), "Axis is out of range");
954
955   loco::TensorShape output_shape;
956   output_shape.rank(first_shape.rank() + 1);
957
958   uint32_t j = 0;
959   for (uint32_t i = 0; i < output_shape.rank(); ++i)
960   {
961     if (i == static_cast<uint32_t>(axis))
962     {
963       output_shape.dim(i) = node->values_count();
964     }
965     else
966     {
967       output_shape.dim(i) = first_shape.dim(j++);
968     }
969   }
970
971   return loco::NodeShape{output_shape};
972 }
973
974 loco::NodeShape infer_pad(const luci::CirclePad *node)
975 {
976   // TODO support non-const case
977   auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
978   return use_paddings(node, paddings);
979 }
980
981 loco::NodeShape infer_pad_v2(const luci::CirclePadV2 *node)
982 {
983   // TODO support non-const case
984   auto paddings = dynamic_cast<luci::CircleConst *>(node->paddings());
985   if (!paddings)
986   {
987     auto node_shape = own_shape(node);
988     return loco::NodeShape{node_shape};
989   }
990   return use_paddings(node, paddings);
991 }
992
993 loco::NodeShape infer_p_relu(const luci::CirclePRelu *node)
994 {
995   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
996   auto alpha_shape = luci::shape_get(node->alpha()).as<loco::TensorShape>();
997
998   auto output_shape = broadcast_shape(input_shape, alpha_shape);
999
1000   return loco::NodeShape{output_shape};
1001 }
1002
1003 loco::NodeShape infer_range(const luci::CircleRange *node)
1004 {
1005   loco::TensorShape output_shape;
1006   output_shape.rank(1);
1007
1008   auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
1009   auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
1010   auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());
1011
1012   if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
1013   {
1014     return use_own(node);
1015   }
1016
1017   double start = 0, limit = 0, delta = 0;
1018
1019 #define GET_RANGE_PARAM(DT)         \
1020   start = start_node->scalar<DT>(); \
1021   limit = limit_node->scalar<DT>(); \
1022   delta = delta_node->scalar<DT>();
1023
1024   switch (start_node->dtype())
1025   {
1026     case loco::DataType::FLOAT32:
1027       GET_RANGE_PARAM(loco::DataType::FLOAT32)
1028       break;
1029     case loco::DataType::S32:
1030       GET_RANGE_PARAM(loco::DataType::S32)
1031       break;
1032     default:
1033       INTERNAL_EXN("Range data type not supported");
1034   }
1035
1036 #undef GET_RANGE_PARAM
1037
1038   if (delta == 0)
1039     INTERNAL_EXN("Delta can not be zero");
1040
1041   output_shape.dim(0) = ceil((limit - start) / delta);
1042
1043   return loco::NodeShape{output_shape};
1044 }
1045
1046 loco::NodeShape infer_reshape(const luci::CircleReshape *node)
1047 {
1048   LOGGER(l);
1049
1050   const loco::DataType S32 = loco::DataType::S32;
1051
1052   loco::TensorShape shape_by_input;
1053   {
1054     LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
1055
1056     // Only support node's shape() is CircleConst with S32
1057     // TODO support other node with other types
1058     auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
1059     if (const_shape_node != nullptr)
1060     {
1061       LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
1062
1063       shape_by_input.rank(const_shape_node->size<S32>());
1064
1065       for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
1066       {
1067         shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
1068       }
1069     }
1070     else
1071     {
1072       // We use shape from the node itself
1073       shape_by_input = own_shape(node);
1074     }
1075   }
1076
1077   loco::TensorShape shape_by_attr;
1078   {
1079     shape_by_attr.rank(node->newShape()->rank());
1080
1081     for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
1082     {
1083       shape_by_attr.dim(axis) = node->newShape()->dim(axis);
1084     }
1085   }
1086
1087   if (!(shape_by_input == shape_by_attr))
1088   {
1089     INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
1090     INFO(l) << "   shape_by_input : " << shape_by_input << std::endl;
1091     INFO(l) << "   shape_by_attr : " << shape_by_attr << std::endl;
1092   }
1093
1094   loco::TensorShape output_shape = shape_by_input;
1095
1096   // One of the dimensions can have special value -1, meaning its actual value should be inferred.
1097   const auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
1098   uint32_t input_element_count = 1;
1099   uint32_t output_element_count = 1;
1100   uint32_t unknown_dim_index = UINT32_MAX;
1101   for (uint32_t i = 0; i < input_shape.rank(); ++i)
1102     input_element_count *= (input_shape.dim(i).known() ? input_shape.dim(i).value() : 1);
1103   for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
1104   {
1105     const uint32_t dim_value = output_shape.dim(dim_index).value();
1106     if (static_cast<int>(dim_value) == -1)
1107     {
1108       LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
1109       unknown_dim_index = dim_index;
1110     }
1111     else
1112     {
1113       output_element_count *= dim_value;
1114     }
1115   }
1116   if (unknown_dim_index != UINT32_MAX)
1117   {
1118     output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
1119   }
1120
1121   return loco::NodeShape{output_shape};
1122 }
1123
1124 template <class CIRCLENODE> loco::NodeShape infer_resize_type(const CIRCLENODE *node)
1125 {
1126   auto input_shape = luci::shape_get(node->input()).template as<loco::TensorShape>();
1127
1128   if (input_shape.rank() != 4)
1129     INTERNAL_EXN("Expected input to have rank 4");
1130
1131   auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1132
1133   if (const_node->dtype() != loco::DataType::S32)
1134     INTERNAL_EXN("Only S32 datatype is supported for size");
1135
1136   if (const_node->rank() != 1)
1137     INTERNAL_EXN("Expected size tensor of rank 1");
1138
1139   if (const_node->dim(0).value() != 2)
1140     INTERNAL_EXN("Expected size tensor with shape [2]");
1141
1142   loco::TensorShape output_shape;
1143   output_shape.rank(4);
1144   output_shape.dim(0) = input_shape.dim(0);
1145   output_shape.dim(1) = const_node->template at<loco::DataType::S32>(0);
1146   output_shape.dim(2) = const_node->template at<loco::DataType::S32>(1);
1147   output_shape.dim(3) = input_shape.dim(3);
1148
1149   return loco::NodeShape{output_shape};
1150 }
1151
1152 loco::NodeShape infer_scatter_nd(const luci::CircleScatterNd *node)
1153 {
1154   loco::TensorShape output_shape;
1155
1156   auto shape_node = loco::must_cast<luci::CircleConst *>(node->shape());
1157
1158   const loco::DataType S32 = loco::DataType::S32;
1159   const loco::DataType S64 = loco::DataType::S64;
1160
1161   std::vector<int64_t> vect_shape;
1162
1163   if (shape_node->dtype() == S32)
1164     vect_shape = vector_from_constant<S32>(shape_node);
1165   else if (shape_node->dtype() == S64)
1166     vect_shape = vector_from_constant<S64>(shape_node);
1167   else
1168     LUCI_ASSERT(false, "Only support int32/int64 for shape()");
1169
1170   output_shape.rank(vect_shape.size());
1171   for (uint32_t i = 0; i < vect_shape.size(); ++i)
1172     output_shape.dim(i) = vect_shape[i];
1173
1174   return loco::NodeShape{output_shape};
1175 }
1176
1177 loco::NodeShape infer_segment_sum(const luci::CircleSegmentSum *node)
1178 {
1179   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1180   auto segment_shape = luci::shape_get(node->segment_ids()).as<loco::TensorShape>();
1181
1182   LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor");
1183   LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(),
1184               "segment_ids size must be equal to the size of data's first dimension");
1185
1186   auto ids_shape_value = loco::must_cast<luci::CircleConst *>(node->segment_ids());
1187
1188   std::vector<int64_t> vect_ids;
1189
1190   if (ids_shape_value->dtype() == loco::DataType::S32)
1191     vect_ids = vector_from_constant<loco::DataType::S32>(ids_shape_value);
1192
1193   LUCI_ASSERT(std::is_sorted(vect_ids.begin(), vect_ids.end()),
1194               "segment_ids values should be sorted")
1195
1196   loco::TensorShape output_shape;
1197
1198   output_shape.rank(input_shape.rank());
1199
1200   for (uint32_t i = 1; i < input_shape.rank(); ++i)
1201     output_shape.dim(i) = input_shape.dim(i);
1202
1203   output_shape.dim(0) = vect_ids.back() + 1;
1204
1205   return loco::NodeShape{output_shape};
1206 }
1207
1208 loco::NodeShape infer_select(const luci::CircleSelect *node)
1209 {
1210   auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
1211   assert(t_shape == luci::shape_get(node->e()).as<loco::TensorShape>());
1212
1213   // condition shape validation
1214   auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
1215   if (c_shape.rank() != t_shape.rank())
1216   {
1217     if (c_shape.rank() != 0 && c_shape.rank() != 1)
1218       INTERNAL_EXN_V("CircleSelect condition rank is not 0 nor 1: ", c_shape.rank());
1219
1220     if (c_shape.rank() == 1)
1221     {
1222       if (c_shape.dim(0).value() != t_shape.dim(0).value())
1223         INTERNAL_EXN("CircleSelect condition dim(0) should match with t.dim(0)");
1224     }
1225   }
1226
1227   return loco::NodeShape{t_shape};
1228 }
1229
1230 loco::NodeShape infer_select_v2(const luci::CircleSelectV2 *node)
1231 {
1232   auto c_shape = luci::shape_get(node->condition()).as<loco::TensorShape>();
1233   auto t_shape = luci::shape_get(node->t()).as<loco::TensorShape>();
1234   auto e_shape = luci::shape_get(node->e()).as<loco::TensorShape>();
1235
1236   // validate ability to broadcast shapes to each other
1237   auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape);
1238   return loco::NodeShape{b_shape};
1239 }
1240
1241 loco::NodeShape infer_shape(const luci::CircleShape *node)
1242 {
1243   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1244
1245   loco::TensorShape output_shape;
1246
1247   output_shape.rank(1);
1248   output_shape.dim(0) = input_shape.rank();
1249
1250   return loco::NodeShape{output_shape};
1251 }
1252
1253 loco::NodeShape infer_slice(const luci::CircleSlice *node)
1254 {
1255   const loco::DataType S32 = loco::DataType::S32;
1256   const loco::DataType S64 = loco::DataType::S64;
1257
1258   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1259
1260   auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin());
1261   auto const_size = loco::must_cast<luci::CircleConst *>(node->size());
1262
1263   loco::TensorShape output_shape;
1264   std::vector<int64_t> vect_begin; // to hold both S32/S64, we use int64_t
1265   std::vector<int64_t> vect_size;
1266
1267   if (const_begin->dtype() == S32)
1268     vect_begin = vector_from_constant<S32>(const_begin);
1269   else if (const_begin->dtype() == S64)
1270     vect_begin = vector_from_constant<S64>(const_begin);
1271   else
1272     LUCI_ASSERT(false, "Only support int32/int64 for begin()");
1273
1274   if (const_size->dtype() == S32)
1275     vect_size = vector_from_constant<S32>(const_size);
1276   else if (const_size->dtype() == S64)
1277     vect_size = vector_from_constant<S64>(const_size);
1278   else
1279     LUCI_ASSERT(false, "Only support int32/int64 for size()");
1280
1281   assert(input_shape.rank() == vect_begin.size());
1282   assert(input_shape.rank() == vect_size.size());
1283
1284   output_shape.rank(vect_begin.size());
1285   for (uint32_t idx = 0; idx < vect_begin.size(); ++idx)
1286   {
1287     auto size = vect_size.at(idx);
1288     if (size == -1)
1289     {
1290       size = static_cast<int64_t>(input_shape.dim(idx).value()) - vect_begin.at(idx);
1291     }
1292     output_shape.dim(idx) = size;
1293   }
1294
1295   return loco::NodeShape{output_shape};
1296 }
1297
1298 loco::NodeShape infer_space_to_batch_nd(const luci::CircleSpaceToBatchND *node)
1299 {
1300   const loco::DataType S32 = loco::DataType::S32;
1301
1302   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1303   // Support only input rank is 3 and 4
1304   assert(input_shape.rank() == 3 || input_shape.rank() == 4);
1305
1306   // Only support block_shape() with S32 type CircleConst for now
1307   auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
1308   LUCI_ASSERT(const_block_shape->dtype() == S32, "Only support int32 block_shape");
1309
1310   // Only support paddings() with S32 type CircleConst for now
1311   auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1312   LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings");
1313
1314   auto const_block_shape_shape = luci::shape_get(const_block_shape).as<loco::TensorShape>();
1315   auto const_paddings_shape = luci::shape_get(const_paddings).as<loco::TensorShape>();
1316   assert(const_block_shape_shape.rank() == 1);
1317   assert(const_paddings_shape.rank() == 2);
1318
1319   int32_t input_spatial_dim = input_shape.rank() - 2;
1320   assert(const_block_shape_shape.dim(0) == input_spatial_dim);
1321   assert(const_paddings_shape.dim(0) == input_spatial_dim);
1322   assert(const_paddings_shape.dim(1) == 2);
1323
1324   // Check all values of block_shape >= 1
1325   uint32_t ele_count = const_block_shape->size<S32>();
1326   for (uint32_t e = 0; e < ele_count; ++e)
1327   {
1328     auto val = const_block_shape->at<S32>(e);
1329     if (val < 1)
1330     {
1331       INTERNAL_EXN_V("All values of block_shape >= 1: ", e);
1332     }
1333   }
1334
1335   loco::TensorShape shape_output;
1336
1337   shape_output.rank(input_shape.rank());
1338
1339   int32_t output_batch_size = input_shape.dim(0).value();
1340   for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
1341   {
1342     int dim_size = input_shape.dim(dim + 1).value();
1343     dim_size += const_paddings->at<S32>(dim * 2);
1344     dim_size += const_paddings->at<S32>(dim * 2 + 1);
1345     shape_output.dim(dim + 1) = dim_size / const_block_shape->at<S32>(dim);
1346
1347     assert(dim_size % const_block_shape->at<S32>(dim) == 0);
1348     output_batch_size = output_batch_size * const_block_shape->at<S32>(dim);
1349   }
1350   shape_output.dim(0) = output_batch_size;
1351   shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
1352
1353   return loco::NodeShape{shape_output};
1354 }
1355
1356 loco::NodeShape infer_space_to_depth(const luci::CircleSpaceToDepth *node)
1357 {
1358   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1359   LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
1360
1361   // Only data format NHWC is supported
1362   int32_t height = input_shape.dim(1).value();
1363   int32_t width = input_shape.dim(2).value();
1364   int32_t depth = input_shape.dim(3).value();
1365
1366   int block_size = node->block_size();
1367
1368   if (block_size < 2)
1369     INTERNAL_EXN("Block size must be >= 2");
1370
1371   if ((height % block_size) || (width % block_size))
1372   {
1373     INTERNAL_EXN("The input tensor's height and width must be divisible by block_size");
1374   }
1375
1376   loco::TensorShape output_shape;
1377   output_shape.rank(4);
1378
1379   output_shape.dim(0) = input_shape.dim(0).value();
1380   output_shape.dim(1) = height / block_size;
1381   output_shape.dim(2) = width / block_size;
1382   output_shape.dim(3) = block_size * block_size * depth;
1383
1384   return loco::NodeShape{output_shape};
1385 }
1386
1387 loco::NodeShape infer_sparse_to_dense(const luci::CircleSparseToDense *node)
1388 {
1389   loco::TensorShape shape;
1390   {
1391     LUCI_ASSERT(node->output_shape(), "dims input should not be nullptr");
1392
1393     auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape());
1394     if (output_shape_node != nullptr)
1395     {
1396       const auto output_shape_type = output_shape_node->dtype();
1397
1398       if (output_shape_node->rank() != 1)
1399         INTERNAL_EXN_V("Only support rank 1 CircleConst",
1400                        oops::to_uint32(output_shape_node->rank()));
1401
1402       if (output_shape_type == loco::DataType::S32)
1403       {
1404         shape.rank(output_shape_node->size<loco::DataType::S32>());
1405
1406         for (uint32_t axis = 0; axis < shape.rank(); ++axis)
1407         {
1408           shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
1409         }
1410       }
1411       else if (output_shape_type == loco::DataType::S64)
1412       {
1413         shape.rank(output_shape_node->size<loco::DataType::S64>());
1414
1415         for (uint32_t axis = 0; axis < shape.rank(); ++axis)
1416         {
1417           shape.dim(axis) = output_shape_node->at<loco::DataType::S64>(axis);
1418         }
1419       }
1420       else
1421       {
1422         INTERNAL_EXN("Output shape of SparseToDense must be either int32 or int64");
1423       }
1424     }
1425     else
1426     {
1427       shape = own_shape(node);
1428     }
1429   }
1430
1431   return loco::NodeShape{shape};
1432 }
1433
1434 loco::NodeShape infer_strided_slice(const luci::CircleStridedSlice *node)
1435 {
1436   auto begin_node = dynamic_cast<luci::CircleConst *>(node->begin());
1437   auto end_node = dynamic_cast<luci::CircleConst *>(node->end());
1438   auto strides_node = dynamic_cast<luci::CircleConst *>(node->strides());
1439
1440   if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr)
1441   {
1442     return use_own(node);
1443   }
1444
1445   loco::TensorShape shape = infer_output_shape(node);
1446   return loco::NodeShape{shape};
1447 }
1448
1449 loco::NodeShape infer_squeeze(const luci::CircleSqueeze *node)
1450 {
1451   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1452
1453   // TODO input shape may be unknown before runtime
1454   std::vector<bool> do_squeeze(input_shape.rank(), false);
1455   uint32_t num_squeezed = 0;
1456
1457   if (!node->squeeze_dims().empty())
1458   {
1459     // SqueezeDims not empty, squeeze only dims specified
1460     for (int32_t raw_dim : node->squeeze_dims())
1461     {
1462       int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim;
1463
1464       if (dim < 0 || static_cast<uint32_t>(dim) >= input_shape.rank() ||
1465           input_shape.dim(dim).value() != 1)
1466       {
1467         INTERNAL_EXN("invalid dimention specified to Squeeze");
1468       }
1469
1470       if (!do_squeeze[dim])
1471         ++num_squeezed;
1472       do_squeeze[dim] = true;
1473     }
1474   }
1475   else
1476   {
1477     // SqueezeDims empty, squeeze any dims with size == 1
1478     for (uint32_t dim = 0; dim < input_shape.rank(); ++dim)
1479     {
1480       if (input_shape.dim(dim) == 1)
1481       {
1482         do_squeeze[dim] = true;
1483         ++num_squeezed;
1484       }
1485     }
1486   }
1487
1488   loco::TensorShape output_shape;
1489   output_shape.rank(input_shape.rank() - num_squeezed);
1490
1491   for (uint32_t in_dim = 0, out_dim = 0; in_dim < input_shape.rank(); ++in_dim)
1492   {
1493     if (!do_squeeze[in_dim])
1494     {
1495       output_shape.dim(out_dim++) = input_shape.dim(in_dim);
1496     }
1497   }
1498
1499   return loco::NodeShape{output_shape};
1500 }
1501
1502 loco::NodeShape infer_svdf(const luci::CircleSVDF *node)
1503 {
1504   const auto ifm_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1505   const auto weight_feature_shape = luci::shape_get(node->weight_feature()).as<loco::TensorShape>();
1506
1507   assert(ifm_shape.rank() == 2);
1508   assert(weight_feature_shape.rank() == 2);
1509
1510   assert(ifm_shape.dim(1) == weight_feature_shape.dim(1));
1511   assert(weight_feature_shape.dim(0).known());
1512
1513   const auto rank = node->svdf_rank();
1514   const auto num_filters = weight_feature_shape.dim(0).value();
1515   assert(num_filters % rank == 0);
1516   const auto num_units = num_filters / rank;
1517
1518   loco::TensorShape ofm_shape;
1519   ofm_shape.rank(2);
1520   ofm_shape.dim(0) = ifm_shape.dim(0);
1521   ofm_shape.dim(1) = num_units;
1522
1523   return loco::NodeShape{ofm_shape};
1524 }
1525
1526 loco::NodeShape infer_tile(const luci::CircleTile *node)
1527 {
1528   const loco::DataType S32 = loco::DataType::S32;
1529
1530   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1531   auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples());
1532
1533   // TODO support non-const case
1534   // TODO support S64 type
1535   LUCI_ASSERT(multiples->dtype() == S32, "Only support int32 multiples");
1536   LUCI_ASSERT(multiples->rank() == 1, "multiples should be rank 1")
1537
1538   uint32_t n = multiples->dim(0).value();
1539
1540   LUCI_ASSERT(n == input_shape.rank(), "length of multiples should be the same with input rank");
1541
1542   loco::TensorShape output_shape;
1543
1544   output_shape.rank(input_shape.rank());
1545   for (uint32_t ni = 0; ni < n; ++ni)
1546   {
1547     int32_t multiple = multiples->at<S32>(ni);
1548     output_shape.dim(ni) = input_shape.dim(ni).value() * static_cast<uint32_t>(multiple);
1549   }
1550
1551   return loco::NodeShape{output_shape};
1552 }
1553
1554 loco::NodeShape infer_transpose(const luci::CircleTranspose *node)
1555 {
1556   auto input_shape = luci::shape_get(node->a()).as<loco::TensorShape>();
1557
1558   auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm());
1559
1560   loco::TensorShape output_shape;
1561   output_shape.rank(input_shape.rank());
1562
1563   assert(perm_node->dtype() == loco::DataType::S32);
1564   assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());
1565
1566   for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
1567   {
1568     auto in_axis = perm_node->template at<loco::DataType::S32>(out_axis);
1569     output_shape.dim(out_axis) = input_shape.dim(in_axis);
1570   }
1571
1572   return output_shape;
1573 }
1574
1575 loco::NodeShape infer_transpose_conv(const luci::CircleTransposeConv *node)
1576 {
1577   // TransposeConv's output shape is written in its 'inputSizes' argument
1578   auto input_sizes_const = loco::must_cast<luci::CircleConst *>(node->inputSizes());
1579   // TODO support non-const type
1580   LUCI_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
1581   LUCI_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
1582               "Only support rank 1 with 4 entries")
1583
1584   loco::TensorShape shape;
1585
1586   shape.rank(4);
1587   for (uint32_t axis = 0; axis < 4; ++axis)
1588     shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);
1589
1590   return loco::NodeShape{shape};
1591 }
1592
1593 loco::NodeShape infer_unpack(const luci::CircleUnpack *node)
1594 {
1595   // CircleUnpack provides list(array) of Tensors which has one less dimension of the input
1596   // We'll set shape of CircleUnpack to shape of actual outputs
1597   // TODO fix this if any problem rises
1598   auto value_shape = luci::shape_get(node->value()).as<loco::TensorShape>();
1599
1600   auto axis = node->axis();
1601   auto num = node->num();
1602   auto rank = static_cast<int32_t>(value_shape.rank());
1603
1604   if (rank == 0)
1605   {
1606     // Unknown shape
1607     return use_own(node);
1608   }
1609
1610   LUCI_ASSERT(-rank <= axis && axis < rank, "Axis is out of range");
1611
1612   if (axis < 0)
1613     axis += rank;
1614
1615   LUCI_ASSERT(num == static_cast<int32_t>(value_shape.dim(axis).value()),
1616               "num, axis maybe incorrect");
1617
1618   loco::TensorShape output_shape;
1619   output_shape.rank(rank - 1);
1620
1621   for (int32_t i = 0, o = 0; i < rank; ++i)
1622   {
1623     if (i != axis)
1624       output_shape.dim(o++) = value_shape.dim(i);
1625   }
1626
1627   return loco::NodeShape{output_shape};
1628 }
1629
1630 loco::NodeShape infer_unidirectionalsequencelstm(const luci::CircleUnidirectionalSequenceLSTM *node)
1631 {
1632   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1633   auto recurrent_to_output_weights =
1634     luci::shape_get(node->recurrent_to_output_weights()).as<loco::TensorShape>();
1635   auto rank = input_shape.rank();
1636   loco::TensorShape output_shape;
1637   output_shape.rank(rank);
1638   for (uint32_t i = 0; i < rank - 1; i++)
1639   {
1640     output_shape.dim(i) = input_shape.dim(i);
1641   }
1642   output_shape.dim(rank - 1) = recurrent_to_output_weights.dim(1);
1643   return loco::NodeShape{output_shape};
1644 }
1645
1646 loco::NodeShape infer_unique(const luci::CircleUnique *node)
1647 {
1648   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1649
1650   assert(input_shape.rank() == 1);
1651
1652   loco::TensorShape shape_output;
1653   shape_output = own_shape(node);
1654
1655   return loco::NodeShape{shape_output};
1656 }
1657
1658 // Circle Only
1659 loco::NodeShape infer_bcq_fully_connected(const luci::CircleBCQFullyConnected *node)
1660 {
1661   loco::TensorShape out_shape;
1662
1663   auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
1664   auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters());
1665
1666   LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2");
1667
1668   int32_t qbits_sum = 0;
1669   for (uint32_t i = 0; i < weights_clusters->dim(0).value(); ++i)
1670   {
1671     qbits_sum += weights_clusters->at<loco::DataType::S32>(i * 2 + 1);
1672   }
1673
1674   out_shape.rank(2);
1675   out_shape.dim(0) = qbits_sum;
1676   out_shape.dim(1) = input_shape.dim(1);
1677
1678   return loco::NodeShape{out_shape};
1679 }
1680
1681 loco::NodeShape infer_bcq_gather(const luci::CircleBCQGather *node)
1682 {
1683   loco::TensorShape input_shape;
1684   loco::TensorShape output_shape;
1685
1686   const auto input_binary_shape = luci::shape_get(node->input_binary()).as<loco::TensorShape>();
1687   const auto indices_shape = luci::shape_get(node->indices()).as<loco::TensorShape>();
1688   auto axis = node->axis();
1689
1690   auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters());
1691   auto qbits_sum = 0;
1692   for (uint32_t i = 0; i < input_clusters->dim(0).value(); ++i)
1693   {
1694     qbits_sum += input_clusters->at<loco::DataType::S32>(i * 2 + 1);
1695   }
1696
1697   input_shape.rank(2);
1698   input_shape.dim(0) = qbits_sum;
1699   input_shape.dim(1) = input_binary_shape.dim(1).value() * 32;
1700
1701   output_shape.rank(input_shape.rank() - 1 + indices_shape.rank());
1702   int32_t outdim_index = 0;
1703   for (int32_t i = 0; i < axis; ++i)
1704     output_shape.dim(outdim_index++) = input_shape.dim(i);
1705   for (uint32_t i = 0; i < indices_shape.rank(); ++i)
1706     output_shape.dim(outdim_index++) = indices_shape.dim(i);
1707   for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
1708     output_shape.dim(outdim_index++) = input_shape.dim(i);
1709
1710   return loco::NodeShape{output_shape};
1711 }
1712
1713 // Virtual
1714 loco::NodeShape infer_input(const luci::CircleInput *node)
1715 {
1716   loco::TensorShape shape;
1717
1718   shape.rank(node->rank());
1719   for (uint32_t axis = 0; axis < node->rank(); axis++)
1720     shape.dim(axis) = node->dim(axis);
1721
1722   return loco::NodeShape{shape};
1723 }
1724
1725 loco::NodeShape infer_output(const luci::CircleOutput *node)
1726 {
1727   auto graph_outputs = node->graph()->outputs();
1728   auto graph_output = graph_outputs->at(node->index());
1729   auto output_shape = graph_output->shape();
1730
1731   return loco::NodeShape{*output_shape};
1732 }
1733
1734 loco::NodeShape infer_non_max_suppression_v4_out(const luci::CircleNonMaxSuppressionV4Out *node)
1735 {
1736   const loco::DataType S32 = loco::DataType::S32;
1737
1738   auto nmsv4 = dynamic_cast<const luci::CircleNonMaxSuppressionV4 *>(node->input());
1739   if (nmsv4 == nullptr)
1740     INTERNAL_EXN("CircleNonMaxSuppressionV4 IR is not configured correctly");
1741
1742   auto index = node->index();
1743   if (index == 1)
1744     return loco::TensorShape({0});
1745
1746   assert(index == 0);
1747
1748   auto unknown = loco::TensorShape{loco::Dimension()};
1749   auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv4->max_output_size());
1750   if (max_output_size == nullptr)
1751     return unknown; // we need CircleConst for max output size
1752
1753   LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
1754
1755   if (max_output_size->size<S32>() < 1)
1756     return unknown;
1757
1758   auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
1759   return loco::TensorShape{max_output_size_value};
1760 }
1761
1762 loco::NodeShape infer_non_max_suppression_v5_out(const luci::CircleNonMaxSuppressionV5Out *node)
1763 {
1764   const loco::DataType S32 = loco::DataType::S32;
1765
1766   auto nmsv5 = dynamic_cast<const luci::CircleNonMaxSuppressionV5 *>(node->input());
1767   if (nmsv5 == nullptr)
1768     INTERNAL_EXN("CircleNonMaxSuppressionV5 IR is not configured correctly");
1769
1770   auto index = node->index();
1771   if (index == 2)
1772     return loco::TensorShape({0});
1773
1774   assert(index == 0 || index == 1);
1775
1776   auto unknown = loco::TensorShape{loco::Dimension()};
1777   auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv5->max_output_size());
1778   if (max_output_size == nullptr)
1779     return unknown; // we need CircleConst for max output size
1780
1781   LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
1782
1783   if (max_output_size->size<S32>() < 1)
1784     return unknown;
1785
1786   auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
1787   return loco::TensorShape{max_output_size_value};
1788 }
1789
1790 loco::NodeShape infer_split_out(const luci::CircleSplitOut *node)
1791 {
1792   const loco::DataType S32 = loco::DataType::S32;
1793
1794   auto split = dynamic_cast<const luci::CircleSplit *>(node->input());
1795   if (split == nullptr)
1796     INTERNAL_EXN("CircleSplit IR is not configured correctly");
1797
1798   loco::NodeShape unknown;
1799
1800   auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
1801
1802   auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
1803   if (split_dim == nullptr)
1804     return unknown; // we need CircleConst for split_dim
1805   LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
1806
1807   assert(split_dim->size<S32>() == 1);
1808   auto split_dim_axis = split_dim->at<S32>(0);
1809   if (split_dim_axis < 0)
1810     split_dim_axis += split_shape.rank();
1811
1812   auto split_dim_value = split_shape.dim(split_dim_axis).value();
1813   assert(split_dim_value % split->num_split() == 0);
1814   const int split_depth = split_dim_value / split->num_split();
1815
1816   loco::TensorShape output_shape = split_shape;
1817
1818   // All shapes are equally same
1819   output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
1820
1821   return loco::NodeShape{output_shape};
1822 }
1823
1824 loco::NodeShape infer_split_v_out(const luci::CircleSplitVOut *node)
1825 {
1826   const loco::DataType S32 = loco::DataType::S32;
1827
1828   auto split = dynamic_cast<const luci::CircleSplitV *>(node->input());
1829   if (split == nullptr)
1830     INTERNAL_EXN("CircleSplit IR is not configured correctly");
1831
1832   loco::NodeShape unknown;
1833
1834   auto split_shape = luci::shape_get(split).as<loco::TensorShape>();
1835
1836   auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
1837   if (size_splits == nullptr)
1838     return unknown; // we need CircleConst for size_splits
1839   LUCI_ASSERT(size_splits->dtype() == S32, "Only support int32 for size_splits");
1840
1841   auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
1842   if (split_dim == nullptr)
1843     return unknown; // we need CircleConst for split_dim
1844   LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
1845
1846   // fetch axis
1847   assert(split_dim->size<S32>() == 1);
1848   auto split_dim_axis = split_dim->at<S32>(0);
1849   if (split_dim_axis < 0)
1850     split_dim_axis += split_shape.rank();
1851
1852   // interpret size_splits values
1853   int32_t size_splits_count = static_cast<int32_t>(size_splits->size<S32>());
1854   assert(size_splits_count == split->num_split());
1855
1856   int64_t minus_one_count = 0, size_splits_sum = 0;
1857   for (int32_t idx = 0; idx < size_splits_count; ++idx)
1858   {
1859     auto size = size_splits->at<S32>(idx);
1860     assert(size >= -1);
1861     if (size == -1)
1862       ++minus_one_count;
1863     else
1864       size_splits_sum += size;
1865   }
1866   if (minus_one_count > 1)
1867     INTERNAL_EXN("CircleSplitV size_splits has more than two -1 values");
1868
1869   // calcuate this SplitVOut shape
1870   auto input_size = split_shape.dim(split_dim_axis).value();
1871   assert(size_splits_sum <= input_size);
1872
1873   auto index_this = node->index();
1874   assert(0 <= index_this && index_this < split->num_split());
1875   auto split_depth = size_splits->at<S32>(index_this);
1876   if (split_depth == -1)
1877     split_depth = static_cast<int32_t>(input_size) - static_cast<int32_t>(size_splits_sum);
1878
1879   loco::TensorShape output_shape = split_shape;
1880
1881   output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
1882
1883   return loco::NodeShape{output_shape};
1884 }
1885
1886 loco::NodeShape infer_top_k_v2_out(const luci::CircleTopKV2Out *node)
1887 {
1888   const loco::DataType S32 = loco::DataType::S32;
1889
1890   auto topkv2 = dynamic_cast<const luci::CircleTopKV2 *>(node->input());
1891   if (topkv2 == nullptr)
1892     INTERNAL_EXN("CircleSplit IR is not configured correctly");
1893
1894   // shape of topkv2 is same as topkv2->input()
1895   auto input_shape = luci::shape_get(topkv2).as<loco::TensorShape>();
1896
1897   auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
1898   LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
1899   assert(node_k->size<S32>() == 1);
1900
1901   loco::TensorShape output_shape;
1902
1903   output_shape.rank(input_shape.rank());
1904   for (uint32_t idx = 0; idx < input_shape.rank() - 1; ++idx)
1905   {
1906     output_shape.dim(idx) = input_shape.dim(idx);
1907   }
1908   output_shape.dim(input_shape.rank() - 1) = node_k->at<S32>(0);
1909
1910   return loco::NodeShape{output_shape};
1911 }
1912
1913 loco::NodeShape infer_unique_out(const luci::CircleUniqueOut *node)
1914 {
1915   if (node->index() == 0)
1916   {
1917     auto unique_shape = own_shape(node);
1918     return loco::NodeShape{unique_shape};
1919   }
1920   assert(node->index() == 1);
1921   auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
1922   auto unique_shape = luci::shape_get(unique->input()).as<loco::TensorShape>();
1923
1924   assert(unique_shape.rank() == 1);
1925
1926   loco::TensorShape shape_output;
1927   shape_output.rank(1);
1928   shape_output.dim(0) = unique_shape.dim(0);
1929   return loco::NodeShape{shape_output};
1930 }
1931
1932 loco::NodeShape infer_unpack_out(const luci::CircleUnpackOut *node)
1933 {
1934   auto unpack = dynamic_cast<const luci::CircleUnpack *>(node->input());
1935   if (unpack == nullptr)
1936   {
1937     INTERNAL_EXN("CircleUnpack IR is not configured correctly");
1938   }
1939
1940   auto unpack_shape = luci::shape_get(unpack).as<loco::TensorShape>();
1941
1942   return loco::NodeShape{unpack_shape};
1943 }
1944
1945 loco::NodeShape infer_while_out(const luci::CircleWhileOut *node)
1946 {
1947   /**
1948    * @note  WHILE operator's shape is the same with the "cond"
1949    *        Graph input.
1950    */
1951   auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
1952   if (circle_while == nullptr)
1953   {
1954     INTERNAL_EXN("CircleWhile IR is not configured correctly");
1955   }
1956
1957   auto index = node->index();
1958   auto cond_graph = circle_while->cond_graph();
1959   assert(cond_graph != nullptr);
1960
1961   // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
1962   // loco::input_nodes
1963   auto cond_inputs = loco::input_nodes(cond_graph);
1964   auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
1965
1966   auto cond_graph_inputs = cond_graph->inputs();
1967   auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
1968
1969   auto cond_graph_input_shape = *cond_graph_input->shape();
1970   auto this_shape = own_shape(node);
1971
1972   if (!(this_shape == cond_graph_input_shape))
1973   {
1974     LOGGER(l);
1975     WARN(l) << "Warning: CircleWhileOut '" << node->name() << "' shape mispatch " << this_shape
1976             << " vs " << cond_graph_input_shape;
1977   }
1978
1979   return loco::NodeShape{this_shape};
1980 }
1981
1982 /**
1983  * @brief Class to infer the shape of CircleNode
1984  *
1985  * @note All CircleNode's inputs and outputs are always loco::Domain::Tensor
1986  */
1987 class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeShape>
1988 {
1989 public:
1990   loco::NodeShape visit(const luci::CircleAbs *node) final { return use_x(node); }
1991
1992   loco::NodeShape visit(const luci::CircleAdd *node) final { return broadcast_xy(node); }
1993
1994   loco::NodeShape visit(const luci::CircleAddN *node) final { return infer_add_n(node); }
1995
1996   loco::NodeShape visit(const luci::CircleArgMax *node) final { return infer_arg_maxmin(node); }
1997
1998   loco::NodeShape visit(const luci::CircleArgMin *node) final { return infer_arg_maxmin(node); }
1999
2000   loco::NodeShape visit(const luci::CircleAveragePool2D *node) final
2001   {
2002     return infer_pool_2d_shape(node);
2003   }
2004
2005   loco::NodeShape visit(const luci::CircleBatchMatMul *node) final
2006   {
2007     auto x_shape = luci::shape_get(node->x()).as<loco::TensorShape>();
2008     auto y_shape = luci::shape_get(node->y()).as<loco::TensorShape>();
2009
2010     return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y());
2011   }
2012
2013   loco::NodeShape visit(const luci::CircleBatchToSpaceND *node) final
2014   {
2015     return infer_batch_to_space_nd(node);
2016   }
2017
2018   loco::NodeShape visit(const luci::CircleCast *node) final { return use_x(node); }
2019
2020   loco::NodeShape visit(const luci::CircleCeil *node) final { return use_x(node); }
2021
2022   loco::NodeShape visit(const luci::CircleConcatenation *node) final
2023   {
2024     return infer_concatenation(node);
2025   }
2026
2027   loco::NodeShape visit(const luci::CircleConst *node) final { return use_own(node); }
2028
2029   loco::NodeShape visit(const luci::CircleConv2D *node) final { return infer_conv2d(node); }
2030
2031   loco::NodeShape visit(const luci::CircleCos *node) final { return use_x(node); }
2032
2033   loco::NodeShape visit(const luci::CircleCustom *node) final { return use_own(node); }
2034
2035   loco::NodeShape visit(const luci::CircleDensify *node) final { return use_input(node); }
2036
2037   loco::NodeShape visit(const luci::CircleDepthToSpace *node) final
2038   {
2039     return infer_depth_to_space(node);
2040   }
2041
2042   loco::NodeShape visit(const luci::CircleDepthwiseConv2D *node) final
2043   {
2044     return infer_depthwise_conv2d(node);
2045   }
2046
2047   loco::NodeShape visit(const luci::CircleDequantize *node) final
2048   {
2049     const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2050     return loco::NodeShape{input_shape};
2051   }
2052
2053   loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); }
2054
2055   loco::NodeShape visit(const luci::CircleElu *node) final
2056   {
2057     auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2058
2059     return loco::NodeShape{input_shape};
2060   }
2061
2062   loco::NodeShape visit(const luci::CircleEqual *node) final { return broadcast_xy(node); }
2063
2064   loco::NodeShape visit(const luci::CircleExp *node) final { return use_x(node); }
2065
2066   loco::NodeShape visit(const luci::CircleExpandDims *node) final
2067   {
2068     return infer_expand_dims(node);
2069   }
2070
2071   loco::NodeShape visit(const luci::CircleFakeQuant *node) final { return use_inputs(node); }
2072
2073   loco::NodeShape visit(const luci::CircleFill *node) final { return infer_fill(node); }
2074
2075   loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); }
2076
2077   loco::NodeShape visit(const luci::CircleFloorDiv *node) final { return broadcast_xy(node); }
2078
2079   loco::NodeShape visit(const luci::CircleFloorMod *node) final { return broadcast_xy(node); }
2080
2081   loco::NodeShape visit(const luci::CircleFullyConnected *node) final
2082   {
2083     return infer_fully_connected(node);
2084   }
2085
2086   loco::NodeShape visit(const luci::CircleGather *node) final { return infer_gather(node); }
2087
2088   loco::NodeShape visit(const luci::CircleGatherNd *node) final { return infer_gather_nd(node); }
2089
2090   loco::NodeShape visit(const luci::CircleGelu *node) final
2091   {
2092     auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2093
2094     return loco::NodeShape{input_shape};
2095   }
2096
2097   loco::NodeShape visit(const luci::CircleGreater *node) final { return broadcast_xy(node); }
2098
2099   loco::NodeShape visit(const luci::CircleGreaterEqual *node) final { return broadcast_xy(node); }
2100
2101   loco::NodeShape visit(const luci::CircleHardSwish *node) final
2102   {
2103     auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2104
2105     return loco::NodeShape{input_shape};
2106   }
2107
2108   loco::NodeShape visit(const luci::CircleIf *node) final
2109   {
2110     // Shape of CircleIf is not used. Just use input 0
2111     assert(node->input_count() > 0);
2112     const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
2113     return loco::NodeShape{input_shape};
2114   }
2115
2116   loco::NodeShape visit(const luci::CircleL2Normalize *node) final { return use_x(node); }
2117
2118   loco::NodeShape visit(const luci::CircleL2Pool2D *node) final
2119   {
2120     return infer_pool_2d_shape(node);
2121   }
2122
2123   loco::NodeShape visit(const luci::CircleLeakyRelu *node) final
2124   {
2125     const auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2126     return loco::NodeShape{input_shape};
2127   }
2128
2129   loco::NodeShape visit(const luci::CircleLess *node) final { return broadcast_xy(node); }
2130
2131   loco::NodeShape visit(const luci::CircleLessEqual *node) final { return broadcast_xy(node); }
2132
2133   loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final
2134   {
2135     const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2136     return loco::NodeShape{input_shape};
2137   }
2138
2139   loco::NodeShape visit(const luci::CircleLog *node) final { return use_x(node); }
2140
2141   loco::NodeShape visit(const luci::CircleLogicalAnd *node) final { return use_x(node); }
2142
2143   loco::NodeShape visit(const luci::CircleLogicalNot *node) final { return use_x(node); }
2144
2145   loco::NodeShape visit(const luci::CircleLogicalOr *node) final { return use_x(node); }
2146
2147   loco::NodeShape visit(const luci::CircleLogistic *node) final { return use_x(node); }
2148
2149   loco::NodeShape visit(const luci::CircleLogSoftmax *node) final { return use_logits(node); }
2150
2151   loco::NodeShape visit(const luci::CircleMatrixDiag *node) final
2152   {
2153     return infer_matrix_diag(node);
2154   }
2155
2156   loco::NodeShape visit(const luci::CircleMatrixSetDiag *node) final
2157   {
2158     return infer_matrix_set_diag(node);
2159   }
2160
2161   loco::NodeShape visit(const luci::CircleMaximum *node) final { return broadcast_xy(node); }
2162
2163   loco::NodeShape visit(const luci::CircleMaxPool2D *node) final
2164   {
2165     return infer_pool_2d_shape(node);
2166   }
2167
2168   loco::NodeShape visit(const luci::CircleMean *node) final
2169   {
2170     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2171     return loco::NodeShape{output_shape};
2172   }
2173
2174   loco::NodeShape visit(const luci::CircleMinimum *node) final { return broadcast_xy(node); }
2175
2176   loco::NodeShape visit(const luci::CircleMirrorPad *node) final { return infer_mirror_pad(node); }
2177
2178   loco::NodeShape visit(const luci::CircleMul *node) final { return broadcast_xy(node); }
2179
2180   loco::NodeShape visit(const luci::CircleNeg *node) final { return use_x(node); }
2181
2182   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
2183   {
2184     const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
2185     return loco::NodeShape{boxes_shape};
2186   }
2187
2188   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5 *node) final
2189   {
2190     const auto boxes_shape = luci::shape_get(node->boxes()).as<loco::TensorShape>();
2191     return loco::NodeShape{boxes_shape};
2192   }
2193
2194   loco::NodeShape visit(const luci::CircleNotEqual *node) final { return broadcast_xy(node); }
2195
2196   loco::NodeShape visit(const luci::CircleOneHot *node) final { return infer_one_hot(node); }
2197
2198   loco::NodeShape visit(const luci::CirclePack *node) final { return infer_pack(node); }
2199
2200   loco::NodeShape visit(const luci::CirclePad *node) final { return infer_pad(node); }
2201
2202   loco::NodeShape visit(const luci::CirclePadV2 *node) final { return infer_pad_v2(node); }
2203
2204   loco::NodeShape visit(const luci::CirclePow *node) final { return broadcast_xy(node); }
2205
2206   loco::NodeShape visit(const luci::CirclePRelu *node) final { return infer_p_relu(node); }
2207
2208   loco::NodeShape visit(const luci::CircleQuantize *node) final
2209   {
2210     const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2211     return loco::NodeShape{input_shape};
2212   }
2213
2214   loco::NodeShape visit(const luci::CircleRange *node) final { return infer_range(node); }
2215
2216   loco::NodeShape visit(const luci::CircleRank *) final
2217   {
2218     loco::TensorShape shape_output;
2219     shape_output.rank(0);
2220
2221     return loco::NodeShape{shape_output};
2222   }
2223
2224   loco::NodeShape visit(const luci::CircleReduceAny *node) final
2225   {
2226     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2227     return loco::NodeShape{output_shape};
2228   }
2229
2230   loco::NodeShape visit(const luci::CircleReduceMax *node) final
2231   {
2232     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2233     return loco::NodeShape{output_shape};
2234   }
2235
2236   loco::NodeShape visit(const luci::CircleReduceMin *node) final
2237   {
2238     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2239     return loco::NodeShape{output_shape};
2240   }
2241
2242   loco::NodeShape visit(const luci::CircleReduceProd *node) final
2243   {
2244     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2245     return loco::NodeShape{output_shape};
2246   }
2247
2248   loco::NodeShape visit(const luci::CircleRelu *node) final
2249   {
2250     auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2251
2252     return loco::NodeShape{input_shape};
2253   }
2254
2255   loco::NodeShape visit(const luci::CircleRelu6 *node) final
2256   {
2257     auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2258
2259     return loco::NodeShape{input_shape};
2260   }
2261
2262   loco::NodeShape visit(const luci::CircleReluN1To1 *node) final
2263   {
2264     auto input_shape = luci::shape_get(node->features()).as<loco::TensorShape>();
2265
2266     return loco::NodeShape{input_shape};
2267   }
2268
2269   /**
2270    * @note  CircleReshape has new shape info in two places: 2nd input and attribute.
2271    *        This shape inference uses shape from input 'shape' node when it's constant.
2272    *        If not, shape will be from node itself. shape from attribute is not used.
2273    *
2274    * TODO Change this policy when not appropriate
2275    */
2276   loco::NodeShape visit(const luci::CircleReshape *node) final { return infer_reshape(node); }
2277
2278   loco::NodeShape visit(const luci::CircleResizeBilinear *node) final
2279   {
2280     return infer_resize_type(node);
2281   }
2282
2283   loco::NodeShape visit(const luci::CircleResizeNearestNeighbor *node) final
2284   {
2285     return infer_resize_type(node);
2286   }
2287
2288   loco::NodeShape visit(const luci::CircleReverseSequence *node) final
2289   {
2290     auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2291
2292     return loco::NodeShape{input_shape};
2293   }
2294
2295   loco::NodeShape visit(const luci::CircleRound *node) final { return use_x(node); }
2296
2297   loco::NodeShape visit(const luci::CircleReverseV2 *node) final
2298   {
2299     auto input_shape = luci::shape_get(node->tensor()).as<loco::TensorShape>();
2300
2301     LUCI_ASSERT(luci::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
2302                 "Tensor must be 1-D");
2303
2304     return loco::NodeShape{input_shape};
2305   }
2306
2307   loco::NodeShape visit(const luci::CircleRsqrt *node) final { return use_x(node); }
2308
2309   loco::NodeShape visit(const luci::CircleScatterNd *node) final { return infer_scatter_nd(node); }
2310
2311   loco::NodeShape visit(const luci::CircleSegmentSum *node) final
2312   {
2313     return infer_segment_sum(node);
2314   }
2315
2316   loco::NodeShape visit(const luci::CircleSelect *node) final { return infer_select(node); }
2317
2318   loco::NodeShape visit(const luci::CircleSelectV2 *node) final { return infer_select_v2(node); }
2319
2320   loco::NodeShape visit(const luci::CircleShape *node) final { return infer_shape(node); }
2321
2322   loco::NodeShape visit(const luci::CircleSin *node) final { return use_x(node); }
2323
2324   loco::NodeShape visit(const luci::CircleSlice *node) final { return infer_slice(node); }
2325
2326   loco::NodeShape visit(const luci::CircleSoftmax *node) final { return use_logits(node); }
2327
2328   loco::NodeShape visit(const luci::CircleSpaceToBatchND *node) final
2329   {
2330     return infer_space_to_batch_nd(node);
2331   }
2332
2333   loco::NodeShape visit(const luci::CircleSpaceToDepth *node) final
2334   {
2335     return infer_space_to_depth(node);
2336   }
2337
2338   loco::NodeShape visit(const luci::CircleSparseToDense *node) final
2339   {
2340     return infer_sparse_to_dense(node);
2341   }
2342
2343   loco::NodeShape visit(const luci::CircleSplit *node) final
2344   {
2345     // We'll set Split output as same as input so that SplitOut can handle it's own shape
2346     auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2347     return loco::NodeShape{input_shape};
2348   }
2349
2350   loco::NodeShape visit(const luci::CircleSplitV *node) final
2351   {
2352     // We'll set SplitV output as same as input so that SplitOut can handle it's own shape
2353     auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2354     return loco::NodeShape{input_shape};
2355   }
2356
2357   loco::NodeShape visit(const luci::CircleSqrt *node) final { return use_x(node); }
2358
2359   loco::NodeShape visit(const luci::CircleSquare *node) final { return use_x(node); }
2360
2361   loco::NodeShape visit(const luci::CircleSquaredDifference *node) final
2362   {
2363     return broadcast_xy(node);
2364   }
2365
2366   loco::NodeShape visit(const luci::CircleStridedSlice *node) final
2367   {
2368     return infer_strided_slice(node);
2369   }
2370
2371   loco::NodeShape visit(const luci::CircleSqueeze *node) final { return infer_squeeze(node); }
2372
2373   loco::NodeShape visit(const luci::CircleSub *node) final { return broadcast_xy(node); }
2374
2375   loco::NodeShape visit(const luci::CircleSum *node) final
2376   {
2377     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
2378     return loco::NodeShape{output_shape};
2379   }
2380
2381   loco::NodeShape visit(const luci::CircleSVDF *node) final { return infer_svdf(node); }
2382
2383   loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
2384
2385   loco::NodeShape visit(const luci::CircleTile *node) final { return infer_tile(node); }
2386
2387   loco::NodeShape visit(const luci::CircleTopKV2 *node) final
2388   {
2389     // set shape of this node as same as input
2390     const auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2391     return loco::NodeShape{input_shape};
2392   }
2393
2394   loco::NodeShape visit(const luci::CircleTranspose *node) final { return infer_transpose(node); }
2395
2396   loco::NodeShape visit(const luci::CircleTransposeConv *node) final
2397   {
2398     return infer_transpose_conv(node);
2399   }
2400
2401   loco::NodeShape visit(const luci::CircleUnpack *node) final { return infer_unpack(node); }
2402
2403   loco::NodeShape visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
2404   {
2405     return infer_unidirectionalsequencelstm(node);
2406   }
2407
2408   loco::NodeShape visit(const luci::CircleUnique *node) final { return infer_unique(node); }
2409
2410   loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); }
2411
2412   loco::NodeShape visit(const luci::CircleWhile *node) final
2413   {
2414     // Shape of CircleWhile is not used. Just use input 0
2415     assert(node->arity() > 0);
2416     const auto input_shape = luci::shape_get(node->input(0)).as<loco::TensorShape>();
2417     return loco::NodeShape{input_shape};
2418   }
2419
2420   loco::NodeShape visit(const luci::CircleZerosLike *node) final
2421   {
2422     auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2423
2424     return loco::NodeShape{input_shape};
2425   }
2426
2427   // Circle Only
2428   loco::NodeShape visit(const luci::CircleBCQFullyConnected *node) final
2429   {
2430     return infer_bcq_fully_connected(node);
2431   }
2432
2433   loco::NodeShape visit(const luci::CircleBCQGather *node) final { return infer_bcq_gather(node); }
2434
2435   loco::NodeShape visit(const luci::CircleInstanceNorm *node) final
2436   {
2437     auto input_shape = luci::shape_get(node->input()).as<loco::TensorShape>();
2438
2439     return loco::NodeShape{input_shape};
2440   }
2441
2442   // Virtual
2443   loco::NodeShape visit(const luci::CircleInput *node) final { return infer_input(node); }
2444
2445   loco::NodeShape visit(const luci::CircleOutput *node) final { return infer_output(node); }
2446
2447   loco::NodeShape visit(const luci::CircleOutputDummy *node) final { return use_own(node); }
2448
2449   loco::NodeShape visit(const luci::CircleOutputExclude *node) final { return use_own(node); }
2450
2451   loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
2452
2453   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
2454   {
2455     return infer_non_max_suppression_v4_out(node);
2456   }
2457
2458   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV5Out *node) final
2459   {
2460     return infer_non_max_suppression_v5_out(node);
2461   }
2462
2463   loco::NodeShape visit(const luci::CircleSplitOut *node) final { return infer_split_out(node); }
2464
2465   loco::NodeShape visit(const luci::CircleSplitVOut *node) final { return infer_split_v_out(node); }
2466
2467   loco::NodeShape visit(const luci::CircleTopKV2Out *node) final
2468   {
2469     return infer_top_k_v2_out(node);
2470   }
2471
2472   loco::NodeShape visit(const luci::CircleUniqueOut *node) final { return infer_unique_out(node); }
2473
2474   loco::NodeShape visit(const luci::CircleUnpackOut *node) final { return infer_unpack_out(node); }
2475
2476   loco::NodeShape visit(const luci::CircleVariable *node) final { return use_own(node); }
2477
2478   loco::NodeShape visit(const luci::CircleWhileOut *node) final { return infer_while_out(node); }
2479 };
2480
2481 } // namespace
2482
2483 namespace luci
2484 {
2485
2486 bool CircleShapeInferenceRule::recognize(const loco::Dialect *d) const
2487 {
2488   return CircleDialect::get() == d;
2489 }
2490
2491 bool CircleShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
2492 {
2493   LOGGER(l);
2494
2495   assert(node->dialect() == CircleDialect::get());
2496
2497   ShapeInferenceAlgorithm alg;
2498   auto circle_node = loco::must_cast<const CircleNode *>(node);
2499
2500   bool is_shape_undefined = (circle_node->shape_status() == ShapeStatus::UNDEFINED);
2501   bool is_shape_none = (circle_node->shape_status() == ShapeStatus::NOSHAPE);
2502   bool is_scalar = (circle_node->rank() == 0);
2503
2504   if (is_shape_undefined)
2505     shape = circle_node->accept(&alg);
2506   else
2507   {
2508     if (is_shape_none || is_scalar)
2509       shape = own_shape(circle_node);
2510     else
2511       shape = circle_node->accept(&alg);
2512   }
2513
2514   VERBOSE(l, 1) << "[luci] shape: " << circle_node->name();
2515   VERBOSE(l, 1) << "              own_shape: " << own_shape(circle_node)
2516                 << " -> infer: " << shape.as<loco::TensorShape>();
2517
2518   return true;
2519 }
2520
2521 } // namespace luci