/*
 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "compiler/StaticShapeInferer.h"
#include "util/ShapeInference.h"
#include "util/logging.h"

#include <sstream>

namespace onert
{
namespace compiler
{
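// StaticShapeInferer walks each OpSequence at compile time and re-sizes every
// output operand whose shape can be deduced from its inputs. Operands whose
// shape depends on non-constant inputs (e.g. a non-constant `shape` or `axis`
// tensor) are marked dynamic instead, deferring their shape inference to run
// time.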
bool StaticShapeInferer::infer(const ir::OpSequence &op_seq)
{
  bool has_dynamic_tensor = false;

  for (const auto &operation_idx : op_seq.operations())
  {
    auto &op = _operations.at(operation_idx);
    auto opcode = op.opcode();

    _return_has_dynamic_tensor = false; // used as a return value inside each operation's visit()

    // If: need shape inference for then, else
    // While: need shape inference for condition, body
    if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
    {
      op.accept(*this);
    }
    else
    {
      _return_has_dynamic_tensor = checkDynamicInput(op);

      if (_return_has_dynamic_tensor)
      {
        setDynamicOutput(op);
      }
      else
      {
        op.accept(*this);
      }
    }

    has_dynamic_tensor = has_dynamic_tensor || _return_has_dynamic_tensor;
  }

  return has_dynamic_tensor;
}
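// A typical driver (the same pattern used by the If/While visitors below) runs
// infer() over every OpSequence in topological order and records the result:
//
//   lowered_subg->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
//     op_seq.has_dynamic_tensor(inferer.infer(op_seq));
//   });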
bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
{
  for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
  {
    if (_operands.at(input_idx).info().isDynamic())
    {
      return true;
    }
  }

  return false;
}
void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
{
  for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
  {
    _operands.at(output_idx).info().setDynamic();
  }
}
void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
                                                  const ir::OperandIndex lhs_idx,
                                                  const ir::OperandIndex rhs_idx)
{
  const auto &lhs = _operands.at(lhs_idx);
  const auto &rhs = _operands.at(rhs_idx);

  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // re-sizing output shape
  ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
  output.info().shape(new_shape);
}
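// A worked example, assuming inferEltwiseShape() follows the usual NumPy-style
// broadcasting rules (dimensions align from the right; a dimension of 1
// stretches to match its counterpart):
//
//   lhs {2, 1, 5}  op  rhs {3, 5}  ->  output {2, 3, 5}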
void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
                                             const ir::OperandIndex input_idx)
{
  const auto &input = _operands.at(input_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // re-sizing output shape
  ir::Shape new_shape = input.info().shape();
  output.info().shape(new_shape);
}
void StaticShapeInferer::dump()
{
  auto get_shape_str = [](const ir::Shape &shape) {
    std::stringstream sstream;
    sstream << "shape : {";
    for (int i = 0; i < shape.rank(); i++)
    {
      if (i == 0)
        sstream << shape.dim(i);
      else
        sstream << " " << shape.dim(i);
    }
    sstream << "}";
    return sstream.str();
  };

  for (const auto &pair : _lowered_subgs)
  {
    const auto index = pair.first;
    const auto &lowered_subg = pair.second;
    VERBOSE(StaticShapeInferer) << "SubGraph #" << index.value() << std::endl;
    lowered_subg->graph().operands().iterate(
        [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
          VERBOSE(StaticShapeInferer) << "Operand #" << ind.value() << ", "
                                      << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
                                      << get_shape_str(operand.info().shape()) << std::endl;
        });
  }
}
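// Sample VERBOSE output produced by dump() for a rank-3 static operand:
//
//   SubGraph #0
//   Operand #5, Static, shape : {1 224 224}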
void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)};
  const auto &axis = _operands.at(axis_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  if (!axis.isConstant())
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  const auto rank = input.info().shape().rank();
  auto axis_value = axis.asScalar<int32_t>();
  axis_value = axis_value < 0 ? axis_value + rank : axis_value;

  // re-sizing output shape
  ir::Shape new_shape =
      shape_inference::inferArgMinMaxShape(input.info().shape(), axis_value, rank);
  output.info().shape(new_shape);
}
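// A worked example, assuming inferArgMinMaxShape() drops the reduced axis:
//
//   input {3, 4, 5}, axis = 1  ->  output {3, 5}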
void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
{
  const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
  const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
  const auto output_index = op.getOutputs().at(0);
  const auto &lhs = _operands.at(lhs_index);
  const auto &rhs = _operands.at(rhs_index);
  auto &output = _operands.at(output_index);
  auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto cluster_idx{
      op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
  const auto &cluster = _operands.at(cluster_idx);

  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
  assert(cluster_buf);

  // re-sizing output shape
  ir::Shape new_shape = shape_inference::inferBCQFullyConnectedShape(
      input.info().shape(), cluster.info().shape(), cluster_buf);
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::BCQGather &op)
{
  const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
  const auto &indices = _operands.at(indices_idx);

  const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
  const auto &input_binary = _operands.at(input_binary_idx);

  const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
  const auto &cluster = _operands.at(cluster_idx);

  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
  assert(cluster_buf);

  auto rank = input_binary.shape().rank();

  // re-sizing output shape
  ir::Shape new_shape = shape_inference::inferBCQGatherShape(
      indices.info().shape(), cluster.info().shape(), cluster_buf, rank, op.param());

  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
{
  handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS),
                           op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS));
}
void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
{
  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
  const auto &shape = _operands.at(shape_idx);

  if (!shape.isConstant())
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  // assert(shape.typeInfo().type() == ir::DataType::INT32);
  auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());

  // re-sizing output shape
  ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Comparison &op)
{
  handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
                           op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
}
void StaticShapeInferer::visit(const ir::operation::Concat &op)
{
  const auto input_count = op.getInputs().size();

  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  shape_inference::Shapes input_shapes;
  for (uint32_t i = 0; i < input_count; i++)
  {
    const auto input_idx{op.getInputs().at(i)};
    const auto &input = _operands.at(input_idx);
    input_shapes.emplace_back(input.shape());
  }

  ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());

  // re-sizing output shape
  output.info().shape(out_shape);
}
void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
  const auto &input = _operands.at(input_idx);
  const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
  const auto &ker = _operands.at(ker_idx);
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // re-sizing output shape
  ir::Shape new_shape =
      shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
{
  handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT));
}

void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op)
{
  handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS),
                           op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS));
}

void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
{
  handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT));
}
void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
  const auto &input = _operands.at(input_idx);
  const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
  const auto &axis = _operands.at(axis_idx);
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  if (!axis.isConstant())
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  // Even when axis is constant, the output shape should be recalculated since the user might call
  // nnfw_set_input_tensorinfo(input, some_new_shape)
  auto axis_type = axis.typeInfo().type();
  assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64);

  assert(axis.data()->base());
  int32_t axis_value =
      (axis_type == ir::DataType::INT32)
          ? reinterpret_cast<const int32_t *>(axis.data()->base())[0]
          : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis.data()->base())[0]);

  // re-sizing output shape
  ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_value);
  output.info().shape(new_shape);
}
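// A worked example, assuming inferExpandDimsShape() inserts a length-1 axis at
// the given position:
//
//   input {2, 3}, axis = 1  ->  output {2, 1, 3}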
void StaticShapeInferer::visit(const ir::operation::Fill &op)
{
  const auto shape_idx{op.getInputs().at(ir::operation::Fill::Input::SHAPE)};
  const auto &shape = _operands.at(shape_idx);
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  if (!shape.isConstant())
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  const auto dims_type = shape.typeInfo().type();
  assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64);

  auto dims_buf = shape.data()->base();
  assert(dims_buf);

  const auto &dims_shape = shape.info().shape();
  auto new_shape = ((dims_type == ir::DataType::INT32)
                        ? shape_inference::inferFillShape<int32_t>(
                              dims_shape, reinterpret_cast<const int32_t *>(dims_buf))
                        : shape_inference::inferFillShape<int64_t>(
                              dims_shape, reinterpret_cast<const int64_t *>(dims_buf)));

  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
  const auto &ker = _operands.at(ker_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // re-sizing output shape
  ir::Shape new_shape =
      shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
{
  handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
}
void StaticShapeInferer::visit(const ir::operation::Gather &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
  const auto &indices = _operands.at(indices_idx);
  const auto rank = input.info().shape().rank();
  const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);

  assert(0 <= axis && axis < rank);

  // re-sizing output shape
  ir::Shape new_shape =
      shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
  output.info().shape(new_shape);
}
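// A worked example, assuming inferGatherShape() replaces the gathered axis of
// the input with the full shape of the indices tensor:
//
//   input {3, 4, 5}, indices {2, 6}, axis = 1  ->  output {3, 2, 6, 5}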
void StaticShapeInferer::visit(const ir::operation::If &op)
{
  auto &then_graph = _lowered_subgs.at(op.param().then_subg_index)->graph();
  auto &else_graph = _lowered_subgs.at(op.param().else_subg_index)->graph();
  // The first input is the boolean condition; the rest are forwarded to the subgraphs
  const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
  const auto &outputs = op.getOutputs();

  // re-sizing input shapes of then subgraph
  const auto &then_inputs = then_graph.getInputs();
  assert(inputs.size() == then_inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i)
  {
    auto &then_input = then_graph.operands().at(then_inputs.at(i));
    if (_operands.at(inputs.at(i)).info().isDynamic())
    {
      then_input.info().setDynamic();
    }
    else
    {
      auto new_shape = _operands.at(inputs.at(i)).info().shape();
      then_input.info().shape(new_shape);
    }
  }

  // re-sizing input shapes of else subgraph
  const auto &else_inputs = else_graph.getInputs();
  assert(inputs.size() == else_inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i)
  {
    auto &else_input = else_graph.operands().at(else_inputs.at(i));
    if (_operands.at(inputs.at(i)).info().isDynamic())
    {
      else_input.info().setDynamic();
    }
    else
    {
      const auto &new_shape = _operands.at(inputs.at(i)).info().shape();
      else_input.info().shape(new_shape);
    }
  }

  // re-sizing operands of then subgraph
  StaticShapeInferer then_inferer(op.param().then_subg_index, _lowered_subgs);
  _lowered_subgs.at(op.param().then_subg_index)
      ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
        bool has_dynamic_tensor = then_inferer.infer(op_seq);
        op_seq.has_dynamic_tensor(has_dynamic_tensor);
      });

  // re-sizing operands of else subgraph
  StaticShapeInferer else_inferer(op.param().else_subg_index, _lowered_subgs);
  _lowered_subgs.at(op.param().else_subg_index)
      ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
        bool has_dynamic_tensor = else_inferer.infer(op_seq);
        op_seq.has_dynamic_tensor(has_dynamic_tensor);
      });

  // re-sizing output shapes: an output is static only when both branches agree
  // on the same static shape; otherwise it must be treated as dynamic
  const auto &then_outputs = _lowered_subgs.at(op.param().then_subg_index)->graph().getOutputs();
  const auto &else_outputs = _lowered_subgs.at(op.param().else_subg_index)->graph().getOutputs();
  assert(outputs.size() == then_outputs.size());
  assert(outputs.size() == else_outputs.size());
  for (size_t i = 0; i < outputs.size(); ++i)
  {
    const auto &then_output = then_graph.operands().at(then_outputs.at(i));
    const auto &else_output = else_graph.operands().at(else_outputs.at(i));
    auto &output = _operands.at(outputs.at(i));
    if (!then_output.info().isDynamic() && !else_output.info().isDynamic() &&
        then_output.shape() == else_output.shape())
    {
      output.info().shape(then_output.shape());
    }
    else
    {
      output.info().setDynamic();
      _return_has_dynamic_tensor = true;
    }
  }
}
void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
{
  handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
}
void StaticShapeInferer::visit(const ir::operation::LSTM &op)
{
  const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
  auto &output = _operands.at(output_index);

  const auto output_state_out_index{
      op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};

  const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};

  const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};

  if (output.info().isDynamic() || (_operands.exist(output_state_out_index) &&
                                    _operands.at(output_state_out_index).info().isDynamic()) ||
      (_operands.exist(cell_state_out_index) &&
       _operands.at(cell_state_out_index).info().isDynamic()) ||
      (_operands.exist(scratch_buffer_index) &&
       _operands.at(scratch_buffer_index).info().isDynamic()))
    return;

  const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)};
  const auto &input = _operands.at(input_index);

  const auto input_to_output_weights_index{
      op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
  const auto &input_to_output_weights = _operands.at(input_to_output_weights_index);

  const auto recurrent_to_output_weights_index{
      op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
  const auto &recurrent_to_output_weights = _operands.at(recurrent_to_output_weights_index);

  // re-sizing outputs
  const int n_batch = (input.shape().rank() == 3 && op.param().time_major) ? input.shape().dim(1)
                                                                           : input.shape().dim(0);
  const int n_cell = input_to_output_weights.shape().dim(0);
  const int n_output = recurrent_to_output_weights.shape().dim(1);
  if (input.shape().rank() == 3)
  {
    if (op.param().time_major)
      output.info().shape(ir::Shape{input.shape().dim(0), n_batch, n_output});
    else
      output.info().shape(ir::Shape{n_batch, input.shape().dim(1), n_output});
  }
  else
  {
    assert(input.shape().rank() == 2);
    output.info().shape(ir::Shape{n_batch, n_output});
  }

  if (_operands.exist(output_state_out_index))
  {
    auto &output_state_out = _operands.at(output_state_out_index);
    output_state_out.info().shape(ir::Shape{n_batch, n_output});
  }

  if (_operands.exist(cell_state_out_index))
  {
    auto &cell_state_out = _operands.at(cell_state_out_index);
    cell_state_out.info().shape(ir::Shape{n_batch, n_cell});
  }

  if (_operands.exist(scratch_buffer_index))
  {
    auto &scratch_buffer = _operands.at(scratch_buffer_index);

    const auto input_to_input_weights_index{
        op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
    const auto recurrent_to_input_weights_index{
        op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};

    bool has_input_to_input_weights =
        _operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
        _operands.at(input_to_input_weights_index).shape().dim(1) != 0;
    bool has_recurrent_to_input_weights =
        _operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
        _operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;

    // NOTE The cell_to_input_weights do not exist in non-peephole mode even for a regular
    // (non-CIFG) LSTM, so CIFG is detected from the input-to-input weights instead.
    // true: no CIFG
    // false: CIFG
    bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
    if (has_cifg_param)
    {
      scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 4});
    }
    else
    {
      scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 3});
    }
  }
}
void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
{
  handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
}
void StaticShapeInferer::visit(const ir::operation::OneHot &op)
{
  const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
  const auto &indice = _operands.at(indice_idx);
  const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
  const auto &depth = _operands.at(depth_idx);

  const auto axis = op.param().axis;

  auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  if (!depth.isConstant())
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
  assert(depth_buf);

  // re-sizing output shape
  ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Pack &op)
{
  const auto input_idx{op.getInputs().at(0)};
  const auto &input = _operands.at(input_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  const auto rank = input.shape().rank() + 1;
  const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
  const auto num = op.param().num;

  assert(0 <= axis && axis < rank);

  // re-sizing output shape
  ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
  output.info().shape(new_shape);
}
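// A worked example, assuming inferPackShape() inserts a new axis of length
// `num` at the given position:
//
//   input {3, 4}, axis = 0, num = 2  ->  output {2, 3, 4}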
void StaticShapeInferer::visit(const ir::operation::Pad &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
  const auto &pad = _operands.at(pad_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // if pad is not constant, output also becomes dynamic
  if (!pad.isConstant())
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  // re-sizing output shape
  const auto new_shape = shape_inference::inferPadShape(
      input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
      pad.shape().num_elements());
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Permute &op)
{
  const auto input_idx{op.getInputs().at(0)};
  const auto &input = _operands.at(input_idx);
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // re-sizing output shape
  // Permute is a special operation whose input and output layouts may differ on the backend.
  // That is not applied here, so input/output keep the frontend layout: "ExecutorFactory"
  // converts the input/output shapes according to the layouts when registering operand info
  // to "TensorBuilder" after calling "StaticShapeInferer".
  const auto new_shape = input.info().shape();
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Pow &op)
{
  handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
                           op.getInputs().at(ir::operation::Pow::Input::RHS));
}
void StaticShapeInferer::visit(const ir::operation::Range &op)
{
  const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
  const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
  const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
  const auto &start_op = _operands.at(start_idx);
  const auto &limit_op = _operands.at(limit_idx);
  const auto &delta_op = _operands.at(delta_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  ir::Shape new_shape;
  if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
  {
    assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
           start_op.typeInfo().type() == delta_op.typeInfo().type());
    if (output.typeInfo().type() == ir::DataType::FLOAT32)
    {
      new_shape = shape_inference::inferRangeShape<float>(
          start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
    }
    else if (output.typeInfo().type() == ir::DataType::INT32)
    {
      new_shape = shape_inference::inferRangeShape<int32_t>(
          start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
    }
    assert(output.shape() == new_shape);
  }
  else
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
  }
}
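// A worked example, assuming inferRangeShape() yields
// ceil((limit - start) / delta) elements:
//
//   start = 0, limit = 10, delta = 2  ->  output shape {5}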
void StaticShapeInferer::visit(const ir::operation::Reduce &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
  const auto &axes = _operands.at(axes_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  std::vector<int32_t> axes_vec;
  for (size_t i = 0; i < axes.shape().num_elements(); ++i)
  {
    switch (axes.typeInfo().type())
    {
      case ir::DataType::INT32:
      {
        axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
        break;
      }
      case ir::DataType::INT64:
      {
        axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
        break;
      }
      default:
        throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
    }
  }
  const auto keep_dims = op.param().keep_dims;

  // re-sizing output shape
  ir::Shape new_shape =
      shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
  output.info().shape(new_shape);
}
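// A worked example, assuming inferReduceShape() drops each reduced axis, or
// keeps it as length 1 when keep_dims is set:
//
//   input {2, 3, 4}, axes {1}, keep_dims = false  ->  output {2, 4}
//   input {2, 3, 4}, axes {1}, keep_dims = true   ->  output {2, 1, 4}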
void StaticShapeInferer::visit(const ir::operation::Reshape &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // New shape is given by second input tensor
  if (op.getInputs().size() == 2)
  {
    // Let's check the second input
    const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
    const auto &shape = _operands.at(shape_idx);

    if (shape.isConstant())
    {
      const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
      assert(shape_buf);

      ir::Shape new_shape = shape_inference::inferReshapeShape(
          shape_buf, shape.shape().num_elements(), input.shape().num_elements());

      // if shape is from Const, TFLC puts the shape of output into tensor
      if (new_shape != output.shape())
      {
        // change on output shape
        output.info().shape(new_shape);
      }
    }
    else
    {
      // if shape is NOT Const, set output shape to be dynamic
      output.info().setDynamic();
      _return_has_dynamic_tensor = true;
    }
  }
  // New shape is given by option
  else if (op.param().new_shape.size() != 0)
  {
    // Let's check the new_shape option
    auto shape = op.param().new_shape;
    ir::Shape new_shape = shape_inference::inferReshapeShape(shape.data(), shape.size(),
                                                             input.shape().num_elements());

    if (new_shape != output.shape())
    {
      // change on output shape
      output.info().shape(new_shape);
    }
  }
  else
  {
    throw std::runtime_error("Reshape: new shape is missing");
  }
}
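// A worked example, assuming inferReshapeShape() resolves a single -1 entry
// from the total element count:
//
//   input with 24 elements, new shape {-1, 4}  ->  output {6, 4}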
void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  int32_t height_out, width_out;
  if (op.getInputs().size() == 2)
  {
    auto &size = _operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE));
    if (!size.isConstant())
    {
      output.info().setDynamic();
      _return_has_dynamic_tensor = true;
      return;
    }
    const auto size_v = size.asVector<std::int32_t>();
    height_out = size_v[0];
    width_out = size_v[1];
  }
  else
  {
    height_out = op.param().height_out;
    width_out = op.param().width_out;
  }

  // Shape inference logic based on Params
  ir::Shape new_shape =
      shape_inference::inferResizeBilinearShape(input.shape(), height_out, width_out);

  // if size_op is from Const, TFLC puts the shape of output into tensor
  if (new_shape != output.shape())
  {
    // change on output shape
    output.info().shape(new_shape);
  }
}
void StaticShapeInferer::visit(const ir::operation::Reverse &op)
{
  handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
}
void StaticShapeInferer::visit(const ir::operation::Select &op)
{
  const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
  const auto &input_cond = _operands.at(input_cond_idx);

  const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
  const auto &input_true = _operands.at(input_true_idx);

  const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
  const auto &input_false = _operands.at(input_false_idx);

  auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // Select output shape
  ir::Shape new_shape = shape_inference::inferSelectShape(
      input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Shape &op)
{
  const auto input_idx{op.getInputs().at(0)};
  const auto &input = _operands.at(input_idx);

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // re-sizing output shape
  ir::Shape output_shape;
  output_shape.append(input.info().shape().rank());

  output.info().shape(output_shape);
}
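// Shape's output is always a rank-1 tensor whose single dimension equals the
// input's rank, e.g. input {2, 3, 4} -> output shape {3}.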
void StaticShapeInferer::visit(const ir::operation::Slice &op)
{
  const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
  const auto &input = _operands.at(input_index);
  const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
  const auto &begins = _operands.at(begins_index);
  const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
  const auto &sizes = _operands.at(sizes_index);
  const auto output_index = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_index);

  // Whether input is constant or not does not affect whether output is dynamic or not
  if (!(begins.isConstant() && sizes.isConstant()))
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  auto begins_buf = reinterpret_cast<const int32_t *>(begins.data()->base());
  auto sizes_buf = reinterpret_cast<const int32_t *>(sizes.data()->base());

  ir::Shape new_shape =
      shape_inference::inferSliceShape(input.info().shape(), begins_buf, sizes_buf);
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Softmax &op)
{
  handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
}
void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
{
  const auto output_index = op.getOutputs().at(0);
  const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
  const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
  const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};

  ir::Operand &output = _operands.at(output_index);
  const auto &input = _operands.at(input_idx);
  const auto &block_shape = _operands.at(block_shape_idx);
  const auto &padding = _operands.at(padding_idx);

  // Whether input is constant or not does not affect whether output is dynamic or not
  if (!(block_shape.isConstant() && padding.isConstant()))
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  auto input_shape = input.info().shape();
  auto block_shape_shape = block_shape.info().shape();
  auto padding_shape = padding.info().shape();

  auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
  auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());

  ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
      input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);

  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Split &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)};
  const auto &axis = _operands.at(axis_idx);

  auto outputs = op.getOutputs();
  if (!axis.isConstant())
  {
    for (auto output_idx : outputs)
    {
      ir::Operand &output = _operands.at(output_idx);
      output.info().setDynamic();
    }
    _return_has_dynamic_tensor = true;
    return;
  }

  const auto num_splits = op.param().num_splits;

  const auto rank = input.info().shape().rank();
  auto axis_value = axis.asScalar<int32_t>();
  axis_value = axis_value < 0 ? axis_value + rank : axis_value;

  assert(0 <= axis_value && axis_value < rank);

  ir::Shape new_shape =
      shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
  for (auto output_idx : outputs)
  {
    ir::Operand &output = _operands.at(output_idx);
    output.info().shape(new_shape);
  }
}
void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
{
  handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
                           op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
}
void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  // Squeeze output shape
  ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
{
  const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
  const auto &input = _operands.at(input_index);
  const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
  const auto &starts = _operands.at(starts_index);
  const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
  const auto &ends = _operands.at(ends_index);
  const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
  const auto &strides = _operands.at(strides_index);
  const auto output_index = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_index);

  if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  const auto begin_mask = op.param().begin_mask;
  const auto end_mask = op.param().end_mask;
  const auto shrink_axis_mask = op.param().shrink_axis_mask;
  const auto rank = input.info().shape().rank();

  auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
  auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
  auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());

  auto op_params = shape_inference::buildStridedSliceParams(
      starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);

  ir::Shape new_shape =
      shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Tile &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
  const auto &multiplier = _operands.at(multiplier_idx);

  const auto output_idx = op.getOutputs().at(0);
  ir::Operand &output = _operands.at(output_idx);

  if (!multiplier.isConstant())
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
  assert(multiplier_buffer);

  // re-sizing output shape
  auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer,
                                                   multiplier.shape().num_elements());
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Transpose &op)
{
  const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
  const auto &input = _operands.at(input_idx);

  const auto perm_idx{op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
  const auto &perm = _operands.at(perm_idx);

  // perm.shape() == ir::Shape{0} means that perm is (n-1...0)
  // TODO This condition changes to perm.num_elements() == 0
  const auto is_regular_transpose = perm.shape() == ir::Shape{0};

  // get mutable output operand
  const auto output_idx = op.getOutputs().at(0);
  auto &output = _operands.at(output_idx);
  if (!perm.isConstant() && !is_regular_transpose)
  {
    output.info().setDynamic();
    _return_has_dynamic_tensor = true;
    return;
  }

  ir::Shape new_shape;
  if (is_regular_transpose)
  {
    // Call by (n-1...0)
    new_shape = shape_inference::inferTransposeShape(input.info().shape(), nullptr, 0);
  }
  else
  {
    // Check rank
    if (input.info().shape().rank() != static_cast<int>(perm.info().shape().num_elements()))
    {
      throw std::runtime_error("StaticShapeInferer failed, bad rank size: " +
                               std::to_string(perm.info().shape().num_elements()));
    }

    // set output shape, based on input and params
    const auto perm_buf = reinterpret_cast<const int32_t *>(perm.data()->base());
    new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm_buf,
                                                     perm.shape().num_elements());
  }
  output.info().shape(new_shape);
}
void StaticShapeInferer::visit(const ir::operation::Unpack &op)
{
  const auto input_idx{op.getInputs().at(0)};
  const auto &input = _operands.at(input_idx);
  const auto num = op.param().num;
  const auto rank = input.shape().rank();
  const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);

  assert(axis < rank);
  if (axis < 0)
  {
    for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
    {
      const auto output_idx = op.getOutputs().at(out_tensor_idx);
      ir::Operand &output = _operands.at(output_idx);
      output.info().setDynamic();
    }
    _return_has_dynamic_tensor = true;
    return;
  }

  ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);

  // re-sizing output shape
  for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
  {
    const auto output_idx = op.getOutputs().at(out_tensor_idx);
    ir::Operand &output = _operands.at(output_idx);
    output.info().shape(new_shape);
  }
}
void StaticShapeInferer::visit(const ir::operation::While &op)
{
  auto &cond_graph = _lowered_subgs.at(op.param().cond_subg_index)->graph();
  auto &body_graph = _lowered_subgs.at(op.param().body_subg_index)->graph();
  const auto inputs = op.getInputs();
  const auto &outputs = op.getOutputs();

  // re-sizing input shapes of cond subgraph
  const auto &cond_inputs = cond_graph.getInputs();
  assert(inputs.size() == cond_inputs.size());
  for (size_t i = 0; i < inputs.size(); ++i)
  {
    const auto &input = _operands.at(inputs.at(i));
    auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    if (input.info().isDynamic())
    {
      cond_input.info().setDynamic();
    }
    else
    {
      auto new_shape = input.info().shape();
      cond_input.info().shape(new_shape);
    }
  }

  // re-sizing input shapes of body subgraph
  const auto &body_inputs = body_graph.getInputs();
  assert(cond_inputs.size() == body_inputs.size());
  for (size_t i = 0; i < cond_inputs.size(); ++i)
  {
    const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    auto &body_input = body_graph.operands().at(body_inputs.at(i));
    if (cond_input.info().isDynamic())
    {
      body_input.info().setDynamic();
    }
    else
    {
      const auto &new_shape = cond_input.info().shape();
      body_input.info().shape(new_shape);
    }
  }

  // re-sizing operands of body subgraph
  StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
  _lowered_subgs.at(op.param().body_subg_index)
      ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
        bool has_dynamic_tensor = body_inferer.infer(op_seq);
        op_seq.has_dynamic_tensor(has_dynamic_tensor);
      });

  // Check whether the while operation's shapes are predictable
  // If any shape of the body outputs differs from the matching cond input, non-constant operands
  // would be set to dynamic
  bool check_unpredictable_dynamic = false;
  const auto &body_outputs = body_graph.getOutputs();
  assert(body_outputs.size() == cond_inputs.size());
  for (size_t i = 0; i < body_outputs.size(); ++i)
  {
    const auto &body_output = body_graph.operands().at(body_outputs.at(i));
    auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    if ((cond_input.info().isDynamic() != body_output.info().isDynamic()) ||
        (cond_input.shape() != body_output.shape()))
    {
      check_unpredictable_dynamic = true;
      break;
    }
  }

  if (check_unpredictable_dynamic)
  {
    // Set inputs of body subgraph
    for (const auto &input_index : body_inputs)
    {
      auto &input = body_graph.operands().at(input_index);
      if (!input.isConstant())
      {
        input.info().setDynamic();
      }
    }

    // Set inputs of cond subgraph
    for (const auto &input_index : cond_inputs)
    {
      auto &input = cond_graph.operands().at(input_index);
      if (!input.isConstant())
      {
        input.info().setDynamic();
      }
    }

    // Set non-constant operands of body subgraph to dynamic
    StaticShapeInferer body_inferer(op.param().body_subg_index, _lowered_subgs);
    _lowered_subgs.at(op.param().body_subg_index)
        ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
          bool has_dynamic_tensor = body_inferer.infer(op_seq);
          op_seq.has_dynamic_tensor(has_dynamic_tensor);
        });
  }

  // re-sizing operands of cond subgraph
  // If check_unpredictable_dynamic is true, non-constant operands of cond subgraph would be set to
  // dynamic
  StaticShapeInferer cond_inferer(op.param().cond_subg_index, _lowered_subgs);
  _lowered_subgs.at(op.param().cond_subg_index)
      ->iterateTopolOpSeqs([&](const ir::OpSequenceIndex &, ir::OpSequence &op_seq) {
        bool has_dynamic_tensor = cond_inferer.infer(op_seq);
        op_seq.has_dynamic_tensor(has_dynamic_tensor);
      });

  // re-sizing outputs of while operation
  // If check_unpredictable_dynamic is true, outputs of while operation would be set to dynamic
  assert(cond_inputs.size() == outputs.size());
  for (size_t i = 0; i < cond_inputs.size(); ++i)
  {
    const auto &cond_input = cond_graph.operands().at(cond_inputs.at(i));
    auto &output = _operands.at(outputs.at(i));
    if (cond_input.info().isDynamic())
    {
      output.info().setDynamic();
      _return_has_dynamic_tensor = true;
    }
    else
    {
      const auto new_shape = cond_input.info().shape();
      output.info().shape(new_shape);
    }
  }
}
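// NOTE on While: a single body pass is run rather than iterating to a fixed
// point. If the body's output shapes already match the cond inputs, those
// static shapes are loop-invariant and can be used for the While outputs;
// otherwise every non-constant operand involved is conservatively marked
// dynamic.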
} // namespace compiler
} // namespace onert