25747d95017c3150866eda0fba5467479e0397a8
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / StaticShapeInferer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "compiler/StaticShapeInferer.h"
18 #include "util/ShapeInference.h"
19 #include "util/logging.h"
20
21 #include <misc/polymorphic_downcast.h>
22
23 #include <sstream>
24 #include <stdexcept>
25
26 namespace onert
27 {
28 namespace compiler
29 {
30 void OperandObserver::updateShapes(const std::vector<ir::OperandInfo> &changed_operands_info,
31                                    bool unpredictable)
32 {
33   assert(changed_operands_info.size() == _operands.size());
34   for (size_t i = 0; i < changed_operands_info.size(); ++i)
35   {
36     const auto &changed_operand_info = changed_operands_info.at(i);
37     auto &operand = _operands.at(i);
38     // assert(changed_operand_info.typeInfo() == operand->typeInfo());
39     // assert(changed_operand_info.typeInfo() == operand->typeInfo());
40     // This error check may by replaced by an assertion if this function is called after the
41     // validation of models are completed.
42     if (changed_operand_info.typeInfo() != operand->typeInfo())
43     {
44       throw std::runtime_error("OperandObserver: The types of operands are mismatched");
45     }
46     if (!operand->info().isConstant() && (changed_operand_info.isDynamic() || unpredictable))
47     {
48       operand->info().setDynamic();
49     }
50     else
51     {
52       const auto &new_shape = changed_operands_info.at(i).shape();
53       operand->info().shape(new_shape);
54     }
55   }
56 }
57
58 void StaticShapeInferer::infer()
59 {
60   for (const auto &op_idx : _lowered_subg->graph().topolSortOperations())
61   {
62     const auto &op = _lowered_subg->graph().operations().at(op_idx);
63     bool has_dynamic_tensor = false;
64     const auto opcode = op.opcode();
65     // IF: requires shape inference for then, else
66     // While: requires shape inference for condition, body
67     if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
68     {
69       op.accept(*this);
70     }
71     else
72     {
73       has_dynamic_tensor = checkDynamicInput(op);
74       if (has_dynamic_tensor)
75       {
76         setDynamicOutput(op);
77       }
78       else
79       {
80         op.accept(*this);
81       }
82     }
83     has_dynamic_tensor = has_dynamic_tensor || checkDynamicOutput(op);
84     _lowered_subg->setHasDynamicTensor(op_idx, has_dynamic_tensor);
85   }
86
87   if (_controlflow_output_observer != nullptr)
88   {
89     // re-sizing output shapes of the controflow operation branching to this subgraph
90     std::vector<ir::OperandInfo> outputs_info;
91     const auto &graph = _lowered_subg->graph();
92     const auto &outputs = graph.getOutputs();
93     for (size_t i = 0; i < outputs.size(); ++i)
94     {
95       const auto &operand_info = graph.operands().at(outputs.at(i)).info();
96       outputs_info.emplace_back(operand_info);
97     }
98     _controlflow_output_observer->updateShapes(outputs_info);
99   }
100 }
101
102 bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
103 {
104   const auto &operands = _lowered_subg->graph().operands();
105   for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
106   {
107     if (operands.at(input_idx).info().isDynamic())
108     {
109       return true;
110     }
111   }
112
113   return false;
114 }
115
116 bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op)
117 {
118   auto &operands = _lowered_subg->graph().operands();
119   for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
120   {
121     if (operands.at(output_idx).info().isDynamic())
122     {
123       return true;
124     }
125   }
126   return false;
127 }
128
129 void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
130 {
131   auto &operands = _lowered_subg->graph().operands();
132   for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
133   {
134     operands.at(output_idx).info().setDynamic();
135   }
136 }
137
138 void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
139                                                   const ir::OperandIndex lhs_idx,
140                                                   const ir::OperandIndex rhs_idx)
141 {
142   auto &operands = _lowered_subg->graph().operands();
143   const auto &lhs = operands.at(lhs_idx);
144   const auto &rhs = operands.at(rhs_idx);
145
146   const auto output_idx = op.getOutputs().at(0);
147   ir::Operand &output = operands.at(output_idx);
148
149   // re-sizing output shape
150   ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
151   output.info().shape(new_shape);
152 }
153
154 void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
155                                              const ir::OperandIndex input_idx)
156 {
157   auto &operands = _lowered_subg->graph().operands();
158   const auto &input = operands.at(input_idx);
159
160   // get mutable output operand
161   const auto output_idx = op.getOutputs().at(0);
162   ir::Operand &output = operands.at(output_idx);
163
164   // re-sizing output shape
165   ir::Shape new_shape = input.info().shape();
166   output.info().shape(new_shape);
167 }
168
169 void StaticShapeInferer::dump()
170 {
171   auto get_shape_str = [](const ir::Shape &shape) {
172     std::stringstream sstream;
173     sstream << "shape : {";
174     for (int i = 0; i < shape.rank(); i++)
175     {
176       if (i == 0)
177         sstream << shape.dim(i);
178       else
179         sstream << " " << shape.dim(i);
180     }
181     sstream << "}";
182     return sstream.str();
183   };
184
185   _lowered_subg->graph().operands().iterate(
186     [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
187       VERBOSE(StaticShapeInferer) << "  " << ind << ", "
188                                   << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
189                                   << get_shape_str(operand.info().shape()) << std::endl;
190     });
191 }
192
193 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
194 StaticShapeInferer::createStaticShapeInferers(
195   const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs)
196 {
197   // Allocate StaticShapeInferer per each subgraph
198   std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers;
199   for (auto &&pair : lowered_subgs)
200   {
201     const auto &subg_index = pair.first;
202     auto &lowered_subg = pair.second;
203     inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg.get());
204   }
205
206   // Append observers in all StaticShapeInferers
207   for (auto &&pair : lowered_subgs)
208   {
209     const auto &subg_index = pair.first;
210     auto &lowered_subg = pair.second;
211
212     // TODO: Change this iteration for all to controlflow iteration
213     lowered_subg->graph().operations().iterate(
214       [&](const ir::OperationIndex &, const ir::Operation &op) {
215         // A Function to append child inferers. These make it possible for a StaticShapeInferer to
216         // call StaticShapeInferes of child subgraphs recursively
217         auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) {
218           auto *child_inferer = inferers.at(child_subg_idx).get();
219           inferers.at(subg_index)->appendChildInferer(child_subg_idx, child_inferer);
220         };
221
222         // A Function to appaend subg input observers. This makes it possible for a
223         // StaticShapeInferer to update inputs of child subgraphs
224         auto appendSubgraphInputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
225           std::vector<ir::Operand *> child_subg_inputs;
226           auto &child_subg = lowered_subgs.at(child_subg_idx)->graph();
227           for (const auto &input_idx : child_subg.getInputs())
228           {
229             auto operand_ptr = child_subg.operands().getRawPtr(input_idx);
230             child_subg_inputs.emplace_back(operand_ptr);
231           }
232           inferers.at(subg_index)
233             ->appendSubgInputObserver(child_subg_idx,
234                                       std::make_unique<OperandObserver>(child_subg_inputs));
235         };
236
237         // A Function to set controlflow output observers. This makes it possible for a
238         // StaticShapeInferer to update outputs of parent controlflow opeerations
239         auto setControlFlowOutputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
240           std::vector<ir::Operand *> cf_outputs;
241           auto &subg = lowered_subg->graph();
242           for (const auto &output_idx : op.getOutputs())
243           {
244             auto operand_ptr = subg.operands().getRawPtr(output_idx);
245             cf_outputs.emplace_back(operand_ptr);
246           }
247           inferers.at(child_subg_idx)
248             ->setControlflowOutputObserver(std::make_unique<OperandObserver>(cf_outputs));
249         };
250
251         // Append Observers in a StaticShapeInferer
252         if (op.opcode() == ir::OpCode::If)
253         {
254           const auto &if_op = nnfw::misc::polymorphic_downcast<const ir::operation::If &>(op);
255
256           appendChildInferer(if_op.param().then_subg_index);
257           appendChildInferer(if_op.param().else_subg_index);
258
259           appendSubgraphInputObserver(if_op.param().then_subg_index);
260           appendSubgraphInputObserver(if_op.param().else_subg_index);
261
262           setControlFlowOutputObserver(if_op.param().then_subg_index);
263         }
264         else if (op.opcode() == ir::OpCode::While)
265         {
266           const auto &while_op = nnfw::misc::polymorphic_downcast<const ir::operation::While &>(op);
267
268           appendChildInferer(while_op.param().cond_subg_index);
269           appendChildInferer(while_op.param().body_subg_index);
270
271           appendSubgraphInputObserver(while_op.param().cond_subg_index);
272           appendSubgraphInputObserver(while_op.param().body_subg_index);
273
274           setControlFlowOutputObserver(while_op.param().body_subg_index);
275         }
276       });
277   }
278
279   return inferers;
280 }
281
282 void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op)
283 {
284   auto &operands = _lowered_subg->graph().operands();
285
286   const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)};
287   const auto &input = operands.at(input_idx);
288
289   const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)};
290   const auto &axis = operands.at(axis_idx);
291
292   // get mutable output operand
293   const auto output_idx = op.getOutputs().at(0);
294   ir::Operand &output = operands.at(output_idx);
295
296   if (!axis.isConstant())
297   {
298     output.info().setDynamic();
299     return;
300   }
301
302   const auto rank = input.info().shape().rank();
303   auto axis_value = axis.asScalar<int32_t>();
304   axis_value = axis_value < 0 ? axis_value + rank : axis_value;
305
306   // re-sizing output shape
307   ir::Shape new_shape =
308     shape_inference::inferArgMinMaxShape(input.info().shape(), axis_value, rank);
309   output.info().shape(new_shape);
310 }
311
312 void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
313 {
314   auto &operands = _lowered_subg->graph().operands();
315
316   const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
317   const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
318   const auto output_index = op.getOutputs().at(0);
319   const auto &lhs = operands.at(lhs_index);
320   const auto &rhs = operands.at(rhs_index);
321   auto &output = operands.at(output_index);
322   auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
323   output.info().shape(new_shape);
324 }
325
326 void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
327 {
328   auto &operands = _lowered_subg->graph().operands();
329
330   const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
331   const auto &input = operands.at(input_idx);
332
333   const auto cluster_idx{
334     op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
335   const auto &cluster = operands.at(cluster_idx);
336
337   const auto output_idx = op.getOutputs().at(0);
338   ir::Operand &output = operands.at(output_idx);
339
340   auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
341   assert(cluster_buf);
342
343   // re-sizing output shape
344   ir::Shape new_shape = shape_inference::inferBCQFullyConnectedShape(
345     input.info().shape(), cluster.info().shape(), cluster_buf);
346   output.info().shape(new_shape);
347 }
348
349 void StaticShapeInferer::visit(const ir::operation::BCQGather &op)
350 {
351   auto &operands = _lowered_subg->graph().operands();
352
353   const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
354   const auto &indices = operands.at(indices_idx);
355
356   const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
357   const auto &input_binary = operands.at(input_binary_idx);
358
359   const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
360   const auto &cluster = operands.at(cluster_idx);
361
362   const auto output_idx = op.getOutputs().at(0);
363   ir::Operand &output = operands.at(output_idx);
364
365   auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
366   assert(cluster_buf);
367
368   auto rank = input_binary.shape().rank();
369
370   // re-sizing output shape
371   ir::Shape new_shape = shape_inference::inferBCQGatherShape(
372     indices.info().shape(), cluster.info().shape(), cluster_buf, rank, op.param());
373
374   output.info().shape(new_shape);
375 }
376
377 void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
378 {
379   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS),
380                            op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS));
381 }
382
383 void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
384 {
385   // get mutable output operand
386   auto &operands = _lowered_subg->graph().operands();
387   const auto output_idx = op.getOutputs().at(0);
388   ir::Operand &output = operands.at(output_idx);
389
390   const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
391   const auto &shape = operands.at(shape_idx);
392
393   if (!shape.isConstant())
394   {
395     output.info().setDynamic();
396     return;
397   }
398
399   // assert(shape.typeInfo().type() == ir::DataType::INT32);
400   auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());
401
402   // re-sizing output shape
403   ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
404   output.info().shape(new_shape);
405 }
406
407 void StaticShapeInferer::visit(const ir::operation::Comparison &op)
408 {
409   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
410                            op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
411 }
412
413 void StaticShapeInferer::visit(const ir::operation::Concat &op)
414 {
415   auto &operands = _lowered_subg->graph().operands();
416
417   const auto input_count = op.getInputs().size();
418
419   const auto output_idx = op.getOutputs().at(0);
420   ir::Operand &output = operands.at(output_idx);
421
422   shape_inference::Shapes input_shapes;
423   for (uint32_t i = 0; i < input_count; i++)
424   {
425     const auto input_idx{op.getInputs().at(i)};
426     const auto &input = operands.at(input_idx);
427     input_shapes.emplace_back(input.shape());
428   }
429
430   ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());
431
432   // re-sizing output shape
433   output.info().shape(out_shape);
434 }
435
436 void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
437 {
438   auto &operands = _lowered_subg->graph().operands();
439
440   const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
441   const auto &input = operands.at(input_idx);
442   const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
443   const auto &ker = operands.at(ker_idx);
444   const auto output_idx = op.getOutputs().at(0);
445   ir::Operand &output = operands.at(output_idx);
446
447   // re-sizing output shape
448   ir::Shape new_shape =
449     shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
450   output.info().shape(new_shape);
451 }
452
453 void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
454 {
455   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT));
456 }
457
458 void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op)
459 {
460   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS),
461                            op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS));
462 }
463
464 void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
465 {
466   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT));
467 }
468
469 void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
470 {
471   auto &operands = _lowered_subg->graph().operands();
472
473   const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
474   const auto &input = operands.at(input_idx);
475   const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
476   const auto &axis = operands.at(axis_idx);
477   const auto output_idx = op.getOutputs().at(0);
478   ir::Operand &output = operands.at(output_idx);
479
480   if (!axis.isConstant())
481   {
482     output.info().setDynamic();
483     return;
484   }
485
486   // even when axis is constant, output shape should be recalculated since user might call
487   // nnfw_set_input_tensorinfo(input, some_new_shape)
488   auto axis_type = axis.typeInfo().type();
489   assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64);
490
491   assert(axis.data()->base());
492   int32_t axis_value =
493     (axis_type == ir::DataType::INT32)
494       ? reinterpret_cast<const int32_t *>(axis.data()->base())[0]
495       : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis.data()->base())[0]);
496
497   // re-sizing output shape
498   ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_value);
499   output.info().shape(new_shape);
500 }
501
502 void StaticShapeInferer::visit(const ir::operation::Fill &op)
503 {
504   auto &operands = _lowered_subg->graph().operands();
505
506   const auto shape_idx{op.getInputs().at(ir::operation::Fill::Input::SHAPE)};
507   const auto &shape = operands.at(shape_idx);
508   const auto output_idx = op.getOutputs().at(0);
509   ir::Operand &output = operands.at(output_idx);
510
511   if (!shape.isConstant())
512   {
513     output.info().setDynamic();
514     return;
515   }
516
517   const auto dims_type = shape.typeInfo().type();
518   assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64);
519
520   auto dims_buf = shape.data()->base();
521   assert(dims_buf);
522
523   const auto &dims_shape = shape.info().shape();
524   auto new_shape = ((dims_type == ir::DataType::INT32)
525                       ? shape_inference::inferFillShape<int32_t>(
526                           dims_shape, reinterpret_cast<const int32_t *>(dims_buf))
527                       : shape_inference::inferFillShape<int64_t>(
528                           dims_shape, reinterpret_cast<const int64_t *>(dims_buf)));
529
530   output.info().shape(new_shape);
531 }
532
533 void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
534 {
535   auto &operands = _lowered_subg->graph().operands();
536
537   const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
538   const auto &input = operands.at(input_idx);
539
540   const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
541   const auto &ker = operands.at(ker_idx);
542
543   // get mutable output operand
544   const auto output_idx = op.getOutputs().at(0);
545   ir::Operand &output = operands.at(output_idx);
546   // re-sizing output shape
547   ir::Shape new_shape =
548     shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
549   output.info().shape(new_shape);
550 }
551
552 void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
553 {
554   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
555 }
556
557 void StaticShapeInferer::visit(const ir::operation::Gather &op)
558 {
559   auto &operands = _lowered_subg->graph().operands();
560
561   const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
562   const auto &input = operands.at(input_idx);
563
564   // get mutable output operand
565   const auto output_idx = op.getOutputs().at(0);
566   ir::Operand &output = operands.at(output_idx);
567
568   const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
569   const auto &indices = operands.at(indices_idx);
570   const auto rank = input.info().shape().rank();
571   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
572
573   assert(0 <= axis && axis < rank);
574
575   // re-sizing output shape
576   ir::Shape new_shape =
577     shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
578   output.info().shape(new_shape);
579 }
580
581 void StaticShapeInferer::visit(const ir::operation::If &op)
582 {
583   // re-sizing input shapes of then/else subgraph
584   const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
585
586   std::vector<ir::OperandInfo> inputs_info;
587   const auto &graph = _lowered_subg->graph();
588   for (size_t i = 0; i < inputs.size(); ++i)
589   {
590     const auto &operand_info = graph.operands().at(inputs.at(i)).info();
591     inputs_info.emplace_back(operand_info);
592   }
593   _subg_input_observers.at(op.param().then_subg_index)->updateShapes(inputs_info);
594   _child_inferers.at(op.param().then_subg_index)->infer();
595
596   _subg_input_observers.at(op.param().else_subg_index)->updateShapes(inputs_info);
597   _child_inferers.at(op.param().else_subg_index)->infer();
598 }
599
600 void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
601 {
602   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
603 }
604
605 void StaticShapeInferer::visit(const ir::operation::LSTM &op)
606 {
607   auto &operands = _lowered_subg->graph().operands();
608
609   const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
610   auto &output = operands.at(output_index);
611
612   const auto output_state_out_index{
613     op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
614
615   const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
616
617   const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
618
619   if (output.info().isDynamic() ||
620       (operands.exist(output_state_out_index) &&
621        operands.at(output_state_out_index).info().isDynamic()) ||
622       (operands.exist(cell_state_out_index) &&
623        operands.at(cell_state_out_index).info().isDynamic()) ||
624       (operands.exist(scratch_buffer_index) &&
625        operands.at(scratch_buffer_index).info().isDynamic()))
626     return;
627
628   const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)};
629   const auto &input = operands.at(input_index);
630
631   const auto input_to_output_weights_index{
632     op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
633   const auto &input_to_output_weights = operands.at(input_to_output_weights_index);
634
635   const auto recurrent_to_output_weights_index{
636     op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
637   const auto &recurrent_to_output_weights = operands.at(recurrent_to_output_weights_index);
638
639   // re-sizing outputs
640   const int n_batch = (input.shape().rank() == 3 && op.param().time_major) ? input.shape().dim(1)
641                                                                            : input.shape().dim(0);
642   const int n_cell = input_to_output_weights.shape().dim(0);
643   const int n_output = recurrent_to_output_weights.shape().dim(1);
644   if (input.shape().rank() == 3)
645   {
646     if (op.param().time_major)
647       output.info().shape(ir::Shape{input.shape().dim(0), n_batch, n_output});
648     else
649       output.info().shape(ir::Shape{n_batch, input.shape().dim(1), n_output});
650   }
651   else
652   {
653     assert(input.shape().rank() == 2);
654     output.info().shape(ir::Shape{n_batch, n_output});
655   }
656
657   if (operands.exist(output_state_out_index))
658   {
659     auto &output_state_out = operands.at(output_state_out_index);
660     output_state_out.info().shape(ir::Shape{n_batch, n_output});
661   }
662
663   if (operands.exist(cell_state_out_index))
664   {
665     auto &cell_state_out = operands.at(cell_state_out_index);
666     cell_state_out.info().shape(ir::Shape{n_batch, n_cell});
667   }
668
669   if (operands.exist(scratch_buffer_index))
670   {
671     auto &scratch_buffer = operands.at(scratch_buffer_index);
672
673     const auto input_to_input_weights_index{
674       op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
675     const auto recurrent_to_input_weights_index{
676       op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
677
678     bool has_input_to_input_weights =
679       operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
680       operands.at(input_to_input_weights_index).shape().dim(1) != 0;
681     bool has_recurrent_to_input_weights =
682       operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
683       operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
684
685     // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
686     // true: no CIFG
687     // false: CIFG
688     bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
689     if (has_cifg_param)
690     {
691       scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 4});
692     }
693     else
694     {
695       scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 3});
696     }
697   }
698 }
699
700 void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
701 {
702   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
703 }
704
705 void StaticShapeInferer::visit(const ir::operation::OneHot &op)
706 {
707   auto &operands = _lowered_subg->graph().operands();
708
709   const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
710   const auto &indice = operands.at(indice_idx);
711   const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
712   const auto &depth = operands.at(depth_idx);
713
714   const auto axis = op.param().axis;
715
716   auto output_idx = op.getOutputs().at(0);
717   ir::Operand &output = operands.at(output_idx);
718
719   if (!depth.isConstant())
720   {
721     output.info().setDynamic();
722     return;
723   }
724
725   const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
726   assert(depth_buf);
727   // re-sizing output shape
728   ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
729   output.info().shape(new_shape);
730 }
731
732 void StaticShapeInferer::visit(const ir::operation::Pack &op)
733 {
734   auto &operands = _lowered_subg->graph().operands();
735
736   const auto input_idx{op.getInputs().at(0)};
737   const auto &input = operands.at(input_idx);
738
739   // get mutable output operand
740   const auto output_idx = op.getOutputs().at(0);
741   ir::Operand &output = operands.at(output_idx);
742
743   const auto rank = input.shape().rank() + 1;
744   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
745   const auto num = op.param().num;
746
747   assert(0 <= axis && axis < rank);
748
749   // re-sizing output shape
750   ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
751   output.info().shape(new_shape);
752 }
753
754 void StaticShapeInferer::visit(const ir::operation::Pad &op)
755 {
756   auto &operands = _lowered_subg->graph().operands();
757
758   const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
759   const auto &input = operands.at(input_idx);
760
761   const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
762   const auto &pad = operands.at(pad_idx);
763
764   // get mutable output operand
765   const auto output_idx = op.getOutputs().at(0);
766   ir::Operand &output = operands.at(output_idx);
767
768   // if pad is not constant, output also becomes dynamic
769   if (!pad.isConstant())
770   {
771     output.info().setDynamic();
772     return;
773   }
774
775   // re-sizing output shape
776   const auto new_shape = shape_inference::inferPadShape(
777     input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
778     pad.shape().num_elements());
779   output.info().shape(new_shape);
780 }
781
782 void StaticShapeInferer::visit(const ir::operation::Permute &op)
783 {
784   auto &operands = _lowered_subg->graph().operands();
785
786   const auto input_idx{op.getInputs().at(0)};
787   const auto &input = operands.at(input_idx);
788   const auto output_idx = op.getOutputs().at(0);
789   ir::Operand &output = operands.at(output_idx);
790
791   // re-sizing output shape
792   // Permute is a special operation that layouts of input/output may be different on backend
793   // However, it is not applied here, so input/output have the same layout of frontend. Because
794   // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering
795   // operand info to "TensorBuilder" after calling "StaticShapeInferer"
796   const auto new_shape = input.info().shape();
797   output.info().shape(new_shape);
798 }
799
800 void StaticShapeInferer::visit(const ir::operation::Pow &op)
801 {
802   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
803                            op.getInputs().at(ir::operation::Pow::Input::RHS));
804 }
805
806 void StaticShapeInferer::visit(const ir::operation::Range &op)
807 {
808   auto &operands = _lowered_subg->graph().operands();
809
810   const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
811   const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
812   const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
813   const auto &start_op = operands.at(start_idx);
814   const auto &limit_op = operands.at(limit_idx);
815   const auto &delta_op = operands.at(delta_idx);
816
817   // get mutable output operand
818   const auto output_idx = op.getOutputs().at(0);
819   ir::Operand &output = operands.at(output_idx);
820
821   ir::Shape new_shape;
822   if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
823   {
824     assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
825            start_op.typeInfo().type() == delta_op.typeInfo().type());
826     if (output.typeInfo().type() == ir::DataType::FLOAT32)
827     {
828       new_shape = shape_inference::inferRangeShape<float>(
829         start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
830     }
831     else if (output.typeInfo().type() == ir::DataType::INT32)
832     {
833       new_shape = shape_inference::inferRangeShape<int32_t>(
834         start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
835     }
836     assert(output.shape() == new_shape);
837   }
838   else
839   {
840     output.info().setDynamic();
841   }
842 }
843
844 void StaticShapeInferer::visit(const ir::operation::Reduce &op)
845 {
846   auto &operands = _lowered_subg->graph().operands();
847
848   const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
849   const auto &input = operands.at(input_idx);
850
851   const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
852   const auto &axes = operands.at(axes_idx);
853
854   // get mutable output operand
855   const auto output_idx = op.getOutputs().at(0);
856   ir::Operand &output = operands.at(output_idx);
857
858   std::vector<int32_t> axes_vec;
859   for (size_t i = 0; i < axes.shape().num_elements(); ++i)
860   {
861     switch (axes.typeInfo().type())
862     {
863       case ir::DataType::INT32:
864       {
865         axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
866         break;
867       }
868       case ir::DataType::INT64:
869       {
870         axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
871         break;
872       }
873       default:
874         throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
875         break;
876     }
877   }
878   const auto keep_dims = op.param().keep_dims;
879
880   // re-sizing output shape
881   ir::Shape new_shape =
882     shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
883   output.info().shape(new_shape);
884 }
885
886 void StaticShapeInferer::visit(const ir::operation::Reshape &op)
887 {
888   auto &operands = _lowered_subg->graph().operands();
889
890   const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
891   const auto &input = operands.at(input_idx);
892
893   // get mutable output operand
894   const auto output_idx = op.getOutputs().at(0);
895   ir::Operand &output = operands.at(output_idx);
896
897   // New shape is given by second input tensor
898   if (op.getInputs().size() == 2)
899   {
900     // Let's check the second input
901     const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
902     const auto &shape = operands.at(shape_idx);
903
904     if (shape.isConstant())
905     {
906       const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
907       assert(shape_buf);
908
909       ir::Shape new_shape = shape_inference::inferReshapeShape(
910         shape_buf, shape.shape().num_elements(), input.shape().num_elements());
911
912       // if shape is from Const, TFLC put the shape of output into tensor
913       if (new_shape != output.shape())
914       {
915         // change on output shape
916         output.info().shape(new_shape);
917       }
918     }
919     else
920     {
921       // if shape is NOT Const, set output shape to be dynamic_
922       output.info().setDynamic();
923     }
924   }
925   // New shape is given by option
926   else if (op.param().new_shape.size() != 0)
927   {
928     // Let's check the new_shape option
929     auto shape = op.param().new_shape;
930     ir::Shape new_shape =
931       shape_inference::inferReshapeShape(shape.data(), shape.size(), input.shape().num_elements());
932
933     if (new_shape != output.shape())
934     {
935       // change on output shape
936       output.info().shape(new_shape);
937     }
938   }
939   else
940   {
941     throw std::runtime_error("Reshape: new shape is missing");
942   }
943 }
944
945 void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
946 {
947   auto &operands = _lowered_subg->graph().operands();
948
949   const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
950   const auto &input = operands.at(input_idx);
951
952   // get mutable output operand
953   const auto output_idx = op.getOutputs().at(0);
954   ir::Operand &output = operands.at(output_idx);
955
956   int32_t height_out, width_out;
957   if (op.getInputs().size() == 2)
958   {
959     auto &size = operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE));
960     if (!size.isConstant())
961     {
962       output.info().setDynamic();
963       return;
964     }
965     const auto size_v = size.asVector<std::int32_t>();
966     height_out = size_v[0];
967     width_out = size_v[1];
968   }
969   else
970   {
971     height_out = op.param().height_out;
972     width_out = op.param().width_out;
973   }
974
975   // Shape inferencing logic based on Params
976   ir::Shape new_shape =
977     shape_inference::inferResizeBilinearShape(input.shape(), height_out, width_out);
978
979   // if size_op is from Const, TFLC put the shape of output into tensor
980   if (new_shape != output.shape())
981   {
982     // change on output shape
983     output.info().shape(new_shape);
984   }
985 }
986
987 void StaticShapeInferer::visit(const ir::operation::Reverse &op)
988 {
989   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
990 }
991
992 void StaticShapeInferer::visit(const ir::operation::Select &op)
993 {
994   auto &operands = _lowered_subg->graph().operands();
995
996   const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
997   const auto &input_cond = operands.at(input_cond_idx);
998
999   const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
1000   const auto &input_true = operands.at(input_true_idx);
1001
1002   const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
1003   const auto &input_false = operands.at(input_false_idx);
1004
1005   auto output_idx = op.getOutputs().at(0);
1006   ir::Operand &output = operands.at(output_idx);
1007
1008   // Select output shpae
1009   ir::Shape new_shape = shape_inference::inferSelectShape(
1010     input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
1011   output.info().shape(new_shape);
1012 }
1013
1014 void StaticShapeInferer::visit(const ir::operation::Shape &op)
1015 {
1016   auto &operands = _lowered_subg->graph().operands();
1017
1018   const auto input_idx{op.getInputs().at(0)};
1019   const auto &input = operands.at(input_idx);
1020
1021   // get mutable output operand
1022   const auto output_idx = op.getOutputs().at(0);
1023   ir::Operand &output = operands.at(output_idx);
1024
1025   // re-sizing output shape
1026   ir::Shape output_shape;
1027   output_shape.append(input.info().shape().rank());
1028
1029   output.info().shape(output_shape);
1030 }
1031
1032 void StaticShapeInferer::visit(const ir::operation::Slice &op)
1033 {
1034   auto &operands = _lowered_subg->graph().operands();
1035
1036   const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
1037   const auto &input = operands.at(input_index);
1038   const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
1039   const auto &begins = operands.at(begins_index);
1040   const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
1041   const auto &sizes = operands.at(sizes_index);
1042   const auto output_index = op.getOutputs().at(0);
1043   ir::Operand &output = operands.at(output_index);
1044
1045   // Whether input is constant or not does not affect whether output is dynamic or not
1046   if (!(begins.isConstant() && sizes.isConstant()))
1047   {
1048     output.info().setDynamic();
1049     return;
1050   }
1051
1052   auto begins_buf = begins.data()->base();
1053   auto sizes_buf = sizes.data()->base();
1054
1055   const auto begins_type = begins.typeInfo().type();
1056   assert(begins_type == ir::DataType::INT32 || begins_type == ir::DataType::INT64);
1057   assert(begins_type == sizes.typeInfo().type());
1058
1059   ir::Shape new_shape =
1060     (begins_type == ir::DataType::INT32)
1061       ? shape_inference::inferSliceShape<int32_t>(input.info().shape(),
1062                                                   reinterpret_cast<const int32_t *>(begins_buf),
1063                                                   reinterpret_cast<const int32_t *>(sizes_buf))
1064       : shape_inference::inferSliceShape<int64_t>(input.info().shape(),
1065                                                   reinterpret_cast<const int64_t *>(begins_buf),
1066                                                   reinterpret_cast<const int64_t *>(sizes_buf));
1067   output.info().shape(new_shape);
1068 }
1069
1070 void StaticShapeInferer::visit(const ir::operation::Softmax &op)
1071 {
1072   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
1073 }
1074
1075 void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
1076 {
1077   auto &operands = _lowered_subg->graph().operands();
1078
1079   const auto output_index = op.getOutputs().at(0);
1080   const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
1081   const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
1082   const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
1083
1084   ir::Operand &output = operands.at(output_index);
1085   const auto &input = operands.at(input_idx);
1086   const auto &block_shape = operands.at(block_shape_idx);
1087   const auto &padding = operands.at(padding_idx);
1088
1089   // Whether input is constant or not does not affect whether output is dynamic or not
1090   if (!(block_shape.isConstant() && padding.isConstant()))
1091   {
1092     output.info().setDynamic();
1093     return;
1094   }
1095
1096   auto input_shape = input.info().shape();
1097   auto block_shape_shape = block_shape.info().shape();
1098   auto padding_shape = padding.info().shape();
1099
1100   auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
1101   auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());
1102
1103   ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
1104     input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
1105
1106   output.info().shape(new_shape);
1107 }
1108
1109 void StaticShapeInferer::visit(const ir::operation::Split &op)
1110 {
1111   auto &operands = _lowered_subg->graph().operands();
1112
1113   const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)};
1114   const auto &input = operands.at(input_idx);
1115
1116   const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)};
1117   const auto &axis = operands.at(axis_idx);
1118
1119   auto outputs = op.getOutputs();
1120   if (!axis.isConstant())
1121   {
1122     for (auto output_idx : outputs)
1123     {
1124       ir::Operand &output = operands.at(output_idx);
1125       output.info().setDynamic();
1126     }
1127     return;
1128   }
1129
1130   const auto num_splits = op.param().num_splits;
1131
1132   const auto rank = input.info().shape().rank();
1133   auto axis_value = axis.asScalar<int32_t>();
1134   axis_value = axis_value < 0 ? axis_value + rank : axis_value;
1135
1136   assert(0 <= axis_value && axis_value < rank);
1137
1138   ir::Shape new_shape =
1139     shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
1140   for (auto output_idx : outputs)
1141   {
1142     ir::Operand &output = operands.at(output_idx);
1143     output.info().shape(new_shape);
1144   }
1145 }
1146
1147 void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
1148 {
1149   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
1150                            op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
1151 }
1152
1153 void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
1154 {
1155   auto &operands = _lowered_subg->graph().operands();
1156
1157   const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
1158   const auto &input = operands.at(input_idx);
1159
1160   const auto output_idx = op.getOutputs().at(0);
1161   ir::Operand &output = operands.at(output_idx);
1162
1163   // Squeeze output shpae
1164   ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
1165   output.info().shape(new_shape);
1166 }
1167
1168 void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
1169 {
1170   auto &operands = _lowered_subg->graph().operands();
1171
1172   const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
1173   const auto &input = operands.at(input_index);
1174   const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
1175   const auto &starts = operands.at(starts_index);
1176   const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
1177   const auto &ends = operands.at(ends_index);
1178   const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
1179   const auto &strides = operands.at(strides_index);
1180   const auto output_index = op.getOutputs().at(0);
1181   ir::Operand &output = operands.at(output_index);
1182
1183   if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
1184   {
1185     output.info().setDynamic();
1186     return;
1187   }
1188
1189   const auto begin_mask = op.param().begin_mask;
1190   const auto end_mask = op.param().end_mask;
1191   const auto shrink_axis_mask = op.param().shrink_axis_mask;
1192   const auto rank = input.info().shape().rank();
1193
1194   auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
1195   auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
1196   auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());
1197
1198   auto op_params = shape_inference::buildStridedSliceParams(
1199     starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);
1200
1201   ir::Shape new_shape =
1202     shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
1203   output.info().shape(new_shape);
1204 }
1205
1206 void StaticShapeInferer::visit(const ir::operation::Tile &op)
1207 {
1208   auto &operands = _lowered_subg->graph().operands();
1209
1210   const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
1211   const auto &input = operands.at(input_idx);
1212
1213   const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
1214   const auto &multiplier = operands.at(multiplier_idx);
1215
1216   const auto output_idx = op.getOutputs().at(0);
1217   ir::Operand &output = operands.at(output_idx);
1218
1219   if (!multiplier.isConstant())
1220   {
1221     output.info().setDynamic();
1222     return;
1223   }
1224
1225   auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
1226   assert(multiplier_buffer);
1227
1228   // re-sizing output shape
1229   auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer,
1230                                                    multiplier.shape().num_elements());
1231   output.info().shape(new_shape);
1232 }
1233
1234 void StaticShapeInferer::visit(const ir::operation::Transpose &op)
1235 {
1236   auto &operands = _lowered_subg->graph().operands();
1237
1238   const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
1239   const auto &input = operands.at(input_idx);
1240
1241   const auto perm_idx{op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
1242   const auto &perm = operands.at(perm_idx);
1243
1244   // perm.shape() != ir::Shape{0} means that perm is (n-1...0)
1245   // TODO This condition changes to perm.num_elements() == 0
1246   const auto is_regular_transpose = perm.shape() == ir::Shape{0};
1247
1248   // get mutable output operand
1249   const auto output_idx = op.getOutputs().at(0);
1250   auto &output = operands.at(output_idx);
1251   if (!perm.isConstant() && !is_regular_transpose)
1252   {
1253     output.info().setDynamic();
1254     return;
1255   }
1256
1257   ir::Shape new_shape;
1258   if (is_regular_transpose)
1259   {
1260     // Call by (n-1...0)
1261     new_shape = shape_inference::inferTransposeShape(input.info().shape(), nullptr, 0);
1262   }
1263   else
1264   {
1265     // Check rank
1266     if (input.info().shape().rank() != static_cast<int>(perm.info().shape().num_elements()))
1267     {
1268       throw std::runtime_error("StaticShapeInferer failed, bad rank size: " +
1269                                std::to_string(perm.info().shape().num_elements()));
1270     }
1271
1272     // set output shape, based on input and params
1273     const auto perm_buf = reinterpret_cast<const int32_t *>(perm.data()->base());
1274     new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm_buf,
1275                                                      perm.shape().num_elements());
1276   }
1277   output.info().shape(new_shape);
1278 }
1279
1280 void StaticShapeInferer::visit(const ir::operation::Unpack &op)
1281 {
1282   auto &operands = _lowered_subg->graph().operands();
1283
1284   const auto input_idx{op.getInputs().at(0)};
1285   const auto &input = operands.at(input_idx);
1286   const auto num = op.param().num;
1287   const auto rank = input.shape().rank();
1288   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
1289
1290   assert(axis < rank);
1291   if (axis < 0)
1292   {
1293     for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1294     {
1295       const auto output_idx = op.getOutputs().at(out_tensor_idx);
1296       ir::Operand &output = operands.at(output_idx);
1297       output.info().setDynamic();
1298     }
1299     return;
1300   }
1301
1302   ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);
1303
1304   // re-sizing output shape
1305   for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1306   {
1307     const auto output_idx = op.getOutputs().at(out_tensor_idx);
1308     ir::Operand &output = operands.at(output_idx);
1309     output.info().shape(new_shape);
1310   }
1311 }
1312
1313 void StaticShapeInferer::visit(const ir::operation::While &op)
1314 {
1315   auto body_input_observer = _subg_input_observers.at(op.param().body_subg_index).get();
1316   auto cond_input_observer = _subg_input_observers.at(op.param().cond_subg_index).get();
1317   // re-sizing input shapes of body subgraph
1318   const auto inputs = op.getInputs();
1319   std::vector<ir::OperandInfo> inputs_info;
1320   const auto &graph = _lowered_subg->graph();
1321   for (size_t i = 0; i < inputs.size(); ++i)
1322   {
1323     const auto &operand_info = graph.operands().at(inputs.at(i)).info();
1324     inputs_info.emplace_back(operand_info);
1325   }
1326
1327   body_input_observer->updateShapes(inputs_info);
1328   _child_inferers.at(op.param().body_subg_index)->infer();
1329
1330   // Check whether while operation's shapes are predictable
1331   // This while op's outputs are also updated in the above function
1332   // "_child_inferers.at(op.param().body_subg_index)->update()". That means that body's outputs and
1333   // thils op's outputs must have the same shape. So we can predict whether body subgraphs will
1334   // change at every step by comparing the shapes of inputs/outputs. If any of shape of body outputs
1335   // and inputs are different Non-constant operands will be set to dynamic.
1336   bool check_unpredictable_dynamic = false;
1337   const auto &updated_outputs = op.getOutputs();
1338   assert(inputs_info.size() == updated_outputs.size());
1339   for (size_t i = 0; i < updated_outputs.size(); ++i)
1340   {
1341     const auto &input_info = inputs_info.at(i);
1342     const auto &output_info = graph.operands().at(updated_outputs.at(i)).info();
1343     if (input_info.isDynamic() != output_info.isDynamic() ||
1344         input_info.shape() != output_info.shape())
1345     {
1346       check_unpredictable_dynamic = true;
1347       break;
1348     }
1349   }
1350
1351   if (check_unpredictable_dynamic)
1352   {
1353     body_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
1354     _child_inferers.at(op.param().body_subg_index)->infer();
1355   }
1356   cond_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
1357   _child_inferers.at(op.param().cond_subg_index)->infer();
1358 }
1359
1360 void StaticShapeInferer::visit(const ir::operation::DetectionPostProcess &op)
1361 {
1362   // TODO: NMS supports very limited input/output size.
1363   ir::operation::DetectionPostProcess::Param param = op.param();
1364
1365   auto &operands = _lowered_subg->graph().operands();
1366   const int num_detected_boxes = param.max_detections * param.max_classes_per_detection;
1367
1368   const auto output_idx1 = op.getOutputs().at(0);
1369   auto &output1 = operands.at(output_idx1);
1370   output1.info().shape({1, num_detected_boxes, 4});
1371
1372   const auto output_idx2 = op.getOutputs().at(1);
1373   auto &output2 = operands.at(output_idx2);
1374   output2.info().shape({1, num_detected_boxes});
1375
1376   const auto output_idx3 = op.getOutputs().at(2);
1377   auto &output3 = operands.at(output_idx3);
1378   output3.info().shape({1, num_detected_boxes});
1379
1380   const auto output_idx4 = op.getOutputs().at(3);
1381   auto &output4 = operands.at(output_idx4);
1382   output4.info().shape({1});
1383 }
1384 void StaticShapeInferer::visit(const ir::operation::Bulk &op)
1385 {
1386   auto &operands = _lowered_subg->graph().operands();
1387
1388   // TODO: support multiple inputs/outputs
1389   const auto input_idx{op.getInputs().at(0)};
1390   const auto &input = operands.at(input_idx);
1391   const auto output_idx = op.getOutputs().at(0);
1392   ir::Operand &output = operands.at(output_idx);
1393
1394   auto cur_input_shape = input.info().shape();
1395   auto origin_input_shape = op.param().origin_input_shapes[0];
1396   auto cur_output_shape = output.info().shape();
1397   auto origin_output_shape = op.param().origin_output_shapes[0];
1398
1399   // TODO: more check for valid batch request
1400   if ((cur_input_shape.dim(0) < origin_output_shape.dim(0)) ||
1401       (cur_input_shape.dim(0) % origin_output_shape.dim(0) != 0))
1402   {
1403     throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported batch size");
1404   }
1405   size_t batch_multiplier = cur_input_shape.dim(0) / origin_output_shape.dim(0);
1406
1407   ir::Shape new_shape;
1408   new_shape.append(origin_output_shape.dim(0) * batch_multiplier);
1409   for (int32_t d = 1; d < origin_output_shape.rank(); ++d)
1410     new_shape.append(origin_output_shape.dim(d));
1411
1412   output.info().shape(new_shape);
1413 }
1414
1415 } // namespace compiler
1416
1417 } // namespace onert