6355ec546b5ecd4820081817010db90a877d4669
[platform/core/ml/nnfw.git] / compiler / luci / service / src / CircleShapeInferenceRule.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/CircleShapeInferenceRule.h"
18 #include "Check.h"
19
20 #include "ShapeInfer_StridedSlice.h"
21
22 #include <luci/IR/CircleNodes.h>
23 #include <luci/IR/CircleDialect.h>
24 #include <luci/IR/CircleNodeVisitor.h>
25 #include <luci/Log.h>
26
27 #include <oops/InternalExn.h>
28
29 #include <algorithm>
30 #include <cassert>
31 #include <cmath>
32 #include <stdexcept>
33
34 namespace
35 {
36
37 std::ostream &operator<<(std::ostream &os, const loco::TensorShape &tensor_shape)
38 {
39   os << "[";
40   for (uint32_t r = 0; r < tensor_shape.rank(); ++r)
41   {
42     if (r)
43       os << ",";
44     os << tensor_shape.dim(r).value();
45   }
46   os << "]";
47   return os;
48 }
49
50 // Call this for CircleAvgPool2D and CircleMaxPool2D only
51 template <class Pool2DType> loco::NodeShape infer_pool_2d_shape(const Pool2DType *node)
52 {
53   LUCI_ASSERT(loco::shape_known(node->value()), "Shape must be known");
54
55   auto ifm_shape = loco::shape_get(node->value()).template as<loco::TensorShape>();
56   assert(ifm_shape.rank() == 4);
57
58   uint32_t input_height = ifm_shape.dim(1).value();
59   uint32_t input_width = ifm_shape.dim(2).value();
60   uint32_t stride_height = node->stride()->h();
61   uint32_t stride_width = node->stride()->w();
62   uint32_t window_height = node->filter()->h();
63   uint32_t window_width = node->filter()->w();
64   uint32_t dilation_height = 1; // dilation for CircleAvgPool2D and CircleMaxPool2D is 1
65   uint32_t dilation_width = 1;
66   uint32_t effective_window_height = dilation_height * (window_height - 1) + 1;
67   uint32_t effective_window_width = dilation_width * (window_width - 1) + 1;
68
69   uint32_t output_height = 0;
70   uint32_t output_width = 0;
71
72   if (node->padding() == luci::Padding::VALID)
73   {
74     output_height = (input_height + stride_height - effective_window_height) / stride_height;
75     output_width = (input_width + stride_width - effective_window_width) / stride_width;
76   }
77   else if (node->padding() == luci::Padding::SAME)
78   {
79     output_height = (input_height + stride_height - 1) / stride_height;
80     output_width = (input_width + stride_width - 1) / stride_width;
81   }
82   else
83     LUCI_ASSERT(false, "Wrong padding type");
84
85   loco::TensorShape ofm_shape;
86   ofm_shape.rank(4);
87   ofm_shape.dim(0) = ifm_shape.dim(0);
88   ofm_shape.dim(1) = output_height;
89   ofm_shape.dim(2) = output_width;
90   ofm_shape.dim(3) = ifm_shape.dim(3);
91
92   return loco::NodeShape{ofm_shape};
93 }
94
95 /**
96  * @brief Create a higher-rank TensorShape following NumPy broadcasting semantics
97  *
98  * HOW TO USE:
99  *
100  *   auto expanded_tensor_shape = expand(tensor_shape).to(N);
101  */
102 class TensorShapeExpander
103 {
104 public:
105   TensorShapeExpander(const loco::TensorShape &shape) : _shape{shape}
106   {
107     // DO NOTHING
108   }
109
110 public:
111   loco::TensorShape to(uint32_t output_rank)
112   {
113     auto const &input_shape = _shape;
114     uint32_t const input_rank = input_shape.rank();
115
116     assert(input_rank <= output_rank && "Cannot shrink rank");
117     uint32_t const axis_shift = output_rank - input_rank;
118
119     loco::TensorShape output_shape;
120
121     output_shape.rank(output_rank);
122     for (uint32_t axis = 0; axis < output_rank; ++axis)
123     {
124       output_shape.dim(axis) = (axis < axis_shift) ? 1 : input_shape.dim(axis - axis_shift);
125     }
126
127     return output_shape;
128   }
129
130 private:
131   const loco::TensorShape _shape;
132 };
133
134 /**
135  * @breif  Expand shape x and y to same rank by align right and filling with 1
136  */
137 void expand_rank(loco::TensorShape &x, loco::TensorShape &y)
138 {
139   auto x_rank = x.rank();
140   auto y_rank = y.rank();
141
142   if (x_rank == y_rank)
143     return;
144
145   TensorShapeExpander x_exp(x);
146   TensorShapeExpander y_exp(y);
147
148   auto xy_rank = std::max(x_rank, y_rank);
149
150   x = x_rank > y_rank ? x : x_exp.to(xy_rank);
151   y = y_rank > x_rank ? y : y_exp.to(xy_rank);
152 }
153
154 /**
155  * @breif  Returns shape of expanded dimension of input x and y having same rank
156  */
157 loco::TensorShape expand_dimension(const loco::TensorShape &x, const loco::TensorShape &y)
158 {
159   assert(x.rank() == y.rank());
160
161   auto rank = x.rank();
162
163   loco::TensorShape output_shape;
164
165   output_shape.rank(rank);
166   for (uint32_t axis = 0; axis < rank; ++axis)
167   {
168     assert(x.dim(axis).known() && y.dim(axis).known());
169
170     auto x_dim = x.dim(axis).value();
171     auto y_dim = y.dim(axis).value();
172
173     // each dimension of x and y should be same or one must be 1 if different
174     if (!((x_dim == y_dim) || (x_dim == 1 || y_dim == 1)))
175       INTERNAL_EXN("Cannot produce expand_dimension of two shapes");
176
177     output_shape.dim(axis) = std::max(x_dim, y_dim);
178   }
179
180   return output_shape;
181 }
182
183 loco::TensorShape broadcast_shape(const loco::TensorShape &x, const loco::TensorShape &y)
184 {
185   auto x_match = x;
186   auto y_match = y;
187
188   expand_rank(x_match, y_match);
189
190   auto output_shape = expand_dimension(x_match, y_match);
191
192   return output_shape;
193 }
194
195 // BatchMatMulV2 supports broadcasting in the batch dimensions(BatchMatMul doesn't)
196 // TODO Distinguish BatchMatMul and BatchMatMulV2
197 loco::NodeShape infer_batchmatmul_shape(const loco::TensorShape &x_shape,
198                                         const loco::TensorShape &y_shape, bool adj_x, bool adj_y)
199 {
200   uint32_t x_rank = x_shape.rank();
201   uint32_t y_rank = y_shape.rank();
202   assert(x_rank >= 2 && y_rank >= 2);
203
204   loco::TensorShape output_shape;
205   output_shape.rank(x_shape.rank());
206   // Braodcast in the batch dimension
207   if (x_rank > 2 || y_rank > 2)
208   {
209     loco::TensorShape dummy_x = x_shape;
210     loco::TensorShape dummy_y = y_shape;
211     expand_rank(dummy_x, dummy_y);
212     if (x_rank < y_rank)
213       expand_rank(output_shape, dummy_y);
214
215     for (uint32_t d = 0; d < output_shape.rank() - 2; d++)
216     {
217       uint32_t max_dim = std::max(dummy_x.dim(d).value(), dummy_y.dim(d).value());
218       if (dummy_x.dim(d) == dummy_y.dim(d) ||
219           dummy_x.dim(d).value() * dummy_y.dim(d).value() == max_dim)
220         output_shape.dim(d).set(max_dim);
221       else
222         INTERNAL_EXN("BatchMatMul has wrong shape");
223     }
224   }
225
226   loco::Dimension x_lhs = adj_x ? x_shape.dim(x_rank - 1) : x_shape.dim(x_rank - 2);
227   loco::Dimension x_rhs = adj_x ? x_shape.dim(x_rank - 2) : x_shape.dim(x_rank - 1);
228   loco::Dimension y_lhs = adj_y ? y_shape.dim(y_rank - 1) : y_shape.dim(y_rank - 2);
229   loco::Dimension y_rhs = adj_y ? y_shape.dim(y_rank - 2) : y_shape.dim(y_rank - 1);
230
231   if (not(x_rhs == y_lhs))
232     INTERNAL_EXN("x_rhs and y_lhs should be same");
233
234   uint32_t out_rank = output_shape.rank();
235   output_shape.dim(out_rank - 2) = x_lhs;
236   output_shape.dim(out_rank - 1) = y_rhs;
237
238   return loco::NodeShape{output_shape};
239 }
240
241 loco::TensorShape own_shape(const luci::CircleNode *node)
242 {
243   loco::TensorShape shape;
244   shape.rank(node->rank());
245   for (uint32_t r = 0; r < node->rank(); ++r)
246     shape.dim(r) = loco::Dimension(node->dim(r).value());
247   return shape;
248 }
249
250 loco::TensorShape infer_reducer(const loco::Node *input, const loco::Node *indices, bool keep_dims)
251 {
252   const loco::DataType S32 = loco::DataType::S32;
253
254   auto input_shape = loco::shape_get(input).as<loco::TensorShape>();
255   auto reduction_indices = loco::must_cast<const luci::CircleConst *>(indices);
256
257   { // Exceptions
258     // TODO support non-const case
259     // TODO support other data type
260     LUCI_ASSERT(reduction_indices->dtype() == S32, "Only support int 32");
261   }
262
263   std::vector<int32_t> reduction_values;
264
265   for (uint32_t i = 0; i < reduction_indices->size<S32>(); ++i)
266   {
267     int32_t axis = reduction_indices->at<S32>(i);
268     if (axis < 0)
269       axis += input_shape.rank();
270     if (not(0 <= axis and axis < static_cast<int32_t>(input_shape.rank())))
271       INTERNAL_EXN_V("Invalid reduction axis for REDUCER", oops::to_uint32(axis));
272     reduction_values.push_back(axis);
273   }
274
275   loco::TensorShape output_shape;
276
277   if (keep_dims)
278   {
279     output_shape.rank(input_shape.rank());
280     for (uint32_t i = 0; i < input_shape.rank(); ++i)
281       output_shape.dim(i) = input_shape.dim(i);
282     for (uint32_t i = 0; i < reduction_values.size(); ++i)
283       output_shape.dim(reduction_values.at(i)) = 1;
284   }
285   else
286   {
287     std::vector<bool> check_reduce(input_shape.rank(), false);
288     for (uint32_t i = 0; i < reduction_values.size(); ++i)
289       check_reduce.at(reduction_values.at(i)) = true;
290
291     uint32_t reduce_cnt = 0;
292     for (uint32_t i = 0; i < check_reduce.size(); ++i)
293       if (check_reduce.at(i))
294         ++reduce_cnt;
295
296     output_shape.rank(input_shape.rank() - reduce_cnt);
297     for (uint32_t i = 0, j = 0; i < check_reduce.size(); ++i)
298       if (check_reduce.at(i) == false)
299         output_shape.dim(j++) = input_shape.dim(i);
300   }
301
302   return output_shape;
303 }
304
305 /**
306  * @brief vector_from_constant will return int64_t vector from CircleConst node
307  */
308 template <loco::DataType T> std::vector<int64_t> vector_from_constant(luci::CircleConst *const_node)
309 {
310   std::vector<int64_t> result;
311
312   for (uint32_t idx = 0; idx < const_node->size<T>(); ++idx)
313     result.push_back(const_node->at<T>(idx));
314
315   return result;
316 }
317
318 template <class CIRCLENODE> loco::NodeShape broadcast_xy(const CIRCLENODE *node)
319 {
320   auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
321   auto y_shape = loco::shape_get(node->y()).template as<loco::TensorShape>();
322
323   auto output_shape = broadcast_shape(x_shape, y_shape);
324
325   return loco::NodeShape{output_shape};
326 }
327
328 template <class CIRCLENODE> loco::NodeShape use_x(const CIRCLENODE *node)
329 {
330   auto x_shape = loco::shape_get(node->x()).template as<loco::TensorShape>();
331   return loco::NodeShape{x_shape};
332 }
333
334 template <class CIRCLENODE> loco::NodeShape use_logits(const CIRCLENODE *node)
335 {
336   auto shape = loco::shape_get(node->logits()).template as<loco::TensorShape>();
337   return loco::NodeShape{shape};
338 }
339
340 loco::NodeShape use_own(const luci::CircleNode *node)
341 {
342   loco::TensorShape shape = own_shape(node);
343   return loco::NodeShape{shape};
344 }
345
346 /**
347  * @brief Class to infer the shape of CircleNode
348  *
349  * @note All CircleNode's inputs and outputs are always loco::Domain::Tensor
350  */
351 class ShapeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::NodeShape>
352 {
353 public:
354   loco::NodeShape visit(const luci::CircleAbs *node) final { return use_x(node); }
355
356   loco::NodeShape visit(const luci::CircleAdd *node) final { return broadcast_xy(node); }
357
358   loco::NodeShape visit(const luci::CircleAddN *node) final
359   {
360     auto shape = loco::shape_get(node->inputs(0)).as<loco::TensorShape>();
361
362     for (uint32_t idx = 1; idx < node->arity(); ++idx)
363     {
364       auto shape_idx = loco::shape_get(node->inputs(idx)).as<loco::TensorShape>();
365       if (!(shape == shape_idx))
366       {
367         INTERNAL_EXN_V("ADD_N shape not same as the first input: ", idx);
368       }
369     }
370
371     return loco::NodeShape{shape};
372   }
373
374   loco::NodeShape visit(const luci::CircleArgMax *node) final
375   {
376     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
377     auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
378
379     int64_t select_axis = 0;
380     {
381       LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
382
383       // Only support node's shape() is CircleConst with S32/S64
384       // Support S32 for now.
385       auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
386       LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
387                   "Only support int32 CircleConst for CircleArgMax");
388
389       if (const_shape_node->rank() > 1)
390         INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
391                        oops::to_uint32(const_shape_node->rank()));
392
393       select_axis = const_shape_node->scalar<loco::DataType::S32>();
394     }
395     assert(select_axis < input_shape.rank());
396     assert(select_axis >= 0); // TODO support minus of this breaks
397
398     // NOTE select_axis is removed
399     loco::TensorShape shape_output;
400     uint32_t rank = input_shape.rank();
401     uint32_t shrink = static_cast<uint32_t>(select_axis);
402     assert(rank > 0);
403     shape_output.rank(rank - 1);
404     for (uint32_t r = 0, d = 0; r < rank; ++r)
405     {
406       if (r == shrink)
407         continue;
408       shape_output.dim(d++) = input_shape.dim(r);
409     }
410     return loco::NodeShape{shape_output};
411   }
412
413   loco::NodeShape visit(const luci::CircleArgMin *node) final
414   {
415     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
416     auto dimension_shape = loco::shape_get(node->dimension()).as<loco::TensorShape>();
417
418     int64_t select_axis = 0;
419     {
420       LUCI_ASSERT(node->dimension(), "2nd input dimension() should not be nullptr");
421
422       // Only support node's shape() is CircleConst with S32/S64
423       // Support S32 for now.
424       auto const_shape_node = loco::must_cast<luci::CircleConst *>(node->dimension());
425       LUCI_ASSERT(const_shape_node->dtype() == loco::DataType::S32,
426                   "Only support int32 CircleConst for CircleArgMin");
427
428       if (const_shape_node->rank() > 1)
429         INTERNAL_EXN_V("Only support rank 0/1 CircleConst",
430                        oops::to_uint32(const_shape_node->rank()));
431
432       select_axis = const_shape_node->scalar<loco::DataType::S32>();
433     }
434     assert(select_axis < input_shape.rank());
435     assert(select_axis >= 0); // TODO support minus of this breaks
436
437     // NOTE select_axis is removed
438     loco::TensorShape shape_output;
439     uint32_t rank = input_shape.rank();
440     uint32_t shrink = static_cast<uint32_t>(select_axis);
441     assert(rank > 0);
442     shape_output.rank(rank - 1);
443     for (uint32_t r = 0, d = 0; r < rank; ++r)
444     {
445       if (r == shrink)
446         continue;
447       shape_output.dim(d++) = input_shape.dim(r);
448     }
449     return loco::NodeShape{shape_output};
450   }
451
452   loco::NodeShape visit(const luci::CircleAveragePool2D *node) final
453   {
454     return infer_pool_2d_shape(node);
455   }
456
457   loco::NodeShape visit(const luci::CircleBatchMatMul *node) final
458   {
459     auto x_shape = loco::shape_get(node->x()).as<loco::TensorShape>();
460     auto y_shape = loco::shape_get(node->y()).as<loco::TensorShape>();
461
462     return infer_batchmatmul_shape(x_shape, y_shape, node->adj_x(), node->adj_y());
463   }
464
465   loco::NodeShape visit(const luci::CircleBatchToSpaceND *node) final
466   {
467     const loco::DataType S32 = loco::DataType::S32;
468
469     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
470     // Support only input rank is 3 and 4
471     assert(input_shape.rank() == 3 || input_shape.rank() == 4);
472
473     // Only support block_shape() with S32 type CircleConst for now
474     auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
475     LUCI_ASSERT(const_block_shape->dtype() == loco::DataType::S32,
476                 "Only support int32 block_shape");
477
478     // Only support crops() with S32 type CircleConst for now
479     auto const_crops = loco::must_cast<luci::CircleConst *>(node->crops());
480     LUCI_ASSERT(const_crops->dtype() == loco::DataType::S32, "Only support int32 crops");
481
482     auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
483     auto const_crops_shape = loco::shape_get(const_crops).as<loco::TensorShape>();
484     assert(const_block_shape_shape.rank() == 1);
485     assert(const_crops_shape.rank() == 2);
486
487     int32_t input_spatial_dim = input_shape.rank() - 2;
488     assert(const_block_shape_shape.dim(0) == input_spatial_dim);
489     assert(const_crops_shape.dim(0) == input_spatial_dim);
490     assert(const_crops_shape.dim(1) == 2);
491
492     loco::TensorShape shape_output;
493
494     shape_output.rank(input_shape.rank());
495
496     int32_t output_batch_size = input_shape.dim(0).value();
497     for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
498     {
499       int dim_size = input_shape.dim(dim + 1).value() * const_block_shape->at<S32>(dim);
500       dim_size -= const_crops->at<S32>(dim * 2);
501       dim_size -= const_crops->at<S32>(dim * 2 + 1);
502       shape_output.dim(dim + 1) = dim_size;
503
504       assert(output_batch_size % const_block_shape->at<S32>(dim) == 0);
505       output_batch_size = output_batch_size / const_block_shape->at<S32>(dim);
506     }
507     shape_output.dim(0) = output_batch_size;
508     shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
509
510     return loco::NodeShape{shape_output};
511   }
512
513   loco::NodeShape visit(const luci::CircleCast *node) final { return use_x(node); }
514
515   loco::NodeShape visit(const luci::CircleCeil *node) final { return use_x(node); }
516
517   loco::NodeShape visit(const luci::CircleConcatenation *node) final
518   {
519     // TODO Support when CircleConcatenation has 0 input
520     assert(node->numValues() > 0);
521
522     auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
523     auto axis = node->axis();
524     if (axis < 0)
525       axis += first_shape.rank();
526
527     assert(0 <= axis);
528     assert(first_shape.rank() > static_cast<uint32_t>(axis));
529
530     loco::TensorShape output_shape;
531
532     output_shape.rank(first_shape.rank());
533     for (uint32_t i = 0; i < output_shape.rank(); ++i)
534       output_shape.dim(i) = first_shape.dim(i);
535
536     for (uint32_t i = 1; i < node->numValues(); ++i)
537     {
538       auto input_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
539
540       for (uint32_t j = 0; j < output_shape.rank(); ++j)
541       {
542         if (j == static_cast<uint32_t>(axis))
543           output_shape.dim(j) = output_shape.dim(j).value() + input_shape.dim(j).value();
544         else
545           assert(output_shape.dim(j) == input_shape.dim(j));
546       }
547     }
548
549     return loco::NodeShape{output_shape};
550   }
551
552   loco::NodeShape visit(const luci::CircleConst *node) final { return use_own(node); }
553
554   loco::NodeShape visit(const luci::CircleConv2D *node) final
555   {
556     LOGGER(l);
557
558     auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
559     auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in OHWI
560
561     INFO(l) << "[luci] CircleConv2D ShapeInf ifm(" << ifm_shape.rank() << ") ker("
562             << ker_shape.rank() << ")" << std::endl;
563
564     assert(ifm_shape.rank() == 4);
565     assert(ker_shape.rank() == 4);
566     assert(ifm_shape.dim(3) == ker_shape.dim(3));
567
568     uint32_t input_height = ifm_shape.dim(1).value();
569     uint32_t input_width = ifm_shape.dim(2).value();
570     uint32_t stride_height = node->stride()->h();
571     uint32_t stride_width = node->stride()->w();
572     uint32_t ker_height = ker_shape.dim(1).value();
573     uint32_t ker_width = ker_shape.dim(2).value();
574     uint32_t dilation_height = node->dilation()->h();
575     uint32_t dilation_width = node->dilation()->w();
576     uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
577     uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
578
579     uint32_t output_height = 0;
580     uint32_t output_width = 0;
581
582     if (node->padding() == luci::Padding::VALID)
583     {
584       output_height = (input_height + stride_height - effective_ker_height) / stride_height;
585       output_width = (input_width + stride_width - effective_ker_width) / stride_width;
586     }
587     else if (node->padding() == luci::Padding::SAME)
588     {
589       output_height = (input_height + stride_height - 1) / stride_height;
590       output_width = (input_width + stride_width - 1) / stride_width;
591     }
592     else
593       LUCI_ASSERT(false, "Wrong padding type");
594
595     loco::TensorShape ofm_shape;
596     ofm_shape.rank(4);
597     ofm_shape.dim(0) = ifm_shape.dim(0);
598     ofm_shape.dim(1) = output_height;
599     ofm_shape.dim(2) = output_width;
600     ofm_shape.dim(3) = ker_shape.dim(0);
601
602     return loco::NodeShape{ofm_shape};
603   }
604
605   loco::NodeShape visit(const luci::CircleCos *node) final { return use_x(node); }
606
607   loco::NodeShape visit(const luci::CircleCustom *node) final { return use_own(node); }
608
609   loco::NodeShape visit(const luci::CircleDepthToSpace *node) final
610   {
611     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
612     LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
613
614     // Only data format NHWC is supported
615     // TODO need to clarify what to do with layout in this operator
616     int32_t height = input_shape.dim(1).value();
617     int32_t width = input_shape.dim(2).value();
618     int32_t depth = input_shape.dim(3).value();
619
620     int block_size = node->block_size();
621
622     if (block_size < 2)
623       INTERNAL_EXN("Block size must be >= 2");
624
625     if (depth % (block_size * block_size))
626     {
627       INTERNAL_EXN("The input tensor's depth must be divisible by block_size^2");
628     }
629
630     loco::TensorShape output_shape;
631     output_shape.rank(4);
632
633     output_shape.dim(0) = input_shape.dim(0).value();
634     output_shape.dim(1) = height * block_size;
635     output_shape.dim(2) = width * block_size;
636     output_shape.dim(3) = depth / (block_size * block_size);
637
638     return loco::NodeShape{output_shape};
639   }
640
641   loco::NodeShape visit(const luci::CircleDepthwiseConv2D *node) final
642   {
643     auto ifm_shape = loco::shape_get(node->input()).as<loco::TensorShape>();  // in NHWC
644     auto ker_shape = loco::shape_get(node->filter()).as<loco::TensorShape>(); // in 1 H W CM
645
646     assert(ifm_shape.rank() == 4);
647     assert(ker_shape.rank() == 4);
648     assert(ker_shape.dim(0).value() == 1);
649
650     uint32_t input_height = ifm_shape.dim(1).value();
651     uint32_t input_width = ifm_shape.dim(2).value();
652     uint32_t stride_height = node->stride()->h();
653     uint32_t stride_width = node->stride()->w();
654     uint32_t ker_height = ker_shape.dim(1).value();
655     uint32_t ker_width = ker_shape.dim(2).value();
656     uint32_t dilation_height = node->dilation()->h();
657     uint32_t dilation_width = node->dilation()->w();
658     uint32_t effective_ker_height = dilation_height * (ker_height - 1) + 1;
659     uint32_t effective_ker_width = dilation_width * (ker_width - 1) + 1;
660
661     uint32_t output_height = 0;
662     uint32_t output_width = 0;
663
664     if (node->padding() == luci::Padding::VALID)
665     {
666       output_height = (input_height + stride_height - effective_ker_height) / stride_height;
667       output_width = (input_width + stride_width - effective_ker_width) / stride_width;
668     }
669     else if (node->padding() == luci::Padding::SAME)
670     {
671       output_height = (input_height + stride_height - 1) / stride_height;
672       output_width = (input_width + stride_width - 1) / stride_width;
673     }
674     else
675       LUCI_ASSERT(false, "Wrong padding type");
676
677     loco::TensorShape ofm_shape;
678     ofm_shape.rank(4);
679     ofm_shape.dim(0) = ifm_shape.dim(0);
680     ofm_shape.dim(1) = output_height;
681     ofm_shape.dim(2) = output_width;
682     ofm_shape.dim(3) = ker_shape.dim(3);
683
684     return loco::NodeShape{ofm_shape};
685   }
686
687   loco::NodeShape visit(const luci::CircleDiv *node) final { return broadcast_xy(node); }
688
689   loco::NodeShape visit(const luci::CircleElu *node) final
690   {
691     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
692
693     return loco::NodeShape{input_shape};
694   }
695
696   loco::NodeShape visit(const luci::CircleEqual *node) final { return broadcast_xy(node); }
697
698   loco::NodeShape visit(const luci::CircleExp *node) final { return use_x(node); }
699
700   loco::NodeShape visit(const luci::CircleExpandDims *node) final
701   {
702     const loco::DataType S32 = loco::DataType::S32;
703     auto x_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
704     if (x_shape.rank() == 0)
705     {
706       // This maybe for unknown shape. We use shape from the node itself.
707       return use_own(node);
708     }
709     auto const_axis = loco::must_cast<luci::CircleConst *>(node->axis());
710     LUCI_ASSERT(const_axis->dtype() == S32, "Only support int32 CircleConst for axis");
711     if (const_axis->rank() != 0 && const_axis->rank() != 1)
712     {
713       INTERNAL_EXN_V("Non-scalar axis in OP", node->opnum());
714     }
715     int32_t axis = const_axis->at<S32>(0);
716     LUCI_ASSERT((axis <= static_cast<int32_t>(x_shape.rank())) &&
717                     (axis >= -1 - static_cast<int32_t>(x_shape.rank())),
718                 "Axis has to be between [-(D+1), D], where D is rank of input.");
719     size_t positive_axis = axis < 0 ? x_shape.rank() + axis + 1 : axis;
720     loco::TensorShape output_shape;
721     output_shape.rank(x_shape.rank() + 1);
722     size_t i = 0;
723     for (; i < positive_axis; i++)
724       output_shape.dim(i) = x_shape.dim(i);
725     output_shape.dim(i) = loco::Dimension(1);
726     for (; i < x_shape.rank(); i++)
727       output_shape.dim(i + 1) = x_shape.dim(i);
728     return loco::NodeShape{output_shape};
729   }
730
731   loco::NodeShape visit(const luci::CircleFill *node) final
732   {
733     loco::TensorShape shape;
734     {
735       LUCI_ASSERT(node->dims(), "dims input should not be nullptr");
736
737       auto dims_node = dynamic_cast<luci::CircleConst *>(node->dims());
738       if (dims_node != nullptr)
739       {
740         // Only support node with S32
741         LUCI_ASSERT(dims_node->dtype() == loco::DataType::S32, "Only support int32 CircleConst");
742
743         if (dims_node->rank() != 1)
744           INTERNAL_EXN_V("Only support rank 1 CircleConst", oops::to_uint32(dims_node->rank()));
745
746         shape.rank(dims_node->dim(0).value());
747
748         for (uint32_t axis = 0; axis < shape.rank(); ++axis)
749         {
750           shape.dim(axis) = dims_node->at<loco::DataType::S32>(axis);
751         }
752       }
753       else
754       {
755         shape = own_shape(node);
756       }
757     }
758
759     return loco::NodeShape{shape};
760   }
761
762   loco::NodeShape visit(const luci::CircleFloor *node) final { return use_x(node); }
763
764   loco::NodeShape visit(const luci::CircleFloorDiv *node) final { return broadcast_xy(node); }
765
766   loco::NodeShape visit(const luci::CircleFloorMod *node) final { return broadcast_xy(node); }
767
768   loco::NodeShape visit(const luci::CircleFullyConnected *node) final
769   {
770     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
771     auto weights_shape = loco::shape_get(node->weights()).as<loco::TensorShape>();
772
773     // Checking shape capability for fully connected layer
774     // Input: a tensor of at least rank 2 [D1, D2, ... Dn]
775     // Weight: [# of units, K]
776     // Output: [D1 * D2 * ... * Dn / K, # of units]
777     if (input_shape.rank() < 2 || weights_shape.rank() != 2)
778     {
779       // Return node own shape if shape inference is not possible
780       return use_own(node);
781     }
782
783     uint32_t input_size = 1;
784     for (uint32_t i = 0; i < input_shape.rank(); i++)
785     {
786       input_size = input_size * input_shape.dim(i).value();
787     }
788     const uint32_t batch_size = input_size / weights_shape.dim(1).value();
789     loco::TensorShape out_shape;
790     out_shape.rank(2);
791     out_shape.dim(0) = batch_size;
792     out_shape.dim(1) = weights_shape.dim(0);
793
794     return loco::NodeShape{out_shape};
795   }
796
797   loco::NodeShape visit(const luci::CircleGather *node) final
798   {
799     loco::TensorShape output_shape;
800
801     const auto input_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
802     const auto positions_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
803     int32_t axis = node->axis();
804
805     // If CircleGather input has a dynamic shape, it can't inference this shape. So, it returns the
806     // shape that node already has.
807     if (input_shape.rank() == 0 || positions_shape.rank() == 0)
808       return use_own(node);
809
810     if (axis < 0)
811       axis += input_shape.rank();
812
813     output_shape.rank(input_shape.rank() - 1 + positions_shape.rank());
814     int32_t outdim_index = 0;
815     for (int32_t i = 0; i < axis; ++i)
816       output_shape.dim(outdim_index++) = input_shape.dim(i);
817     for (uint32_t i = 0; i < positions_shape.rank(); ++i)
818       output_shape.dim(outdim_index++) = positions_shape.dim(i);
819     for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
820       output_shape.dim(outdim_index++) = input_shape.dim(i);
821
822     return loco::NodeShape{output_shape};
823   }
824
825   loco::NodeShape visit(const luci::CircleGatherNd *node) final
826   {
827     loco::TensorShape output_shape;
828
829     const auto params_shape = loco::shape_get(node->params()).as<loco::TensorShape>();
830     const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
831
832     const auto params_rank = params_shape.rank();
833     const auto indices_rank = indices_shape.rank();
834
835     // see https://www.tensorflow.org/api_docs/python/tf/gather_nd
836     // output.shape = indices.shape[:-1] + params.shape[indices.shape[-1]:]
837     // batch_dims isn't supported in tflite
838
839     // TODO: replace exceptions with setting shape to unknown?
840
841     if (!indices_shape.dim(indices_rank - 1).known())
842       INTERNAL_EXN("Last indices dimension is unknown");
843
844     auto indices_last_dim = indices_shape.dim(indices_rank - 1).value();
845
846     if (indices_last_dim > params_rank)
847       INTERNAL_EXN("Last indices dimension should be <= params rank");
848
849     const uint32_t output_rank = indices_rank + params_rank - indices_last_dim - 1;
850
851     output_shape.rank(output_rank);
852
853     uint32_t output_index = 0;
854     for (uint32_t i = 0; i < indices_rank - 1; ++i)
855     {
856       auto &dim = indices_shape.dim(i);
857       if (!dim.known())
858         INTERNAL_EXN("Unknown indices dimension is unsupported");
859       output_shape.dim(output_index++).set(dim.value());
860     }
861
862     for (uint32_t i = indices_last_dim; i < params_rank; ++i)
863     {
864       auto &dim = params_shape.dim(i);
865       if (!dim.known())
866         INTERNAL_EXN("Unknown params dimension is unsupported");
867       output_shape.dim(output_index++).set(dim.value());
868     }
869
870     return loco::NodeShape{output_shape};
871   }
872
873   loco::NodeShape visit(const luci::CircleGreater *node) final { return broadcast_xy(node); }
874
875   loco::NodeShape visit(const luci::CircleGreaterEqual *node) final { return broadcast_xy(node); }
876
877   loco::NodeShape visit(const luci::CircleIf *node) final
878   {
879     // Shape of CircleIf is not used. Just use input 0
880     assert(node->input_count() > 0);
881     const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
882     return loco::NodeShape{input_shape};
883   }
884
885   loco::NodeShape visit(const luci::CircleL2Normalize *node) final { return use_x(node); }
886
887   loco::NodeShape visit(const luci::CircleL2Pool2D *node) final
888   {
889     return infer_pool_2d_shape(node);
890   }
891
892   loco::NodeShape visit(const luci::CircleLeakyRelu *node) final
893   {
894     const auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
895     return loco::NodeShape{input_shape};
896   }
897
898   loco::NodeShape visit(const luci::CircleLess *node) final { return broadcast_xy(node); }
899
900   loco::NodeShape visit(const luci::CircleLessEqual *node) final { return broadcast_xy(node); }
901
902   loco::NodeShape visit(const luci::CircleLocalResponseNormalization *node) final
903   {
904     const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
905     return loco::NodeShape{input_shape};
906   }
907
908   loco::NodeShape visit(const luci::CircleLog *node) final { return use_x(node); }
909
910   loco::NodeShape visit(const luci::CircleLogicalAnd *node) final { return use_x(node); }
911
912   loco::NodeShape visit(const luci::CircleLogicalNot *node) final { return use_x(node); }
913
914   loco::NodeShape visit(const luci::CircleLogicalOr *node) final { return use_x(node); }
915
916   loco::NodeShape visit(const luci::CircleLogistic *node) final { return use_x(node); }
917
918   loco::NodeShape visit(const luci::CircleMatrixSetDiag *node) final
919   {
920     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
921     auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
922
923     auto rank = diagonal_shape.rank();
924
925     LUCI_ASSERT(rank == input_shape.rank() - 1, "diagonal rank = input rank - 1");
926
927     for (uint32_t i = 0; i < rank - 1; i++)
928     {
929       LUCI_ASSERT(diagonal_shape.dim(i) == input_shape.dim(i), "diagonal dims = input dims");
930     }
931
932     auto dim = std::min(input_shape.dim(rank - 1).value(), input_shape.dim(rank).value());
933
934     LUCI_ASSERT(dim == diagonal_shape.dim(rank - 1), "Max diag len error");
935
936     return loco::NodeShape{input_shape};
937   }
938
939   loco::NodeShape visit(const luci::CircleLogSoftmax *node) final { return use_logits(node); }
940
941   loco::NodeShape visit(const luci::CircleMatrixDiag *node) final
942   {
943     loco::TensorShape output_shape;
944
945     auto diagonal_shape = loco::shape_get(node->diagonal()).as<loco::TensorShape>();
946     auto rank = diagonal_shape.rank();
947
948     output_shape.rank(rank + 1);
949
950     for (uint32_t i = 0; i < rank; i++)
951     {
952       output_shape.dim(i) = diagonal_shape.dim(i);
953     }
954
955     output_shape.dim(rank) = diagonal_shape.dim(rank - 1);
956
957     return loco::NodeShape{output_shape};
958   }
959
960   loco::NodeShape visit(const luci::CircleMaximum *node) final { return broadcast_xy(node); }
961
962   loco::NodeShape visit(const luci::CircleMaxPool2D *node) final
963   {
964     return infer_pool_2d_shape(node);
965   }
966
967   loco::NodeShape visit(const luci::CircleMean *node) final
968   {
969     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
970     return loco::NodeShape{output_shape};
971   }
972
973   loco::NodeShape visit(const luci::CircleMinimum *node) final { return broadcast_xy(node); }
974
975   loco::NodeShape visit(const luci::CircleMirrorPad *node) final
976   {
977     const loco::DataType S32 = loco::DataType::S32;
978
979     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
980     auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
981
982     // TODO support non-const case
983     // TODO support other data type
984     LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
985     LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2")
986
987     int32_t n = paddings->dim(0).value();
988     int32_t v = paddings->dim(1).value();
989
990     LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
991     LUCI_ASSERT(n == int32_t(input_shape.rank()),
992                 "paddings [n, 2] should have same value of input rank");
993
994     loco::TensorShape output_shape;
995
996     output_shape.rank(input_shape.rank());
997     for (int32_t ni = 0; ni < n; ++ni)
998     {
999       int32_t idx = ni * 2;
1000       int value = input_shape.dim(ni).value();
1001       value += paddings->at<S32>(idx + 0); // left
1002       value += paddings->at<S32>(idx + 1); // right
1003       output_shape.dim(ni) = value;
1004     }
1005
1006     return loco::NodeShape{output_shape};
1007   }
1008
1009   loco::NodeShape visit(const luci::CircleMul *node) final { return broadcast_xy(node); }
1010
1011   loco::NodeShape visit(const luci::CircleNeg *node) final { return use_x(node); }
1012
1013   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4 *node) final
1014   {
1015     const auto boxes_shape = loco::shape_get(node->boxes()).as<loco::TensorShape>();
1016     return loco::NodeShape{boxes_shape};
1017   }
1018
1019   loco::NodeShape visit(const luci::CircleNotEqual *node) final { return broadcast_xy(node); }
1020
1021   loco::NodeShape visit(const luci::CircleOneHot *node) final
1022   {
1023     const loco::DataType S32 = loco::DataType::S32;
1024     auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
1025     // Only support OneHot node's depth() is CircleConst with type S32
1026     // TODO support depth with other types
1027     auto depth = loco::must_cast<luci::CircleConst *>(node->depth());
1028     LUCI_ASSERT(depth->dtype() == S32, "Only support int32 CircleConst");
1029     if (depth->rank() != 0)
1030       INTERNAL_EXN_V("Only support rank 0 CircleOneHot in Depth", oops::to_uint32(depth->rank()));
1031     loco::TensorShape output_shape;
1032     output_shape.rank(indices_shape.rank() + 1);
1033     auto axis = node->axis();
1034     if (axis < 0)
1035       axis += indices_shape.rank() + 1;
1036     LUCI_ASSERT(0 <= axis, "Axis is out of range");
1037     LUCI_ASSERT(static_cast<uint32_t>(axis) <= indices_shape.rank(), "Axis is out of range");
1038     uint32_t j = 0;
1039     for (uint32_t i = 0; i < output_shape.rank(); i++)
1040     {
1041       if (i == static_cast<uint32_t>(axis))
1042       {
1043         output_shape.dim(i) = depth->at<S32>(0);
1044       }
1045       else
1046       {
1047         output_shape.dim(i) = indices_shape.dim(j++);
1048       }
1049     }
1050     return loco::NodeShape{output_shape};
1051   }
1052
1053   loco::NodeShape visit(const luci::CirclePack *node) final
1054   {
1055     LUCI_ASSERT(node->values_count() > 0, "Only support one or more inputs");
1056
1057     auto first_shape = loco::shape_get(node->values(0)).as<loco::TensorShape>();
1058     // Make sure all inputs have the same shape.
1059     for (uint32_t i = 1; i < node->values_count(); ++i)
1060     {
1061       auto in_shape = loco::shape_get(node->values(i)).as<loco::TensorShape>();
1062       LUCI_ASSERT(loco::NodeShape{first_shape} == loco::NodeShape{in_shape},
1063                   "All inputs must have the same shape");
1064     }
1065
1066     // Checking shape capability for pack layer
1067     // Input: tensors [D1, D2, ... Dn]
1068     // Axis: K
1069     // Output: [D1, D2, ... , D_K-1, n, D_K+1, ... Dn]
1070     auto axis = node->axis();
1071     if (axis < 0)
1072       axis += first_shape.rank() + 1;
1073
1074     LUCI_ASSERT(0 <= axis, "Axis is out of range");
1075     LUCI_ASSERT(static_cast<uint32_t>(axis) <= first_shape.rank(), "Axis is out of range");
1076
1077     loco::TensorShape output_shape;
1078     output_shape.rank(first_shape.rank() + 1);
1079
1080     uint32_t j = 0;
1081     for (uint32_t i = 0; i < output_shape.rank(); ++i)
1082     {
1083       if (i == static_cast<uint32_t>(axis))
1084       {
1085         output_shape.dim(i) = node->values_count();
1086       }
1087       else
1088       {
1089         output_shape.dim(i) = first_shape.dim(j++);
1090       }
1091     }
1092
1093     return loco::NodeShape{output_shape};
1094   }
1095
1096   loco::NodeShape visit(const luci::CirclePad *node) final
1097   {
1098     const loco::DataType S32 = loco::DataType::S32;
1099
1100     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1101     auto paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1102
1103     // TODO support non-const case
1104     // TODO support other data type
1105     LUCI_ASSERT(paddings->dtype() == S32, "Only support int 32 for now");
1106     LUCI_ASSERT(paddings->rank() == 2, "paddings should be rank 2")
1107
1108     int32_t n = paddings->dim(0).value();
1109     int32_t v = paddings->dim(1).value();
1110
1111     LUCI_ASSERT(v == 2, "paddings should be [n, 2]");
1112     LUCI_ASSERT(n == int32_t(input_shape.rank()),
1113                 "paddings [n, 2] should have same value of input rank");
1114
1115     loco::TensorShape output_shape;
1116
1117     output_shape.rank(input_shape.rank());
1118     for (int32_t ni = 0; ni < n; ++ni)
1119     {
1120       int32_t idx = ni * 2;
1121       int value = input_shape.dim(ni).value();
1122       value += paddings->at<S32>(idx + 0); // left
1123       value += paddings->at<S32>(idx + 1); // right
1124       output_shape.dim(ni) = value;
1125     }
1126
1127     return loco::NodeShape{output_shape};
1128   }
1129
1130   loco::NodeShape visit(const luci::CirclePow *node) final { return broadcast_xy(node); }
1131
1132   loco::NodeShape visit(const luci::CirclePRelu *node) final
1133   {
1134     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1135     auto alpha_shape = loco::shape_get(node->alpha()).as<loco::TensorShape>();
1136
1137     auto output_shape = broadcast_shape(input_shape, alpha_shape);
1138
1139     return loco::NodeShape{output_shape};
1140   }
1141
1142   loco::NodeShape visit(const luci::CircleRange *node) final
1143   {
1144     loco::TensorShape output_shape;
1145     output_shape.rank(1);
1146
1147     auto start_node = dynamic_cast<luci::CircleConst *>(node->start());
1148     auto limit_node = dynamic_cast<luci::CircleConst *>(node->limit());
1149     auto delta_node = dynamic_cast<luci::CircleConst *>(node->delta());
1150
1151     if (start_node == nullptr || limit_node == nullptr || delta_node == nullptr)
1152     {
1153       return use_own(node);
1154     }
1155
1156     double start = 0, limit = 0, delta = 0;
1157
1158 #define GET_RANGE_PARAM(DT)         \
1159   start = start_node->scalar<DT>(); \
1160   limit = limit_node->scalar<DT>(); \
1161   delta = delta_node->scalar<DT>();
1162
1163     switch (start_node->dtype())
1164     {
1165       case loco::DataType::FLOAT32:
1166         GET_RANGE_PARAM(loco::DataType::FLOAT32)
1167         break;
1168       case loco::DataType::S32:
1169         GET_RANGE_PARAM(loco::DataType::S32)
1170         break;
1171       default:
1172         INTERNAL_EXN("Range data type not supported");
1173     }
1174
1175 #undef GET_RANGE_PARAM
1176
1177     if (delta == 0)
1178       INTERNAL_EXN("Delta can not be zero");
1179
1180     output_shape.dim(0) = ceil((limit - start) / delta);
1181
1182     return loco::NodeShape{output_shape};
1183   }
1184
1185   loco::NodeShape visit(const luci::CircleRank *) final
1186   {
1187     loco::TensorShape shape_output;
1188     shape_output.rank(0);
1189
1190     return loco::NodeShape{shape_output};
1191   }
1192
1193   loco::NodeShape visit(const luci::CircleReduceAny *node) final
1194   {
1195     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
1196     return loco::NodeShape{output_shape};
1197   }
1198
1199   loco::NodeShape visit(const luci::CircleReduceMax *node) final
1200   {
1201     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
1202     return loco::NodeShape{output_shape};
1203   }
1204
1205   loco::NodeShape visit(const luci::CircleReduceMin *node) final
1206   {
1207     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
1208     return loco::NodeShape{output_shape};
1209   }
1210
1211   loco::NodeShape visit(const luci::CircleReduceProd *node) final
1212   {
1213     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
1214     return loco::NodeShape{output_shape};
1215   }
1216
1217   loco::NodeShape visit(const luci::CircleRelu *node) final
1218   {
1219     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
1220
1221     return loco::NodeShape{input_shape};
1222   }
1223
1224   loco::NodeShape visit(const luci::CircleRelu6 *node) final
1225   {
1226     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
1227
1228     return loco::NodeShape{input_shape};
1229   }
1230
1231   loco::NodeShape visit(const luci::CircleReluN1To1 *node) final
1232   {
1233     auto input_shape = loco::shape_get(node->features()).as<loco::TensorShape>();
1234
1235     return loco::NodeShape{input_shape};
1236   }
1237
1238   /**
1239    * @note  CircleReshape has new shape info in two places: 2nd input and attribute.
1240    *        This shape inference uses shape from input 'shape' node when it's constant.
1241    *        If not, shape will be from node itself. shape from attribute is not used.
1242    *
1243    * TODO Change this policy when not appropriate
1244    */
1245   loco::NodeShape visit(const luci::CircleReshape *node) final
1246   {
1247     LOGGER(l);
1248
1249     const loco::DataType S32 = loco::DataType::S32;
1250
1251     loco::TensorShape shape_by_input;
1252     {
1253       LUCI_ASSERT(node->shape(), "2nd input shape() should not be nullptr");
1254
1255       // Only support node's shape() is CircleConst with S32
1256       // TODO support other node with other types
1257       auto const_shape_node = dynamic_cast<luci::CircleConst *>(node->shape());
1258       if (const_shape_node != nullptr)
1259       {
1260         LUCI_ASSERT(const_shape_node->dtype() == S32, "Only support int32 CircleConst");
1261
1262         shape_by_input.rank(const_shape_node->size<S32>());
1263
1264         for (uint32_t axis = 0; axis < shape_by_input.rank(); ++axis)
1265         {
1266           shape_by_input.dim(axis) = const_shape_node->at<S32>(axis);
1267         }
1268       }
1269       else
1270       {
1271         // We use shape from the node itself
1272         shape_by_input = own_shape(node);
1273       }
1274     }
1275
1276     loco::TensorShape shape_by_attr;
1277     {
1278       shape_by_attr.rank(node->newShape()->rank());
1279
1280       for (uint32_t axis = 0; axis < shape_by_attr.rank(); ++axis)
1281       {
1282         shape_by_attr.dim(axis) = node->newShape()->dim(axis);
1283       }
1284     }
1285
1286     if (!(shape_by_input == shape_by_attr))
1287     {
1288       INFO(l) << "CircleReshape: Two new shape information mismatched : " << std::endl;
1289       INFO(l) << "   shape_by_input : " << shape_by_input << std::endl;
1290       INFO(l) << "   shape_by_attr : " << shape_by_attr << std::endl;
1291     }
1292
1293     loco::TensorShape output_shape = shape_by_input;
1294
1295     // One of the dimensions can have special value -1, meaning its actual value should be inferred.
1296     const auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
1297     const uint32_t input_element_count = loco::element_count(&input_shape);
1298     uint32_t output_element_count = 1;
1299     uint32_t unknown_dim_index = UINT32_MAX;
1300     for (uint32_t dim_index = 0; dim_index < output_shape.rank(); ++dim_index)
1301     {
1302       const uint32_t dim_value = output_shape.dim(dim_index).value();
1303       if (static_cast<int>(dim_value) == -1)
1304       {
1305         LUCI_ASSERT(unknown_dim_index == UINT32_MAX, "More than one unknown dimension");
1306         unknown_dim_index = dim_index;
1307       }
1308       else
1309       {
1310         output_element_count *= dim_value;
1311       }
1312     }
1313     if (unknown_dim_index != UINT32_MAX)
1314     {
1315       output_shape.dim(unknown_dim_index) = input_element_count / output_element_count;
1316     }
1317
1318     return loco::NodeShape{output_shape};
1319   }
1320
1321   loco::NodeShape visit(const luci::CircleResizeBilinear *node) final
1322   {
1323     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1324
1325     if (input_shape.rank() != 4)
1326       INTERNAL_EXN("Expected ResizeBilinear input to have rank 4");
1327
1328     auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1329
1330     if (const_node->dtype() != loco::DataType::S32)
1331       INTERNAL_EXN("Only S32 datatype is supported for ResizeBilinear size");
1332
1333     if (const_node->rank() != 1)
1334       INTERNAL_EXN("Expected size tensor of rank 1");
1335
1336     if (const_node->dim(0).value() != 2)
1337       INTERNAL_EXN("Expected size tensor with shape [2]");
1338
1339     loco::TensorShape output_shape;
1340     output_shape.rank(4);
1341     output_shape.dim(0) = input_shape.dim(0);
1342     output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
1343     output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
1344     output_shape.dim(3) = input_shape.dim(3);
1345
1346     return loco::NodeShape{output_shape};
1347   }
1348
1349   loco::NodeShape visit(const luci::CircleResizeNearestNeighbor *node) final
1350   {
1351     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1352
1353     if (input_shape.rank() != 4)
1354       INTERNAL_EXN("Expected ResizeNearesNeighbor input to have rank 4");
1355
1356     auto *const_node = loco::must_cast<luci::CircleConst *>(node->size());
1357
1358     if (const_node->dtype() != loco::DataType::S32)
1359       INTERNAL_EXN("Only S32 datatype is supported for ResizeNearesNeighbor size");
1360
1361     if (const_node->rank() != 1)
1362       INTERNAL_EXN("Expected size tensor of rank 1");
1363
1364     if (const_node->dim(0).value() != 2)
1365       INTERNAL_EXN("Expected size tensor with shape [2]");
1366
1367     loco::TensorShape output_shape;
1368     output_shape.rank(4);
1369     output_shape.dim(0) = input_shape.dim(0);
1370     output_shape.dim(1) = const_node->at<loco::DataType::S32>(0);
1371     output_shape.dim(2) = const_node->at<loco::DataType::S32>(1);
1372     output_shape.dim(3) = input_shape.dim(3);
1373
1374     return loco::NodeShape{output_shape};
1375   }
1376
1377   loco::NodeShape visit(const luci::CircleReverseSequence *node) final
1378   {
1379     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1380
1381     return loco::NodeShape{input_shape};
1382   }
1383
1384   loco::NodeShape visit(const luci::CircleRound *node) final { return use_x(node); }
1385
1386   loco::NodeShape visit(const luci::CircleReverseV2 *node) final
1387   {
1388     auto input_shape = loco::shape_get(node->tensor()).as<loco::TensorShape>();
1389
1390     LUCI_ASSERT(loco::shape_get(node->axis()).as<loco::TensorShape>().rank() == 1,
1391                 "Tensor must be 1-D");
1392
1393     return loco::NodeShape{input_shape};
1394   }
1395
1396   loco::NodeShape visit(const luci::CircleRsqrt *node) final { return use_x(node); }
1397
1398   loco::NodeShape visit(const luci::CircleScatterNd *node) final
1399   {
1400     loco::TensorShape output_shape;
1401
1402     auto shape_node = loco::must_cast<luci::CircleConst *>(node->shape());
1403
1404     const loco::DataType S32 = loco::DataType::S32;
1405     const loco::DataType S64 = loco::DataType::S64;
1406
1407     std::vector<int64_t> vect_shape;
1408
1409     if (shape_node->dtype() == S32)
1410       vect_shape = vector_from_constant<S32>(shape_node);
1411     else if (shape_node->dtype() == S64)
1412       vect_shape = vector_from_constant<S64>(shape_node);
1413     else
1414       LUCI_ASSERT(false, "Only support int32/int64 for shape()");
1415
1416     output_shape.rank(vect_shape.size());
1417     for (uint32_t i = 0; i < vect_shape.size(); ++i)
1418       output_shape.dim(i) = vect_shape[i];
1419
1420     return loco::NodeShape{output_shape};
1421   }
1422
1423   loco::NodeShape visit(const luci::CircleSegmentSum *node) final
1424   {
1425     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1426     auto segment_shape = loco::shape_get(node->segment_ids()).as<loco::TensorShape>();
1427
1428     LUCI_ASSERT(segment_shape.rank() == 1, "segment_ids must be 1-D tensor");
1429     LUCI_ASSERT(segment_shape.dim(0).value() == input_shape.dim(0).value(),
1430                 "segment_ids size must be equal to the size of data's first dimension");
1431
1432     auto ids_shape_value = loco::must_cast<luci::CircleConst *>(node->segment_ids());
1433
1434     std::vector<int64_t> vect_ids;
1435
1436     if (ids_shape_value->dtype() == loco::DataType::S32)
1437       vect_ids = vector_from_constant<loco::DataType::S32>(ids_shape_value);
1438
1439     LUCI_ASSERT(std::is_sorted(vect_ids.begin(), vect_ids.end()),
1440                 "segment_ids values should be sorted")
1441
1442     loco::TensorShape output_shape;
1443
1444     output_shape.rank(input_shape.rank());
1445
1446     for (uint32_t i = 1; i < input_shape.rank(); ++i)
1447       output_shape.dim(i) = input_shape.dim(i);
1448
1449     output_shape.dim(0) = vect_ids.back() + 1;
1450
1451     return loco::NodeShape{output_shape};
1452   }
1453
1454   loco::NodeShape visit(const luci::CircleSelect *node) final
1455   {
1456     auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
1457     assert(t_shape == loco::shape_get(node->e()).as<loco::TensorShape>());
1458
1459     // condition shape validation
1460     auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
1461     if (c_shape.rank() != t_shape.rank())
1462     {
1463       if (c_shape.rank() != 0 && c_shape.rank() != 1)
1464         INTERNAL_EXN_V("CircleSelect condition rank is not 0 nor 1: ", c_shape.rank());
1465
1466       if (c_shape.rank() == 1)
1467       {
1468         if (c_shape.dim(0).value() != t_shape.dim(0).value())
1469           INTERNAL_EXN("CircleSelect condition dim(0) should match with t.dim(0)");
1470       }
1471     }
1472
1473     return loco::NodeShape{t_shape};
1474   }
1475
1476   loco::NodeShape visit(const luci::CircleSelectV2 *node) final
1477   {
1478     auto c_shape = loco::shape_get(node->condition()).as<loco::TensorShape>();
1479     auto t_shape = loco::shape_get(node->t()).as<loco::TensorShape>();
1480     auto e_shape = loco::shape_get(node->e()).as<loco::TensorShape>();
1481
1482     // validate ability to broadcast shapes to each other
1483     auto b_shape = broadcast_shape(broadcast_shape(c_shape, t_shape), e_shape);
1484     return loco::NodeShape{b_shape};
1485   }
1486
1487   loco::NodeShape visit(const luci::CircleShape *node) final
1488   {
1489     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1490
1491     loco::TensorShape output_shape;
1492
1493     output_shape.rank(1);
1494     output_shape.dim(0) = input_shape.rank();
1495
1496     return loco::NodeShape{output_shape};
1497   }
1498
1499   loco::NodeShape visit(const luci::CircleSin *node) final { return use_x(node); }
1500
1501   loco::NodeShape visit(const luci::CircleSlice *node) final
1502   {
1503     const loco::DataType S32 = loco::DataType::S32;
1504     const loco::DataType S64 = loco::DataType::S64;
1505
1506     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1507
1508     auto const_begin = loco::must_cast<luci::CircleConst *>(node->begin());
1509     auto const_size = loco::must_cast<luci::CircleConst *>(node->size());
1510
1511     loco::TensorShape output_shape;
1512     std::vector<int64_t> vect_begin; // to hold both S32/S64, we use int64_t
1513     std::vector<int64_t> vect_size;
1514
1515     if (const_begin->dtype() == S32)
1516       vect_begin = vector_from_constant<S32>(const_begin);
1517     else if (const_begin->dtype() == S64)
1518       vect_begin = vector_from_constant<S64>(const_begin);
1519     else
1520       LUCI_ASSERT(false, "Only support int32/int64 for begin()");
1521
1522     if (const_size->dtype() == S32)
1523       vect_size = vector_from_constant<S32>(const_size);
1524     else if (const_size->dtype() == S64)
1525       vect_size = vector_from_constant<S64>(const_size);
1526     else
1527       LUCI_ASSERT(false, "Only support int32/int64 for size()");
1528
1529     assert(input_shape.rank() == vect_begin.size());
1530     assert(input_shape.rank() == vect_size.size());
1531
1532     output_shape.rank(vect_begin.size());
1533     for (uint32_t idx = 0; idx < vect_begin.size(); ++idx)
1534     {
1535       auto size = vect_size.at(idx);
1536       if (size == -1)
1537       {
1538         size = input_shape.dim(idx).value() - vect_begin.at(idx);
1539       }
1540       output_shape.dim(idx) = size;
1541     }
1542
1543     return loco::NodeShape{output_shape};
1544   }
1545
1546   loco::NodeShape visit(const luci::CircleSoftmax *node) final { return use_logits(node); }
1547
1548   loco::NodeShape visit(const luci::CircleSpaceToBatchND *node) final
1549   {
1550     const loco::DataType S32 = loco::DataType::S32;
1551
1552     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1553     // Support only input rank is 3 and 4
1554     assert(input_shape.rank() == 3 || input_shape.rank() == 4);
1555
1556     // Only support block_shape() with S32 type CircleConst for now
1557     auto const_block_shape = loco::must_cast<luci::CircleConst *>(node->block_shape());
1558     LUCI_ASSERT(const_block_shape->dtype() == S32, "Only support int32 block_shape");
1559
1560     // Only support paddings() with S32 type CircleConst for now
1561     auto const_paddings = loco::must_cast<luci::CircleConst *>(node->paddings());
1562     LUCI_ASSERT(const_paddings->dtype() == S32, "Only support int32 paddings");
1563
1564     auto const_block_shape_shape = loco::shape_get(const_block_shape).as<loco::TensorShape>();
1565     auto const_paddings_shape = loco::shape_get(const_paddings).as<loco::TensorShape>();
1566     assert(const_block_shape_shape.rank() == 1);
1567     assert(const_paddings_shape.rank() == 2);
1568
1569     int32_t input_spatial_dim = input_shape.rank() - 2;
1570     assert(const_block_shape_shape.dim(0) == input_spatial_dim);
1571     assert(const_paddings_shape.dim(0) == input_spatial_dim);
1572     assert(const_paddings_shape.dim(1) == 2);
1573
1574     // Check all values of block_shape >= 1
1575     uint32_t ele_count = const_block_shape->size<S32>();
1576     for (uint32_t e = 0; e < ele_count; ++e)
1577     {
1578       auto val = const_block_shape->at<S32>(e);
1579       if (val < 1)
1580       {
1581         INTERNAL_EXN_V("All values of block_shape >= 1: ", e);
1582       }
1583     }
1584
1585     loco::TensorShape shape_output;
1586
1587     shape_output.rank(input_shape.rank());
1588
1589     int32_t output_batch_size = input_shape.dim(0).value();
1590     for (int32_t dim = 0; dim < input_spatial_dim; ++dim)
1591     {
1592       int dim_size = input_shape.dim(dim + 1).value();
1593       dim_size += const_paddings->at<S32>(dim * 2);
1594       dim_size += const_paddings->at<S32>(dim * 2 + 1);
1595       shape_output.dim(dim + 1) = dim_size / const_block_shape->at<S32>(dim);
1596
1597       assert(dim_size % const_block_shape->at<S32>(dim) == 0);
1598       output_batch_size = output_batch_size * const_block_shape->at<S32>(dim);
1599     }
1600     shape_output.dim(0) = output_batch_size;
1601     shape_output.dim(input_shape.rank() - 1) = input_shape.dim(input_shape.rank() - 1);
1602
1603     return loco::NodeShape{shape_output};
1604   }
1605
1606   loco::NodeShape visit(const luci::CircleSpaceToDepth *node) final
1607   {
1608     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1609     LUCI_ASSERT(input_shape.rank() == 4, "Only input rank 4 is supported");
1610
1611     // Only data format NHWC is supported
1612     int32_t height = input_shape.dim(1).value();
1613     int32_t width = input_shape.dim(2).value();
1614     int32_t depth = input_shape.dim(3).value();
1615
1616     int block_size = node->block_size();
1617
1618     if (block_size < 2)
1619       INTERNAL_EXN("Block size must be >= 2");
1620
1621     if ((height % block_size) || (width % block_size))
1622     {
1623       INTERNAL_EXN("The input tensor's height and width must be divisible by block_size");
1624     }
1625
1626     loco::TensorShape output_shape;
1627     output_shape.rank(4);
1628
1629     output_shape.dim(0) = input_shape.dim(0).value();
1630     output_shape.dim(1) = height / block_size;
1631     output_shape.dim(2) = width / block_size;
1632     output_shape.dim(3) = block_size * block_size * depth;
1633
1634     return loco::NodeShape{output_shape};
1635   }
1636
1637   loco::NodeShape visit(const luci::CircleSparseToDense *node) final
1638   {
1639     loco::TensorShape shape;
1640     {
1641       LUCI_ASSERT(node->output_shape(), "dims input should not be nullptr");
1642
1643       auto output_shape_node = dynamic_cast<luci::CircleConst *>(node->output_shape());
1644       if (output_shape_node != nullptr)
1645       {
1646         // Only support node with S32
1647         LUCI_ASSERT(output_shape_node->dtype() == loco::DataType::S32,
1648                     "Only support int32 CircleConst");
1649
1650         if (output_shape_node->rank() != 1)
1651           INTERNAL_EXN_V("Only support rank 1 CircleConst",
1652                          oops::to_uint32(output_shape_node->rank()));
1653
1654         shape.rank(output_shape_node->dim(0).value());
1655
1656         for (uint32_t axis = 0; axis < shape.rank(); ++axis)
1657         {
1658           shape.dim(axis) = output_shape_node->at<loco::DataType::S32>(axis);
1659         }
1660       }
1661       else
1662       {
1663         shape = own_shape(node);
1664       }
1665     }
1666
1667     return loco::NodeShape{shape};
1668   }
1669
1670   loco::NodeShape visit(const luci::CircleSplit *node) final
1671   {
1672     // We'll set Split output as same as input so that SplitOut can handle it's own shape
1673     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1674     return loco::NodeShape{input_shape};
1675   }
1676
1677   loco::NodeShape visit(const luci::CircleSplitV *node) final
1678   {
1679     // We'll set SplitV output as same as input so that SplitOut can handle it's own shape
1680     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1681     return loco::NodeShape{input_shape};
1682   }
1683
1684   loco::NodeShape visit(const luci::CircleSqrt *node) final { return use_x(node); }
1685
1686   loco::NodeShape visit(const luci::CircleSquare *node) final { return use_x(node); }
1687
1688   loco::NodeShape visit(const luci::CircleSquaredDifference *node) final
1689   {
1690     return broadcast_xy(node);
1691   }
1692
1693   loco::NodeShape visit(const luci::CircleStridedSlice *node) final
1694   {
1695     auto begin_node = dynamic_cast<luci::CircleConst *>(node->begin());
1696     auto end_node = dynamic_cast<luci::CircleConst *>(node->end());
1697     auto strides_node = dynamic_cast<luci::CircleConst *>(node->strides());
1698
1699     if (begin_node == nullptr || end_node == nullptr || strides_node == nullptr)
1700     {
1701       return use_own(node);
1702     }
1703
1704     loco::TensorShape shape = infer_output_shape(node);
1705     return loco::NodeShape{shape};
1706   }
1707
1708   loco::NodeShape visit(const luci::CircleSqueeze *node) final
1709   {
1710     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1711
1712     // TODO input shape may be unknown before runtime
1713     std::vector<bool> do_squeeze(input_shape.rank(), false);
1714     uint32_t num_squeezed = 0;
1715
1716     if (!node->squeeze_dims().empty())
1717     {
1718       // SqueezeDims not empty, squeeze only dims specified
1719       for (int32_t raw_dim : node->squeeze_dims())
1720       {
1721         int32_t dim = raw_dim < 0 ? raw_dim + input_shape.rank() : raw_dim;
1722
1723         if (dim < 0 || static_cast<uint32_t>(dim) >= input_shape.rank() ||
1724             input_shape.dim(dim).value() != 1)
1725         {
1726           INTERNAL_EXN("invalid dimention specified to Squeeze");
1727         }
1728
1729         if (!do_squeeze[dim])
1730           ++num_squeezed;
1731         do_squeeze[dim] = true;
1732       }
1733     }
1734     else
1735     {
1736       // SqueezeDims empty, squeeze any dims with size == 1
1737       for (uint32_t dim = 0; dim < input_shape.rank(); ++dim)
1738       {
1739         if (input_shape.dim(dim) == 1)
1740         {
1741           do_squeeze[dim] = true;
1742           ++num_squeezed;
1743         }
1744       }
1745     }
1746
1747     loco::TensorShape output_shape;
1748     output_shape.rank(input_shape.rank() - num_squeezed);
1749
1750     for (uint32_t in_dim = 0, out_dim = 0; in_dim < input_shape.rank(); ++in_dim)
1751     {
1752       if (!do_squeeze[in_dim])
1753       {
1754         output_shape.dim(out_dim++) = input_shape.dim(in_dim);
1755       }
1756     }
1757
1758     return loco::NodeShape{output_shape};
1759   }
1760
1761   loco::NodeShape visit(const luci::CircleSub *node) final { return broadcast_xy(node); }
1762
1763   loco::NodeShape visit(const luci::CircleSum *node) final
1764   {
1765     auto output_shape = infer_reducer(node->input(), node->reduction_indices(), node->keep_dims());
1766     return loco::NodeShape{output_shape};
1767   }
1768
1769   loco::NodeShape visit(const luci::CircleTanh *node) final { return use_x(node); }
1770
1771   loco::NodeShape visit(const luci::CircleTile *node) final
1772   {
1773     const loco::DataType S32 = loco::DataType::S32;
1774
1775     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1776     auto multiples = loco::must_cast<luci::CircleConst *>(node->multiples());
1777
1778     // TODO support non-const case
1779     // TODO support S64 type
1780     LUCI_ASSERT(multiples->dtype() == S32, "Only support int32 multiples");
1781     LUCI_ASSERT(multiples->rank() == 1, "multiples should be rank 1")
1782
1783     uint32_t n = multiples->dim(0).value();
1784
1785     LUCI_ASSERT(n == input_shape.rank(), "length of multiples should be the same with input rank");
1786
1787     loco::TensorShape output_shape;
1788
1789     output_shape.rank(input_shape.rank());
1790     for (uint32_t ni = 0; ni < n; ++ni)
1791     {
1792       int32_t multiple = multiples->at<S32>(ni);
1793       output_shape.dim(ni) = input_shape.dim(ni).value() * static_cast<uint32_t>(multiple);
1794     }
1795
1796     return loco::NodeShape{output_shape};
1797   }
1798
1799   loco::NodeShape visit(const luci::CircleTopKV2 *node) final
1800   {
1801     // set shape of this node as same as input
1802     const auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1803     return loco::NodeShape{input_shape};
1804   }
1805
1806   loco::NodeShape visit(const luci::CircleTranspose *node) final
1807   {
1808     auto input_shape = loco::shape_get(node->a()).as<loco::TensorShape>();
1809
1810     auto perm_node = loco::must_cast<luci::CircleConst *>(node->perm());
1811
1812     loco::TensorShape output_shape;
1813     output_shape.rank(input_shape.rank());
1814
1815     assert(perm_node->dtype() == loco::DataType::S32);
1816     assert(input_shape.rank() == perm_node->template size<loco::DataType::S32>());
1817
1818     for (uint32_t out_axis = 0; out_axis < output_shape.rank(); out_axis++)
1819     {
1820       auto in_axis = perm_node->template at<loco::DataType::S32>(out_axis);
1821       output_shape.dim(out_axis) = input_shape.dim(in_axis);
1822     }
1823
1824     return output_shape;
1825   }
1826
1827   loco::NodeShape visit(const luci::CircleUnique *node) final
1828   {
1829     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1830
1831     assert(input_shape.rank() == 1);
1832
1833     loco::TensorShape shape_output;
1834     shape_output = own_shape(node);
1835
1836     return loco::NodeShape{shape_output};
1837   }
1838
1839   loco::NodeShape visit(const luci::CircleTransposeConv *node) final
1840   {
1841     // TransposeConv's output shape is written in its 'inputSizes' argument
1842     auto input_sizes_const = loco::must_cast<luci::CircleConst *>(node->inputSizes());
1843     // TODO support non-const type
1844     LUCI_ASSERT(input_sizes_const->dtype() == loco::DataType::S32, "Only support S32 dtype")
1845     LUCI_ASSERT(input_sizes_const->rank() == 1 && input_sizes_const->dim(0).value() == 4,
1846                 "Only support rank 1 with 4 entries")
1847
1848     loco::TensorShape shape;
1849
1850     shape.rank(4);
1851     for (uint32_t axis = 0; axis < 4; ++axis)
1852       shape.dim(axis) = input_sizes_const->at<loco::DataType::S32>(axis);
1853
1854     return loco::NodeShape{shape};
1855   }
1856
1857   loco::NodeShape visit(const luci::CircleUnpack *node) final
1858   {
1859     // CircleUnpack provides list(array) of Tensors which has one less dimension of the input
1860     // We'll set shape of CircleUnpack to shape of actual outputs
1861     // TODO fix this if any problem rises
1862     auto value_shape = loco::shape_get(node->value()).as<loco::TensorShape>();
1863
1864     auto axis = node->axis();
1865     auto num = node->num();
1866     auto rank = static_cast<int32_t>(value_shape.rank());
1867
1868     if (rank == 0)
1869     {
1870       // Unknown shape
1871       return use_own(node);
1872     }
1873
1874     LUCI_ASSERT(-rank <= axis && axis < rank, "Axis is out of range");
1875
1876     if (axis < 0)
1877       axis += rank;
1878
1879     LUCI_ASSERT(num == static_cast<int32_t>(value_shape.dim(axis).value()),
1880                 "num, axis maybe incorrect");
1881
1882     loco::TensorShape output_shape;
1883     output_shape.rank(rank - 1);
1884
1885     for (int32_t i = 0, o = 0; i < rank; ++i)
1886     {
1887       if (i != axis)
1888         output_shape.dim(o++) = value_shape.dim(i);
1889     }
1890
1891     return loco::NodeShape{output_shape};
1892   }
1893
1894   loco::NodeShape visit(const luci::CircleWhere *node) final { return use_own(node); }
1895
1896   loco::NodeShape visit(const luci::CircleWhile *node) final
1897   {
1898     // Shape of CircleWhile is not used. Just use input 0
1899     assert(node->arity() > 0);
1900     const auto input_shape = loco::shape_get(node->input(0)).as<loco::TensorShape>();
1901     return loco::NodeShape{input_shape};
1902   }
1903
1904   loco::NodeShape visit(const luci::CircleZerosLike *node) final
1905   {
1906     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1907
1908     return loco::NodeShape{input_shape};
1909   }
1910
1911   // Circle Only
1912   loco::NodeShape visit(const luci::CircleBCQFullyConnected *node) final
1913   {
1914     loco::TensorShape out_shape;
1915
1916     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1917     auto weights_clusters = loco::must_cast<luci::CircleConst *>(node->weights_clusters());
1918
1919     LUCI_ASSERT(input_shape.rank() == 2, "Input rank of BCQFullyConnected should be 2");
1920
1921     int32_t qbits_sum = 0;
1922     for (uint32_t i = 0; i < weights_clusters->dim(0).value(); ++i)
1923     {
1924       qbits_sum += weights_clusters->at<loco::DataType::S32>(i * 2 + 1);
1925     }
1926
1927     out_shape.rank(2);
1928     out_shape.dim(0) = qbits_sum;
1929     out_shape.dim(1) = input_shape.dim(1);
1930
1931     return loco::NodeShape{out_shape};
1932   }
1933
1934   loco::NodeShape visit(const luci::CircleBCQGather *node) final
1935   {
1936     loco::TensorShape input_shape;
1937     loco::TensorShape output_shape;
1938
1939     const auto input_binary_shape = loco::shape_get(node->input_binary()).as<loco::TensorShape>();
1940     const auto indices_shape = loco::shape_get(node->indices()).as<loco::TensorShape>();
1941     auto axis = node->axis();
1942
1943     auto input_clusters = loco::must_cast<luci::CircleConst *>(node->input_clusters());
1944     auto qbits_sum = 0;
1945     for (uint32_t i = 0; i < input_clusters->dim(0).value(); ++i)
1946     {
1947       qbits_sum += input_clusters->at<loco::DataType::S32>(i * 2 + 1);
1948     }
1949
1950     input_shape.rank(2);
1951     input_shape.dim(0) = qbits_sum;
1952     input_shape.dim(1) = input_binary_shape.dim(1).value() * 32;
1953
1954     output_shape.rank(input_shape.rank() - 1 + indices_shape.rank());
1955     int32_t outdim_index = 0;
1956     for (int32_t i = 0; i < axis; ++i)
1957       output_shape.dim(outdim_index++) = input_shape.dim(i);
1958     for (uint32_t i = 0; i < indices_shape.rank(); ++i)
1959       output_shape.dim(outdim_index++) = indices_shape.dim(i);
1960     for (uint32_t i = axis + 1; i < input_shape.rank(); ++i)
1961       output_shape.dim(outdim_index++) = input_shape.dim(i);
1962
1963     return loco::NodeShape{output_shape};
1964   }
1965
1966   loco::NodeShape visit(const luci::CircleInstanceNorm *node) final
1967   {
1968     auto input_shape = loco::shape_get(node->input()).as<loco::TensorShape>();
1969
1970     return loco::NodeShape{input_shape};
1971   }
1972
1973   // Virtual
1974   loco::NodeShape visit(const luci::CircleInput *node) final
1975   {
1976     loco::TensorShape shape;
1977
1978     shape.rank(node->rank());
1979     for (uint32_t axis = 0; axis < node->rank(); axis++)
1980       shape.dim(axis) = node->dim(axis);
1981
1982     return loco::NodeShape{shape};
1983   }
1984
1985   loco::NodeShape visit(const luci::CircleOutput *node) final
1986   {
1987     auto graph_outputs = node->graph()->outputs();
1988     auto graph_output = graph_outputs->at(node->index());
1989     auto output_shape = graph_output->shape();
1990
1991     return loco::NodeShape{*output_shape};
1992   }
1993
1994   loco::NodeShape visit(const luci::CircleOutputDummy *node) final { return use_own(node); }
1995
1996   loco::NodeShape visit(const luci::CircleOutputExclude *node) final { return use_own(node); }
1997
1998   loco::NodeShape visit(const luci::CircleCustomOut *node) final { return use_own(node); }
1999
2000   loco::NodeShape visit(const luci::CircleIfOut *node) final
2001   {
2002     /**
2003      * @note  IF operator type and shape are that of the "then" and "else"
2004      *        Graph Outputs.
2005      */
2006     auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
2007     if (circle_if == nullptr)
2008     {
2009       INTERNAL_EXN("CircleIf IR is not configured correctly");
2010     }
2011
2012     auto index = node->index();
2013     auto then_graph = circle_if->then_graph();
2014     auto else_graph = circle_if->else_graph();
2015     assert(then_graph != nullptr);
2016     assert(else_graph != nullptr);
2017
2018     // shape and type are assumed to be same
2019     // these are checked at post_import_graph() in Import
2020     auto then_outputs = loco::output_nodes(then_graph);
2021     auto else_outputs = loco::output_nodes(else_graph);
2022     assert(then_outputs.size() == else_outputs.size());
2023     assert(index < static_cast<int32_t>(then_outputs.size()));
2024
2025     auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
2026     auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
2027
2028     auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
2029     auto else_graph_outputs = else_graph->outputs();
2030     assert(then_graph_outputs->size() == else_graph_outputs->size());
2031
2032     auto then_graph_output = then_graph_outputs->at(then_out->index());
2033     auto else_graph_output = else_graph_outputs->at(else_out->index());
2034     (void)else_graph_output; // make compiler happy for unused variable warnings
2035     assert(*then_graph_output->shape() == *else_graph_output->shape());
2036
2037     return loco::NodeShape{*then_graph_output->shape()};
2038   }
2039
2040   loco::NodeShape visit(const luci::CircleNonMaxSuppressionV4Out *node) final
2041   {
2042     const loco::DataType S32 = loco::DataType::S32;
2043
2044     auto nmsv4 = dynamic_cast<const luci::CircleNonMaxSuppressionV4 *>(node->input());
2045     if (nmsv4 == nullptr)
2046       INTERNAL_EXN("CircleNonMaxSuppressionV4 IR is not configured correctly");
2047
2048     auto index = node->index();
2049     if (index == 1)
2050       return loco::TensorShape({0});
2051
2052     assert(index == 0);
2053
2054     auto unknown = loco::TensorShape{loco::Dimension()};
2055     auto max_output_size = dynamic_cast<const luci::CircleConst *>(nmsv4->max_output_size());
2056     if (max_output_size == nullptr)
2057       return unknown; // we need CircleConst for max output size
2058
2059     LUCI_ASSERT(max_output_size->dtype() == S32, "Only support int32 for max_output_size");
2060
2061     if (max_output_size->size<S32>() < 1)
2062       return unknown;
2063
2064     auto max_output_size_value = uint32_t(max_output_size->at<S32>(0));
2065     return loco::TensorShape{max_output_size_value};
2066   }
2067
2068   loco::NodeShape visit(const luci::CircleSplitOut *node) final
2069   {
2070     const loco::DataType S32 = loco::DataType::S32;
2071
2072     auto split = dynamic_cast<const luci::CircleSplit *>(node->input());
2073     if (split == nullptr)
2074       INTERNAL_EXN("CircleSplit IR is not configured correctly");
2075
2076     loco::NodeShape unknown;
2077
2078     auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
2079
2080     auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
2081     if (split_dim == nullptr)
2082       return unknown; // we need CircleConst for split_dim
2083     LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
2084
2085     assert(split_dim->size<S32>() == 1);
2086     auto split_dim_axis = split_dim->at<S32>(0);
2087     if (split_dim_axis < 0)
2088       split_dim_axis += split_shape.rank();
2089
2090     auto split_dim_value = split_shape.dim(split_dim_axis).value();
2091     assert(split_dim_value % split->num_split() == 0);
2092     const int split_depth = split_dim_value / split->num_split();
2093
2094     loco::TensorShape output_shape = split_shape;
2095
2096     // All shapes are equally same
2097     output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
2098
2099     return loco::NodeShape{output_shape};
2100   }
2101
2102   loco::NodeShape visit(const luci::CircleSplitVOut *node) final
2103   {
2104     const loco::DataType S32 = loco::DataType::S32;
2105
2106     auto split = dynamic_cast<const luci::CircleSplitV *>(node->input());
2107     if (split == nullptr)
2108       INTERNAL_EXN("CircleSplit IR is not configured correctly");
2109
2110     loco::NodeShape unknown;
2111
2112     auto split_shape = loco::shape_get(split).as<loco::TensorShape>();
2113
2114     auto size_splits = dynamic_cast<const luci::CircleConst *>(split->size_splits());
2115     if (size_splits == nullptr)
2116       return unknown; // we need CircleConst for size_splits
2117     LUCI_ASSERT(size_splits->dtype() == S32, "Only support int32 for size_splits");
2118
2119     auto split_dim = dynamic_cast<const luci::CircleConst *>(split->split_dim());
2120     if (split_dim == nullptr)
2121       return unknown; // we need CircleConst for split_dim
2122     LUCI_ASSERT(split_dim->dtype() == S32, "Only support int32 for split_dim");
2123
2124     // fetch axis
2125     assert(split_dim->size<S32>() == 1);
2126     auto split_dim_axis = split_dim->at<S32>(0);
2127     if (split_dim_axis < 0)
2128       split_dim_axis += split_shape.rank();
2129
2130     // interpret size_splits values
2131     int32_t size_splits_count = static_cast<int32_t>(size_splits->size<S32>());
2132     assert(size_splits_count == split->num_split());
2133
2134     int64_t minus_one_count = 0, size_splits_sum = 0;
2135     for (int32_t idx = 0; idx < size_splits_count; ++idx)
2136     {
2137       auto size = size_splits->at<S32>(idx);
2138       assert(size >= -1);
2139       if (size == -1)
2140         ++minus_one_count;
2141       else
2142         size_splits_sum += size;
2143     }
2144     if (minus_one_count > 1)
2145       INTERNAL_EXN("CircleSplitV size_splits has more than two -1 values");
2146
2147     // calcuate this SplitVOut shape
2148     auto input_size = split_shape.dim(split_dim_axis).value();
2149     assert(size_splits_sum <= input_size);
2150
2151     auto index_this = node->index();
2152     assert(0 <= index_this && index_this < split->num_split());
2153     auto split_depth = size_splits->at<S32>(index_this);
2154     if (split_depth == -1)
2155       split_depth = input_size - size_splits_sum;
2156
2157     loco::TensorShape output_shape = split_shape;
2158
2159     output_shape.dim(split_dim_axis) = loco::Dimension(split_depth);
2160
2161     return loco::NodeShape{output_shape};
2162   }
2163
2164   loco::NodeShape visit(const luci::CircleTopKV2Out *node) final
2165   {
2166     const loco::DataType S32 = loco::DataType::S32;
2167
2168     auto topkv2 = dynamic_cast<const luci::CircleTopKV2 *>(node->input());
2169     if (topkv2 == nullptr)
2170       INTERNAL_EXN("CircleSplit IR is not configured correctly");
2171
2172     // shape of topkv2 is same as topkv2->input()
2173     auto input_shape = loco::shape_get(topkv2).as<loco::TensorShape>();
2174
2175     auto node_k = loco::must_cast<const luci::CircleConst *>(topkv2->k());
2176     LUCI_ASSERT(node_k->dtype() == S32, "Only support Int32");
2177     assert(node_k->size<S32>() == 1);
2178
2179     loco::TensorShape output_shape;
2180
2181     output_shape.rank(input_shape.rank());
2182     for (uint32_t idx = 0; idx < input_shape.rank() - 1; ++idx)
2183     {
2184       output_shape.dim(idx) = input_shape.dim(idx);
2185     }
2186     output_shape.dim(input_shape.rank() - 1) = node_k->at<S32>(0);
2187
2188     return loco::NodeShape{output_shape};
2189   }
2190
2191   loco::NodeShape visit(const luci::CircleUniqueOut *node) final
2192   {
2193     auto unique = dynamic_cast<const luci::CircleUnique *>(node->input());
2194     if (unique == nullptr)
2195     {
2196       INTERNAL_EXN("CircleUnique IR is not configured correctly");
2197     }
2198
2199     auto unique_shape = loco::shape_get(unique).as<loco::TensorShape>();
2200
2201     return loco::NodeShape{unique_shape};
2202   }
2203
2204   loco::NodeShape visit(const luci::CircleUnpackOut *node) final
2205   {
2206     auto unpack = dynamic_cast<const luci::CircleUnpack *>(node->input());
2207     if (unpack == nullptr)
2208     {
2209       INTERNAL_EXN("CircleUnpack IR is not configured correctly");
2210     }
2211
2212     auto unpack_shape = loco::shape_get(unpack).as<loco::TensorShape>();
2213
2214     return loco::NodeShape{unpack_shape};
2215   }
2216
2217   loco::NodeShape visit(const luci::CircleWhileOut *node) final
2218   {
2219     /**
2220      * @note  WHILE operator's shape is the same with the "cond"
2221      *        Graph input.
2222      */
2223     auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
2224     if (circle_while == nullptr)
2225     {
2226       INTERNAL_EXN("CircleWhile IR is not configured correctly");
2227     }
2228
2229     auto index = node->index();
2230     auto cond_graph = circle_while->cond_graph();
2231     assert(cond_graph != nullptr);
2232
2233     // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
2234     // loco::input_nodes
2235     auto cond_inputs = loco::input_nodes(cond_graph);
2236     auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
2237
2238     auto cond_graph_inputs = cond_graph->inputs();
2239     auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
2240
2241     auto cond_graph_input_shape = *cond_graph_input->shape();
2242     auto this_shape = own_shape(node);
2243
2244     if (!(this_shape == cond_graph_input_shape))
2245     {
2246       LOGGER(l);
2247       WARN(l) << "Warning: CircleWhileOut '" << node->name() << "' shape mispatch " << this_shape
2248               << " vs " << cond_graph_input_shape;
2249     }
2250
2251     return loco::NodeShape{this_shape};
2252   }
2253 };
2254
2255 } // namespace
2256
2257 namespace luci
2258 {
2259
2260 bool CircleShapeInferenceRule::recognize(const loco::Dialect *d) const
2261 {
2262   return CircleDialect::get() == d;
2263 }
2264
2265 bool CircleShapeInferenceRule::infer(const loco::Node *node, loco::NodeShape &shape) const
2266 {
2267   LOGGER(l);
2268
2269   assert(node->dialect() == CircleDialect::get());
2270
2271   ShapeInferenceAlgorithm alg;
2272   auto circle_node = loco::must_cast<const CircleNode *>(node);
2273
2274   bool is_shape_undefined = (circle_node->shape_status() == ShapeStatus::UNDEFINED);
2275   bool is_shape_none = (circle_node->shape_status() == ShapeStatus::NOSHAPE);
2276   bool is_scalar = (circle_node->rank() == 0);
2277
2278   if (is_shape_undefined)
2279     shape = circle_node->accept(&alg);
2280   else
2281   {
2282     if (is_shape_none || is_scalar)
2283       shape = own_shape(circle_node);
2284     else
2285       shape = circle_node->accept(&alg);
2286   }
2287
2288   VERBOSE(l, 1) << "[luci] shape: " << circle_node->name();
2289   VERBOSE(l, 1) << "              own_shape: " << own_shape(circle_node)
2290                 << " -> infer: " << shape.as<loco::TensorShape>();
2291
2292   return true;
2293 }
2294
2295 } // namespace luci