2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "compiler/StaticShapeInferer.h"
18 #include "util/ShapeInference.h"
19 #include "util/logging.h"
21 #include <misc/polymorphic_downcast.h>
30 void OperandObserver::updateShapes(const std::vector<ir::OperandInfo> &changed_operands_info,
33 assert(changed_operands_info.size() == _operands.size());
34 for (size_t i = 0; i < changed_operands_info.size(); ++i)
36 const auto &changed_operand_info = changed_operands_info.at(i);
37 auto &operand = _operands.at(i);
38 // assert(changed_operand_info.typeInfo() == operand->typeInfo());
39 // assert(changed_operand_info.typeInfo() == operand->typeInfo());
40 // This error check may by replaced by an assertion if this function is called after the
41 // validation of models are completed.
42 if (changed_operand_info.typeInfo() != operand->typeInfo())
44 throw std::runtime_error("OperandObserver: The types of operands are mismatched");
46 if (!operand->info().isConstant() && (changed_operand_info.isDynamic() || unpredictable))
48 operand->info().setDynamic();
52 const auto &new_shape = changed_operands_info.at(i).shape();
53 operand->info().shape(new_shape);
58 void StaticShapeInferer::infer()
60 for (const auto &op_idx : _lowered_subg->graph().topolSortOperations())
62 const auto &op = _lowered_subg->graph().operations().at(op_idx);
63 bool has_dynamic_tensor = false;
64 const auto opcode = op.opcode();
65 // IF: requires shape inference for then, else
66 // While: requires shape inference for condition, body
67 if (opcode == ir::OpCode::If || opcode == ir::OpCode::While)
73 has_dynamic_tensor = checkDynamicInput(op);
74 if (has_dynamic_tensor)
83 has_dynamic_tensor = has_dynamic_tensor || checkDynamicOutput(op);
84 _lowered_subg->setHasDynamicTensor(op_idx, has_dynamic_tensor);
87 if (_controlflow_output_observer != nullptr)
89 // re-sizing output shapes of the controflow operation branching to this subgraph
90 std::vector<ir::OperandInfo> outputs_info;
91 const auto &graph = _lowered_subg->graph();
92 const auto &outputs = graph.getOutputs();
93 for (size_t i = 0; i < outputs.size(); ++i)
95 const auto &operand_info = graph.operands().at(outputs.at(i)).info();
96 outputs_info.emplace_back(operand_info);
98 _controlflow_output_observer->updateShapes(outputs_info);
102 bool StaticShapeInferer::checkDynamicInput(const ir::Operation &op)
104 const auto &operands = _lowered_subg->graph().operands();
105 for (auto input_idx : op.getInputs() | ir::Remove::UNDEFINED | ir::Remove::DUPLICATED)
107 if (operands.at(input_idx).info().isDynamic())
116 bool StaticShapeInferer::checkDynamicOutput(const ir::Operation &op)
118 auto &operands = _lowered_subg->graph().operands();
119 for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
121 if (operands.at(output_idx).info().isDynamic())
129 void StaticShapeInferer::setDynamicOutput(const ir::Operation &op)
131 auto &operands = _lowered_subg->graph().operands();
132 for (auto output_idx : op.getOutputs() | ir::Remove::UNDEFINED)
134 operands.at(output_idx).info().setDynamic();
138 void StaticShapeInferer::handleBinaryArithmeticOp(const ir::Operation &op,
139 const ir::OperandIndex lhs_idx,
140 const ir::OperandIndex rhs_idx)
142 auto &operands = _lowered_subg->graph().operands();
143 const auto &lhs = operands.at(lhs_idx);
144 const auto &rhs = operands.at(rhs_idx);
146 const auto output_idx = op.getOutputs().at(0);
147 ir::Operand &output = operands.at(output_idx);
149 // re-sizing output shape
150 ir::Shape new_shape = shape_inference::inferEltwiseShape(lhs.info().shape(), rhs.info().shape());
151 output.info().shape(new_shape);
154 void StaticShapeInferer::handleSimpleUnaryOp(const ir::Operation &op,
155 const ir::OperandIndex input_idx)
157 auto &operands = _lowered_subg->graph().operands();
158 const auto &input = operands.at(input_idx);
160 // get mutable output operand
161 const auto output_idx = op.getOutputs().at(0);
162 ir::Operand &output = operands.at(output_idx);
164 // re-sizing output shape
165 ir::Shape new_shape = input.info().shape();
166 output.info().shape(new_shape);
169 void StaticShapeInferer::dump()
171 auto get_shape_str = [](const ir::Shape &shape) {
172 std::stringstream sstream;
173 sstream << "shape : {";
174 for (int i = 0; i < shape.rank(); i++)
177 sstream << shape.dim(i);
179 sstream << " " << shape.dim(i);
182 return sstream.str();
185 _lowered_subg->graph().operands().iterate(
186 [&](const ir::OperandIndex &ind, const ir::Operand &operand) {
187 VERBOSE(StaticShapeInferer) << " " << ind << ", "
188 << (operand.info().isDynamic() ? "Dynamic" : "Static") << ", "
189 << get_shape_str(operand.info().shape()) << std::endl;
193 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>>
194 StaticShapeInferer::createStaticShapeInferers(
195 const std::unordered_map<ir::SubgraphIndex, std::unique_ptr<LoweredGraph>> &lowered_subgs)
197 // Allocate StaticShapeInferer per each subgraph
198 std::unordered_map<ir::SubgraphIndex, std::unique_ptr<StaticShapeInferer>> inferers;
199 for (auto &&pair : lowered_subgs)
201 const auto &subg_index = pair.first;
202 auto &lowered_subg = pair.second;
203 inferers[subg_index] = std::make_unique<StaticShapeInferer>(lowered_subg.get());
206 // Append observers in all StaticShapeInferers
207 for (auto &&pair : lowered_subgs)
209 const auto &subg_index = pair.first;
210 auto &lowered_subg = pair.second;
212 // TODO: Change this iteration for all to controlflow iteration
213 lowered_subg->graph().operations().iterate(
214 [&](const ir::OperationIndex &, const ir::Operation &op) {
215 // A Function to append child inferers. These make it possible for a StaticShapeInferer to
216 // call StaticShapeInferes of child subgraphs recursively
217 auto appendChildInferer = [&](const ir::SubgraphIndex &child_subg_idx) {
218 auto *child_inferer = inferers.at(child_subg_idx).get();
219 inferers.at(subg_index)->appendChildInferer(child_subg_idx, child_inferer);
222 // A Function to appaend subg input observers. This makes it possible for a
223 // StaticShapeInferer to update inputs of child subgraphs
224 auto appendSubgraphInputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
225 std::vector<ir::Operand *> child_subg_inputs;
226 auto &child_subg = lowered_subgs.at(child_subg_idx)->graph();
227 for (const auto &input_idx : child_subg.getInputs())
229 auto operand_ptr = child_subg.operands().getRawPtr(input_idx);
230 child_subg_inputs.emplace_back(operand_ptr);
232 inferers.at(subg_index)
233 ->appendSubgInputObserver(child_subg_idx,
234 std::make_unique<OperandObserver>(child_subg_inputs));
237 // A Function to set controlflow output observers. This makes it possible for a
238 // StaticShapeInferer to update outputs of parent controlflow opeerations
239 auto setControlFlowOutputObserver = [&](const ir::SubgraphIndex &child_subg_idx) {
240 std::vector<ir::Operand *> cf_outputs;
241 auto &subg = lowered_subg->graph();
242 for (const auto &output_idx : op.getOutputs())
244 auto operand_ptr = subg.operands().getRawPtr(output_idx);
245 cf_outputs.emplace_back(operand_ptr);
247 inferers.at(child_subg_idx)
248 ->setControlflowOutputObserver(std::make_unique<OperandObserver>(cf_outputs));
251 // Append Observers in a StaticShapeInferer
252 if (op.opcode() == ir::OpCode::If)
254 const auto &if_op = nnfw::misc::polymorphic_downcast<const ir::operation::If &>(op);
256 appendChildInferer(if_op.param().then_subg_index);
257 appendChildInferer(if_op.param().else_subg_index);
259 appendSubgraphInputObserver(if_op.param().then_subg_index);
260 appendSubgraphInputObserver(if_op.param().else_subg_index);
262 setControlFlowOutputObserver(if_op.param().then_subg_index);
264 else if (op.opcode() == ir::OpCode::While)
266 const auto &while_op = nnfw::misc::polymorphic_downcast<const ir::operation::While &>(op);
268 appendChildInferer(while_op.param().cond_subg_index);
269 appendChildInferer(while_op.param().body_subg_index);
271 appendSubgraphInputObserver(while_op.param().cond_subg_index);
272 appendSubgraphInputObserver(while_op.param().body_subg_index);
274 setControlFlowOutputObserver(while_op.param().body_subg_index);
282 void StaticShapeInferer::visit(const ir::operation::ArgMinMax &op)
284 auto &operands = _lowered_subg->graph().operands();
286 const auto input_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::INPUT)};
287 const auto &input = operands.at(input_idx);
289 const auto axis_idx{op.getInputs().at(ir::operation::ArgMinMax::Input::AXIS)};
290 const auto &axis = operands.at(axis_idx);
292 // get mutable output operand
293 const auto output_idx = op.getOutputs().at(0);
294 ir::Operand &output = operands.at(output_idx);
296 if (!axis.isConstant())
298 output.info().setDynamic();
302 const auto rank = input.info().shape().rank();
303 auto axis_value = axis.asScalar<int32_t>();
304 axis_value = axis_value < 0 ? axis_value + rank : axis_value;
306 // re-sizing output shape
307 ir::Shape new_shape =
308 shape_inference::inferArgMinMaxShape(input.info().shape(), axis_value, rank);
309 output.info().shape(new_shape);
312 void StaticShapeInferer::visit(const ir::operation::BatchMatMul &op)
314 auto &operands = _lowered_subg->graph().operands();
316 const auto lhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::LHS);
317 const auto rhs_index = op.getInputs().at(ir::operation::BatchMatMul::Input::RHS);
318 const auto output_index = op.getOutputs().at(0);
319 const auto &lhs = operands.at(lhs_index);
320 const auto &rhs = operands.at(rhs_index);
321 auto &output = operands.at(output_index);
322 auto new_shape = shape_inference::inferBatchMatMulShape(lhs.shape(), rhs.shape(), op.param());
323 output.info().shape(new_shape);
326 void StaticShapeInferer::visit(const ir::operation::BCQFullyConnected &op)
328 auto &operands = _lowered_subg->graph().operands();
330 const auto input_idx{op.getInputs().at(ir::operation::BCQFullyConnected::Input::INPUT)};
331 const auto &input = operands.at(input_idx);
333 const auto cluster_idx{
334 op.getInputs().at(ir::operation::BCQFullyConnected::Input::WEIGHTS_CLUSTERS)};
335 const auto &cluster = operands.at(cluster_idx);
337 const auto output_idx = op.getOutputs().at(0);
338 ir::Operand &output = operands.at(output_idx);
340 auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
343 // re-sizing output shape
344 ir::Shape new_shape = shape_inference::inferBCQFullyConnectedShape(
345 input.info().shape(), cluster.info().shape(), cluster_buf);
346 output.info().shape(new_shape);
349 void StaticShapeInferer::visit(const ir::operation::BCQGather &op)
351 auto &operands = _lowered_subg->graph().operands();
353 const auto indices_idx{op.getInputs().at(ir::operation::BCQGather::Input::INDICES)};
354 const auto &indices = operands.at(indices_idx);
356 const auto input_binary_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_BINARY)};
357 const auto &input_binary = operands.at(input_binary_idx);
359 const auto cluster_idx{op.getInputs().at(ir::operation::BCQGather::Input::INPUT_CLUSTERS)};
360 const auto &cluster = operands.at(cluster_idx);
362 const auto output_idx = op.getOutputs().at(0);
363 ir::Operand &output = operands.at(output_idx);
365 auto cluster_buf = reinterpret_cast<const int32_t *>(cluster.data()->base());
368 auto rank = input_binary.shape().rank();
370 // re-sizing output shape
371 ir::Shape new_shape = shape_inference::inferBCQGatherShape(
372 indices.info().shape(), cluster.info().shape(), cluster_buf, rank, op.param());
374 output.info().shape(new_shape);
377 void StaticShapeInferer::visit(const ir::operation::BinaryArithmetic &op)
379 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::BinaryArithmetic::Input::LHS),
380 op.getInputs().at(ir::operation::BinaryArithmetic::Input::RHS));
383 void StaticShapeInferer::visit(const ir::operation::BroadcastTo &op)
385 // get mutable output operand
386 auto &operands = _lowered_subg->graph().operands();
387 const auto output_idx = op.getOutputs().at(0);
388 ir::Operand &output = operands.at(output_idx);
390 const auto shape_idx{op.getInputs().at(ir::operation::BroadcastTo::Input::SHAPE)};
391 const auto &shape = operands.at(shape_idx);
393 if (!shape.isConstant())
395 output.info().setDynamic();
399 // assert(shape.typeInfo().type() == ir::DataType::INT32);
400 auto shape_buffer = reinterpret_cast<const int32_t *>(shape.data()->base());
402 // re-sizing output shape
403 ir::Shape new_shape = shape_inference::inferBroadcastToShape(shape.info().shape(), shape_buffer);
404 output.info().shape(new_shape);
407 void StaticShapeInferer::visit(const ir::operation::Comparison &op)
409 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Comparison::Input::INPUT0),
410 op.getInputs().at(ir::operation::Comparison::Input::INPUT1));
413 void StaticShapeInferer::visit(const ir::operation::Concat &op)
415 auto &operands = _lowered_subg->graph().operands();
417 const auto input_count = op.getInputs().size();
419 const auto output_idx = op.getOutputs().at(0);
420 ir::Operand &output = operands.at(output_idx);
422 shape_inference::Shapes input_shapes;
423 for (uint32_t i = 0; i < input_count; i++)
425 const auto input_idx{op.getInputs().at(i)};
426 const auto &input = operands.at(input_idx);
427 input_shapes.emplace_back(input.shape());
430 ir::Shape out_shape = shape_inference::inferConcatShape(input_shapes, op.param());
432 // re-sizing output shape
433 output.info().shape(out_shape);
436 void StaticShapeInferer::visit(const ir::operation::Conv2D &op)
438 auto &operands = _lowered_subg->graph().operands();
440 const auto input_idx{op.getInputs().at(ir::operation::Conv2D::Input::INPUT)};
441 const auto &input = operands.at(input_idx);
442 const auto ker_idx{op.getInputs().at(ir::operation::Conv2D::Input::KERNEL)};
443 const auto &ker = operands.at(ker_idx);
444 const auto output_idx = op.getOutputs().at(0);
445 ir::Operand &output = operands.at(output_idx);
447 // re-sizing output shape
448 ir::Shape new_shape =
449 shape_inference::inferConv2DShape(input.info().shape(), ker.info().shape(), op.param());
450 output.info().shape(new_shape);
453 void StaticShapeInferer::visit(const ir::operation::ElementwiseActivation &op)
455 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseActivation::Input::INPUT));
458 void StaticShapeInferer::visit(const ir::operation::ElementwiseBinary &op)
460 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::ElementwiseBinary::Input::LHS),
461 op.getInputs().at(ir::operation::ElementwiseBinary::Input::RHS));
464 void StaticShapeInferer::visit(const ir::operation::ElementwiseUnary &op)
466 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::ElementwiseUnary::Input::INPUT));
469 void StaticShapeInferer::visit(const ir::operation::ExpandDims &op)
471 auto &operands = _lowered_subg->graph().operands();
473 const auto input_idx{op.getInputs().at(ir::operation::ExpandDims::Input::INPUT)};
474 const auto &input = operands.at(input_idx);
475 const auto axis_idx{op.getInputs().at(ir::operation::ExpandDims::Input::AXIS)};
476 const auto &axis = operands.at(axis_idx);
477 const auto output_idx = op.getOutputs().at(0);
478 ir::Operand &output = operands.at(output_idx);
480 if (!axis.isConstant())
482 output.info().setDynamic();
486 // even when axis is constant, output shape should be recalculated since user might call
487 // nnfw_set_input_tensorinfo(input, some_new_shape)
488 auto axis_type = axis.typeInfo().type();
489 assert(axis_type == ir::DataType::INT32 || axis_type == ir::DataType::INT64);
491 assert(axis.data()->base());
493 (axis_type == ir::DataType::INT32)
494 ? reinterpret_cast<const int32_t *>(axis.data()->base())[0]
495 : static_cast<int32_t>(reinterpret_cast<const int64_t *>(axis.data()->base())[0]);
497 // re-sizing output shape
498 ir::Shape new_shape = shape_inference::inferExpandDimsShape(input.info().shape(), axis_value);
499 output.info().shape(new_shape);
502 void StaticShapeInferer::visit(const ir::operation::Fill &op)
504 auto &operands = _lowered_subg->graph().operands();
506 const auto shape_idx{op.getInputs().at(ir::operation::Fill::Input::SHAPE)};
507 const auto &shape = operands.at(shape_idx);
508 const auto output_idx = op.getOutputs().at(0);
509 ir::Operand &output = operands.at(output_idx);
511 if (!shape.isConstant())
513 output.info().setDynamic();
517 const auto dims_type = shape.typeInfo().type();
518 assert(dims_type == ir::DataType::INT32 || dims_type == ir::DataType::INT64);
520 auto dims_buf = shape.data()->base();
523 const auto &dims_shape = shape.info().shape();
524 auto new_shape = ((dims_type == ir::DataType::INT32)
525 ? shape_inference::inferFillShape<int32_t>(
526 dims_shape, reinterpret_cast<const int32_t *>(dims_buf))
527 : shape_inference::inferFillShape<int64_t>(
528 dims_shape, reinterpret_cast<const int64_t *>(dims_buf)));
530 output.info().shape(new_shape);
533 void StaticShapeInferer::visit(const ir::operation::FullyConnected &op)
535 auto &operands = _lowered_subg->graph().operands();
537 const auto input_idx{op.getInputs().at(ir::operation::FullyConnected::Input::INPUT)};
538 const auto &input = operands.at(input_idx);
540 const auto ker_idx{op.getInputs().at(ir::operation::FullyConnected::Input::WEIGHT)};
541 const auto &ker = operands.at(ker_idx);
543 // get mutable output operand
544 const auto output_idx = op.getOutputs().at(0);
545 ir::Operand &output = operands.at(output_idx);
546 // re-sizing output shape
547 ir::Shape new_shape =
548 shape_inference::inferFullyConnectedShape(input.info().shape(), ker.info().shape());
549 output.info().shape(new_shape);
552 void StaticShapeInferer::visit(const ir::operation::FusedBatchNorm &op)
554 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::FusedBatchNorm::Input::INPUT));
557 void StaticShapeInferer::visit(const ir::operation::Gather &op)
559 auto &operands = _lowered_subg->graph().operands();
561 const auto input_idx{op.getInputs().at(ir::operation::Gather::Input::INPUT)};
562 const auto &input = operands.at(input_idx);
564 // get mutable output operand
565 const auto output_idx = op.getOutputs().at(0);
566 ir::Operand &output = operands.at(output_idx);
568 const auto indices_idx{op.getInputs().at(ir::operation::Gather::Input::INDICES)};
569 const auto &indices = operands.at(indices_idx);
570 const auto rank = input.info().shape().rank();
571 const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
573 assert(0 <= axis && axis < rank);
575 // re-sizing output shape
576 ir::Shape new_shape =
577 shape_inference::inferGatherShape(input.info().shape(), indices.info().shape(), axis, rank);
578 output.info().shape(new_shape);
581 void StaticShapeInferer::visit(const ir::operation::If &op)
583 // re-sizing input shapes of then/else subgraph
584 const std::vector<ir::OperandIndex> inputs{op.getInputs().begin() + 1, op.getInputs().end()};
586 std::vector<ir::OperandInfo> inputs_info;
587 const auto &graph = _lowered_subg->graph();
588 for (size_t i = 0; i < inputs.size(); ++i)
590 const auto &operand_info = graph.operands().at(inputs.at(i)).info();
591 inputs_info.emplace_back(operand_info);
593 _subg_input_observers.at(op.param().then_subg_index)->updateShapes(inputs_info);
594 _child_inferers.at(op.param().then_subg_index)->infer();
596 _subg_input_observers.at(op.param().else_subg_index)->updateShapes(inputs_info);
597 _child_inferers.at(op.param().else_subg_index)->infer();
600 void StaticShapeInferer::visit(const ir::operation::L2Normalization &op)
602 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::L2Normalization::Input::INPUT));
605 void StaticShapeInferer::visit(const ir::operation::LSTM &op)
607 auto &operands = _lowered_subg->graph().operands();
609 const auto output_index{op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT)};
610 auto &output = operands.at(output_index);
612 const auto output_state_out_index{
613 op.getOutputs().at(ir::operation::LSTM::Output::OUTPUT_STATE_OUT)};
615 const auto cell_state_out_index{op.getOutputs().at(ir::operation::LSTM::Output::CELL_STATE_OUT)};
617 const auto scratch_buffer_index{op.getOutputs().at(ir::operation::LSTM::Output::SCRATCH_BUFFER)};
619 if (output.info().isDynamic() ||
620 (operands.exist(output_state_out_index) &&
621 operands.at(output_state_out_index).info().isDynamic()) ||
622 (operands.exist(cell_state_out_index) &&
623 operands.at(cell_state_out_index).info().isDynamic()) ||
624 (operands.exist(scratch_buffer_index) &&
625 operands.at(scratch_buffer_index).info().isDynamic()))
628 const auto input_index{op.getInputs().at(ir::operation::LSTM::Input::INPUT)};
629 const auto &input = operands.at(input_index);
631 const auto input_to_output_weights_index{
632 op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_OUTPUT_WEIGHTS)};
633 const auto &input_to_output_weights = operands.at(input_to_output_weights_index);
635 const auto recurrent_to_output_weights_index{
636 op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_OUTPUT_WEIGHTS)};
637 const auto &recurrent_to_output_weights = operands.at(recurrent_to_output_weights_index);
640 const int n_batch = (input.shape().rank() == 3 && op.param().time_major) ? input.shape().dim(1)
641 : input.shape().dim(0);
642 const int n_cell = input_to_output_weights.shape().dim(0);
643 const int n_output = recurrent_to_output_weights.shape().dim(1);
644 if (input.shape().rank() == 3)
646 if (op.param().time_major)
647 output.info().shape(ir::Shape{input.shape().dim(0), n_batch, n_output});
649 output.info().shape(ir::Shape{n_batch, input.shape().dim(1), n_output});
653 assert(input.shape().rank() == 2);
654 output.info().shape(ir::Shape{n_batch, n_output});
657 if (operands.exist(output_state_out_index))
659 auto &output_state_out = operands.at(output_state_out_index);
660 output_state_out.info().shape(ir::Shape{n_batch, n_output});
663 if (operands.exist(cell_state_out_index))
665 auto &cell_state_out = operands.at(cell_state_out_index);
666 cell_state_out.info().shape(ir::Shape{n_batch, n_cell});
669 if (operands.exist(scratch_buffer_index))
671 auto &scratch_buffer = operands.at(scratch_buffer_index);
673 const auto input_to_input_weights_index{
674 op.getInputs().at(ir::operation::LSTM::Input::INPUT_TO_INPUT_WEIGHTS)};
675 const auto recurrent_to_input_weights_index{
676 op.getInputs().at(ir::operation::LSTM::Input::RECURRENT_TO_INPUT_WEIGHTS)};
678 bool has_input_to_input_weights =
679 operands.at(input_to_input_weights_index).shape().dim(0) != 0 &&
680 operands.at(input_to_input_weights_index).shape().dim(1) != 0;
681 bool has_recurrent_to_input_weights =
682 operands.at(recurrent_to_input_weights_index).shape().dim(0) != 0 &&
683 operands.at(recurrent_to_input_weights_index).shape().dim(1) != 0;
685 // NOTE The cell_to_input_weights do not exist in non-peephole although regular LSTM(non-CIFG).
688 bool has_cifg_param = has_input_to_input_weights && has_recurrent_to_input_weights;
691 scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 4});
695 scratch_buffer.info().shape(ir::Shape{n_batch, n_cell * 3});
700 void StaticShapeInferer::visit(const ir::operation::MatrixBandPart &op)
702 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::MatrixBandPart::Input::INPUT));
705 void StaticShapeInferer::visit(const ir::operation::OneHot &op)
707 auto &operands = _lowered_subg->graph().operands();
709 const auto indice_idx{op.getInputs().at(ir::operation::OneHot::Input::INDICES)};
710 const auto &indice = operands.at(indice_idx);
711 const auto depth_idx{op.getInputs().at(ir::operation::OneHot::Input::DEPTH)};
712 const auto &depth = operands.at(depth_idx);
714 const auto axis = op.param().axis;
716 auto output_idx = op.getOutputs().at(0);
717 ir::Operand &output = operands.at(output_idx);
719 if (!depth.isConstant())
721 output.info().setDynamic();
725 const auto *depth_buf = reinterpret_cast<const int32_t *>(depth.data()->base());
727 // re-sizing output shape
728 ir::Shape new_shape = shape_inference::inferOnehotShape(indice.info().shape(), *depth_buf, axis);
729 output.info().shape(new_shape);
732 void StaticShapeInferer::visit(const ir::operation::Pack &op)
734 auto &operands = _lowered_subg->graph().operands();
736 const auto input_idx{op.getInputs().at(0)};
737 const auto &input = operands.at(input_idx);
739 // get mutable output operand
740 const auto output_idx = op.getOutputs().at(0);
741 ir::Operand &output = operands.at(output_idx);
743 const auto rank = input.shape().rank() + 1;
744 const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
745 const auto num = op.param().num;
747 assert(0 <= axis && axis < rank);
749 // re-sizing output shape
750 ir::Shape new_shape = shape_inference::inferPackShape(input.info().shape(), axis, rank, num);
751 output.info().shape(new_shape);
754 void StaticShapeInferer::visit(const ir::operation::Pad &op)
756 auto &operands = _lowered_subg->graph().operands();
758 const auto input_idx{op.getInputs().at(ir::operation::Pad::Input::INPUT)};
759 const auto &input = operands.at(input_idx);
761 const auto pad_idx{op.getInputs().at(ir::operation::Pad::Input::PAD)};
762 const auto &pad = operands.at(pad_idx);
764 // get mutable output operand
765 const auto output_idx = op.getOutputs().at(0);
766 ir::Operand &output = operands.at(output_idx);
768 // if pad is not constant, output also becomes dynamic
769 if (!pad.isConstant())
771 output.info().setDynamic();
775 // re-sizing output shape
776 const auto new_shape = shape_inference::inferPadShape(
777 input.shape(), reinterpret_cast<const int32_t *>(pad.data()->base()),
778 pad.shape().num_elements());
779 output.info().shape(new_shape);
782 void StaticShapeInferer::visit(const ir::operation::Permute &op)
784 auto &operands = _lowered_subg->graph().operands();
786 const auto input_idx{op.getInputs().at(0)};
787 const auto &input = operands.at(input_idx);
788 const auto output_idx = op.getOutputs().at(0);
789 ir::Operand &output = operands.at(output_idx);
791 // re-sizing output shape
792 // Permute is a special operation that layouts of input/output may be different on backend
793 // However, it is not applied here, so input/output have the same layout of frontend. Because
794 // "ExecutorFactory" would convert shape of input/output accoding to the layouts when registering
795 // operand info to "TensorBuilder" after calling "StaticShapeInferer"
796 const auto new_shape = input.info().shape();
797 output.info().shape(new_shape);
800 void StaticShapeInferer::visit(const ir::operation::Pow &op)
802 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::Pow::Input::LHS),
803 op.getInputs().at(ir::operation::Pow::Input::RHS));
806 void StaticShapeInferer::visit(const ir::operation::Range &op)
808 auto &operands = _lowered_subg->graph().operands();
810 const auto start_idx{op.getInputs().at(ir::operation::Range::Input::START)};
811 const auto limit_idx{op.getInputs().at(ir::operation::Range::Input::LIMIT)};
812 const auto delta_idx{op.getInputs().at(ir::operation::Range::Input::DELTA)};
813 const auto &start_op = operands.at(start_idx);
814 const auto &limit_op = operands.at(limit_idx);
815 const auto &delta_op = operands.at(delta_idx);
817 // get mutable output operand
818 const auto output_idx = op.getOutputs().at(0);
819 ir::Operand &output = operands.at(output_idx);
822 if (start_op.isConstant() && limit_op.isConstant() && delta_op.isConstant())
824 assert(start_op.typeInfo().type() == limit_op.typeInfo().type() &&
825 start_op.typeInfo().type() == delta_op.typeInfo().type());
826 if (output.typeInfo().type() == ir::DataType::FLOAT32)
828 new_shape = shape_inference::inferRangeShape<float>(
829 start_op.asScalar<float>(), limit_op.asScalar<float>(), delta_op.asScalar<float>());
831 else if (output.typeInfo().type() == ir::DataType::INT32)
833 new_shape = shape_inference::inferRangeShape<int32_t>(
834 start_op.asScalar<int32_t>(), limit_op.asScalar<int32_t>(), delta_op.asScalar<int32_t>());
836 assert(output.shape() == new_shape);
840 output.info().setDynamic();
844 void StaticShapeInferer::visit(const ir::operation::Reduce &op)
846 auto &operands = _lowered_subg->graph().operands();
848 const auto input_idx{op.getInputs().at(ir::operation::Reduce::Input::INPUT)};
849 const auto &input = operands.at(input_idx);
851 const auto axes_idx{op.getInputs().at(ir::operation::Reduce::Input::AXES)};
852 const auto &axes = operands.at(axes_idx);
854 // get mutable output operand
855 const auto output_idx = op.getOutputs().at(0);
856 ir::Operand &output = operands.at(output_idx);
858 std::vector<int32_t> axes_vec;
859 for (size_t i = 0; i < axes.shape().num_elements(); ++i)
861 switch (axes.typeInfo().type())
863 case ir::DataType::INT32:
865 axes_vec.emplace_back(reinterpret_cast<const int32_t *>(axes.data()->base())[i]);
868 case ir::DataType::INT64:
870 axes_vec.emplace_back(reinterpret_cast<const int64_t *>(axes.data()->base())[i]);
874 throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported data type");
878 const auto keep_dims = op.param().keep_dims;
880 // re-sizing output shape
881 ir::Shape new_shape =
882 shape_inference::inferReduceShape(input.info().shape(), axes_vec, keep_dims);
883 output.info().shape(new_shape);
886 void StaticShapeInferer::visit(const ir::operation::Reshape &op)
888 auto &operands = _lowered_subg->graph().operands();
890 const auto input_idx{op.getInputs().at(ir::operation::Reshape::Input::INPUT)};
891 const auto &input = operands.at(input_idx);
893 // get mutable output operand
894 const auto output_idx = op.getOutputs().at(0);
895 ir::Operand &output = operands.at(output_idx);
897 // New shape is given by second input tensor
898 if (op.getInputs().size() == 2)
900 // Let's check the second input
901 const auto shape_idx{op.getInputs().at(ir::operation::Reshape::Input::SHAPE)};
902 const auto &shape = operands.at(shape_idx);
904 if (shape.isConstant())
906 const auto *shape_buf = reinterpret_cast<const int32_t *>(shape.data()->base());
909 ir::Shape new_shape = shape_inference::inferReshapeShape(
910 shape_buf, shape.shape().num_elements(), input.shape().num_elements());
912 // if shape is from Const, TFLC put the shape of output into tensor
913 if (new_shape != output.shape())
915 // change on output shape
916 output.info().shape(new_shape);
921 // if shape is NOT Const, set output shape to be dynamic_
922 output.info().setDynamic();
925 // New shape is given by option
926 else if (op.param().new_shape.size() != 0)
928 // Let's check the new_shape option
929 auto shape = op.param().new_shape;
930 ir::Shape new_shape =
931 shape_inference::inferReshapeShape(shape.data(), shape.size(), input.shape().num_elements());
933 if (new_shape != output.shape())
935 // change on output shape
936 output.info().shape(new_shape);
941 throw std::runtime_error("Reshape: new shape is missing");
945 void StaticShapeInferer::visit(const ir::operation::ResizeBilinear &op)
947 auto &operands = _lowered_subg->graph().operands();
949 const auto input_idx{op.getInputs().at(ir::operation::ResizeBilinear::Input::INPUT)};
950 const auto &input = operands.at(input_idx);
952 // get mutable output operand
953 const auto output_idx = op.getOutputs().at(0);
954 ir::Operand &output = operands.at(output_idx);
956 int32_t height_out, width_out;
957 if (op.getInputs().size() == 2)
959 auto &size = operands.at(op.getInputs().at(ir::operation::ResizeBilinear::Input::SIZE));
960 if (!size.isConstant())
962 output.info().setDynamic();
965 const auto size_v = size.asVector<std::int32_t>();
966 height_out = size_v[0];
967 width_out = size_v[1];
971 height_out = op.param().height_out;
972 width_out = op.param().width_out;
975 // Shape inferencing logic based on Params
976 ir::Shape new_shape =
977 shape_inference::inferResizeBilinearShape(input.shape(), height_out, width_out);
979 // if size_op is from Const, TFLC put the shape of output into tensor
980 if (new_shape != output.shape())
982 // change on output shape
983 output.info().shape(new_shape);
987 void StaticShapeInferer::visit(const ir::operation::Reverse &op)
989 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Reverse::Input::INPUT));
992 void StaticShapeInferer::visit(const ir::operation::Select &op)
994 auto &operands = _lowered_subg->graph().operands();
996 const auto input_cond_idx{op.getInputs().at(ir::operation::Select::Input::CONDITION)};
997 const auto &input_cond = operands.at(input_cond_idx);
999 const auto input_true_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_TRUE)};
1000 const auto &input_true = operands.at(input_true_idx);
1002 const auto input_false_idx{op.getInputs().at(ir::operation::Select::Input::INPUT_FALSE)};
1003 const auto &input_false = operands.at(input_false_idx);
1005 auto output_idx = op.getOutputs().at(0);
1006 ir::Operand &output = operands.at(output_idx);
1008 // Select output shpae
1009 ir::Shape new_shape = shape_inference::inferSelectShape(
1010 input_cond.info().shape(), input_true.info().shape(), input_false.info().shape());
1011 output.info().shape(new_shape);
1014 void StaticShapeInferer::visit(const ir::operation::Shape &op)
1016 auto &operands = _lowered_subg->graph().operands();
1018 const auto input_idx{op.getInputs().at(0)};
1019 const auto &input = operands.at(input_idx);
1021 // get mutable output operand
1022 const auto output_idx = op.getOutputs().at(0);
1023 ir::Operand &output = operands.at(output_idx);
1025 // re-sizing output shape
1026 ir::Shape output_shape;
1027 output_shape.append(input.info().shape().rank());
1029 output.info().shape(output_shape);
1032 void StaticShapeInferer::visit(const ir::operation::Slice &op)
1034 auto &operands = _lowered_subg->graph().operands();
1036 const auto input_index{op.getInputs().at(ir::operation::Slice::Input::INPUT)};
1037 const auto &input = operands.at(input_index);
1038 const auto begins_index{op.getInputs().at(ir::operation::Slice::Input::BEGINS)};
1039 const auto &begins = operands.at(begins_index);
1040 const auto sizes_index{op.getInputs().at(ir::operation::Slice::Input::SIZES)};
1041 const auto &sizes = operands.at(sizes_index);
1042 const auto output_index = op.getOutputs().at(0);
1043 ir::Operand &output = operands.at(output_index);
1045 // Whether input is constant or not does not affect whether output is dynamic or not
1046 if (!(begins.isConstant() && sizes.isConstant()))
1048 output.info().setDynamic();
1052 auto begins_buf = begins.data()->base();
1053 auto sizes_buf = sizes.data()->base();
1055 const auto begins_type = begins.typeInfo().type();
1056 assert(begins_type == ir::DataType::INT32 || begins_type == ir::DataType::INT64);
1057 assert(begins_type == sizes.typeInfo().type());
1059 ir::Shape new_shape =
1060 (begins_type == ir::DataType::INT32)
1061 ? shape_inference::inferSliceShape<int32_t>(input.info().shape(),
1062 reinterpret_cast<const int32_t *>(begins_buf),
1063 reinterpret_cast<const int32_t *>(sizes_buf))
1064 : shape_inference::inferSliceShape<int64_t>(input.info().shape(),
1065 reinterpret_cast<const int64_t *>(begins_buf),
1066 reinterpret_cast<const int64_t *>(sizes_buf));
1067 output.info().shape(new_shape);
1070 void StaticShapeInferer::visit(const ir::operation::Softmax &op)
1072 handleSimpleUnaryOp(op, op.getInputs().at(ir::operation::Softmax::Input::INPUT));
1075 void StaticShapeInferer::visit(const ir::operation::SpaceToBatchND &op)
1077 auto &operands = _lowered_subg->graph().operands();
1079 const auto output_index = op.getOutputs().at(0);
1080 const auto input_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::INPUT)};
1081 const auto block_shape_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::BLOCK_SIZE)};
1082 const auto padding_idx{op.getInputs().at(ir::operation::SpaceToBatchND::Input::PADDINGS)};
1084 ir::Operand &output = operands.at(output_index);
1085 const auto &input = operands.at(input_idx);
1086 const auto &block_shape = operands.at(block_shape_idx);
1087 const auto &padding = operands.at(padding_idx);
1089 // Whether input is constant or not does not affect whether output is dynamic or not
1090 if (!(block_shape.isConstant() && padding.isConstant()))
1092 output.info().setDynamic();
1096 auto input_shape = input.info().shape();
1097 auto block_shape_shape = block_shape.info().shape();
1098 auto padding_shape = padding.info().shape();
1100 auto block_shape_data = reinterpret_cast<const int32_t *>(block_shape.data()->base());
1101 auto padding_data = reinterpret_cast<const int32_t *>(padding.data()->base());
1103 ir::Shape new_shape = shape_inference::inferSpaceToBatchNDShape(
1104 input_shape, block_shape_shape, padding_shape, block_shape_data, padding_data);
1106 output.info().shape(new_shape);
1109 void StaticShapeInferer::visit(const ir::operation::Split &op)
1111 auto &operands = _lowered_subg->graph().operands();
1113 const auto input_idx{op.getInputs().at(ir::operation::Split::Input::INPUT)};
1114 const auto &input = operands.at(input_idx);
1116 const auto axis_idx{op.getInputs().at(ir::operation::Split::Input::AXIS)};
1117 const auto &axis = operands.at(axis_idx);
1119 auto outputs = op.getOutputs();
1120 if (!axis.isConstant())
1122 for (auto output_idx : outputs)
1124 ir::Operand &output = operands.at(output_idx);
1125 output.info().setDynamic();
1130 const auto num_splits = op.param().num_splits;
1132 const auto rank = input.info().shape().rank();
1133 auto axis_value = axis.asScalar<int32_t>();
1134 axis_value = axis_value < 0 ? axis_value + rank : axis_value;
1136 assert(0 <= axis_value && axis_value < rank);
1138 ir::Shape new_shape =
1139 shape_inference::inferSplitShape(input.info().shape(), axis_value, num_splits);
1140 for (auto output_idx : outputs)
1142 ir::Operand &output = operands.at(output_idx);
1143 output.info().shape(new_shape);
1147 void StaticShapeInferer::visit(const ir::operation::SquaredDifference &op)
1149 handleBinaryArithmeticOp(op, op.getInputs().at(ir::operation::SquaredDifference::Input::LHS),
1150 op.getInputs().at(ir::operation::SquaredDifference::Input::RHS));
1153 void StaticShapeInferer::visit(const ir::operation::Squeeze &op)
1155 auto &operands = _lowered_subg->graph().operands();
1157 const auto input_idx{op.getInputs().at(ir::operation::Squeeze::Input::INPUT)};
1158 const auto &input = operands.at(input_idx);
1160 const auto output_idx = op.getOutputs().at(0);
1161 ir::Operand &output = operands.at(output_idx);
1163 // Squeeze output shpae
1164 ir::Shape new_shape = shape_inference::inferSqueezeShape(input.info().shape(), op.param());
1165 output.info().shape(new_shape);
1168 void StaticShapeInferer::visit(const ir::operation::StridedSlice &op)
1170 auto &operands = _lowered_subg->graph().operands();
1172 const auto input_index{op.getInputs().at(ir::operation::StridedSlice::Input::INPUT)};
1173 const auto &input = operands.at(input_index);
1174 const auto starts_index{op.getInputs().at(ir::operation::StridedSlice::Input::STARTS)};
1175 const auto &starts = operands.at(starts_index);
1176 const auto ends_index{op.getInputs().at(ir::operation::StridedSlice::Input::ENDS)};
1177 const auto &ends = operands.at(ends_index);
1178 const auto strides_index{op.getInputs().at(ir::operation::StridedSlice::Input::STRIDES)};
1179 const auto &strides = operands.at(strides_index);
1180 const auto output_index = op.getOutputs().at(0);
1181 ir::Operand &output = operands.at(output_index);
1183 if (!(starts.isConstant() && ends.isConstant() && strides.isConstant()))
1185 output.info().setDynamic();
1189 const auto begin_mask = op.param().begin_mask;
1190 const auto end_mask = op.param().end_mask;
1191 const auto shrink_axis_mask = op.param().shrink_axis_mask;
1192 const auto rank = input.info().shape().rank();
1194 auto starts_buf = reinterpret_cast<const uint32_t *>(starts.data()->base());
1195 auto ends_buf = reinterpret_cast<const uint32_t *>(ends.data()->base());
1196 auto strides_buf = reinterpret_cast<const uint32_t *>(strides.data()->base());
1198 auto op_params = shape_inference::buildStridedSliceParams(
1199 starts_buf, ends_buf, strides_buf, begin_mask, end_mask, shrink_axis_mask, rank);
1201 ir::Shape new_shape =
1202 shape_inference::inferStridedSliceShape(input.info().shape(), op_params, rank);
1203 output.info().shape(new_shape);
1206 void StaticShapeInferer::visit(const ir::operation::Tile &op)
1208 auto &operands = _lowered_subg->graph().operands();
1210 const auto input_idx{op.getInputs().at(ir::operation::Tile::Input::INPUT)};
1211 const auto &input = operands.at(input_idx);
1213 const auto multiplier_idx{op.getInputs().at(ir::operation::Tile::Input::MULTIPLES)};
1214 const auto &multiplier = operands.at(multiplier_idx);
1216 const auto output_idx = op.getOutputs().at(0);
1217 ir::Operand &output = operands.at(output_idx);
1219 if (!multiplier.isConstant())
1221 output.info().setDynamic();
1225 auto multiplier_buffer = reinterpret_cast<const int32_t *>(multiplier.data()->base());
1226 assert(multiplier_buffer);
1228 // re-sizing output shape
1229 auto new_shape = shape_inference::inferTileShape(input.info().shape(), multiplier_buffer,
1230 multiplier.shape().num_elements());
1231 output.info().shape(new_shape);
1234 void StaticShapeInferer::visit(const ir::operation::Transpose &op)
1236 auto &operands = _lowered_subg->graph().operands();
1238 const auto input_idx{op.getInputs().at(ir::operation::Transpose::Input::INPUT)};
1239 const auto &input = operands.at(input_idx);
1241 const auto perm_idx{op.getInputs().at(ir::operation::Transpose::Input::PERMUTATION)};
1242 const auto &perm = operands.at(perm_idx);
1244 // perm.shape() != ir::Shape{0} means that perm is (n-1...0)
1245 // TODO This condition changes to perm.num_elements() == 0
1246 const auto is_regular_transpose = perm.shape() == ir::Shape{0};
1248 // get mutable output operand
1249 const auto output_idx = op.getOutputs().at(0);
1250 auto &output = operands.at(output_idx);
1251 if (!perm.isConstant() && !is_regular_transpose)
1253 output.info().setDynamic();
1257 ir::Shape new_shape;
1258 if (is_regular_transpose)
1260 // Call by (n-1...0)
1261 new_shape = shape_inference::inferTransposeShape(input.info().shape(), nullptr, 0);
1266 if (input.info().shape().rank() != static_cast<int>(perm.info().shape().num_elements()))
1268 throw std::runtime_error("StaticShapeInferer failed, bad rank size: " +
1269 std::to_string(perm.info().shape().num_elements()));
1272 // set output shape, based on input and params
1273 const auto perm_buf = reinterpret_cast<const int32_t *>(perm.data()->base());
1274 new_shape = shape_inference::inferTransposeShape(input.info().shape(), perm_buf,
1275 perm.shape().num_elements());
1277 output.info().shape(new_shape);
1280 void StaticShapeInferer::visit(const ir::operation::Unpack &op)
1282 auto &operands = _lowered_subg->graph().operands();
1284 const auto input_idx{op.getInputs().at(0)};
1285 const auto &input = operands.at(input_idx);
1286 const auto num = op.param().num;
1287 const auto rank = input.shape().rank();
1288 const auto axis = ((op.param().axis < 0) ? rank + op.param().axis : op.param().axis);
1290 assert(axis < rank);
1293 for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1295 const auto output_idx = op.getOutputs().at(out_tensor_idx);
1296 ir::Operand &output = operands.at(output_idx);
1297 output.info().setDynamic();
1302 ir::Shape new_shape = shape_inference::inferUnpackShape(input.info().shape(), axis, rank);
1304 // re-sizing output shape
1305 for (int out_tensor_idx = 0; out_tensor_idx < num; out_tensor_idx++)
1307 const auto output_idx = op.getOutputs().at(out_tensor_idx);
1308 ir::Operand &output = operands.at(output_idx);
1309 output.info().shape(new_shape);
1313 void StaticShapeInferer::visit(const ir::operation::While &op)
1315 auto body_input_observer = _subg_input_observers.at(op.param().body_subg_index).get();
1316 auto cond_input_observer = _subg_input_observers.at(op.param().cond_subg_index).get();
1317 // re-sizing input shapes of body subgraph
1318 const auto inputs = op.getInputs();
1319 std::vector<ir::OperandInfo> inputs_info;
1320 const auto &graph = _lowered_subg->graph();
1321 for (size_t i = 0; i < inputs.size(); ++i)
1323 const auto &operand_info = graph.operands().at(inputs.at(i)).info();
1324 inputs_info.emplace_back(operand_info);
1327 body_input_observer->updateShapes(inputs_info);
1328 _child_inferers.at(op.param().body_subg_index)->infer();
1330 // Check whether while operation's shapes are predictable
1331 // This while op's outputs are also updated in the above function
1332 // "_child_inferers.at(op.param().body_subg_index)->update()". That means that body's outputs and
1333 // thils op's outputs must have the same shape. So we can predict whether body subgraphs will
1334 // change at every step by comparing the shapes of inputs/outputs. If any of shape of body outputs
1335 // and inputs are different Non-constant operands will be set to dynamic.
1336 bool check_unpredictable_dynamic = false;
1337 const auto &updated_outputs = op.getOutputs();
1338 assert(inputs_info.size() == updated_outputs.size());
1339 for (size_t i = 0; i < updated_outputs.size(); ++i)
1341 const auto &input_info = inputs_info.at(i);
1342 const auto &output_info = graph.operands().at(updated_outputs.at(i)).info();
1343 if (input_info.isDynamic() != output_info.isDynamic() ||
1344 input_info.shape() != output_info.shape())
1346 check_unpredictable_dynamic = true;
1351 if (check_unpredictable_dynamic)
1353 body_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
1354 _child_inferers.at(op.param().body_subg_index)->infer();
1356 cond_input_observer->updateShapes(inputs_info, check_unpredictable_dynamic);
1357 _child_inferers.at(op.param().cond_subg_index)->infer();
1360 void StaticShapeInferer::visit(const ir::operation::DetectionPostProcess &op)
1362 // TODO: NMS supports very limited input/output size.
1363 ir::operation::DetectionPostProcess::Param param = op.param();
1365 auto &operands = _lowered_subg->graph().operands();
1366 const int num_detected_boxes = param.max_detections * param.max_classes_per_detection;
1368 const auto output_idx1 = op.getOutputs().at(0);
1369 auto &output1 = operands.at(output_idx1);
1370 output1.info().shape({1, num_detected_boxes, 4});
1372 const auto output_idx2 = op.getOutputs().at(1);
1373 auto &output2 = operands.at(output_idx2);
1374 output2.info().shape({1, num_detected_boxes});
1376 const auto output_idx3 = op.getOutputs().at(2);
1377 auto &output3 = operands.at(output_idx3);
1378 output3.info().shape({1, num_detected_boxes});
1380 const auto output_idx4 = op.getOutputs().at(3);
1381 auto &output4 = operands.at(output_idx4);
1382 output4.info().shape({1});
1384 void StaticShapeInferer::visit(const ir::operation::Bulk &op)
1386 auto &operands = _lowered_subg->graph().operands();
1388 // TODO: support multiple inputs/outputs
1389 const auto input_idx{op.getInputs().at(0)};
1390 const auto &input = operands.at(input_idx);
1391 const auto output_idx = op.getOutputs().at(0);
1392 ir::Operand &output = operands.at(output_idx);
1394 auto cur_input_shape = input.info().shape();
1395 auto origin_input_shape = op.param().origin_input_shapes[0];
1396 auto cur_output_shape = output.info().shape();
1397 auto origin_output_shape = op.param().origin_output_shapes[0];
1399 // TODO: more check for valid batch request
1400 if ((cur_input_shape.dim(0) < origin_output_shape.dim(0)) ||
1401 (cur_input_shape.dim(0) % origin_output_shape.dim(0) != 0))
1403 throw std::runtime_error("StaticShapeInferer " + op.name() + ": Not supported batch size");
1405 size_t batch_multiplier = cur_input_shape.dim(0) / origin_output_shape.dim(0);
1407 ir::Shape new_shape;
1408 new_shape.append(origin_output_shape.dim(0) * batch_multiplier);
1409 for (int32_t d = 1; d < origin_output_shape.rank(); ++d)
1410 new_shape.append(origin_output_shape.dim(d));
1412 output.info().shape(new_shape);
1415 } // namespace compiler
1417 } // namespace onert