2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "compiler/StaticShapeInference.h"
18 #include "util/ShapeInference.h"
19 #include "util/logging.h"
28 void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
29 const ir::OperandIndex lhs_idx,
30 const ir::OperandIndex rhs_idx)
32 const auto &lhs = _operands.at(lhs_idx);
33 const auto &rhs = _operands.at(rhs_idx);
35 const auto output_idx = op.getOutputs().at(0);
36 ir::Operand &output = _operands.at(output_idx);
38 if (lhs.info().isDynamic() || rhs.info().isDynamic())
40 output.info().setDynamic();
41 _return_has_dynamic_tensor = true;
45 // re-sizing output shape
46 ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
47 output.info().shape(new_shape);
50 void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
51 const ir::OperandIndex input_idx)
53 const auto &input = _operands.at(input_idx);
55 // get mutable output operand
56 const auto output_idx = op.getOutputs().at(0);
57 ir::Operand &output = _operands.at(output_idx);
59 // if input is dynamic, output also becomes dynamic
60 if (input.info().isDynamic())
62 output.info().setDynamic();
63 _return_has_dynamic_tensor = true;
67 // re-sizing output shape
68 ir::Shape new_shape = input.info().shape();
69 output.info().shape(new_shape);
72 void StaticShapeInferer::dump()
74 auto get_shape_str = [](const ir::Shape &shape) {
75 std::stringstream sstream;
76 sstream << "shape : {";
77 for (int i = 0; i < shape.rank(); i++)
80 sstream << shape.dim(i);
82 sstream << " " << shape.dim(i);
88 for (const auto &pair : _lowered_subgs)
90 const auto index = pair.first;
91 const auto &lowered_subg = pair.second;
92 VERBOSE(StaticShapeInferer) << "SubGraph #" << index.value() << std::endl;
93 lowered_subg->graph().operands().iterate(
94 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
95 VERBOSE(StaticShapeInferer) << "Operand #" << ind.value() << ", "
96 << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
97 << get_shape_str(operand.info().shape()) << std::endl;
102 void StaticShapeInferer::visit(const ir::operation::Abs &op)
104 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Abs::Input::INPUT));
107 void StaticShapeInferer::visit(const ir::operation::Add &op)
109 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Add::Input::LHS),
110 op.getInputs().at(ir::operation::Add::Input::RHS));
113 void StaticShapeInferer::visit(const ir::operation::ArgMax &op)
115 const auto input_idx{op.getInputs().at(ir::operation::ArgMax::Input::INPUT)};
116 const auto &input = _operands.at(input_idx);
118 // get mutable output operand
119 const auto output_idx = op.getOutputs().at(0);
120 ir::Operand &output = _operands.at(output_idx);
122 // if input is dynamic, output also becomes dynamic
123 if (input.info().isDynamic())
125 output.info().setDynamic();
126 _return_has_dynamic_tensor = true;
130 const auto rank = input.info().shape().rank();
131 const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
133 assert(0 <= axis && axis < rank);
135 // re-sizing output shape
136 ir::Shape new_shape = shape_inference::inferArgMaxShape(input.info().shape(), axis, rank);
137 output.info().shape(new_shape);
140 void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
142 const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
143 const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
144 const auto output_index = op.getOutputs().at(0);
145 const auto lhs = _operands.at(lhs_index);
146 const auto rhs = _operands.at(rhs_index);
147 auto &output = _operands.at(output_index);
149 if (lhs.info().isDynamic() || rhs.info().isDynamic())
151 output.info().setDynamic();
152 _return_has_dynamic_tensor = true;
156 auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
157 output.info().shape(new_shape);
160 void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
162 const auto input_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::INPUT)};
163 const auto &input = _operands.at(input_idx);
165 // get mutable output operand
166 const auto output_idx = op.getOutputs().at(0);
167 ir::Operand &output = _operands.at(output_idx);
169 // if input is dynamic, output also becomes dynamic.
170 if (input.info().isDynamic())
172 output.info().setDynamic();
173 _return_has_dynamic_tensor = true;
177 const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
178 const auto &shape = _operands.at(shape_idx);
180 if (!shape.isConstant())
182 output.info().setDynamic();
183 _return_has_dynamic_tensor = true;
187 // assert(shape.typeInfo().type() == ir::DataType::INT32);
188 auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());
190 // re-sizing output shape
191 ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
192 output.info().shape(new_shape);
195 void StaticShapeInferer::visit(const ir::operation::Cast &op)
197 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Cast::Input::INPUT));
200 void StaticShapeInferer::visit(const ir::operation::Comparison &op)
202 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
203 op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
206 void StaticShapeInferer::visit(const ir::operation::Concat &op)
208 const auto input_count = op.getInputs().size();
210 const auto output_idx = op.getOutputs().at(0);
211 ir::Operand &output = _operands.at(output_idx);
213 shape_inference::Shapes input_shapes;
214 for (uint32_t i = 0; i < input_count; i++)
216 const auto input_idx{op.getInputs().at(i)};
217 const auto &input = _operands.at(input_idx);
219 if (input.info().isDynamic())
221 output.info().setDynamic();
222 _return_has_dynamic_tensor = true;
226 input_shapes.emplace_back(input.shape());
229 ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());
231 // re-sizing output shape
232 output.info().shape(out_shape);
235 void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
237 const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
238 const auto &input = _operands.at(input_idx);
239 const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
240 const auto &ker = _operands.at(ker_idx);
241 const auto output_idx = op.getOutputs().at(0);
242 ir::Operand &output = _operands.at(output_idx);
244 if (input.info().isDynamic() || ker.info().isDynamic())
246 output.info().setDynamic();
247 _return_has_dynamic_tensor = true;
251 // re-sizing output shape
252 ir::Shape new_shape =
253 shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
254 output.info().shape(new_shape);
257 void StaticShapeInferer::visit(const ir::operation::Cos &op)
259 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Cos::Input::INPUT));
262 void StaticShapeInferer::visit(const ir::operation::Div &op)
264 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Div::Input::LHS),
265 op.getInputs().at(ir::operation::Div::Input::RHS));
268 void StaticShapeInferer::visit(const ir::operation::Exp &op)
270 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Exp::Input::INPUT));
273 void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
275 const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
276 const auto &input = _operands.at(input_idx);
277 const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
278 const auto &axis = _operands.at(axis_idx);
279 const auto output_idx = op.getOutputs().at(0);
280 ir::Operand &output = _operands.at(output_idx);
282 if (input.info().isDynamic())
284 output.info().setDynamic();
285 _return_has_dynamic_tensor = true;
289 if (!axis.isConstant())
291 output.info().setDynamic();
292 _return_has_dynamic_tensor = true;
296 // even when axis is constant, output shape should be recalculated since user might call
297 // nnfw_set_input_tensorinfo(input, some_new_shape)
298 auto axis_buf = reinterpret_cast<const int32_t *>(axis.data()->base());
301 // re-sizing output shape
302 ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_buf[0]);
303 output.info().shape(new_shape);
306 void StaticShapeInferer::visit(const ir::operation::Fill &op)
308 const auto input_idx{op.getInputs().at(ir::operation::Fill::Input::INPUT)};
309 const auto &input = _operands.at(input_idx);
310 const auto output_idx = op.getOutputs().at(0);
311 ir::Operand &output = _operands.at(output_idx);
313 if (input.info().isDynamic())
315 output.info().setDynamic();
316 _return_has_dynamic_tensor = true;
320 if (!input.isConstant())
322 output.info().setDynamic();
323 _return_has_dynamic_tensor = true;
327 assert(input.typeInfo().type() == ir::DataType::INT32);
329 auto input_buf = reinterpret_cast<const int32_t *>(input.data()->base());
332 // re-sizing output shape
333 ir::Shape new_shape = shape_inference::inferFillShape(input.info().shape(), input_buf);
334 output.info().shape(new_shape);
337 void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
339 const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
340 const auto &input = _operands.at(input_idx);
342 const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
343 const auto &ker = _operands.at(ker_idx);
345 // get mutable output operand
346 const auto output_idx = op.getOutputs().at(0);
347 ir::Operand &output = _operands.at(output_idx);
349 // if input or ker is dynamic, output also becomes dynamic
350 if (input.info().isDynamic() || ker.info().isDynamic())
352 output.info().setDynamic();
353 _return_has_dynamic_tensor = true;
357 // re-sizing output shape
358 ir::Shape new_shape =
359 shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
360 output.info().shape(new_shape);
363 void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
365 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
368 void StaticShapeInferer::visit(const ir::operation::Gather &op)
370 const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
371 const auto &input = _operands.at(input_idx);
373 // get mutable output operand
374 const auto output_idx = op.getOutputs().at(0);
375 ir::Operand &output = _operands.at(output_idx);
377 const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
378 const auto &indices = _operands.at(indices_idx);
380 // if input is dynamic, output also becomes dynamic
381 if (input.info().isDynamic() || indices.info().isDynamic())
383 output.info().setDynamic();
384 _return_has_dynamic_tensor = true;
388 const auto rank = input.info().shape().rank();
389 const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
391 assert(0 <= axis && axis < rank);
393 // re-sizing output shape
394 ir::Shape new_shape =
395 shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
396 output.info().shape(new_shape);
399 void StaticShapeInferer::visit(const ir::operation::If &op)
401 auto &then_graph = _lowered_subgs.at(op.param().then_subg_index)->graph();
402 auto &else_graph = _lowered_subgs.at(op.param().else_subg_index)->graph();
403 const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
404 const auto &outputs = op.getOutputs();
406 // re-sizing input shapes of then subgraph
407 const auto &then_inputs = then_graph.getInputs();
408 assert(inputs.size() == then_inputs.size());
409 for (size_t i = 0; i < inputs.size(); ++i)
411 auto &then_input = then_graph.operands().at(then_inputs.at(i));
412 if (_operands.at(inputs.at(i)).info().isDynamic())
414 then_input.info().setDynamic();
418 auto new_shape = _operands.at(inputs.at(i)).info().shape();
419 then_input.info().shape(new_shape);
423 // re-sizing input shapes of else subgraph
424 const auto &else_inputs = else_graph.getInputs();
425 assert(inputs.size() == else_inputs.size());
426 for (size_t i = 0; i < inputs.size(); ++i)
428 auto &else_input = else_graph.operands().at(else_inputs.at(i));
429 if (_operands.at(inputs.at(i)).info().isDynamic())
431 else_input.info().setDynamic();
435 const auto &new_shape = _operands.at(inputs.at(i)).info().shape();
436 else_input.info().shape(new_shape);
440 // re-sizing operands of then subgraph
441 StaticShapeInferer then_inferer(op.param().then_subg_index, _lowered_subgs);
442 _lowered_subgs.at(op.param().then_subg_index)
443 ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
444 bool has_dynamic_tensor = then_inferer.infer(op_seq);
445 op_seq.has_dynamic_tensor(has_dynamic_tensor);
448 // re-sizing operands of else subgraph
449 StaticShapeInferer else_inferer(op.param().else_subg_index, _lowered_subgs);
450 _lowered_subgs.at(op.param().else_subg_index)
451 ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
452 bool has_dynamic_tensor = else_inferer.infer(op_seq);
453 op_seq.has_dynamic_tensor(has_dynamic_tensor);
456 // re-sizing output shapes
457 const auto &then_outputs = _lowered_subgs.at(op.param().then_subg_index)->graph().getOutputs();
458 const auto &else_outputs = _lowered_subgs.at(op.param().else_subg_index)->graph().getOutputs();
459 assert(outputs.size() == then_outputs.size());
460 assert(outputs.size() == else_outputs.size());
461 for (size_t i = 0; i < outputs.size(); ++i)
463 const auto &then_output = then_graph.operands().at(then_outputs.at(i));
464 const auto &else_output = else_graph.operands().at(else_outputs.at(i));
465 auto &output = _operands.at(outputs.at(i));
466 if (!then_output.info().isDynamic() && !else_output.info().isDynamic() &&
467 then_output.shape() == else_output.shape())
469 output.info().shape(then_output.shape());
473 output.info().setDynamic();
474 _return_has_dynamic_tensor = true;
479 void StaticShapeInferer::visit(const ir::operation::Log &op)
481 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Log::Input::INPUT));
484 void StaticShapeInferer::visit(const ir::operation::LogicalNot &op)
486 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::LogicalNot::Input::INPUT));
489 void StaticShapeInferer::visit(const ir::operation::LogicalOr &op)
491 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::LogicalOr::Input::INPUT0),
492 op.getInputs().at(ir::operation::LogicalOr::Input::INPUT1));
495 void StaticShapeInferer::visit(const ir::operation::Logistic &op)
497 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Logistic::Input::INPUT));
500 void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
502 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
505 void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
507 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
510 void StaticShapeInferer::visit(const ir::operation::Max &op)
512 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Max::Input::LHS),
513 op.getInputs().at(ir::operation::Max::Input::RHS));
516 void StaticShapeInferer::visit(const ir::operation::Min &op)
518 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Min::Input::LHS),
519 op.getInputs().at(ir::operation::Min::Input::RHS));
522 void StaticShapeInferer::visit(const ir::operation::Mul &op)
524 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Mul::Input::LHS),
525 op.getInputs().at(ir::operation::Mul::Input::RHS));
528 void StaticShapeInferer::visit(const ir::operation::Neg &op)
530 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Neg::Input::INPUT));
533 void StaticShapeInferer::visit(const ir::operation::OneHot &op)
535 const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
536 const auto &indice = _operands.at(indice_idx);
537 const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
538 const auto &depth = _operands.at(depth_idx);
540 const auto axis = op.param().axis;
542 auto output_idx = op.getOutputs().at(0);
543 ir::Operand &output = _operands.at(output_idx);
545 if (indice.info().isDynamic() || depth.info().isDynamic() || !depth.isConstant())
547 output.info().setDynamic();
548 _return_has_dynamic_tensor = true;
552 const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
554 // re-sizing output shape
555 ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
556 output.info().shape(new_shape);
559 void StaticShapeInferer::visit(const ir::operation::Pack &op)
561 bool is_any_of_inputs_dynamic = [&]() -> bool {
562 for (uint32_t i = 0; i < op.getInputs().size(); ++i)
564 const auto &input = _operands.at(op.getInputs().at(i));
565 if (input.info().isDynamic())
573 const auto input_idx{op.getInputs().at(0)};
574 const auto &input = _operands.at(input_idx);
576 // get mutable output operand
577 const auto output_idx = op.getOutputs().at(0);
578 ir::Operand &output = _operands.at(output_idx);
580 // if input is dynamic, output also becomes dynamic
581 if (is_any_of_inputs_dynamic)
583 output.info().setDynamic();
584 _return_has_dynamic_tensor = true;
588 const auto rank = input.shape().rank() + 1;
589 const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
590 const auto num = op.param().num;
592 assert(0 <= axis && axis < rank);
594 // re-sizing output shape
595 ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
596 output.info().shape(new_shape);
599 void StaticShapeInferer::visit(const ir::operation::Pad &op)
601 const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
602 const auto &input = _operands.at(input_idx);
604 const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
605 const auto &pad = _operands.at(pad_idx);
607 // get mutable output operand
608 const auto output_idx = op.getOutputs().at(0);
609 ir::Operand &output = _operands.at(output_idx);
611 // if input is dynamic or pad is dynamic, output also becomes dynamic
612 if (input.info().isDynamic() || pad.info().isDynamic())
614 output.info().setDynamic();
615 _return_has_dynamic_tensor = true;
619 // if pad is not constant, output also becomes dynamic
620 if (!pad.isConstant())
622 output.info().setDynamic();
623 _return_has_dynamic_tensor = true;
627 // re-sizing output shape
628 const auto new_shape = shape_inference::inferPadShape(
629 input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
630 pad.shape().num_elements());
631 output.info().shape(new_shape);
634 void StaticShapeInferer::visit(const ir::operation::Permute &op)
636 const auto input_idx{op.getInputs().at(0)};
637 const auto &input = _operands.at(input_idx);
638 const auto output_idx = op.getOutputs().at(0);
639 ir::Operand &output = _operands.at(output_idx);
641 if (input.info().isDynamic())
643 output.info().setDynamic();
644 _return_has_dynamic_tensor = true;
648 // re-sizing output shape
649 // Permute is a special operation that layouts of input/output may be different on backend
650 // However, it is not applied here, so input/output have the same layout of frontend. Because
651 // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering
652 // operand info to "TensorBuilder" after calling "StaticShapeInferer"
653 const auto new_shape = input.info().shape();
654 output.info().shape(new_shape);
657 void StaticShapeInferer::visit(const ir::operation::Pow &op)
659 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
660 op.getInputs().at(ir::operation::Pow::Input::RHS));
663 void StaticShapeInferer::visit(const ir::operation::Range &op)
665 const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
666 const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
667 const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
668 const auto &start_op = _operands.at(start_idx);
669 const auto &limit_op = _operands.at(limit_idx);
670 const auto &delta_op = _operands.at(delta_idx);
672 // get mutable output operand
673 const auto output_idx = op.getOutputs().at(0);
674 ir::Operand &output = _operands.at(output_idx);
675 // if any input is dynamic, output also becomes dynamic
676 if (start_op.info().isDynamic() || limit_op.info().isDynamic() || delta_op.info().isDynamic())
678 output.info().setDynamic();
679 _return_has_dynamic_tensor = true;
684 if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
686 assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
687 start_op.typeInfo().type() == delta_op.typeInfo().type());
688 if (output.typeInfo().type() == ir::DataType::FLOAT32)
690 new_shape = shape_inference::inferRangeShape<float>(
691 start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
693 else if (output.typeInfo().type() == ir::DataType::INT32)
695 new_shape = shape_inference::inferRangeShape<int32_t>(
696 start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
698 assert(output.shape() == new_shape);
702 output.info().setDynamic();
703 _return_has_dynamic_tensor = true;
707 void StaticShapeInferer::visit(const ir::operation::Reduce &op)
709 const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
710 const auto &input = _operands.at(input_idx);
712 const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
713 const auto &axes = _operands.at(axes_idx);
715 // get mutable output operand
716 const auto output_idx = op.getOutputs().at(0);
717 ir::Operand &output = _operands.at(output_idx);
719 // if input is dynamic, output also becomes dynamic
720 if (input.info().isDynamic())
722 output.info().setDynamic();
723 _return_has_dynamic_tensor = true;
727 std::vector<int32_t> axes_vec;
728 for (size_t i = 0; i < axes.shape().num_elements(); ++i)
730 switch (axes.typeInfo().type())
732 case ir::DataType::INT32:
734 axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
737 case ir::DataType::INT64:
739 axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
743 throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
747 const auto keep_dims = op.param().keep_dims;
749 // re-sizing output shape
750 ir::Shape new_shape =
751 shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
752 output.info().shape(new_shape);
755 void StaticShapeInferer::visit(const ir::operation::Reshape &op)
757 const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
758 const auto &input = _operands.at(input_idx);
760 // get mutable output operand
761 const auto output_idx = op.getOutputs().at(0);
762 ir::Operand &output = _operands.at(output_idx);
764 // if input is dynamic, output also becomes dynamic
765 if (input.info().isDynamic())
767 output.info().setDynamic();
768 _return_has_dynamic_tensor = true;
772 // New shape is given by second input tensor
773 if (op.getInputs().size() == 2)
775 // Let's check the second input
776 const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
777 const auto &shape = _operands.at(shape_idx);
779 if (shape.isConstant())
781 const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
784 ir::Shape new_shape = shape_inference::inferReshapeShape(
785 shape_buf, shape.shape().num_elements(), input.shape().num_elements());
787 // if shape is from Const, TFLC put the shape of output into tensor
788 if (new_shape != output.shape())
790 // change on output shape
791 output.info().shape(new_shape);
796 // if shape is NOT Const, set output shape to be dynamic_
797 output.info().setDynamic();
798 _return_has_dynamic_tensor = true;
801 // New shape is given by option
802 else if (op.param().new_shape.size() != 0)
804 // Let's check the new_shape option
805 auto shape = op.param().new_shape;
806 ir::Shape new_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(),
807 input.shape().num_elements());
809 if (new_shape != output.shape())
811 // change on output shape
812 output.info().shape(new_shape);
817 throw std::runtime_error("Reshape: new shape is missing");
821 void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
823 const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
824 const auto &input = _operands.at(input_idx);
826 // get mutable output operand
827 const auto output_idx = op.getOutputs().at(0);
828 ir::Operand &output = _operands.at(output_idx);
830 // if input is dynamic, output also becomes dynamic
831 if (input.info().isDynamic())
833 output.info().setDynamic();
834 _return_has_dynamic_tensor = true;
838 // Shape inferencing logic based on Params
839 ir::Shape new_shape = shape_inference::inferResizeBilinearShape(
840 input.shape(), op.param().height_out, op.param().width_out);
842 // if size_op is from Const, TFLC put the shape of output into tensor
843 if (new_shape != output.shape())
845 // change on output shape
846 output.info().shape(new_shape);
850 void StaticShapeInferer::visit(const ir::operation::Reverse &op)
852 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
855 void StaticShapeInferer::visit(const ir::operation::Round &op)
857 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Round::Input::INPUT));
860 void StaticShapeInferer::visit(const ir::operation::RSQRT &op)
862 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::RSQRT::Input::INPUT));
865 void StaticShapeInferer::visit(const ir::operation::Select &op)
867 const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
868 const auto &input_cond = _operands.at(input_cond_idx);
870 const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
871 const auto &input_true = _operands.at(input_true_idx);
873 const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
874 const auto &input_false = _operands.at(input_false_idx);
876 auto output_idx = op.getOutputs().at(0);
877 ir::Operand &output = _operands.at(output_idx);
879 if (input_cond.info().isDynamic() || input_true.info().isDynamic() ||
880 input_false.info().isDynamic())
882 output.info().setDynamic();
883 _return_has_dynamic_tensor = true;
887 // Select output shpae
888 ir::Shape new_shape = shape_inference::inferSelectShape(
889 input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
890 output.info().shape(new_shape);
893 void StaticShapeInferer::visit(const ir::operation::Shape &op)
895 const auto input_idx{op.getInputs().at(0)};
896 const auto &input = _operands.at(input_idx);
898 // get mutable output operand
899 const auto output_idx = op.getOutputs().at(0);
900 ir::Operand &output = _operands.at(output_idx);
902 // if input is dynamic, output also becomes dynamic
903 if (input.info().isDynamic())
905 output.info().setDynamic();
906 _return_has_dynamic_tensor = true;
910 // re-sizing output shape
911 ir::Shape output_shape;
912 output_shape.append(input.info().shape().rank());
914 output.info().shape(output_shape);
917 void StaticShapeInferer::visit(const ir::operation::Sin &op)
919 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Sin::Input::INPUT));
922 void StaticShapeInferer::visit(const ir::operation::Slice &op)
924 const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
925 const auto &input = _operands.at(input_index);
926 const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
927 const auto &begins = _operands.at(begins_index);
928 const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
929 const auto &sizes = _operands.at(sizes_index);
930 const auto output_index = op.getOutputs().at(0);
931 ir::Operand &output = _operands.at(output_index);
933 if (input.info().isDynamic() || begins.info().isDynamic() || sizes.info().isDynamic())
935 output.info().setDynamic();
936 _return_has_dynamic_tensor = true;
940 // Whether input is constant or not does not affect whether output is dynamic or not
941 if (!(begins.isConstant() && sizes.isConstant()))
943 output.info().setDynamic();
944 _return_has_dynamic_tensor = true;
948 auto begins_buf = reinterpret_cast<const int32_t *>(begins.data()->base());
949 auto sizes_buf = reinterpret_cast<const int32_t *>(sizes.data()->base());
951 ir::Shape new_shape =
952 shape_inference::inferSliceShape(input.info().shape(), begins_buf, sizes_buf);
953 output.info().shape(new_shape);
956 void StaticShapeInferer::visit(const ir::operation::Softmax &op)
958 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
961 void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
963 const auto output_index = op.getOutputs().at(0);
964 const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
965 const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
966 const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
968 ir::Operand &output = _operands.at(output_index);
969 const auto &input = _operands.at(input_idx);
970 const auto &block_shape = _operands.at(block_shape_idx);
971 const auto &padding = _operands.at(padding_idx);
973 if (input.info().isDynamic() || block_shape.info().isDynamic() || padding.info().isDynamic())
975 output.info().setDynamic();
976 _return_has_dynamic_tensor = true;
980 // Whether input is constant or not does not affect whether output is dynamic or not
981 if (!(block_shape.isConstant() && padding.isConstant()))
983 output.info().setDynamic();
984 _return_has_dynamic_tensor = true;
988 auto input_shape = input.info().shape();
989 auto block_shape_shape = block_shape.info().shape();
990 auto padding_shape = padding.info().shape();
992 auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
993 auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());
995 ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
996 input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
998 output.info().shape(new_shape);
1001 void StaticShapeInferer::visit(const ir::operation::Split &op)
1003 const auto input_idx{op.getInputs().at(0)};
1004 const auto &input = _operands.at(input_idx);
1006 const auto axis = op.param().axis;
1007 const auto num_splits = op.param().num_splits;
1009 if (input.info().isDynamic())
1011 for (int out_tensor_idx = 0; out_tensor_idx < num_splits; out_tensor_idx++)
1013 const auto output_idx = op.getOutputs().at(out_tensor_idx);
1014 ir::Operand &output = _operands.at(output_idx);
1015 output.info().setDynamic();
1017 _return_has_dynamic_tensor = true;
1021 const auto rank = input.info().shape().rank();
1022 auto axis_resolved = axis < 0 ? axis + rank : axis;
1024 assert(0 <= axis_resolved && axis_resolved < rank);
1026 ir::Shape new_shape =
1027 shape_inference::inferSplitShape(input.info().shape(), axis_resolved, num_splits);
1028 auto output_tensors = op.getOutputs();
1029 for (auto output_idx : output_tensors)
1031 ir::Operand &output = _operands.at(output_idx);
1032 output.info().shape(new_shape);
1036 void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
1038 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
1039 op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
1042 void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
1044 const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
1045 const auto &input = _operands.at(input_idx);
1047 const auto output_idx = op.getOutputs().at(0);
1048 ir::Operand &output = _operands.at(output_idx);
1050 if (input.info().isDynamic())
1052 output.info().setDynamic();
1053 _return_has_dynamic_tensor = true;
1057 // Squeeze output shpae
1058 ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
1059 output.info().shape(new_shape);
1062 void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
1064 const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
1065 const auto &input = _operands.at(input_index);
1066 const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
1067 const auto &starts = _operands.at(starts_index);
1068 const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
1069 const auto &ends = _operands.at(ends_index);
1070 const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
1071 const auto &strides = _operands.at(strides_index);
1072 const auto output_index = op.getOutputs().at(0);
1073 ir::Operand &output = _operands.at(output_index);
1075 if (input.info().isDynamic() || starts.info().isDynamic() || ends.info().isDynamic() ||
1076 strides.info().isDynamic())
1078 output.info().setDynamic();
1079 _return_has_dynamic_tensor = true;
1083 if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
1085 output.info().setDynamic();
1086 _return_has_dynamic_tensor = true;
1090 const auto begin_mask = op.param().begin_mask;
1091 const auto end_mask = op.param().end_mask;
1092 const auto shrink_axis_mask = op.param().shrink_axis_mask;
1093 const auto rank = input.info().shape().rank();
1095 auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
1096 auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
1097 auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());
1099 auto op_params = shape_inference::buildStridedSliceParams(
1100 starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);
1102 ir::Shape new_shape =
1103 shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
1104 output.info().shape(new_shape);
1107 void StaticShapeInferer::visit(const ir::operation::Sub &op)
1109 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Sub::Input::LHS),
1110 op.getInputs().at(ir::operation::Sub::Input::RHS));
1113 void StaticShapeInferer::visit(const ir::operation::Tanh &op)
1115 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Tanh::Input::INPUT));
1118 void StaticShapeInferer::visit(const ir::operation::Tile &op)
1120 const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
1121 const auto &input = _operands.at(input_idx);
1123 const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
1124 const auto &multiplier = _operands.at(multiplier_idx);
1126 const auto output_idx = op.getOutputs().at(0);
1127 ir::Operand &output = _operands.at(output_idx);
1129 if (input.info().isDynamic())
1131 output.info().setDynamic();
1132 _return_has_dynamic_tensor = true;
1136 if (!multiplier.isConstant())
1138 output.info().setDynamic();
1139 _return_has_dynamic_tensor = true;
1143 auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
1144 assert(multiplier_buffer);
1146 // re-sizing output shape
1147 auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer);
1148 output.info().shape(new_shape);
1151 void StaticShapeInferer::visit(const ir::operation::Transpose &op)
1153 const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
1154 const auto &input = _operands.at(input_idx);
1156 // get mutable output operand
1157 const auto output_idx = op.getOutputs().at(0);
1158 ir::Operand &output = _operands.at(output_idx);
1159 const auto perm{op.param().perm};
1160 // const auto rank{op.param().rank};
1161 // if input is dynamic, output also becomes dynamic
1162 if (input.info().isDynamic())
1164 output.info().setDynamic();
1165 _return_has_dynamic_tensor = true;
1168 // set output shape, based on input and params
1169 ir::Shape new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm);
1170 output.info().shape(new_shape);
1173 void StaticShapeInferer::visit(const ir::operation::Unpack &op)
1175 const auto input_idx{op.getInputs().at(0)};
1176 const auto &input = _operands.at(input_idx);
1177 const auto num = op.param().num;
1179 // if input is dynamic, output also becomes dynamic
1180 if (input.info().isDynamic())
1182 for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1184 const auto output_idx = op.getOutputs().at(out_tensor_idx);
1185 ir::Operand &output = _operands.at(output_idx);
1186 output.info().setDynamic();
1188 _return_has_dynamic_tensor = true;
1192 const auto rank = input.shape().rank();
1193 const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
1195 assert(axis < rank);
1198 for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1200 const auto output_idx = op.getOutputs().at(out_tensor_idx);
1201 ir::Operand &output = _operands.at(output_idx);
1202 output.info().setDynamic();
1204 _return_has_dynamic_tensor = true;
1208 ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);
1210 // re-sizing output shape
1211 for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1213 const auto output_idx = op.getOutputs().at(out_tensor_idx);
1214 ir::Operand &output = _operands.at(output_idx);
1215 output.info().shape(new_shape);
void StaticShapeInferer::visit(const ir::operation::While &op)
{
  // Static shape inference for While: propagate the op's input shapes through
  // the cond subgraph, then the body subgraph, and use the (possibly dynamic)
  // cond-input shapes as the shapes of the While outputs.
  auto &cond_graph = _lowered_subgs.at(op.param().cond_subg_index)->graph();
  auto &body_graph = _lowered_subgs.at(op.param().body_subg_index)->graph();
  const auto inputs = op.getInputs();
  const auto &outputs = op.getOutputs();

  // re-sizing input shapes of then subgraph
  const auto &cond_inputs = cond_graph.getInputs();
  assert(inputs.size() == cond_inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i)
  {
    const auto &input = _operands.at(inputs.at(i));
    auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    if (input.info().isDynamic())
    {
      cond_input.info().setDynamic();
    }
    else
    {
      auto new_shape = input.info().shape();
      cond_input.info().shape(new_shape);
    }
  }

  // re-sizing input shapes of body subgraph
  const auto &body_inputs = body_graph.getInputs();
  assert(cond_inputs.size() == body_inputs.size());
  for (size_t i = 0; i < cond_inputs.size(); ++i)
  {
    const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    auto &body_input = body_graph.operands().at(body_inputs.at(i));
    if (cond_input.info().isDynamic())
    {
      body_input.info().setDynamic();
    }
    else
    {
      const auto &new_shape = cond_input.info().shape();
      body_input.info().shape(new_shape);
    }
  }

  // re-sizing operands of body subgraph by running static inference over it
  StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
  _lowered_subgs.at(op.param().body_subg_index)
      ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
        bool has_dynamic_tensor = body_inferer.infer(op_seq);
        op_seq.has_dynamic_tensor(has_dynamic_tensor);
      });

  // Check whether while operation's shapes are predictable.
  // If any shape of the body outputs differs from the corresponding cond input
  // (in dynamic-ness or dimensions), the loop-carried shapes are not stable
  // across iterations, so non-constant operands must be treated as dynamic.
  bool check_unpredictable_dynamic = false;
  const auto &body_outputs = body_graph.getOutputs();
  assert(body_outputs.size() == cond_inputs.size());
  for (size_t i = 0; i < body_outputs.size(); ++i)
  {
    const auto &body_output = body_graph.operands().at(body_outputs.at(i));
    auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    if ((cond_input.info().isDynamic() != body_output.info().isDynamic()) ||
        (cond_input.shape() != body_output.shape()))
    {
      check_unpredictable_dynamic = true;
      break;
    }
  }

  if (check_unpredictable_dynamic)
  {
    // Set inputs of body subgraph
    for (const auto &input_index : body_inputs)
    {
      auto &input = body_graph.operands().at(input_index);
      if (!input.isConstant())
      {
        input.info().setDynamic();
      }
    }

    // Set inputs of cond subgraph
    for (const auto &input_index : cond_inputs)
    {
      auto &input = cond_graph.operands().at(input_index);
      if (!input.isConstant())
      {
        input.info().setDynamic();
      }
    }

    // Set non-constant operands of body subgraph to dynamic by re-running
    // inference with the now-dynamic inputs (shadows the earlier inferer on purpose).
    StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
    _lowered_subgs.at(op.param().body_subg_index)
        ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
          bool has_dynamic_tensor = body_inferer.infer(op_seq);
          op_seq.has_dynamic_tensor(has_dynamic_tensor);
        });
  }

  // re-sizing operands of cond subgraph
  // If check_unpredictable_dynamic is true, non-constant operands of cond subgraph
  // would be set to dynamic by this pass (its inputs were marked dynamic above).
  StaticShapeInferer cond_inferer(op.param().cond_subg_index, _lowered_subgs);
  _lowered_subgs.at(op.param().cond_subg_index)
      ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
        bool has_dynamic_tensor = cond_inferer.infer(op_seq);
        op_seq.has_dynamic_tensor(has_dynamic_tensor);
      });

  // re-sizing outputs of while operation
  // If check_unpredictable_dynamic is true, outputs of while operation would be set to dynamic
  assert(cond_inputs.size() == outputs.size());
  for (size_t i = 0; i < cond_inputs.size(); ++i)
  {
    const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    auto &output = _operands.at(outputs.at(i));
    if (cond_input.info().isDynamic())
    {
      output.info().setDynamic();
      _return_has_dynamic_tensor = true;
    }
    else
    {
      const auto new_shape = cond_input.info().shape();
      output.info().shape(new_shape);
    }
  }
}
1349 void StaticShapeInferer::visit(const ir::operation::ZerosLike &op)
1351 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ZerosLike::Input::INPUT));
1354 } // namespace compiler
1356 } // namespace onert