Imported Upstream version 1.12.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / StaticShapeInferer.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "compiler/StaticShapeInferer.h"
18 #include "util/ShapeInference.h"
19 #include "util/logging.h"
20
21 #include <sstream>
22
23 namespace onert
24 {
25 namespace compiler
26 {
27
28 bool StaticShapeInferer::infer(const ir::OpSequence &op_seq)
29 {
30   bool has_dynamic_tensor = false;
31
32   for (const auto &operation_idx : op_seq.operations())
33   {
34     auto &op = _operations.at(operation_idx);
35     auto opcode = op.opcode();
36
37     _return_has_dynamic_tensor = false; // this is used as a return value inside operation's visit()
38
39     // IF: need shape inference for then, else
40     // While: need shape inference for condition, body
41     if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
42     {
43       op.accept(*this);
44     }
45     else
46     {
47       _return_has_dynamic_tensor = checkDynamicInput(op);
48
49       if (_return_has_dynamic_tensor)
50       {
51         setDynamicOutput(op);
52       }
53       else
54       {
55         op.accept(*this);
56       }
57     }
58
59     has_dynamic_tensor = has_dynamic_tensor || _return_has_dynamic_tensor;
60   }
61
62   return has_dynamic_tensor;
63 }
64
65 bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
66 {
67   for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
68   {
69     if (_operands.at(input_idx).info().isDynamic())
70     {
71       return true;
72     }
73   }
74
75   return false;
76 }
77
78 void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
79 {
80   for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
81   {
82     _operands.at(output_idx).info().setDynamic();
83   }
84 }
85
86 void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
87                                                   const ir::OperandIndex lhs_idx,
88                                                   const ir::OperandIndex rhs_idx)
89 {
90   const auto &lhs = _operands.at(lhs_idx);
91   const auto &rhs = _operands.at(rhs_idx);
92
93   const auto output_idx = op.getOutputs().at(0);
94   ir::Operand &output = _operands.at(output_idx);
95
96   // re-sizing output shape
97   ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
98   output.info().shape(new_shape);
99 }
100
101 void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
102                                              const ir::OperandIndex input_idx)
103 {
104   const auto &input = _operands.at(input_idx);
105
106   // get mutable output operand
107   const auto output_idx = op.getOutputs().at(0);
108   ir::Operand &output = _operands.at(output_idx);
109
110   // re-sizing output shape
111   ir::Shape new_shape = input.info().shape();
112   output.info().shape(new_shape);
113 }
114
115 void StaticShapeInferer::dump()
116 {
117   auto get_shape_str = [](const ir::Shape &shape) {
118     std::stringstream sstream;
119     sstream << "shape : {";
120     for (int i = 0; i < shape.rank(); i++)
121     {
122       if (i == 0)
123         sstream << shape.dim(i);
124       else
125         sstream << " " << shape.dim(i);
126     }
127     sstream << "}";
128     return sstream.str();
129   };
130
131   for (const auto &pair : _lowered_subgs)
132   {
133     const auto index = pair.first;
134     const auto &lowered_subg = pair.second;
135     VERBOSE(StaticShapeInferer) << "SubGraph #" << index.value() << std::endl;
136     lowered_subg->graph().operands().iterate(
137         [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
138           VERBOSE(StaticShapeInferer) << "Operand #" << ind.value() << ", "
139                                       << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
140                                       << get_shape_str(operand.info().shape()) << std::endl;
141         });
142   }
143 }
144
145 void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op)
146 {
147   const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)};
148   const auto &input = _operands.at(input_idx);
149
150   const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)};
151   const auto &axis = _operands.at(axis_idx);
152
153   // get mutable output operand
154   const auto output_idx = op.getOutputs().at(0);
155   ir::Operand &output = _operands.at(output_idx);
156
157   if (!axis.isConstant())
158   {
159     output.info().setDynamic();
160     _return_has_dynamic_tensor = true;
161     return;
162   }
163
164   const auto rank = input.info().shape().rank();
165   auto axis_value = axis.asScalar<int32_t>();
166   axis_value = axis_value < 0 ? axis_value + rank : axis_value;
167
168   // re-sizing output shape
169   ir::Shape new_shape =
170       shape_inference::inferArgMinMaxShape(input.info().shape(), axis_value, rank);
171   output.info().shape(new_shape);
172 }
173
174 void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
175 {
176   const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
177   const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
178   const auto output_index = op.getOutputs().at(0);
179   const auto &lhs = _operands.at(lhs_index);
180   const auto &rhs = _operands.at(rhs_index);
181   auto &output = _operands.at(output_index);
182   auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
183   output.info().shape(new_shape);
184 }
185
186 void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
187 {
188   const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
189   const auto &input = _operands.at(input_idx);
190
191   const auto cluster_idx{
192       op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
193   const auto &cluster = _operands.at(cluster_idx);
194
195   const auto output_idx = op.getOutputs().at(0);
196   ir::Operand &output = _operands.at(output_idx);
197
198   auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
199   assert(cluster_buf);
200
201   // re-sizing output shape
202   ir::Shape new_shape = shape_inference::inferBCQFullyConnectedShape(
203       input.info().shape(), cluster.info().shape(), cluster_buf);
204   output.info().shape(new_shape);
205 }
206
207 void StaticShapeInferer::visit(const ir::operation::BCQGather &op)
208 {
209   const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
210   const auto &indices = _operands.at(indices_idx);
211
212   const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
213   const auto &input_binary = _operands.at(input_binary_idx);
214
215   const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
216   const auto &cluster = _operands.at(cluster_idx);
217
218   const auto output_idx = op.getOutputs().at(0);
219   ir::Operand &output = _operands.at(output_idx);
220
221   auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
222   assert(cluster_buf);
223
224   auto rank = input_binary.shape().rank();
225
226   // re-sizing output shape
227   ir::Shape new_shape = shape_inference::inferBCQGatherShape(
228       indices.info().shape(), cluster.info().shape(), cluster_buf, rank, op.param());
229
230   output.info().shape(new_shape);
231 }
232
233 void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
234 {
235   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS),
236                            op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS));
237 }
238
239 void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
240 {
241   // get mutable output operand
242   const auto output_idx = op.getOutputs().at(0);
243   ir::Operand &output = _operands.at(output_idx);
244
245   const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
246   const auto &shape = _operands.at(shape_idx);
247
248   if (!shape.isConstant())
249   {
250     output.info().setDynamic();
251     _return_has_dynamic_tensor = true;
252     return;
253   }
254
255   // assert(shape.typeInfo().type() == ir::DataType::INT32);
256   auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());
257
258   // re-sizing output shape
259   ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
260   output.info().shape(new_shape);
261 }
262
263 void StaticShapeInferer::visit(const ir::operation::Comparison &op)
264 {
265   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
266                            op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
267 }
268
269 void StaticShapeInferer::visit(const ir::operation::Concat &op)
270 {
271   const auto input_count = op.getInputs().size();
272
273   const auto output_idx = op.getOutputs().at(0);
274   ir::Operand &output = _operands.at(output_idx);
275
276   shape_inference::Shapes input_shapes;
277   for (uint32_t i = 0; i < input_count; i++)
278   {
279     const auto input_idx{op.getInputs().at(i)};
280     const auto &input = _operands.at(input_idx);
281     input_shapes.emplace_back(input.shape());
282   }
283
284   ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());
285
286   // re-sizing output shape
287   output.info().shape(out_shape);
288 }
289
290 void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
291 {
292   const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
293   const auto &input = _operands.at(input_idx);
294   const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
295   const auto &ker = _operands.at(ker_idx);
296   const auto output_idx = op.getOutputs().at(0);
297   ir::Operand &output = _operands.at(output_idx);
298
299   // re-sizing output shape
300   ir::Shape new_shape =
301       shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
302   output.info().shape(new_shape);
303 }
304
305 void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
306 {
307   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT));
308 }
309
310 void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op)
311 {
312   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS),
313                            op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS));
314 }
315
316 void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
317 {
318   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT));
319 }
320
321 void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
322 {
323   const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
324   const auto &input = _operands.at(input_idx);
325   const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
326   const auto &axis = _operands.at(axis_idx);
327   const auto output_idx = op.getOutputs().at(0);
328   ir::Operand &output = _operands.at(output_idx);
329
330   if (!axis.isConstant())
331   {
332     output.info().setDynamic();
333     _return_has_dynamic_tensor = true;
334     return;
335   }
336
337   // even when axis is constant, output shape should be recalculated since user might call
338   // nnfw_set_input_tensorinfo(input, some_new_shape)
339   auto axis_type = axis.typeInfo().type();
340   assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64);
341
342   assert(axis.data()->base());
343   int32_t axis_value =
344       (axis_type == ir::DataType::INT32)
345           ? reinterpret_cast<const int32_t *>(axis.data()->base())[0]
346           : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis.data()->base())[0]);
347
348   // re-sizing output shape
349   ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_value);
350   output.info().shape(new_shape);
351 }
352
353 void StaticShapeInferer::visit(const ir::operation::Fill &op)
354 {
355   const auto shape_idx{op.getInputs().at(ir::operation::Fill::Input::SHAPE)};
356   const auto &shape = _operands.at(shape_idx);
357   const auto output_idx = op.getOutputs().at(0);
358   ir::Operand &output = _operands.at(output_idx);
359
360   if (!shape.isConstant())
361   {
362     output.info().setDynamic();
363     _return_has_dynamic_tensor = true;
364     return;
365   }
366
367   const auto dims_type = shape.typeInfo().type();
368   assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64);
369
370   auto dims_buf = shape.data()->base();
371   assert(dims_buf);
372
373   const auto &dims_shape = shape.info().shape();
374   auto new_shape = ((dims_type == ir::DataType::INT32)
375                         ? shape_inference::inferFillShape<int32_t>(
376                               dims_shape, reinterpret_cast<const int32_t *>(dims_buf))
377                         : shape_inference::inferFillShape<int64_t>(
378                               dims_shape, reinterpret_cast<const int64_t *>(dims_buf)));
379
380   output.info().shape(new_shape);
381 }
382
383 void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
384 {
385   const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
386   const auto &input = _operands.at(input_idx);
387
388   const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
389   const auto &ker = _operands.at(ker_idx);
390
391   // get mutable output operand
392   const auto output_idx = op.getOutputs().at(0);
393   ir::Operand &output = _operands.at(output_idx);
394   // re-sizing output shape
395   ir::Shape new_shape =
396       shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
397   output.info().shape(new_shape);
398 }
399
400 void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
401 {
402   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
403 }
404
405 void StaticShapeInferer::visit(const ir::operation::Gather &op)
406 {
407   const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
408   const auto &input = _operands.at(input_idx);
409
410   // get mutable output operand
411   const auto output_idx = op.getOutputs().at(0);
412   ir::Operand &output = _operands.at(output_idx);
413
414   const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
415   const auto &indices = _operands.at(indices_idx);
416   const auto rank = input.info().shape().rank();
417   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
418
419   assert(0 <= axis && axis < rank);
420
421   // re-sizing output shape
422   ir::Shape new_shape =
423       shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
424   output.info().shape(new_shape);
425 }
426
427 void StaticShapeInferer::visit(const ir::operation::If &op)
428 {
429   auto &then_graph = _lowered_subgs.at(op.param().then_subg_index)->graph();
430   auto &else_graph = _lowered_subgs.at(op.param().else_subg_index)->graph();
431   const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
432   const auto &outputs = op.getOutputs();
433
434   // re-sizing input shapes of then subgraph
435   const auto &then_inputs = then_graph.getInputs();
436   assert(inputs.size() == then_inputs.size());
437   for (size_t i = 0; i < inputs.size(); ++i)
438   {
439     auto &then_input = then_graph.operands().at(then_inputs.at(i));
440     if (_operands.at(inputs.at(i)).info().isDynamic())
441     {
442       then_input.info().setDynamic();
443     }
444     else
445     {
446       auto new_shape = _operands.at(inputs.at(i)).info().shape();
447       then_input.info().shape(new_shape);
448     }
449   }
450
451   // re-sizing input shapes of else subgraph
452   const auto &else_inputs = else_graph.getInputs();
453   assert(inputs.size() == else_inputs.size());
454   for (size_t i = 0; i < inputs.size(); ++i)
455   {
456     auto &else_input = else_graph.operands().at(else_inputs.at(i));
457     if (_operands.at(inputs.at(i)).info().isDynamic())
458     {
459       else_input.info().setDynamic();
460     }
461     else
462     {
463       const auto &new_shape = _operands.at(inputs.at(i)).info().shape();
464       else_input.info().shape(new_shape);
465     }
466   }
467
468   // re-sizing operands of then subgraph
469   StaticShapeInferer then_inferer(op.param().then_subg_index, _lowered_subgs);
470   _lowered_subgs.at(op.param().then_subg_index)
471       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
472         bool has_dynamic_tensor = then_inferer.infer(op_seq);
473         op_seq.has_dynamic_tensor(has_dynamic_tensor);
474       });
475
476   // re-sizing operands of else subgraph
477   StaticShapeInferer else_inferer(op.param().else_subg_index, _lowered_subgs);
478   _lowered_subgs.at(op.param().else_subg_index)
479       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
480         bool has_dynamic_tensor = else_inferer.infer(op_seq);
481         op_seq.has_dynamic_tensor(has_dynamic_tensor);
482       });
483
484   // re-sizing output shapes
485   const auto &then_outputs = _lowered_subgs.at(op.param().then_subg_index)->graph().getOutputs();
486   const auto &else_outputs = _lowered_subgs.at(op.param().else_subg_index)->graph().getOutputs();
487   assert(outputs.size() == then_outputs.size());
488   assert(outputs.size() == else_outputs.size());
489   for (size_t i = 0; i < outputs.size(); ++i)
490   {
491     const auto &then_output = then_graph.operands().at(then_outputs.at(i));
492     const auto &else_output = else_graph.operands().at(else_outputs.at(i));
493     auto &output = _operands.at(outputs.at(i));
494     if (!then_output.info().isDynamic() && !else_output.info().isDynamic() &&
495         then_output.shape() == else_output.shape())
496     {
497       output.info().shape(then_output.shape());
498     }
499     else
500     {
501       output.info().setDynamic();
502       _return_has_dynamic_tensor = true;
503     }
504   }
505 }
506
507 void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
508 {
509   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
510 }
511
512 void StaticShapeInferer::visit(const ir::operation::LSTM &op)
513 {
514   const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
515   auto &output = _operands.at(output_index);
516
517   const auto output_state_out_index{
518       op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
519
520   const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
521
522   const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
523
524   if (output.info().isDynamic() || (_operands.exist(output_state_out_index) &&
525                                     _operands.at(output_state_out_index).info().isDynamic()) ||
526       (_operands.exist(cell_state_out_index) &&
527        _operands.at(cell_state_out_index).info().isDynamic()) ||
528       (_operands.exist(scratch_buffer_index) &&
529        _operands.at(scratch_buffer_index).info().isDynamic()))
530     return;
531
532   const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)};
533   const auto &input = _operands.at(input_index);
534
535   const auto input_to_output_weights_index{
536       op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
537   const auto &input_to_output_weights = _operands.at(input_to_output_weights_index);
538
539   const auto recurrent_to_output_weights_index{
540       op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
541   const auto &recurrent_to_output_weights = _operands.at(recurrent_to_output_weights_index);
542
543   // re-sizing outputs
544   const int n_batch = (input.shape().rank() == 3 && op.param().time_major) ? input.shape().dim(1)
545                                                                            : input.shape().dim(0);
546   const int n_cell = input_to_output_weights.shape().dim(0);
547   const int n_output = recurrent_to_output_weights.shape().dim(1);
548   if (input.shape().rank() == 3)
549   {
550     if (op.param().time_major)
551       output.info().shape(ir::Shape{input.shape().dim(0), n_batch, n_output});
552     else
553       output.info().shape(ir::Shape{n_batch, input.shape().dim(1), n_output});
554   }
555   else
556   {
557     assert(input.shape().rank() == 2);
558     output.info().shape(ir::Shape{n_batch, n_output});
559   }
560
561   if (_operands.exist(output_state_out_index))
562   {
563     auto &output_state_out = _operands.at(output_state_out_index);
564     output_state_out.info().shape(ir::Shape{n_batch, n_output});
565   }
566
567   if (_operands.exist(cell_state_out_index))
568   {
569     auto &cell_state_out = _operands.at(cell_state_out_index);
570     cell_state_out.info().shape(ir::Shape{n_batch, n_cell});
571   }
572
573   if (_operands.exist(scratch_buffer_index))
574   {
575     auto &scratch_buffer = _operands.at(scratch_buffer_index);
576
577     const auto input_to_input_weights_index{
578         op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
579     const auto recurrent_to_input_weights_index{
580         op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
581
582     bool has_input_to_input_weights =
583         _operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
584         _operands.at(input_to_input_weights_index).shape().dim(1) != 0;
585     bool has_recurrent_to_input_weights =
586         _operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
587         _operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
588
589     // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
590     // true: no CIFG
591     // false: CIFG
592     bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
593     if (has_cifg_param)
594     {
595       scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 4});
596     }
597     else
598     {
599       scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 3});
600     }
601   }
602 }
603
604 void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
605 {
606   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
607 }
608
609 void StaticShapeInferer::visit(const ir::operation::OneHot &op)
610 {
611   const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
612   const auto &indice = _operands.at(indice_idx);
613   const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
614   const auto &depth = _operands.at(depth_idx);
615
616   const auto axis = op.param().axis;
617
618   auto output_idx = op.getOutputs().at(0);
619   ir::Operand &output = _operands.at(output_idx);
620
621   if (!depth.isConstant())
622   {
623     output.info().setDynamic();
624     _return_has_dynamic_tensor = true;
625     return;
626   }
627
628   const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
629   assert(depth_buf);
630   // re-sizing output shape
631   ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
632   output.info().shape(new_shape);
633 }
634
635 void StaticShapeInferer::visit(const ir::operation::Pack &op)
636 {
637   const auto input_idx{op.getInputs().at(0)};
638   const auto &input = _operands.at(input_idx);
639
640   // get mutable output operand
641   const auto output_idx = op.getOutputs().at(0);
642   ir::Operand &output = _operands.at(output_idx);
643
644   const auto rank = input.shape().rank() + 1;
645   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
646   const auto num = op.param().num;
647
648   assert(0 <= axis && axis < rank);
649
650   // re-sizing output shape
651   ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
652   output.info().shape(new_shape);
653 }
654
655 void StaticShapeInferer::visit(const ir::operation::Pad &op)
656 {
657   const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
658   const auto &input = _operands.at(input_idx);
659
660   const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
661   const auto &pad = _operands.at(pad_idx);
662
663   // get mutable output operand
664   const auto output_idx = op.getOutputs().at(0);
665   ir::Operand &output = _operands.at(output_idx);
666
667   // if pad is not constant, output also becomes dynamic
668   if (!pad.isConstant())
669   {
670     output.info().setDynamic();
671     _return_has_dynamic_tensor = true;
672     return;
673   }
674
675   // re-sizing output shape
676   const auto new_shape = shape_inference::inferPadShape(
677       input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
678       pad.shape().num_elements());
679   output.info().shape(new_shape);
680 }
681
682 void StaticShapeInferer::visit(const ir::operation::Permute &op)
683 {
684   const auto input_idx{op.getInputs().at(0)};
685   const auto &input = _operands.at(input_idx);
686   const auto output_idx = op.getOutputs().at(0);
687   ir::Operand &output = _operands.at(output_idx);
688
689   // re-sizing output shape
690   // Permute is a special operation that layouts of input/output may be different on backend
691   // However, it is not applied here, so input/output have the same layout of frontend. Because
692   // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering
693   // operand info to "TensorBuilder" after calling "StaticShapeInferer"
694   const auto new_shape = input.info().shape();
695   output.info().shape(new_shape);
696 }
697
698 void StaticShapeInferer::visit(const ir::operation::Pow &op)
699 {
700   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
701                            op.getInputs().at(ir::operation::Pow::Input::RHS));
702 }
703
704 void StaticShapeInferer::visit(const ir::operation::Range &op)
705 {
706   const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
707   const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
708   const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
709   const auto &start_op = _operands.at(start_idx);
710   const auto &limit_op = _operands.at(limit_idx);
711   const auto &delta_op = _operands.at(delta_idx);
712
713   // get mutable output operand
714   const auto output_idx = op.getOutputs().at(0);
715   ir::Operand &output = _operands.at(output_idx);
716
717   ir::Shape new_shape;
718   if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
719   {
720     assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
721            start_op.typeInfo().type() == delta_op.typeInfo().type());
722     if (output.typeInfo().type() == ir::DataType::FLOAT32)
723     {
724       new_shape = shape_inference::inferRangeShape<float>(
725           start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
726     }
727     else if (output.typeInfo().type() == ir::DataType::INT32)
728     {
729       new_shape = shape_inference::inferRangeShape<int32_t>(
730           start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
731     }
732     assert(output.shape() == new_shape);
733   }
734   else
735   {
736     output.info().setDynamic();
737     _return_has_dynamic_tensor = true;
738   }
739 }
740
741 void StaticShapeInferer::visit(const ir::operation::Reduce &op)
742 {
743   const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
744   const auto &input = _operands.at(input_idx);
745
746   const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
747   const auto &axes = _operands.at(axes_idx);
748
749   // get mutable output operand
750   const auto output_idx = op.getOutputs().at(0);
751   ir::Operand &output = _operands.at(output_idx);
752
753   std::vector<int32_t> axes_vec;
754   for (size_t i = 0; i < axes.shape().num_elements(); ++i)
755   {
756     switch (axes.typeInfo().type())
757     {
758       case ir::DataType::INT32:
759       {
760         axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
761         break;
762       }
763       case ir::DataType::INT64:
764       {
765         axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
766         break;
767       }
768       default:
769         throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
770         break;
771     }
772   }
773   const auto keep_dims = op.param().keep_dims;
774
775   // re-sizing output shape
776   ir::Shape new_shape =
777       shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
778   output.info().shape(new_shape);
779 }
780
781 void StaticShapeInferer::visit(const ir::operation::Reshape &op)
782 {
783   const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
784   const auto &input = _operands.at(input_idx);
785
786   // get mutable output operand
787   const auto output_idx = op.getOutputs().at(0);
788   ir::Operand &output = _operands.at(output_idx);
789
790   // New shape is given by second input tensor
791   if (op.getInputs().size() == 2)
792   {
793     // Let's check the second input
794     const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
795     const auto &shape = _operands.at(shape_idx);
796
797     if (shape.isConstant())
798     {
799       const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
800       assert(shape_buf);
801
802       ir::Shape new_shape = shape_inference::inferReshapeShape(
803           shape_buf, shape.shape().num_elements(), input.shape().num_elements());
804
805       // if shape is from Const, TFLC put the shape of output into tensor
806       if (new_shape != output.shape())
807       {
808         // change on output shape
809         output.info().shape(new_shape);
810       }
811     }
812     else
813     {
814       // if shape is NOT Const, set output shape to be dynamic_
815       output.info().setDynamic();
816       _return_has_dynamic_tensor = true;
817     }
818   }
819   // New shape is given by option
820   else if (op.param().new_shape.size() != 0)
821   {
822     // Let's check the new_shape option
823     auto shape = op.param().new_shape;
824     ir::Shape new_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(),
825                                                              input.shape().num_elements());
826
827     if (new_shape != output.shape())
828     {
829       // change on output shape
830       output.info().shape(new_shape);
831     }
832   }
833   else
834   {
835     throw std::runtime_error("Reshape: new shape is missing");
836   }
837 }
838
839 void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
840 {
841   const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
842   const auto &input = _operands.at(input_idx);
843
844   // get mutable output operand
845   const auto output_idx = op.getOutputs().at(0);
846   ir::Operand &output = _operands.at(output_idx);
847
848   int32_t height_out, width_out;
849   if (op.getInputs().size() == 2)
850   {
851     auto &size = _operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE));
852     if (!size.isConstant())
853     {
854       output.info().setDynamic();
855       _return_has_dynamic_tensor = true;
856       return;
857     }
858     const auto size_v = size.asVector<std::int32_t>();
859     height_out = size_v[0];
860     width_out = size_v[1];
861   }
862   else
863   {
864     height_out = op.param().height_out;
865     width_out = op.param().width_out;
866   }
867
868   // Shape inferencing logic based on Params
869   ir::Shape new_shape =
870       shape_inference::inferResizeBilinearShape(input.shape(), height_out, width_out);
871
872   // if size_op is from Const, TFLC put the shape of output into tensor
873   if (new_shape != output.shape())
874   {
875     // change on output shape
876     output.info().shape(new_shape);
877   }
878 }
879
880 void StaticShapeInferer::visit(const ir::operation::Reverse &op)
881 {
882   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
883 }
884
885 void StaticShapeInferer::visit(const ir::operation::Select &op)
886 {
887   const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
888   const auto &input_cond = _operands.at(input_cond_idx);
889
890   const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
891   const auto &input_true = _operands.at(input_true_idx);
892
893   const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
894   const auto &input_false = _operands.at(input_false_idx);
895
896   auto output_idx = op.getOutputs().at(0);
897   ir::Operand &output = _operands.at(output_idx);
898
899   // Select output shpae
900   ir::Shape new_shape = shape_inference::inferSelectShape(
901       input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
902   output.info().shape(new_shape);
903 }
904
905 void StaticShapeInferer::visit(const ir::operation::Shape &op)
906 {
907   const auto input_idx{op.getInputs().at(0)};
908   const auto &input = _operands.at(input_idx);
909
910   // get mutable output operand
911   const auto output_idx = op.getOutputs().at(0);
912   ir::Operand &output = _operands.at(output_idx);
913
914   // re-sizing output shape
915   ir::Shape output_shape;
916   output_shape.append(input.info().shape().rank());
917
918   output.info().shape(output_shape);
919 }
920
921 void StaticShapeInferer::visit(const ir::operation::Slice &op)
922 {
923   const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
924   const auto &input = _operands.at(input_index);
925   const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
926   const auto &begins = _operands.at(begins_index);
927   const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
928   const auto &sizes = _operands.at(sizes_index);
929   const auto output_index = op.getOutputs().at(0);
930   ir::Operand &output = _operands.at(output_index);
931
932   // Whether input is constant or not does not affect whether output is dynamic or not
933   if (!(begins.isConstant() && sizes.isConstant()))
934   {
935     output.info().setDynamic();
936     _return_has_dynamic_tensor = true;
937     return;
938   }
939
940   auto begins_buf = reinterpret_cast<const int32_t *>(begins.data()->base());
941   auto sizes_buf = reinterpret_cast<const int32_t *>(sizes.data()->base());
942
943   ir::Shape new_shape =
944       shape_inference::inferSliceShape(input.info().shape(), begins_buf, sizes_buf);
945   output.info().shape(new_shape);
946 }
947
948 void StaticShapeInferer::visit(const ir::operation::Softmax &op)
949 {
950   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
951 }
952
953 void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
954 {
955   const auto output_index = op.getOutputs().at(0);
956   const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
957   const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
958   const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
959
960   ir::Operand &output = _operands.at(output_index);
961   const auto &input = _operands.at(input_idx);
962   const auto &block_shape = _operands.at(block_shape_idx);
963   const auto &padding = _operands.at(padding_idx);
964
965   // Whether input is constant or not does not affect whether output is dynamic or not
966   if (!(block_shape.isConstant() && padding.isConstant()))
967   {
968     output.info().setDynamic();
969     _return_has_dynamic_tensor = true;
970     return;
971   }
972
973   auto input_shape = input.info().shape();
974   auto block_shape_shape = block_shape.info().shape();
975   auto padding_shape = padding.info().shape();
976
977   auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
978   auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());
979
980   ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
981       input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
982
983   output.info().shape(new_shape);
984 }
985
986 void StaticShapeInferer::visit(const ir::operation::Split &op)
987 {
988   const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)};
989   const auto &input = _operands.at(input_idx);
990
991   const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)};
992   const auto &axis = _operands.at(axis_idx);
993
994   auto outputs = op.getOutputs();
995   if (!axis.isConstant())
996   {
997     for (auto output_idx : outputs)
998     {
999       ir::Operand &output = _operands.at(output_idx);
1000       output.info().setDynamic();
1001     }
1002     _return_has_dynamic_tensor = true;
1003     return;
1004   }
1005
1006   const auto num_splits = op.param().num_splits;
1007
1008   const auto rank = input.info().shape().rank();
1009   auto axis_value = axis.asScalar<int32_t>();
1010   axis_value = axis_value < 0 ? axis_value + rank : axis_value;
1011
1012   assert(0 <= axis_value && axis_value < rank);
1013
1014   ir::Shape new_shape =
1015       shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
1016   for (auto output_idx : outputs)
1017   {
1018     ir::Operand &output = _operands.at(output_idx);
1019     output.info().shape(new_shape);
1020   }
1021 }
1022
1023 void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
1024 {
1025   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
1026                            op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
1027 }
1028
1029 void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
1030 {
1031   const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
1032   const auto &input = _operands.at(input_idx);
1033
1034   const auto output_idx = op.getOutputs().at(0);
1035   ir::Operand &output = _operands.at(output_idx);
1036
1037   // Squeeze output shpae
1038   ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
1039   output.info().shape(new_shape);
1040 }
1041
1042 void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
1043 {
1044   const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
1045   const auto &input = _operands.at(input_index);
1046   const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
1047   const auto &starts = _operands.at(starts_index);
1048   const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
1049   const auto &ends = _operands.at(ends_index);
1050   const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
1051   const auto &strides = _operands.at(strides_index);
1052   const auto output_index = op.getOutputs().at(0);
1053   ir::Operand &output = _operands.at(output_index);
1054
1055   if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
1056   {
1057     output.info().setDynamic();
1058     _return_has_dynamic_tensor = true;
1059     return;
1060   }
1061
1062   const auto begin_mask = op.param().begin_mask;
1063   const auto end_mask = op.param().end_mask;
1064   const auto shrink_axis_mask = op.param().shrink_axis_mask;
1065   const auto rank = input.info().shape().rank();
1066
1067   auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
1068   auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
1069   auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());
1070
1071   auto op_params = shape_inference::buildStridedSliceParams(
1072       starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);
1073
1074   ir::Shape new_shape =
1075       shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
1076   output.info().shape(new_shape);
1077 }
1078
1079 void StaticShapeInferer::visit(const ir::operation::Tile &op)
1080 {
1081   const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
1082   const auto &input = _operands.at(input_idx);
1083
1084   const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
1085   const auto &multiplier = _operands.at(multiplier_idx);
1086
1087   const auto output_idx = op.getOutputs().at(0);
1088   ir::Operand &output = _operands.at(output_idx);
1089
1090   if (!multiplier.isConstant())
1091   {
1092     output.info().setDynamic();
1093     _return_has_dynamic_tensor = true;
1094     return;
1095   }
1096
1097   auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
1098   assert(multiplier_buffer);
1099
1100   // re-sizing output shape
1101   auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer,
1102                                                    multiplier.shape().num_elements());
1103   output.info().shape(new_shape);
1104 }
1105
1106 void StaticShapeInferer::visit(const ir::operation::Transpose &op)
1107 {
1108   const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
1109   const auto &input = _operands.at(input_idx);
1110
1111   const auto perm_idx{op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
1112   const auto &perm = _operands.at(perm_idx);
1113
1114   // perm.shape() != ir::Shape{0} means that perm is (n-1...0)
1115   // TODO This condition changes to perm.num_elements() == 0
1116   const auto is_regular_transpose = perm.shape() == ir::Shape{0};
1117
1118   // get mutable output operand
1119   const auto output_idx = op.getOutputs().at(0);
1120   auto &output = _operands.at(output_idx);
1121   if (!perm.isConstant() && !is_regular_transpose)
1122   {
1123     output.info().setDynamic();
1124     _return_has_dynamic_tensor = true;
1125     return;
1126   }
1127
1128   ir::Shape new_shape;
1129   if (is_regular_transpose)
1130   {
1131     // Call by (n-1...0)
1132     new_shape = shape_inference::inferTransposeShape(input.info().shape(), nullptr, 0);
1133   }
1134   else
1135   {
1136     // Check rank
1137     if (input.info().shape().rank() != static_cast<int>(perm.info().shape().num_elements()))
1138     {
1139       throw std::runtime_error("StaticShapeInferer failed, bad rank size: " +
1140                                std::to_string(perm.info().shape().num_elements()));
1141     }
1142
1143     // set output shape, based on input and params
1144     const auto perm_buf = reinterpret_cast<const int32_t *>(perm.data()->base());
1145     new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm_buf,
1146                                                      perm.shape().num_elements());
1147   }
1148   output.info().shape(new_shape);
1149 }
1150
1151 void StaticShapeInferer::visit(const ir::operation::Unpack &op)
1152 {
1153   const auto input_idx{op.getInputs().at(0)};
1154   const auto &input = _operands.at(input_idx);
1155   const auto num = op.param().num;
1156   const auto rank = input.shape().rank();
1157   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
1158
1159   assert(axis < rank);
1160   if (axis < 0)
1161   {
1162     for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1163     {
1164       const auto output_idx = op.getOutputs().at(out_tensor_idx);
1165       ir::Operand &output = _operands.at(output_idx);
1166       output.info().setDynamic();
1167     }
1168     _return_has_dynamic_tensor = true;
1169     return;
1170   }
1171
1172   ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);
1173
1174   // re-sizing output shape
1175   for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1176   {
1177     const auto output_idx = op.getOutputs().at(out_tensor_idx);
1178     ir::Operand &output = _operands.at(output_idx);
1179     output.info().shape(new_shape);
1180   }
1181 }
1182
1183 void StaticShapeInferer::visit(const ir::operation::While &op)
1184 {
1185   auto &cond_graph = _lowered_subgs.at(op.param().cond_subg_index)->graph();
1186   auto &body_graph = _lowered_subgs.at(op.param().body_subg_index)->graph();
1187   const auto inputs = op.getInputs();
1188   const auto &outputs = op.getOutputs();
1189
1190   // re-sizing input shapes of then subgraph
1191   const auto &cond_inputs = cond_graph.getInputs();
1192   assert(inputs.size() == cond_inputs.size());
1193   for (size_t i = 0; i < inputs.size(); ++i)
1194   {
1195     const auto &input = _operands.at(inputs.at(i));
1196     auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
1197     if (input.info().isDynamic())
1198     {
1199       cond_input.info().setDynamic();
1200     }
1201     else
1202     {
1203       auto new_shape = input.info().shape();
1204       cond_input.info().shape(new_shape);
1205     }
1206   }
1207
1208   // re-sizing input shapes of body subgraph
1209   const auto &body_inputs = body_graph.getInputs();
1210   assert(cond_inputs.size() == body_inputs.size());
1211   for (size_t i = 0; i < cond_inputs.size(); ++i)
1212   {
1213     const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
1214     auto &body_input = body_graph.operands().at(body_inputs.at(i));
1215     if (cond_input.info().isDynamic())
1216     {
1217       body_input.info().setDynamic();
1218     }
1219     else
1220     {
1221       const auto &new_shape = cond_input.info().shape();
1222       body_input.info().shape(new_shape);
1223     }
1224   }
1225
1226   // re-sizing operands of body subgraph
1227   StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
1228   _lowered_subgs.at(op.param().body_subg_index)
1229       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
1230         bool has_dynamic_tensor = body_inferer.infer(op_seq);
1231         op_seq.has_dynamic_tensor(has_dynamic_tensor);
1232       });
1233
1234   // Check whether while operation's shapes are predictable
1235   // If any of shape of body outputs and cond inputs are different, non-constant operands would be
1236   // set to dynamic
1237   bool check_unpredictable_dynamic = false;
1238   const auto &body_outputs = body_graph.getOutputs();
1239   assert(body_outputs.size() == cond_inputs.size());
1240   for (size_t i = 0; i < body_outputs.size(); ++i)
1241   {
1242     const auto &body_output = body_graph.operands().at(body_outputs.at(i));
1243     auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
1244     if ((cond_input.info().isDynamic() != body_output.info().isDynamic()) ||
1245         (cond_input.shape() != body_output.shape()))
1246     {
1247       check_unpredictable_dynamic = true;
1248       break;
1249     }
1250   }
1251
1252   if (check_unpredictable_dynamic)
1253   {
1254     // Set inputs of body subgraph
1255     for (const auto &input_index : body_inputs)
1256     {
1257       auto &input = body_graph.operands().at(input_index);
1258       if (!input.isConstant())
1259       {
1260         input.info().setDynamic();
1261       }
1262     }
1263
1264     // Set inputs of cond subgraph
1265     for (const auto &input_index : cond_inputs)
1266     {
1267       auto &input = cond_graph.operands().at(input_index);
1268       if (!input.isConstant())
1269       {
1270         input.info().setDynamic();
1271       }
1272     }
1273
1274     // Set non-constant operands of body subgraph to dynamic
1275     StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
1276     _lowered_subgs.at(op.param().body_subg_index)
1277         ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
1278           bool has_dynamic_tensor = body_inferer.infer(op_seq);
1279           op_seq.has_dynamic_tensor(has_dynamic_tensor);
1280         });
1281   }
1282
1283   // re-sizing operands of cond subgraph
1284   // If check_unpredictable_dynamic is true, non-constant operands of cond subgraph would be set to
1285   // dynamic
1286   StaticShapeInferer cond_inferer(op.param().cond_subg_index, _lowered_subgs);
1287   _lowered_subgs.at(op.param().cond_subg_index)
1288       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
1289         bool has_dynamic_tensor = cond_inferer.infer(op_seq);
1290         op_seq.has_dynamic_tensor(has_dynamic_tensor);
1291       });
1292
1293   // re-sizing outputs of while operation
1294   // If check_unpredictable_dynamic is true, outputs of while operation would be set to dynamic
1295   assert(cond_inputs.size() == outputs.size());
1296   for (size_t i = 0; i < cond_inputs.size(); ++i)
1297   {
1298     const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
1299     auto &output = _operands.at(outputs.at(i));
1300     if (cond_input.info().isDynamic())
1301     {
1302       output.info().setDynamic();
1303       _return_has_dynamic_tensor = true;
1304     }
1305     else
1306     {
1307       const auto new_shape = cond_input.info().shape();
1308       output.info().shape(new_shape);
1309     }
1310   }
1311 }
1312
1313 } // namespace compiler
1314
1315 } // namespace onert