Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / runtime / onert / core / src / compiler / StaticShapeInference.cc
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "compiler/StaticShapeInference.h"
18 #include "util/ShapeInference.h"
19 #include "util/logging.h"
20
21 #include <sstream>
22
23 namespace onert
24 {
25 namespace compiler
26 {
27
28 bool StaticShapeInferer::infer(const ir::OpSequence &op_seq)
29 {
30   bool has_dynamic_tensor = false;
31
32   for (const auto &operation_idx : op_seq.operations())
33   {
34     auto &op = _operations.at(operation_idx);
35     auto opcode = op.opcode();
36
37     _return_has_dynamic_tensor = false; // this is used as a return value inside operation's visit()
38
39     // IF: need shape inference for then, else
40     // While: need shape inference for condition, body
41     if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
42     {
43       op.accept(*this);
44     }
45     else
46     {
47       _return_has_dynamic_tensor = checkDynamicInput(op);
48
49       if (_return_has_dynamic_tensor)
50       {
51         setDynamicOutput(op);
52       }
53       else
54       {
55         op.accept(*this);
56       }
57     }
58
59     has_dynamic_tensor = has_dynamic_tensor || _return_has_dynamic_tensor;
60   }
61
62   return has_dynamic_tensor;
63 }
64
65 bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
66 {
67   for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
68   {
69     if (_operands.at(input_idx).info().isDynamic())
70     {
71       return true;
72     }
73   }
74
75   return false;
76 }
77
78 void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
79 {
80   for (auto output_idx : op.getOutputs())
81   {
82     _operands.at(output_idx).info().setDynamic();
83   }
84 }
85
86 void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
87                                                   const ir::OperandIndex lhs_idx,
88                                                   const ir::OperandIndex rhs_idx)
89 {
90   const auto &lhs = _operands.at(lhs_idx);
91   const auto &rhs = _operands.at(rhs_idx);
92
93   const auto output_idx = op.getOutputs().at(0);
94   ir::Operand &output = _operands.at(output_idx);
95
96   // re-sizing output shape
97   ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
98   output.info().shape(new_shape);
99 }
100
101 void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
102                                              const ir::OperandIndex input_idx)
103 {
104   const auto &input = _operands.at(input_idx);
105
106   // get mutable output operand
107   const auto output_idx = op.getOutputs().at(0);
108   ir::Operand &output = _operands.at(output_idx);
109
110   // re-sizing output shape
111   ir::Shape new_shape = input.info().shape();
112   output.info().shape(new_shape);
113 }
114
115 void StaticShapeInferer::dump()
116 {
117   auto get_shape_str = [](const ir::Shape &shape) {
118     std::stringstream sstream;
119     sstream << "shape : {";
120     for (int i = 0; i < shape.rank(); i++)
121     {
122       if (i == 0)
123         sstream << shape.dim(i);
124       else
125         sstream << " " << shape.dim(i);
126     }
127     sstream << "}";
128     return sstream.str();
129   };
130
131   for (const auto &pair : _lowered_subgs)
132   {
133     const auto index = pair.first;
134     const auto &lowered_subg = pair.second;
135     VERBOSE(StaticShapeInferer) << "SubGraph #" << index.value() << std::endl;
136     lowered_subg->graph().operands().iterate(
137         [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
138           VERBOSE(StaticShapeInferer) << "Operand #" << ind.value() << ", "
139                                       << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
140                                       << get_shape_str(operand.info().shape()) << std::endl;
141         });
142   }
143 }
144
145 void StaticShapeInferer::visit(const ir::operation::ArgMax &op)
146 {
147   const auto input_idx{op.getInputs().at(ir::operation::ArgMax::Input::INPUT)};
148   const auto &input = _operands.at(input_idx);
149
150   // get mutable output operand
151   const auto output_idx = op.getOutputs().at(0);
152   ir::Operand &output = _operands.at(output_idx);
153   const auto rank = input.info().shape().rank();
154   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
155
156   assert(0 <= axis && axis < rank);
157
158   // re-sizing output shape
159   ir::Shape new_shape = shape_inference::inferArgMaxShape(input.info().shape(), axis, rank);
160   output.info().shape(new_shape);
161 }
162
163 void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
164 {
165   const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
166   const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
167   const auto output_index = op.getOutputs().at(0);
168   const auto lhs = _operands.at(lhs_index);
169   const auto rhs = _operands.at(rhs_index);
170   auto &output = _operands.at(output_index);
171   auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
172   output.info().shape(new_shape);
173 }
174
175 void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
176 {
177   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS),
178                            op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS));
179 }
180
181 void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
182 {
183   // get mutable output operand
184   const auto output_idx = op.getOutputs().at(0);
185   ir::Operand &output = _operands.at(output_idx);
186
187   const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
188   const auto &shape = _operands.at(shape_idx);
189
190   if (!shape.isConstant())
191   {
192     output.info().setDynamic();
193     _return_has_dynamic_tensor = true;
194     return;
195   }
196
197   // assert(shape.typeInfo().type() == ir::DataType::INT32);
198   auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());
199
200   // re-sizing output shape
201   ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
202   output.info().shape(new_shape);
203 }
204
205 void StaticShapeInferer::visit(const ir::operation::Comparison &op)
206 {
207   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
208                            op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
209 }
210
211 void StaticShapeInferer::visit(const ir::operation::Concat &op)
212 {
213   const auto input_count = op.getInputs().size();
214
215   const auto output_idx = op.getOutputs().at(0);
216   ir::Operand &output = _operands.at(output_idx);
217
218   shape_inference::Shapes input_shapes;
219   for (uint32_t i = 0; i < input_count; i++)
220   {
221     const auto input_idx{op.getInputs().at(i)};
222     const auto &input = _operands.at(input_idx);
223     input_shapes.emplace_back(input.shape());
224   }
225
226   ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());
227
228   // re-sizing output shape
229   output.info().shape(out_shape);
230 }
231
232 void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
233 {
234   const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
235   const auto &input = _operands.at(input_idx);
236   const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
237   const auto &ker = _operands.at(ker_idx);
238   const auto output_idx = op.getOutputs().at(0);
239   ir::Operand &output = _operands.at(output_idx);
240
241   // re-sizing output shape
242   ir::Shape new_shape =
243       shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
244   output.info().shape(new_shape);
245 }
246
247 void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
248 {
249   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT));
250 }
251
252 void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op)
253 {
254   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS),
255                            op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS));
256 }
257
258 void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
259 {
260   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT));
261 }
262
263 void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
264 {
265   const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
266   const auto &input = _operands.at(input_idx);
267   const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
268   const auto &axis = _operands.at(axis_idx);
269   const auto output_idx = op.getOutputs().at(0);
270   ir::Operand &output = _operands.at(output_idx);
271
272   if (!axis.isConstant())
273   {
274     output.info().setDynamic();
275     _return_has_dynamic_tensor = true;
276     return;
277   }
278
279   // even when axis is constant, output shape should be recalculated since user might call
280   // nnfw_set_input_tensorinfo(input, some_new_shape)
281   auto axis_buf = reinterpret_cast<const int32_t *>(axis.data()->base());
282   assert(axis_buf);
283
284   // re-sizing output shape
285   ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_buf[0]);
286   output.info().shape(new_shape);
287 }
288
289 void StaticShapeInferer::visit(const ir::operation::Fill &op)
290 {
291   const auto input_idx{op.getInputs().at(ir::operation::Fill::Input::INPUT)};
292   const auto &input = _operands.at(input_idx);
293   const auto output_idx = op.getOutputs().at(0);
294   ir::Operand &output = _operands.at(output_idx);
295
296   if (!input.isConstant())
297   {
298     output.info().setDynamic();
299     _return_has_dynamic_tensor = true;
300     return;
301   }
302
303   assert(input.typeInfo().type() == ir::DataType::INT32);
304
305   auto input_buf = reinterpret_cast<const int32_t *>(input.data()->base());
306   assert(input_buf);
307
308   // re-sizing output shape
309   ir::Shape new_shape = shape_inference::inferFillShape(input.info().shape(), input_buf);
310   output.info().shape(new_shape);
311 }
312
313 void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
314 {
315   const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
316   const auto &input = _operands.at(input_idx);
317
318   const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
319   const auto &ker = _operands.at(ker_idx);
320
321   // get mutable output operand
322   const auto output_idx = op.getOutputs().at(0);
323   ir::Operand &output = _operands.at(output_idx);
324   // re-sizing output shape
325   ir::Shape new_shape =
326       shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
327   output.info().shape(new_shape);
328 }
329
330 void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
331 {
332   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
333 }
334
335 void StaticShapeInferer::visit(const ir::operation::Gather &op)
336 {
337   const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
338   const auto &input = _operands.at(input_idx);
339
340   // get mutable output operand
341   const auto output_idx = op.getOutputs().at(0);
342   ir::Operand &output = _operands.at(output_idx);
343
344   const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
345   const auto &indices = _operands.at(indices_idx);
346   const auto rank = input.info().shape().rank();
347   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
348
349   assert(0 <= axis && axis < rank);
350
351   // re-sizing output shape
352   ir::Shape new_shape =
353       shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
354   output.info().shape(new_shape);
355 }
356
357 void StaticShapeInferer::visit(const ir::operation::If &op)
358 {
359   auto &then_graph = _lowered_subgs.at(op.param().then_subg_index)->graph();
360   auto &else_graph = _lowered_subgs.at(op.param().else_subg_index)->graph();
361   const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
362   const auto &outputs = op.getOutputs();
363
364   // re-sizing input shapes of then subgraph
365   const auto &then_inputs = then_graph.getInputs();
366   assert(inputs.size() == then_inputs.size());
367   for (size_t i = 0; i < inputs.size(); ++i)
368   {
369     auto &then_input = then_graph.operands().at(then_inputs.at(i));
370     if (_operands.at(inputs.at(i)).info().isDynamic())
371     {
372       then_input.info().setDynamic();
373     }
374     else
375     {
376       auto new_shape = _operands.at(inputs.at(i)).info().shape();
377       then_input.info().shape(new_shape);
378     }
379   }
380
381   // re-sizing input shapes of else subgraph
382   const auto &else_inputs = else_graph.getInputs();
383   assert(inputs.size() == else_inputs.size());
384   for (size_t i = 0; i < inputs.size(); ++i)
385   {
386     auto &else_input = else_graph.operands().at(else_inputs.at(i));
387     if (_operands.at(inputs.at(i)).info().isDynamic())
388     {
389       else_input.info().setDynamic();
390     }
391     else
392     {
393       const auto &new_shape = _operands.at(inputs.at(i)).info().shape();
394       else_input.info().shape(new_shape);
395     }
396   }
397
398   // re-sizing operands of then subgraph
399   StaticShapeInferer then_inferer(op.param().then_subg_index, _lowered_subgs);
400   _lowered_subgs.at(op.param().then_subg_index)
401       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
402         bool has_dynamic_tensor = then_inferer.infer(op_seq);
403         op_seq.has_dynamic_tensor(has_dynamic_tensor);
404       });
405
406   // re-sizing operands of else subgraph
407   StaticShapeInferer else_inferer(op.param().else_subg_index, _lowered_subgs);
408   _lowered_subgs.at(op.param().else_subg_index)
409       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
410         bool has_dynamic_tensor = else_inferer.infer(op_seq);
411         op_seq.has_dynamic_tensor(has_dynamic_tensor);
412       });
413
414   // re-sizing output shapes
415   const auto &then_outputs = _lowered_subgs.at(op.param().then_subg_index)->graph().getOutputs();
416   const auto &else_outputs = _lowered_subgs.at(op.param().else_subg_index)->graph().getOutputs();
417   assert(outputs.size() == then_outputs.size());
418   assert(outputs.size() == else_outputs.size());
419   for (size_t i = 0; i < outputs.size(); ++i)
420   {
421     const auto &then_output = then_graph.operands().at(then_outputs.at(i));
422     const auto &else_output = else_graph.operands().at(else_outputs.at(i));
423     auto &output = _operands.at(outputs.at(i));
424     if (!then_output.info().isDynamic() && !else_output.info().isDynamic() &&
425         then_output.shape() == else_output.shape())
426     {
427       output.info().shape(then_output.shape());
428     }
429     else
430     {
431       output.info().setDynamic();
432       _return_has_dynamic_tensor = true;
433     }
434   }
435 }
436
437 void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
438 {
439   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
440 }
441
442 void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
443 {
444   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
445 }
446
447 void StaticShapeInferer::visit(const ir::operation::OneHot &op)
448 {
449   const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
450   const auto &indice = _operands.at(indice_idx);
451   const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
452   const auto &depth = _operands.at(depth_idx);
453
454   const auto axis = op.param().axis;
455
456   auto output_idx = op.getOutputs().at(0);
457   ir::Operand &output = _operands.at(output_idx);
458
459   if (!depth.isConstant())
460   {
461     output.info().setDynamic();
462     _return_has_dynamic_tensor = true;
463     return;
464   }
465
466   const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
467   assert(depth_buf);
468   // re-sizing output shape
469   ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
470   output.info().shape(new_shape);
471 }
472
473 void StaticShapeInferer::visit(const ir::operation::Pack &op)
474 {
475   const auto input_idx{op.getInputs().at(0)};
476   const auto &input = _operands.at(input_idx);
477
478   // get mutable output operand
479   const auto output_idx = op.getOutputs().at(0);
480   ir::Operand &output = _operands.at(output_idx);
481
482   const auto rank = input.shape().rank() + 1;
483   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
484   const auto num = op.param().num;
485
486   assert(0 <= axis && axis < rank);
487
488   // re-sizing output shape
489   ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
490   output.info().shape(new_shape);
491 }
492
493 void StaticShapeInferer::visit(const ir::operation::Pad &op)
494 {
495   const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
496   const auto &input = _operands.at(input_idx);
497
498   const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
499   const auto &pad = _operands.at(pad_idx);
500
501   // get mutable output operand
502   const auto output_idx = op.getOutputs().at(0);
503   ir::Operand &output = _operands.at(output_idx);
504
505   // if pad is not constant, output also becomes dynamic
506   if (!pad.isConstant())
507   {
508     output.info().setDynamic();
509     _return_has_dynamic_tensor = true;
510     return;
511   }
512
513   // re-sizing output shape
514   const auto new_shape = shape_inference::inferPadShape(
515       input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
516       pad.shape().num_elements());
517   output.info().shape(new_shape);
518 }
519
520 void StaticShapeInferer::visit(const ir::operation::Permute &op)
521 {
522   const auto input_idx{op.getInputs().at(0)};
523   const auto &input = _operands.at(input_idx);
524   const auto output_idx = op.getOutputs().at(0);
525   ir::Operand &output = _operands.at(output_idx);
526
527   // re-sizing output shape
528   // Permute is a special operation that layouts of input/output may be different on backend
529   // However, it is not applied here, so input/output have the same layout of frontend. Because
530   // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering
531   // operand info to "TensorBuilder" after calling "StaticShapeInferer"
532   const auto new_shape = input.info().shape();
533   output.info().shape(new_shape);
534 }
535
536 void StaticShapeInferer::visit(const ir::operation::Pow &op)
537 {
538   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
539                            op.getInputs().at(ir::operation::Pow::Input::RHS));
540 }
541
542 void StaticShapeInferer::visit(const ir::operation::Range &op)
543 {
544   const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
545   const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
546   const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
547   const auto &start_op = _operands.at(start_idx);
548   const auto &limit_op = _operands.at(limit_idx);
549   const auto &delta_op = _operands.at(delta_idx);
550
551   // get mutable output operand
552   const auto output_idx = op.getOutputs().at(0);
553   ir::Operand &output = _operands.at(output_idx);
554
555   ir::Shape new_shape;
556   if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
557   {
558     assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
559            start_op.typeInfo().type() == delta_op.typeInfo().type());
560     if (output.typeInfo().type() == ir::DataType::FLOAT32)
561     {
562       new_shape = shape_inference::inferRangeShape<float>(
563           start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
564     }
565     else if (output.typeInfo().type() == ir::DataType::INT32)
566     {
567       new_shape = shape_inference::inferRangeShape<int32_t>(
568           start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
569     }
570     assert(output.shape() == new_shape);
571   }
572   else
573   {
574     output.info().setDynamic();
575     _return_has_dynamic_tensor = true;
576   }
577 }
578
579 void StaticShapeInferer::visit(const ir::operation::Reduce &op)
580 {
581   const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
582   const auto &input = _operands.at(input_idx);
583
584   const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
585   const auto &axes = _operands.at(axes_idx);
586
587   // get mutable output operand
588   const auto output_idx = op.getOutputs().at(0);
589   ir::Operand &output = _operands.at(output_idx);
590
591   std::vector<int32_t> axes_vec;
592   for (size_t i = 0; i < axes.shape().num_elements(); ++i)
593   {
594     switch (axes.typeInfo().type())
595     {
596       case ir::DataType::INT32:
597       {
598         axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
599         break;
600       }
601       case ir::DataType::INT64:
602       {
603         axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
604         break;
605       }
606       default:
607         throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
608         break;
609     }
610   }
611   const auto keep_dims = op.param().keep_dims;
612
613   // re-sizing output shape
614   ir::Shape new_shape =
615       shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
616   output.info().shape(new_shape);
617 }
618
619 void StaticShapeInferer::visit(const ir::operation::Reshape &op)
620 {
621   const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
622   const auto &input = _operands.at(input_idx);
623
624   // get mutable output operand
625   const auto output_idx = op.getOutputs().at(0);
626   ir::Operand &output = _operands.at(output_idx);
627
628   // New shape is given by second input tensor
629   if (op.getInputs().size() == 2)
630   {
631     // Let's check the second input
632     const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
633     const auto &shape = _operands.at(shape_idx);
634
635     if (shape.isConstant())
636     {
637       const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
638       assert(shape_buf);
639
640       ir::Shape new_shape = shape_inference::inferReshapeShape(
641           shape_buf, shape.shape().num_elements(), input.shape().num_elements());
642
643       // if shape is from Const, TFLC put the shape of output into tensor
644       if (new_shape != output.shape())
645       {
646         // change on output shape
647         output.info().shape(new_shape);
648       }
649     }
650     else
651     {
652       // if shape is NOT Const, set output shape to be dynamic_
653       output.info().setDynamic();
654       _return_has_dynamic_tensor = true;
655     }
656   }
657   // New shape is given by option
658   else if (op.param().new_shape.size() != 0)
659   {
660     // Let's check the new_shape option
661     auto shape = op.param().new_shape;
662     ir::Shape new_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(),
663                                                              input.shape().num_elements());
664
665     if (new_shape != output.shape())
666     {
667       // change on output shape
668       output.info().shape(new_shape);
669     }
670   }
671   else
672   {
673     throw std::runtime_error("Reshape: new shape is missing");
674   }
675 }
676
677 void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
678 {
679   const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
680   const auto &input = _operands.at(input_idx);
681
682   // get mutable output operand
683   const auto output_idx = op.getOutputs().at(0);
684   ir::Operand &output = _operands.at(output_idx);
685
686   // Shape inferencing logic based on Params
687   ir::Shape new_shape = shape_inference::inferResizeBilinearShape(
688       input.shape(), op.param().height_out, op.param().width_out);
689
690   // if size_op is from Const, TFLC put the shape of output into tensor
691   if (new_shape != output.shape())
692   {
693     // change on output shape
694     output.info().shape(new_shape);
695   }
696 }
697
698 void StaticShapeInferer::visit(const ir::operation::Reverse &op)
699 {
700   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
701 }
702
703 void StaticShapeInferer::visit(const ir::operation::Select &op)
704 {
705   const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
706   const auto &input_cond = _operands.at(input_cond_idx);
707
708   const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
709   const auto &input_true = _operands.at(input_true_idx);
710
711   const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
712   const auto &input_false = _operands.at(input_false_idx);
713
714   auto output_idx = op.getOutputs().at(0);
715   ir::Operand &output = _operands.at(output_idx);
716
717   // Select output shpae
718   ir::Shape new_shape = shape_inference::inferSelectShape(
719       input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
720   output.info().shape(new_shape);
721 }
722
723 void StaticShapeInferer::visit(const ir::operation::Shape &op)
724 {
725   const auto input_idx{op.getInputs().at(0)};
726   const auto &input = _operands.at(input_idx);
727
728   // get mutable output operand
729   const auto output_idx = op.getOutputs().at(0);
730   ir::Operand &output = _operands.at(output_idx);
731
732   // re-sizing output shape
733   ir::Shape output_shape;
734   output_shape.append(input.info().shape().rank());
735
736   output.info().shape(output_shape);
737 }
738
739 void StaticShapeInferer::visit(const ir::operation::Slice &op)
740 {
741   const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
742   const auto &input = _operands.at(input_index);
743   const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
744   const auto &begins = _operands.at(begins_index);
745   const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
746   const auto &sizes = _operands.at(sizes_index);
747   const auto output_index = op.getOutputs().at(0);
748   ir::Operand &output = _operands.at(output_index);
749
750   // Whether input is constant or not does not affect whether output is dynamic or not
751   if (!(begins.isConstant() && sizes.isConstant()))
752   {
753     output.info().setDynamic();
754     _return_has_dynamic_tensor = true;
755     return;
756   }
757
758   auto begins_buf = reinterpret_cast<const int32_t *>(begins.data()->base());
759   auto sizes_buf = reinterpret_cast<const int32_t *>(sizes.data()->base());
760
761   ir::Shape new_shape =
762       shape_inference::inferSliceShape(input.info().shape(), begins_buf, sizes_buf);
763   output.info().shape(new_shape);
764 }
765
766 void StaticShapeInferer::visit(const ir::operation::Softmax &op)
767 {
768   handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
769 }
770
771 void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
772 {
773   const auto output_index = op.getOutputs().at(0);
774   const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
775   const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
776   const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
777
778   ir::Operand &output = _operands.at(output_index);
779   const auto &input = _operands.at(input_idx);
780   const auto &block_shape = _operands.at(block_shape_idx);
781   const auto &padding = _operands.at(padding_idx);
782
783   // Whether input is constant or not does not affect whether output is dynamic or not
784   if (!(block_shape.isConstant() && padding.isConstant()))
785   {
786     output.info().setDynamic();
787     _return_has_dynamic_tensor = true;
788     return;
789   }
790
791   auto input_shape = input.info().shape();
792   auto block_shape_shape = block_shape.info().shape();
793   auto padding_shape = padding.info().shape();
794
795   auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
796   auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());
797
798   ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
799       input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
800
801   output.info().shape(new_shape);
802 }
803
804 void StaticShapeInferer::visit(const ir::operation::Split &op)
805 {
806   const auto input_idx{op.getInputs().at(0)};
807   const auto &input = _operands.at(input_idx);
808
809   const auto axis = op.param().axis;
810   const auto num_splits = op.param().num_splits;
811
812   const auto rank = input.info().shape().rank();
813   auto axis_resolved = axis < 0 ? axis + rank : axis;
814
815   assert(0 <= axis_resolved && axis_resolved < rank);
816
817   ir::Shape new_shape =
818       shape_inference::inferSplitShape(input.info().shape(), axis_resolved, num_splits);
819   auto output_tensors = op.getOutputs();
820   for (auto output_idx : output_tensors)
821   {
822     ir::Operand &output = _operands.at(output_idx);
823     output.info().shape(new_shape);
824   }
825 }
826
827 void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
828 {
829   handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
830                            op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
831 }
832
833 void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
834 {
835   const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
836   const auto &input = _operands.at(input_idx);
837
838   const auto output_idx = op.getOutputs().at(0);
839   ir::Operand &output = _operands.at(output_idx);
840
841   if (input.info().isDynamic())
842   {
843     output.info().setDynamic();
844     _return_has_dynamic_tensor = true;
845     return;
846   }
847
848   // Squeeze output shpae
849   ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
850   output.info().shape(new_shape);
851 }
852
853 void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
854 {
855   const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
856   const auto &input = _operands.at(input_index);
857   const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
858   const auto &starts = _operands.at(starts_index);
859   const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
860   const auto &ends = _operands.at(ends_index);
861   const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
862   const auto &strides = _operands.at(strides_index);
863   const auto output_index = op.getOutputs().at(0);
864   ir::Operand &output = _operands.at(output_index);
865
866   if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
867   {
868     output.info().setDynamic();
869     _return_has_dynamic_tensor = true;
870     return;
871   }
872
873   const auto begin_mask = op.param().begin_mask;
874   const auto end_mask = op.param().end_mask;
875   const auto shrink_axis_mask = op.param().shrink_axis_mask;
876   const auto rank = input.info().shape().rank();
877
878   auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
879   auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
880   auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());
881
882   auto op_params = shape_inference::buildStridedSliceParams(
883       starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);
884
885   ir::Shape new_shape =
886       shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
887   output.info().shape(new_shape);
888 }
889
890 void StaticShapeInferer::visit(const ir::operation::Tile &op)
891 {
892   const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
893   const auto &input = _operands.at(input_idx);
894
895   const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
896   const auto &multiplier = _operands.at(multiplier_idx);
897
898   const auto output_idx = op.getOutputs().at(0);
899   ir::Operand &output = _operands.at(output_idx);
900
901   if (!multiplier.isConstant())
902   {
903     output.info().setDynamic();
904     _return_has_dynamic_tensor = true;
905     return;
906   }
907
908   auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
909   assert(multiplier_buffer);
910
911   // re-sizing output shape
912   auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer);
913   output.info().shape(new_shape);
914 }
915
916 void StaticShapeInferer::visit(const ir::operation::Transpose &op)
917 {
918   const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
919   const auto &input = _operands.at(input_idx);
920
921   // get mutable output operand
922   const auto output_idx = op.getOutputs().at(0);
923   ir::Operand &output = _operands.at(output_idx);
924   const auto perm{op.param().perm};
925   // const auto rank{op.param().rank};
926
927   // set output shape, based on input and params
928   ir::Shape new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm);
929   output.info().shape(new_shape);
930 }
931
932 void StaticShapeInferer::visit(const ir::operation::Unpack &op)
933 {
934   const auto input_idx{op.getInputs().at(0)};
935   const auto &input = _operands.at(input_idx);
936   const auto num = op.param().num;
937   const auto rank = input.shape().rank();
938   const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
939
940   assert(axis < rank);
941   if (axis < 0)
942   {
943     for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
944     {
945       const auto output_idx = op.getOutputs().at(out_tensor_idx);
946       ir::Operand &output = _operands.at(output_idx);
947       output.info().setDynamic();
948     }
949     _return_has_dynamic_tensor = true;
950     return;
951   }
952
953   ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);
954
955   // re-sizing output shape
956   for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
957   {
958     const auto output_idx = op.getOutputs().at(out_tensor_idx);
959     ir::Operand &output = _operands.at(output_idx);
960     output.info().shape(new_shape);
961   }
962 }
963
964 void StaticShapeInferer::visit(const ir::operation::While &op)
965 {
966   auto &cond_graph = _lowered_subgs.at(op.param().cond_subg_index)->graph();
967   auto &body_graph = _lowered_subgs.at(op.param().body_subg_index)->graph();
968   const auto inputs = op.getInputs();
969   const auto &outputs = op.getOutputs();
970
971   // re-sizing input shapes of then subgraph
972   const auto &cond_inputs = cond_graph.getInputs();
973   assert(inputs.size() == cond_inputs.size());
974   for (size_t i = 0; i < inputs.size(); ++i)
975   {
976     const auto &input = _operands.at(inputs.at(i));
977     auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
978     if (input.info().isDynamic())
979     {
980       cond_input.info().setDynamic();
981     }
982     else
983     {
984       auto new_shape = input.info().shape();
985       cond_input.info().shape(new_shape);
986     }
987   }
988
989   // re-sizing input shapes of body subgraph
990   const auto &body_inputs = body_graph.getInputs();
991   assert(cond_inputs.size() == body_inputs.size());
992   for (size_t i = 0; i < cond_inputs.size(); ++i)
993   {
994     const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
995     auto &body_input = body_graph.operands().at(body_inputs.at(i));
996     if (cond_input.info().isDynamic())
997     {
998       body_input.info().setDynamic();
999     }
1000     else
1001     {
1002       const auto &new_shape = cond_input.info().shape();
1003       body_input.info().shape(new_shape);
1004     }
1005   }
1006
1007   // re-sizing operands of body subgraph
1008   StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
1009   _lowered_subgs.at(op.param().body_subg_index)
1010       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
1011         bool has_dynamic_tensor = body_inferer.infer(op_seq);
1012         op_seq.has_dynamic_tensor(has_dynamic_tensor);
1013       });
1014
1015   // Check whether while operation's shapes are predictable
1016   // If any of shape of body outputs and cond inputs are different, non-constant operands would be
1017   // set to dynamic
1018   bool check_unpredictable_dynamic = false;
1019   const auto &body_outputs = body_graph.getOutputs();
1020   assert(body_outputs.size() == cond_inputs.size());
1021   for (size_t i = 0; i < body_outputs.size(); ++i)
1022   {
1023     const auto &body_output = body_graph.operands().at(body_outputs.at(i));
1024     auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
1025     if ((cond_input.info().isDynamic() != body_output.info().isDynamic()) ||
1026         (cond_input.shape() != body_output.shape()))
1027     {
1028       check_unpredictable_dynamic = true;
1029       break;
1030     }
1031   }
1032
1033   if (check_unpredictable_dynamic)
1034   {
1035     // Set inputs of body subgraph
1036     for (const auto &input_index : body_inputs)
1037     {
1038       auto &input = body_graph.operands().at(input_index);
1039       if (!input.isConstant())
1040       {
1041         input.info().setDynamic();
1042       }
1043     }
1044
1045     // Set inputs of cond subgraph
1046     for (const auto &input_index : cond_inputs)
1047     {
1048       auto &input = cond_graph.operands().at(input_index);
1049       if (!input.isConstant())
1050       {
1051         input.info().setDynamic();
1052       }
1053     }
1054
1055     // Set non-constant operands of body subgraph to dynamic
1056     StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
1057     _lowered_subgs.at(op.param().body_subg_index)
1058         ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
1059           bool has_dynamic_tensor = body_inferer.infer(op_seq);
1060           op_seq.has_dynamic_tensor(has_dynamic_tensor);
1061         });
1062   }
1063
1064   // re-sizing operands of cond subgraph
1065   // If check_unpredictable_dynamic is true, non-constant operands of cond subgraph would be set to
1066   // dynamic
1067   StaticShapeInferer cond_inferer(op.param().cond_subg_index, _lowered_subgs);
1068   _lowered_subgs.at(op.param().cond_subg_index)
1069       ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
1070         bool has_dynamic_tensor = cond_inferer.infer(op_seq);
1071         op_seq.has_dynamic_tensor(has_dynamic_tensor);
1072       });
1073
1074   // re-sizing outputs of while operation
1075   // If check_unpredictable_dynamic is true, outputs of while operation would be set to dynamic
1076   assert(cond_inputs.size() == outputs.size());
1077   for (size_t i = 0; i < cond_inputs.size(); ++i)
1078   {
1079     const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
1080     auto &output = _operands.at(outputs.at(i));
1081     if (cond_input.info().isDynamic())
1082     {
1083       output.info().setDynamic();
1084       _return_has_dynamic_tensor = true;
1085     }
1086     else
1087     {
1088       const auto new_shape = cond_input.info().shape();
1089       output.info().shape(new_shape);
1090     }
1091   }
1092 }
1093
1094 } // namespace compiler
1095
1096 } // namespace onert