2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Service/CircleTypeInferenceRule.h"
18 #include "CircleTypeInferenceHelper.h"
20 #include <luci/IR/CircleDialect.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/IR/CircleNodes.h>
29 struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataType>
31 // TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or
33 loco::DataType visit(const luci::CircleAbs *node) final { return luci::dtype_get(node->x()); }
35 loco::DataType visit(const luci::CircleAdd *node) final { return luci::dtype_get(node->x()); }
37 loco::DataType visit(const luci::CircleAddN *node) final
39 auto dtype = luci::dtype_get(node->inputs(0));
41 for (uint32_t idx = 1; idx < node->arity(); ++idx)
43 auto dtype_idx = luci::dtype_get(node->inputs(idx));
44 if (dtype != dtype_idx)
46 INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx);
50 return luci::dtype_get(node->inputs(0));
53 loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); }
55 loco::DataType visit(const luci::CircleArgMin *node) final { return node->output_type(); }
57 loco::DataType visit(const luci::CircleAveragePool2D *node) final
59 return luci::dtype_get(node->value());
62 loco::DataType visit(const luci::CircleBatchMatMul *node) final
64 return luci::dtype_get(node->x());
67 loco::DataType visit(const luci::CircleBatchToSpaceND *node) final
69 return luci::dtype_get(node->input());
72 loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); }
74 loco::DataType visit(const luci::CircleCeil *node) final { return luci::dtype_get(node->x()); }
76 loco::DataType visit(const luci::CircleConcatenation *node) final
78 // TODO Support when CircleConcatenation has 0 input
79 assert(node->numValues() > 0);
81 for (uint32_t i = 1; i < node->numValues(); ++i)
82 assert(luci::dtype_get(node->values(i - 1)) == luci::dtype_get(node->values(i)));
84 return luci::dtype_get(node->values(0));
87 loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); }
89 loco::DataType visit(const luci::CircleConv2D *node) final
91 return luci::dtype_get(node->input());
94 loco::DataType visit(const luci::CircleCos *node) final { return luci::dtype_get(node->x()); }
96 loco::DataType visit(const luci::CircleCustom *node) final
98 if (node->custom_code() == "BatchMatMulV2")
100 return luci::dtype_get(node->inputs(0));
102 return node->dtype();
105 loco::DataType visit(const luci::CircleDensify *node) final
107 return luci::dtype_get(node->input());
110 loco::DataType visit(const luci::CircleDepthToSpace *node) final
112 return luci::dtype_get(node->input());
115 loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final
117 return luci::dtype_get(node->input());
120 loco::DataType visit(const luci::CircleDequantize *) final { return loco::DataType::FLOAT32; }
122 loco::DataType visit(const luci::CircleDiv *node) final { return luci::dtype_get(node->x()); }
124 loco::DataType visit(const luci::CircleElu *node) final
126 return luci::dtype_get(node->features());
129 loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; }
131 loco::DataType visit(const luci::CircleExp *node) final { return luci::dtype_get(node->x()); }
133 loco::DataType visit(const luci::CircleExpandDims *node) final
135 return luci::dtype_get(node->input());
138 loco::DataType visit(const luci::CircleFakeQuant *node) final
140 return luci::dtype_get(node->inputs());
143 loco::DataType visit(const luci::CircleFill *node) final
145 return luci::dtype_get(node->value());
148 loco::DataType visit(const luci::CircleFloor *node) final { return luci::dtype_get(node->x()); }
150 loco::DataType visit(const luci::CircleFloorDiv *node) final
152 return luci::dtype_get(node->x());
155 loco::DataType visit(const luci::CircleFloorMod *node) final
157 return luci::dtype_get(node->x());
160 loco::DataType visit(const luci::CircleFullyConnected *node) final
162 return luci::dtype_get(node->input());
165 loco::DataType visit(const luci::CircleGather *node) final
167 return luci::dtype_get(node->params());
170 loco::DataType visit(const luci::CircleGatherNd *node) final
172 return luci::dtype_get(node->params());
175 loco::DataType visit(const luci::CircleGelu *node) final
177 return luci::dtype_get(node->features());
180 loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; }
182 loco::DataType visit(const luci::CircleGreaterEqual *) final { return loco::DataType::BOOL; }
184 loco::DataType visit(const luci::CircleHardSwish *node) final
186 return luci::dtype_get(node->features());
189 loco::DataType visit(const luci::CircleIf *node) final
191 // Type of If is not used. Just use input 0
192 assert(node->input_count() > 0);
193 return luci::dtype_get(node->input(0));
196 loco::DataType visit(const luci::CircleL2Normalize *node) final
198 return luci::dtype_get(node->x());
201 loco::DataType visit(const luci::CircleL2Pool2D *node) final
203 return luci::dtype_get(node->value());
206 loco::DataType visit(const luci::CircleLeakyRelu *node) final
208 return luci::dtype_get(node->features());
211 loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; }
213 loco::DataType visit(const luci::CircleLessEqual *) final { return loco::DataType::BOOL; }
215 loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final
217 return luci::dtype_get(node->input());
220 loco::DataType visit(const luci::CircleLog *node) final { return luci::dtype_get(node->x()); }
222 loco::DataType visit(const luci::CircleLogicalAnd *node) final
224 return luci::dtype_get(node->x());
227 loco::DataType visit(const luci::CircleLogicalNot *node) final
229 return luci::dtype_get(node->x());
232 loco::DataType visit(const luci::CircleLogicalOr *node) final
234 return luci::dtype_get(node->x());
237 loco::DataType visit(const luci::CircleLogistic *node) final
239 return luci::dtype_get(node->x());
242 loco::DataType visit(const luci::CircleLogSoftmax *node) final
244 return luci::dtype_get(node->logits());
247 loco::DataType visit(const luci::CircleMatrixDiag *node) final
249 return luci::dtype_get(node->diagonal());
252 loco::DataType visit(const luci::CircleMatrixSetDiag *node) final
254 return luci::dtype_get(node->input());
257 loco::DataType visit(const luci::CircleMaximum *node) final { return luci::dtype_get(node->x()); }
259 loco::DataType visit(const luci::CircleMaxPool2D *node) final
261 return luci::dtype_get(node->value());
264 loco::DataType visit(const luci::CircleMean *node) final
266 return luci::dtype_get(node->input());
269 loco::DataType visit(const luci::CircleMinimum *node) final { return luci::dtype_get(node->x()); }
271 loco::DataType visit(const luci::CircleMirrorPad *node) final
273 return luci::dtype_get(node->input());
276 loco::DataType visit(const luci::CircleNeg *node) final { return luci::dtype_get(node->x()); }
278 loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final
280 return luci::dtype_get(node->boxes());
283 loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final
285 return luci::dtype_get(node->boxes());
288 loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; }
290 loco::DataType visit(const luci::CirclePack *node) final
292 // Only support CirclePack with one or more inputs
293 assert(node->values_count() > 0);
295 auto first_value_type = luci::dtype_get(node->values(0));
296 for (uint32_t i = 1; i < node->values_count(); ++i)
297 assert(first_value_type == luci::dtype_get(node->values(i)));
299 return first_value_type;
302 loco::DataType visit(const luci::CirclePad *node) final { return luci::dtype_get(node->input()); }
304 loco::DataType visit(const luci::CirclePadV2 *node) final
306 return luci::dtype_get(node->input());
309 loco::DataType visit(const luci::CirclePow *node) final
311 // TODO make sure types cannot differ
312 auto x_type = luci::dtype_get(node->x());
313 auto y_type = luci::dtype_get(node->y());
315 if (x_type != y_type)
316 INTERNAL_EXN("Different datatype for x and y are not supported");
321 loco::DataType visit(const luci::CirclePRelu *node) final
323 auto input_type = luci::dtype_get(node->input());
324 auto alpha_type = luci::dtype_get(node->alpha());
326 if (input_type != alpha_type)
327 INTERNAL_EXN("Different datatype for input and alpha are not supported");
332 loco::DataType visit(const luci::CircleQuantize *node) final { return luci::dtype_get(node); }
334 loco::DataType visit(const luci::CircleRange *node) final
336 return luci::dtype_get(node->start());
339 loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; }
341 loco::DataType visit(const luci::CircleMul *node) final { return luci::dtype_get(node->x()); }
343 loco::DataType visit(const luci::CircleOneHot *node) final
345 return luci::dtype_get(node->on_value());
348 loco::DataType visit(const luci::CircleReduceAny *node) final
350 return luci::dtype_get(node->input());
353 loco::DataType visit(const luci::CircleReduceMax *node) final
355 return luci::dtype_get(node->input());
358 loco::DataType visit(const luci::CircleReduceMin *node) final
360 return luci::dtype_get(node->input());
363 loco::DataType visit(const luci::CircleReduceProd *node) final
365 return luci::dtype_get(node->input());
368 loco::DataType visit(const luci::CircleRelu *node) final
370 return luci::dtype_get(node->features());
373 loco::DataType visit(const luci::CircleRelu6 *node) final
375 return luci::dtype_get(node->features());
378 loco::DataType visit(const luci::CircleReluN1To1 *node) final
380 return luci::dtype_get(node->features());
383 loco::DataType visit(const luci::CircleReshape *node) final
385 return luci::dtype_get(node->tensor());
388 loco::DataType visit(const luci::CircleResizeBilinear *node) final
390 return luci::dtype_get(node->input());
393 loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final
395 return luci::dtype_get(node->input());
398 loco::DataType visit(const luci::CircleReverseSequence *node) final
400 return luci::dtype_get(node->input());
403 loco::DataType visit(const luci::CircleReverseV2 *node) final
405 return luci::dtype_get(node->tensor());
408 loco::DataType visit(const luci::CircleRound *node) final { return luci::dtype_get(node->x()); }
410 loco::DataType visit(const luci::CircleRsqrt *node) final { return luci::dtype_get(node->x()); }
412 loco::DataType visit(const luci::CircleScatterNd *node) final
414 return luci::dtype_get(node->updates());
417 loco::DataType visit(const luci::CircleSegmentSum *node) final
419 return luci::dtype_get(node->input());
422 loco::DataType visit(const luci::CircleSelect *node) final
424 assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e()));
425 return luci::dtype_get(node->t());
428 loco::DataType visit(const luci::CircleSelectV2 *node) final
430 assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e()));
431 return luci::dtype_get(node->t());
434 loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); }
436 loco::DataType visit(const luci::CircleSin *node) final { return luci::dtype_get(node->x()); }
438 loco::DataType visit(const luci::CircleSlice *node) final
440 return luci::dtype_get(node->input());
443 loco::DataType visit(const luci::CircleSoftmax *node) final
445 return luci::dtype_get(node->logits());
448 loco::DataType visit(const luci::CircleSpaceToBatchND *node) final
450 return luci::dtype_get(node->input());
453 loco::DataType visit(const luci::CircleSpaceToDepth *node) final
455 return luci::dtype_get(node->input());
458 loco::DataType visit(const luci::CircleSparseToDense *node) final
460 return luci::dtype_get(node->values());
463 loco::DataType visit(const luci::CircleSplit *node) final
465 return luci::dtype_get(node->input());
468 loco::DataType visit(const luci::CircleSplitV *node) final
470 return luci::dtype_get(node->input());
473 loco::DataType visit(const luci::CircleSqrt *node) final { return luci::dtype_get(node->x()); }
475 loco::DataType visit(const luci::CircleSquare *node) final { return luci::dtype_get(node->x()); }
477 loco::DataType visit(const luci::CircleSquaredDifference *node) final
479 return luci::dtype_get(node->x());
482 loco::DataType visit(const luci::CircleSqueeze *node) final
484 return luci::dtype_get(node->input());
487 loco::DataType visit(const luci::CircleStridedSlice *node) final
489 return luci::dtype_get(node->input());
492 loco::DataType visit(const luci::CircleSub *node) final { return luci::dtype_get(node->x()); }
494 loco::DataType visit(const luci::CircleSum *node) final { return luci::dtype_get(node->input()); }
496 loco::DataType visit(const luci::CircleSVDF *node) final
498 return luci::dtype_get(node->input());
501 loco::DataType visit(const luci::CircleTanh *node) final { return luci::dtype_get(node->x()); }
503 loco::DataType visit(const luci::CircleTile *node) final
505 return luci::dtype_get(node->input());
508 loco::DataType visit(const luci::CircleTopKV2 *node) final
510 return luci::dtype_get(node->input());
513 loco::DataType visit(const luci::CircleTranspose *node) final
515 return luci::dtype_get(node->a());
518 loco::DataType visit(const luci::CircleTransposeConv *node) final
520 return luci::dtype_get(node->outBackprop());
523 loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
525 return luci::dtype_get(node->input());
528 loco::DataType visit(const luci::CircleUnique *node) final
530 return luci::dtype_get(node->input());
533 loco::DataType visit(const luci::CircleUnpack *node) final
535 return luci::dtype_get(node->value());
538 loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; }
540 loco::DataType visit(const luci::CircleWhile *node) final
542 // Type of While is not used. Just use input 0
543 assert(node->input_count() > 0);
544 return luci::dtype_get(node->input(0));
547 loco::DataType visit(const luci::CircleZerosLike *node) final
549 return luci::dtype_get(node->input());
553 loco::DataType visit(const luci::CircleBCQFullyConnected *) final
555 return loco::DataType::FLOAT32;
558 loco::DataType visit(const luci::CircleBCQGather *) final { return loco::DataType::FLOAT32; }
560 loco::DataType visit(const luci::CircleInstanceNorm *node) final
562 return luci::dtype_get(node->input());
566 loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }
568 loco::DataType visit(const luci::CircleOutput *node) final
570 auto graph_outputs = node->graph()->outputs();
571 auto graph_output = graph_outputs->at(node->index());
572 auto output_dtype = graph_output->dtype();
574 if (dynamic_cast<luci::CircleOutputDummy *>(node->from()) == nullptr &&
575 dynamic_cast<luci::CircleOutputExclude *>(node->from()) == nullptr)
577 // We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude
578 // from() type should match that of CircleOutput
579 assert(output_dtype == luci::dtype_get(node->from()));
584 loco::DataType visit(const luci::CircleOutputDummy *node) final { return node->dtype(); }
586 loco::DataType visit(const luci::CircleOutputExclude *node) final
588 // NOTE We don't care CircleOutputExclude dtype, but set to FLOAT32
589 // if it's Unknown to make type inference happy.
590 if (node->dtype() == loco::DataType::Unknown)
591 return loco::DataType::FLOAT32;
592 return node->dtype();
595 loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); }
597 loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final
600 assert(node->index() == 0 || node->index() == 1);
601 return loco::DataType::S32;
604 loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final
607 if (node->index() == 0 || node->index() == 2)
609 return loco::DataType::S32;
611 assert(node->index() == 1);
612 return loco::DataType::FLOAT32;
615 loco::DataType visit(const luci::CircleSplitOut *node) final
617 return luci::dtype_get(node->input());
620 loco::DataType visit(const luci::CircleSplitVOut *node) final
622 return luci::dtype_get(node->input());
625 loco::DataType visit(const luci::CircleTopKV2Out *node) final
627 // First output is same as input
628 if (node->index() == 0)
629 return luci::dtype_get(node->input());
630 // Second outout is always S32
631 assert(node->index() == 1);
632 return loco::DataType::S32;
635 loco::DataType visit(const luci::CircleVariable *node) final { return node->dtype(); }
637 loco::DataType visit(const luci::CircleUniqueOut *node) final
639 if (node->index() == 0)
641 return luci::dtype_get(node->input());
643 assert(node->index() == 1);
644 auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
645 return unique->idx_out_type();
648 loco::DataType visit(const luci::CircleUnpackOut *node) final
650 return luci::dtype_get(node->input());
653 loco::DataType visit(const luci::CircleWhileOut *node) final
656 * @note WHILE operator's type is the same with the "cond"
659 auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
660 if (circle_while == nullptr)
662 INTERNAL_EXN("CircleWhile IR is not configured correctly");
665 auto index = node->index();
666 auto cond_graph = circle_while->cond_graph();
667 assert(cond_graph != nullptr);
669 // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
671 auto cond_inputs = loco::input_nodes(cond_graph);
672 auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
674 auto cond_graph_inputs = cond_graph->inputs();
675 auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
677 return cond_graph_input->dtype();
686 bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const
688 return CircleDialect::get() == d;
691 bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
693 assert(node->dialect() == CircleDialect::get());
695 TypeInferenceAlgorithm alg;
697 auto circle_node = loco::must_cast<const CircleNode *>(node);
698 dtype = circle_node->accept(&alg);
699 assert(dtype != loco::DataType::Unknown);