2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Service/CircleTypeInferenceRule.h"
19 #include <luci/IR/CircleDialect.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/IR/CircleNodes.h>
28 struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataType>
30 // TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or
32 loco::DataType visit(const luci::CircleAbs *node) final { return loco::dtype_get(node->x()); }
34 loco::DataType visit(const luci::CircleAdd *node) final { return loco::dtype_get(node->x()); }
36 loco::DataType visit(const luci::CircleAddN *node) final
38 auto dtype = loco::dtype_get(node->inputs(0));
40 for (uint32_t idx = 1; idx < node->arity(); ++idx)
42 auto dtype_idx = loco::dtype_get(node->inputs(idx));
43 if (dtype != dtype_idx)
45 INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx);
49 return loco::dtype_get(node->inputs(0));
52 loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); }
54 loco::DataType visit(const luci::CircleArgMin *node) final { return node->output_type(); }
56 loco::DataType visit(const luci::CircleAveragePool2D *node) final
58 return loco::dtype_get(node->value());
61 loco::DataType visit(const luci::CircleBatchMatMul *node) final
63 return loco::dtype_get(node->x());
66 loco::DataType visit(const luci::CircleBatchToSpaceND *node) final
68 return loco::dtype_get(node->input());
71 loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); }
73 loco::DataType visit(const luci::CircleCeil *node) final { return loco::dtype_get(node->x()); }
75 loco::DataType visit(const luci::CircleConcatenation *node) final
77 // TODO Support when CircleConcatenation has 0 input
78 assert(node->numValues() > 0);
80 for (uint32_t i = 1; i < node->numValues(); ++i)
81 assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i)));
83 return loco::dtype_get(node->values(0));
86 loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); }
88 loco::DataType visit(const luci::CircleConv2D *node) final
90 return loco::dtype_get(node->input());
93 loco::DataType visit(const luci::CircleCos *node) final { return loco::dtype_get(node->x()); }
95 loco::DataType visit(const luci::CircleCustom *node) final
97 if (node->custom_code() == "BatchMatMulV2")
99 return loco::dtype_get(node->inputs(0));
101 return node->dtype();
104 loco::DataType visit(const luci::CircleDepthToSpace *node) final
106 return loco::dtype_get(node->input());
109 loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final
111 return loco::dtype_get(node->input());
114 loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); }
116 loco::DataType visit(const luci::CircleElu *node) final
118 return loco::dtype_get(node->features());
121 loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; }
123 loco::DataType visit(const luci::CircleExp *node) final { return loco::dtype_get(node->x()); }
125 loco::DataType visit(const luci::CircleExpandDims *node) final
127 return loco::dtype_get(node->input());
130 loco::DataType visit(const luci::CircleFill *node) final
132 return loco::dtype_get(node->value());
135 loco::DataType visit(const luci::CircleFloor *node) final { return loco::dtype_get(node->x()); }
137 loco::DataType visit(const luci::CircleFloorDiv *node) final
139 return loco::dtype_get(node->x());
142 loco::DataType visit(const luci::CircleFloorMod *node) final
144 return loco::dtype_get(node->x());
147 loco::DataType visit(const luci::CircleFullyConnected *node) final
149 return loco::dtype_get(node->input());
152 loco::DataType visit(const luci::CircleGather *node) final
154 return loco::dtype_get(node->params());
157 loco::DataType visit(const luci::CircleGatherNd *node) final
159 return loco::dtype_get(node->params());
162 loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; }
164 loco::DataType visit(const luci::CircleGreaterEqual *) final { return loco::DataType::BOOL; }
166 loco::DataType visit(const luci::CircleIf *node) final
168 // Type of If is not used. Just use input 0
169 assert(node->input_count() > 0);
170 return loco::dtype_get(node->input(0));
173 loco::DataType visit(const luci::CircleL2Normalize *node) final
175 return loco::dtype_get(node->x());
178 loco::DataType visit(const luci::CircleL2Pool2D *node) final
180 return loco::dtype_get(node->value());
183 loco::DataType visit(const luci::CircleLeakyRelu *node) final
185 return loco::dtype_get(node->features());
188 loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; }
190 loco::DataType visit(const luci::CircleLessEqual *) final { return loco::DataType::BOOL; }
192 loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final
194 return loco::dtype_get(node->input());
197 loco::DataType visit(const luci::CircleLog *node) final { return loco::dtype_get(node->x()); }
199 loco::DataType visit(const luci::CircleLogicalAnd *node) final
201 return loco::dtype_get(node->x());
204 loco::DataType visit(const luci::CircleLogicalNot *node) final
206 return loco::dtype_get(node->x());
209 loco::DataType visit(const luci::CircleLogicalOr *node) final
211 return loco::dtype_get(node->x());
214 loco::DataType visit(const luci::CircleLogistic *node) final
216 return loco::dtype_get(node->x());
219 loco::DataType visit(const luci::CircleLogSoftmax *node) final
221 return loco::dtype_get(node->logits());
224 loco::DataType visit(const luci::CircleMatrixDiag *node) final
226 return loco::dtype_get(node->diagonal());
229 loco::DataType visit(const luci::CircleMatrixSetDiag *node) final
231 return loco::dtype_get(node->input());
234 loco::DataType visit(const luci::CircleMaximum *node) final { return loco::dtype_get(node->x()); }
236 loco::DataType visit(const luci::CircleMaxPool2D *node) final
238 return loco::dtype_get(node->value());
241 loco::DataType visit(const luci::CircleMean *node) final
243 return loco::dtype_get(node->input());
246 loco::DataType visit(const luci::CircleMinimum *node) final { return loco::dtype_get(node->x()); }
248 loco::DataType visit(const luci::CircleMirrorPad *node) final
250 return loco::dtype_get(node->input());
253 loco::DataType visit(const luci::CircleNeg *node) final { return loco::dtype_get(node->x()); }
255 loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final
257 return loco::dtype_get(node->boxes());
260 loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final
262 return loco::dtype_get(node->boxes());
265 loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; }
267 loco::DataType visit(const luci::CirclePack *node) final
269 // Only support CirclePack with one or more inputs
270 assert(node->values_count() > 0);
272 auto first_value_type = loco::dtype_get(node->values(0));
273 for (uint32_t i = 1; i < node->values_count(); ++i)
274 assert(first_value_type == loco::dtype_get(node->values(i)));
276 return first_value_type;
279 loco::DataType visit(const luci::CirclePad *node) final { return loco::dtype_get(node->input()); }
281 loco::DataType visit(const luci::CirclePadV2 *node) final
283 return loco::dtype_get(node->input());
286 loco::DataType visit(const luci::CirclePow *node) final
288 // TODO make sure types cannot differ
289 auto x_type = loco::dtype_get(node->x());
290 auto y_type = loco::dtype_get(node->y());
292 if (x_type != y_type)
293 INTERNAL_EXN("Different datatype for x and y are not supported");
298 loco::DataType visit(const luci::CirclePRelu *node) final
300 auto input_type = loco::dtype_get(node->input());
301 auto alpha_type = loco::dtype_get(node->alpha());
303 if (input_type != alpha_type)
304 INTERNAL_EXN("Different datatype for input and alpha are not supported");
309 loco::DataType visit(const luci::CircleRange *node) final
311 return loco::dtype_get(node->start());
314 loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; }
316 loco::DataType visit(const luci::CircleMul *node) final { return loco::dtype_get(node->x()); }
318 loco::DataType visit(const luci::CircleOneHot *node) final
320 return loco::dtype_get(node->on_value());
323 loco::DataType visit(const luci::CircleReduceAny *node) final
325 return loco::dtype_get(node->input());
328 loco::DataType visit(const luci::CircleReduceMax *node) final
330 return loco::dtype_get(node->input());
333 loco::DataType visit(const luci::CircleReduceMin *node) final
335 return loco::dtype_get(node->input());
338 loco::DataType visit(const luci::CircleReduceProd *node) final
340 return loco::dtype_get(node->input());
343 loco::DataType visit(const luci::CircleRelu *node) final
345 return loco::dtype_get(node->features());
348 loco::DataType visit(const luci::CircleRelu6 *node) final
350 return loco::dtype_get(node->features());
353 loco::DataType visit(const luci::CircleReluN1To1 *node) final
355 return loco::dtype_get(node->features());
358 loco::DataType visit(const luci::CircleReshape *node) final
360 return loco::dtype_get(node->tensor());
363 loco::DataType visit(const luci::CircleResizeBilinear *node) final
365 return loco::dtype_get(node->input());
368 loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final
370 return loco::dtype_get(node->input());
373 loco::DataType visit(const luci::CircleReverseSequence *node) final
375 return loco::dtype_get(node->input());
378 loco::DataType visit(const luci::CircleReverseV2 *node) final
380 return loco::dtype_get(node->tensor());
383 loco::DataType visit(const luci::CircleRound *node) final { return loco::dtype_get(node->x()); }
385 loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); }
387 loco::DataType visit(const luci::CircleScatterNd *node) final
389 return loco::dtype_get(node->updates());
392 loco::DataType visit(const luci::CircleSegmentSum *node) final
394 return loco::dtype_get(node->input());
397 loco::DataType visit(const luci::CircleSelect *node) final
399 assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
400 return loco::dtype_get(node->t());
403 loco::DataType visit(const luci::CircleSelectV2 *node) final
405 assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
406 return loco::dtype_get(node->t());
409 loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); }
411 loco::DataType visit(const luci::CircleSin *node) final { return loco::dtype_get(node->x()); }
413 loco::DataType visit(const luci::CircleSlice *node) final
415 return loco::dtype_get(node->input());
418 loco::DataType visit(const luci::CircleSoftmax *node) final
420 return loco::dtype_get(node->logits());
423 loco::DataType visit(const luci::CircleSpaceToBatchND *node) final
425 return loco::dtype_get(node->input());
428 loco::DataType visit(const luci::CircleSpaceToDepth *node) final
430 return loco::dtype_get(node->input());
433 loco::DataType visit(const luci::CircleSparseToDense *node) final
435 return loco::dtype_get(node->values());
438 loco::DataType visit(const luci::CircleSplit *node) final
440 return loco::dtype_get(node->input());
443 loco::DataType visit(const luci::CircleSplitV *node) final
445 return loco::dtype_get(node->input());
448 loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); }
450 loco::DataType visit(const luci::CircleSquare *node) final { return loco::dtype_get(node->x()); }
452 loco::DataType visit(const luci::CircleSquaredDifference *node) final
454 return loco::dtype_get(node->x());
457 loco::DataType visit(const luci::CircleSqueeze *node) final
459 return loco::dtype_get(node->input());
462 loco::DataType visit(const luci::CircleStridedSlice *node) final
464 return loco::dtype_get(node->input());
467 loco::DataType visit(const luci::CircleSub *node) final { return loco::dtype_get(node->x()); }
469 loco::DataType visit(const luci::CircleSum *node) final { return loco::dtype_get(node->input()); }
471 loco::DataType visit(const luci::CircleTanh *node) final { return loco::dtype_get(node->x()); }
473 loco::DataType visit(const luci::CircleTile *node) final
475 return loco::dtype_get(node->input());
478 loco::DataType visit(const luci::CircleTopKV2 *node) final
480 return loco::dtype_get(node->input());
483 loco::DataType visit(const luci::CircleTranspose *node) final
485 return loco::dtype_get(node->a());
488 loco::DataType visit(const luci::CircleTransposeConv *node) final
490 return loco::dtype_get(node->outBackprop());
493 loco::DataType visit(const luci::CircleUnique *node) final
495 return loco::dtype_get(node->input());
498 loco::DataType visit(const luci::CircleUnpack *node) final
500 return loco::dtype_get(node->value());
503 loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; }
505 loco::DataType visit(const luci::CircleWhile *node) final
507 // Type of While is not used. Just use input 0
508 assert(node->input_count() > 0);
509 return loco::dtype_get(node->input(0));
512 loco::DataType visit(const luci::CircleZerosLike *node) final
514 return loco::dtype_get(node->input());
518 loco::DataType visit(const luci::CircleBCQFullyConnected *) final
520 return loco::DataType::FLOAT32;
523 loco::DataType visit(const luci::CircleBCQGather *) final { return loco::DataType::FLOAT32; }
525 loco::DataType visit(const luci::CircleInstanceNorm *node) final
527 return loco::dtype_get(node->input());
531 loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }
533 loco::DataType visit(const luci::CircleOutput *node) final
535 auto graph_outputs = node->graph()->outputs();
536 auto graph_output = graph_outputs->at(node->index());
537 auto output_dtype = graph_output->dtype();
539 if (dynamic_cast<luci::CircleOutputDummy *>(node->from()) == nullptr &&
540 dynamic_cast<luci::CircleOutputExclude *>(node->from()) == nullptr)
542 // We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude
543 // from() type should match that of CircleOutput
544 assert(output_dtype == loco::dtype_get(node->from()));
549 loco::DataType visit(const luci::CircleOutputDummy *node) final { return node->dtype(); }
551 loco::DataType visit(const luci::CircleOutputExclude *node) final { return node->dtype(); }
553 loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); }
555 loco::DataType visit(const luci::CircleIfOut *node) final
558 * @note IF operator type and shape are that of the "then" and "else"
561 auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
562 if (circle_if == nullptr)
564 INTERNAL_EXN("CircleIf IR is not configured correctly");
567 auto index = node->index();
568 auto then_graph = circle_if->then_graph();
569 auto else_graph = circle_if->else_graph();
570 assert(then_graph != nullptr);
571 assert(else_graph != nullptr);
573 // shape and type are assumed to be same
574 // these are checked at post_import_graph() in Import
575 auto then_outputs = loco::output_nodes(then_graph);
576 auto else_outputs = loco::output_nodes(else_graph);
577 assert(then_outputs.size() == else_outputs.size());
578 assert(index < static_cast<int32_t>(then_outputs.size()));
580 auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
581 auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
583 auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
584 auto else_graph_outputs = else_graph->outputs();
585 assert(then_graph_outputs->size() == else_graph_outputs->size());
587 auto then_graph_output = then_graph_outputs->at(then_out->index());
588 auto else_graph_output = else_graph_outputs->at(else_out->index());
589 (void)else_graph_output; // make compiler happy for unused variable warnings
590 assert(then_graph_output->dtype() == else_graph_output->dtype());
592 return then_graph_output->dtype();
595 loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final
598 assert(node->index() == 0 || node->index() == 1);
599 return loco::DataType::S32;
602 loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final
605 if (node->index() == 0 || node->index() == 2)
607 return loco::DataType::S32;
609 assert(node->index() == 1);
610 return loco::DataType::FLOAT32;
613 loco::DataType visit(const luci::CircleSplitOut *node) final
615 return loco::dtype_get(node->input());
618 loco::DataType visit(const luci::CircleSplitVOut *node) final
620 return loco::dtype_get(node->input());
623 loco::DataType visit(const luci::CircleTopKV2Out *node) final
625 // First output is same as input
626 if (node->index() == 0)
627 return loco::dtype_get(node->input());
628 // Second outout is always S32
629 assert(node->index() == 1);
630 return loco::DataType::S32;
633 loco::DataType visit(const luci::CircleUniqueOut *node) final
635 if (node->index() == 0)
637 return loco::dtype_get(node->input());
639 assert(node->index() == 1);
640 auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
641 return unique->idx_out_type();
644 loco::DataType visit(const luci::CircleUnpackOut *node) final
646 return loco::dtype_get(node->input());
649 loco::DataType visit(const luci::CircleWhileOut *node) final
652 * @note WHILE operator's type is the same with the "cond"
655 auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
656 if (circle_while == nullptr)
658 INTERNAL_EXN("CircleWhile IR is not configured correctly");
661 auto index = node->index();
662 auto cond_graph = circle_while->cond_graph();
663 assert(cond_graph != nullptr);
665 // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
667 auto cond_inputs = loco::input_nodes(cond_graph);
668 auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
670 auto cond_graph_inputs = cond_graph->inputs();
671 auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
673 return cond_graph_input->dtype();
682 bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const
684 return CircleDialect::get() == d;
687 bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
689 assert(node->dialect() == CircleDialect::get());
691 TypeInferenceAlgorithm alg;
693 auto circle_node = loco::must_cast<const CircleNode *>(node);
694 dtype = circle_node->accept(&alg);
695 assert(dtype != loco::DataType::Unknown);