2 * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "luci/Service/CircleTypeInferenceRule.h"
19 #include <luci/IR/CircleDialect.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/IR/CircleNodes.h>
28 struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataType>
30 // TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or
32 loco::DataType visit(const luci::CircleAbs *node) final { return loco::dtype_get(node->x()); }
34 loco::DataType visit(const luci::CircleAdd *node) final { return loco::dtype_get(node->x()); }
36 loco::DataType visit(const luci::CircleAddN *node) final
38 auto dtype = loco::dtype_get(node->inputs(0));
40 for (uint32_t idx = 1; idx < node->arity(); ++idx)
42 auto dtype_idx = loco::dtype_get(node->inputs(idx));
43 if (dtype != dtype_idx)
45 INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx);
49 return loco::dtype_get(node->inputs(0));
52 loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); }
54 loco::DataType visit(const luci::CircleArgMin *node) final { return node->output_type(); }
56 loco::DataType visit(const luci::CircleAveragePool2D *node) final
58 return loco::dtype_get(node->value());
61 loco::DataType visit(const luci::CircleBatchMatMul *node) final
63 return loco::dtype_get(node->x());
66 loco::DataType visit(const luci::CircleBatchToSpaceND *node) final
68 return loco::dtype_get(node->input());
71 loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); }
73 loco::DataType visit(const luci::CircleCeil *node) final { return loco::dtype_get(node->x()); }
75 loco::DataType visit(const luci::CircleConcatenation *node) final
77 // TODO Support when CircleConcatenation has 0 input
78 assert(node->numValues() > 0);
80 for (uint32_t i = 1; i < node->numValues(); ++i)
81 assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i)));
83 return loco::dtype_get(node->values(0));
86 loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); }
88 loco::DataType visit(const luci::CircleConv2D *node) final
90 return loco::dtype_get(node->input());
93 loco::DataType visit(const luci::CircleCos *node) final { return loco::dtype_get(node->x()); }
95 loco::DataType visit(const luci::CircleCustom *node) final
97 if (node->custom_code() == "BatchMatMulV2")
99 return loco::dtype_get(node->inputs(0));
101 return node->dtype();
104 loco::DataType visit(const luci::CircleDepthToSpace *node) final
106 return loco::dtype_get(node->input());
109 loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final
111 return loco::dtype_get(node->input());
114 loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); }
116 loco::DataType visit(const luci::CircleElu *node) final
118 return loco::dtype_get(node->features());
121 loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; }
123 loco::DataType visit(const luci::CircleExp *node) final { return loco::dtype_get(node->x()); }
125 loco::DataType visit(const luci::CircleExpandDims *node) final
127 return loco::dtype_get(node->input());
130 loco::DataType visit(const luci::CircleFill *node) final
132 return loco::dtype_get(node->value());
135 loco::DataType visit(const luci::CircleFloor *node) final { return loco::dtype_get(node->x()); }
137 loco::DataType visit(const luci::CircleFloorDiv *node) final
139 return loco::dtype_get(node->x());
142 loco::DataType visit(const luci::CircleFloorMod *node) final
144 return loco::dtype_get(node->x());
147 loco::DataType visit(const luci::CircleFullyConnected *node) final
149 return loco::dtype_get(node->input());
152 loco::DataType visit(const luci::CircleGather *node) final
154 return loco::dtype_get(node->params());
157 loco::DataType visit(const luci::CircleGatherNd *node) final
159 return loco::dtype_get(node->params());
162 loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; }
164 loco::DataType visit(const luci::CircleGreaterEqual *) final { return loco::DataType::BOOL; }
166 loco::DataType visit(const luci::CircleIf *node) final
168 // Type of If is not used. Just use input 0
169 assert(node->input_count() > 0);
170 return loco::dtype_get(node->input(0));
173 loco::DataType visit(const luci::CircleL2Normalize *node) final
175 return loco::dtype_get(node->x());
178 loco::DataType visit(const luci::CircleL2Pool2D *node) final
180 return loco::dtype_get(node->value());
183 loco::DataType visit(const luci::CircleLeakyRelu *node) final
185 return loco::dtype_get(node->features());
188 loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; }
190 loco::DataType visit(const luci::CircleLessEqual *) final { return loco::DataType::BOOL; }
192 loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final
194 return loco::dtype_get(node->input());
197 loco::DataType visit(const luci::CircleLog *node) final { return loco::dtype_get(node->x()); }
199 loco::DataType visit(const luci::CircleLogicalAnd *node) final
201 return loco::dtype_get(node->x());
204 loco::DataType visit(const luci::CircleLogicalNot *node) final
206 return loco::dtype_get(node->x());
209 loco::DataType visit(const luci::CircleLogicalOr *node) final
211 return loco::dtype_get(node->x());
214 loco::DataType visit(const luci::CircleLogistic *node) final
216 return loco::dtype_get(node->x());
219 loco::DataType visit(const luci::CircleLogSoftmax *node) final
221 return loco::dtype_get(node->logits());
224 loco::DataType visit(const luci::CircleMatrixDiag *node) final
226 return loco::dtype_get(node->diagonal());
229 loco::DataType visit(const luci::CircleMatrixSetDiag *node) final
231 return loco::dtype_get(node->input());
234 loco::DataType visit(const luci::CircleMaximum *node) final { return loco::dtype_get(node->x()); }
236 loco::DataType visit(const luci::CircleMaxPool2D *node) final
238 return loco::dtype_get(node->value());
241 loco::DataType visit(const luci::CircleMean *node) final
243 return loco::dtype_get(node->input());
246 loco::DataType visit(const luci::CircleMinimum *node) final { return loco::dtype_get(node->x()); }
248 loco::DataType visit(const luci::CircleMirrorPad *node) final
250 return loco::dtype_get(node->input());
253 loco::DataType visit(const luci::CircleNeg *node) final { return loco::dtype_get(node->x()); }
255 loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final
257 return loco::dtype_get(node->boxes());
260 loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; }
262 loco::DataType visit(const luci::CirclePack *node) final
264 // Only support CirclePack with one or more inputs
265 assert(node->values_count() > 0);
267 auto first_value_type = loco::dtype_get(node->values(0));
268 for (uint32_t i = 1; i < node->values_count(); ++i)
269 assert(first_value_type == loco::dtype_get(node->values(i)));
271 return first_value_type;
274 loco::DataType visit(const luci::CirclePad *node) final { return loco::dtype_get(node->input()); }
276 loco::DataType visit(const luci::CirclePow *node) final
278 // TODO make sure types cannot differ
279 auto x_type = loco::dtype_get(node->x());
280 auto y_type = loco::dtype_get(node->y());
282 if (x_type != y_type)
283 INTERNAL_EXN("Different datatype for x and y are not supported");
288 loco::DataType visit(const luci::CirclePRelu *node) final
290 auto input_type = loco::dtype_get(node->input());
291 auto alpha_type = loco::dtype_get(node->alpha());
293 if (input_type != alpha_type)
294 INTERNAL_EXN("Different datatype for input and alpha are not supported");
299 loco::DataType visit(const luci::CircleRange *node) final
301 return loco::dtype_get(node->start());
304 loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; }
306 loco::DataType visit(const luci::CircleMul *node) final { return loco::dtype_get(node->x()); }
308 loco::DataType visit(const luci::CircleOneHot *node) final
310 return loco::dtype_get(node->on_value());
313 loco::DataType visit(const luci::CircleReduceAny *node) final
315 return loco::dtype_get(node->input());
318 loco::DataType visit(const luci::CircleReduceMax *node) final
320 return loco::dtype_get(node->input());
323 loco::DataType visit(const luci::CircleReduceMin *node) final
325 return loco::dtype_get(node->input());
328 loco::DataType visit(const luci::CircleReduceProd *node) final
330 return loco::dtype_get(node->input());
333 loco::DataType visit(const luci::CircleRelu *node) final
335 return loco::dtype_get(node->features());
338 loco::DataType visit(const luci::CircleRelu6 *node) final
340 return loco::dtype_get(node->features());
343 loco::DataType visit(const luci::CircleReluN1To1 *node) final
345 return loco::dtype_get(node->features());
348 loco::DataType visit(const luci::CircleReshape *node) final
350 return loco::dtype_get(node->tensor());
353 loco::DataType visit(const luci::CircleResizeBilinear *node) final
355 return loco::dtype_get(node->input());
358 loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final
360 return loco::dtype_get(node->input());
363 loco::DataType visit(const luci::CircleReverseSequence *node) final
365 return loco::dtype_get(node->input());
368 loco::DataType visit(const luci::CircleReverseV2 *node) final
370 return loco::dtype_get(node->tensor());
373 loco::DataType visit(const luci::CircleRound *node) final { return loco::dtype_get(node->x()); }
375 loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); }
377 loco::DataType visit(const luci::CircleScatterNd *node) final
379 return loco::dtype_get(node->updates());
382 loco::DataType visit(const luci::CircleSegmentSum *node) final
384 return loco::dtype_get(node->input());
387 loco::DataType visit(const luci::CircleSelect *node) final
389 assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
390 return loco::dtype_get(node->t());
393 loco::DataType visit(const luci::CircleSelectV2 *node) final
395 assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
396 return loco::dtype_get(node->t());
399 loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); }
401 loco::DataType visit(const luci::CircleSin *node) final { return loco::dtype_get(node->x()); }
403 loco::DataType visit(const luci::CircleSlice *node) final
405 return loco::dtype_get(node->input());
408 loco::DataType visit(const luci::CircleSoftmax *node) final
410 return loco::dtype_get(node->logits());
413 loco::DataType visit(const luci::CircleSpaceToBatchND *node) final
415 return loco::dtype_get(node->input());
418 loco::DataType visit(const luci::CircleSpaceToDepth *node) final
420 return loco::dtype_get(node->input());
423 loco::DataType visit(const luci::CircleSparseToDense *node) final
425 return loco::dtype_get(node->values());
428 loco::DataType visit(const luci::CircleSplit *node) final
430 return loco::dtype_get(node->input());
433 loco::DataType visit(const luci::CircleSplitV *node) final
435 return loco::dtype_get(node->input());
438 loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); }
440 loco::DataType visit(const luci::CircleSquare *node) final { return loco::dtype_get(node->x()); }
442 loco::DataType visit(const luci::CircleSquaredDifference *node) final
444 return loco::dtype_get(node->x());
447 loco::DataType visit(const luci::CircleSqueeze *node) final
449 return loco::dtype_get(node->input());
452 loco::DataType visit(const luci::CircleStridedSlice *node) final
454 return loco::dtype_get(node->input());
457 loco::DataType visit(const luci::CircleSub *node) final { return loco::dtype_get(node->x()); }
459 loco::DataType visit(const luci::CircleSum *node) final { return loco::dtype_get(node->input()); }
461 loco::DataType visit(const luci::CircleTanh *node) final { return loco::dtype_get(node->x()); }
463 loco::DataType visit(const luci::CircleTile *node) final
465 return loco::dtype_get(node->input());
468 loco::DataType visit(const luci::CircleTopKV2 *node) final
470 return loco::dtype_get(node->input());
473 loco::DataType visit(const luci::CircleTranspose *node) final
475 return loco::dtype_get(node->a());
478 loco::DataType visit(const luci::CircleTransposeConv *node) final
480 return loco::dtype_get(node->outBackprop());
483 loco::DataType visit(const luci::CircleUnique *node) final
485 return loco::dtype_get(node->input());
488 loco::DataType visit(const luci::CircleUnpack *node) final
490 return loco::dtype_get(node->value());
493 loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; }
495 loco::DataType visit(const luci::CircleWhile *node) final
497 // Type of While is not used. Just use input 0
498 assert(node->input_count() > 0);
499 return loco::dtype_get(node->input(0));
502 loco::DataType visit(const luci::CircleZerosLike *node) final
504 return loco::dtype_get(node->input());
508 loco::DataType visit(const luci::CircleBCQFullyConnected *) final
510 return loco::DataType::FLOAT32;
513 loco::DataType visit(const luci::CircleBCQGather *) final { return loco::DataType::FLOAT32; }
515 loco::DataType visit(const luci::CircleInstanceNorm *node) final
517 return loco::dtype_get(node->input());
521 loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }
523 loco::DataType visit(const luci::CircleOutput *node) final
525 auto graph_outputs = node->graph()->outputs();
526 auto graph_output = graph_outputs->at(node->index());
527 auto output_dtype = graph_output->dtype();
529 if (dynamic_cast<luci::CircleOutputDummy *>(node->from()) == nullptr &&
530 dynamic_cast<luci::CircleOutputExclude *>(node->from()) == nullptr)
532 // We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude
533 // from() type should match that of CircleOutput
534 assert(output_dtype == loco::dtype_get(node->from()));
539 loco::DataType visit(const luci::CircleOutputDummy *node) final { return node->dtype(); }
541 loco::DataType visit(const luci::CircleOutputExclude *node) final { return node->dtype(); }
543 loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); }
545 loco::DataType visit(const luci::CircleIfOut *node) final
548 * @note IF operator type and shape are that of the "then" and "else"
551 auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
552 if (circle_if == nullptr)
554 INTERNAL_EXN("CircleIf IR is not configured correctly");
557 auto index = node->index();
558 auto then_graph = circle_if->then_graph();
559 auto else_graph = circle_if->else_graph();
560 assert(then_graph != nullptr);
561 assert(else_graph != nullptr);
563 // shape and type are assumed to be same
564 // these are checked at post_import_graph() in Import
565 auto then_outputs = loco::output_nodes(then_graph);
566 auto else_outputs = loco::output_nodes(else_graph);
567 assert(then_outputs.size() == else_outputs.size());
568 assert(index < static_cast<int32_t>(then_outputs.size()));
570 auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
571 auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
573 auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
574 auto else_graph_outputs = else_graph->outputs();
575 assert(then_graph_outputs->size() == else_graph_outputs->size());
577 auto then_graph_output = then_graph_outputs->at(then_out->index());
578 auto else_graph_output = else_graph_outputs->at(else_out->index());
579 (void)else_graph_output; // make compiler happy for unused variable warnings
580 assert(then_graph_output->dtype() == else_graph_output->dtype());
582 return then_graph_output->dtype();
585 loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final
588 assert(node->index() == 0 || node->index() == 1);
589 return loco::DataType::S32;
592 loco::DataType visit(const luci::CircleSplitOut *node) final
594 return loco::dtype_get(node->input());
597 loco::DataType visit(const luci::CircleSplitVOut *node) final
599 return loco::dtype_get(node->input());
602 loco::DataType visit(const luci::CircleTopKV2Out *node) final
604 // First output is same as input
605 if (node->index() == 0)
606 return loco::dtype_get(node->input());
607 // Second outout is always S32
608 assert(node->index() == 1);
609 return loco::DataType::S32;
612 loco::DataType visit(const luci::CircleUniqueOut *node) final
614 if (node->index() == 0)
616 return loco::dtype_get(node->input());
618 assert(node->index() == 1);
619 auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
620 return unique->idx_out_type();
623 loco::DataType visit(const luci::CircleUnpackOut *node) final
625 return loco::dtype_get(node->input());
628 loco::DataType visit(const luci::CircleWhileOut *node) final
631 * @note WHILE operator's type is the same with the "cond"
634 auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
635 if (circle_while == nullptr)
637 INTERNAL_EXN("CircleWhile IR is not configured correctly");
640 auto index = node->index();
641 auto cond_graph = circle_while->cond_graph();
642 assert(cond_graph != nullptr);
644 // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
646 auto cond_inputs = loco::input_nodes(cond_graph);
647 auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
649 auto cond_graph_inputs = cond_graph->inputs();
650 auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
652 return cond_graph_input->dtype();
661 bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const
663 return CircleDialect::get() == d;
666 bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
668 assert(node->dialect() == CircleDialect::get());
670 TypeInferenceAlgorithm alg;
672 auto circle_node = loco::must_cast<const CircleNode *>(node);
673 dtype = circle_node->accept(&alg);
674 assert(dtype != loco::DataType::Unknown);