Imported Upstream version 1.25.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / CircleTypeInferenceRule.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/CircleTypeInferenceRule.h"
18 #include "CircleTypeInferenceHelper.h"
19
20 #include <luci/IR/CircleDialect.h>
21 #include <luci/IR/CircleNodeVisitor.h>
22 #include <luci/IR/CircleNodes.h>
23
24 #include <cassert>
25
26 namespace
27 {
28
29 struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataType>
30 {
31   // TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or
32   // float64.
33   loco::DataType visit(const luci::CircleAbs *node) final { return luci::dtype_get(node->x()); }
34
35   loco::DataType visit(const luci::CircleAdd *node) final { return luci::dtype_get(node->x()); }
36
37   loco::DataType visit(const luci::CircleAddN *node) final
38   {
39     auto dtype = luci::dtype_get(node->inputs(0));
40
41     for (uint32_t idx = 1; idx < node->arity(); ++idx)
42     {
43       auto dtype_idx = luci::dtype_get(node->inputs(idx));
44       if (dtype != dtype_idx)
45       {
46         INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx);
47       }
48     }
49
50     return luci::dtype_get(node->inputs(0));
51   }
52
53   loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); }
54
55   loco::DataType visit(const luci::CircleArgMin *node) final { return node->output_type(); }
56
57   loco::DataType visit(const luci::CircleAveragePool2D *node) final
58   {
59     return luci::dtype_get(node->value());
60   }
61
62   loco::DataType visit(const luci::CircleBatchMatMul *node) final
63   {
64     return luci::dtype_get(node->x());
65   }
66
67   loco::DataType visit(const luci::CircleBatchToSpaceND *node) final
68   {
69     return luci::dtype_get(node->input());
70   }
71
72   loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); }
73
74   loco::DataType visit(const luci::CircleCeil *node) final { return luci::dtype_get(node->x()); }
75
76   loco::DataType visit(const luci::CircleConcatenation *node) final
77   {
78     // TODO Support when CircleConcatenation has 0 input
79     assert(node->numValues() > 0);
80
81     for (uint32_t i = 1; i < node->numValues(); ++i)
82       assert(luci::dtype_get(node->values(i - 1)) == luci::dtype_get(node->values(i)));
83
84     return luci::dtype_get(node->values(0));
85   }
86
87   loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); }
88
89   loco::DataType visit(const luci::CircleConv2D *node) final
90   {
91     return luci::dtype_get(node->input());
92   }
93
94   loco::DataType visit(const luci::CircleCos *node) final { return luci::dtype_get(node->x()); }
95
96   loco::DataType visit(const luci::CircleCustom *node) final
97   {
98     if (node->custom_code() == "BatchMatMulV2")
99     {
100       return luci::dtype_get(node->inputs(0));
101     }
102     return node->dtype();
103   }
104
105   loco::DataType visit(const luci::CircleDensify *node) final
106   {
107     return luci::dtype_get(node->input());
108   }
109
110   loco::DataType visit(const luci::CircleDepthToSpace *node) final
111   {
112     return luci::dtype_get(node->input());
113   }
114
115   loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final
116   {
117     return luci::dtype_get(node->input());
118   }
119
120   loco::DataType visit(const luci::CircleDequantize *) final { return loco::DataType::FLOAT32; }
121
122   loco::DataType visit(const luci::CircleDiv *node) final { return luci::dtype_get(node->x()); }
123
124   loco::DataType visit(const luci::CircleElu *node) final
125   {
126     return luci::dtype_get(node->features());
127   }
128
129   loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; }
130
131   loco::DataType visit(const luci::CircleExp *node) final { return luci::dtype_get(node->x()); }
132
133   loco::DataType visit(const luci::CircleExpandDims *node) final
134   {
135     return luci::dtype_get(node->input());
136   }
137
138   loco::DataType visit(const luci::CircleFakeQuant *node) final
139   {
140     return luci::dtype_get(node->inputs());
141   }
142
143   loco::DataType visit(const luci::CircleFill *node) final
144   {
145     return luci::dtype_get(node->value());
146   }
147
148   loco::DataType visit(const luci::CircleFloor *node) final { return luci::dtype_get(node->x()); }
149
150   loco::DataType visit(const luci::CircleFloorDiv *node) final
151   {
152     return luci::dtype_get(node->x());
153   }
154
155   loco::DataType visit(const luci::CircleFloorMod *node) final
156   {
157     return luci::dtype_get(node->x());
158   }
159
160   loco::DataType visit(const luci::CircleFullyConnected *node) final
161   {
162     return luci::dtype_get(node->input());
163   }
164
165   loco::DataType visit(const luci::CircleGather *node) final
166   {
167     return luci::dtype_get(node->params());
168   }
169
170   loco::DataType visit(const luci::CircleGatherNd *node) final
171   {
172     return luci::dtype_get(node->params());
173   }
174
175   loco::DataType visit(const luci::CircleGelu *node) final
176   {
177     return luci::dtype_get(node->features());
178   }
179
180   loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; }
181
182   loco::DataType visit(const luci::CircleGreaterEqual *) final { return loco::DataType::BOOL; }
183
184   loco::DataType visit(const luci::CircleHardSwish *node) final
185   {
186     return luci::dtype_get(node->features());
187   }
188
189   loco::DataType visit(const luci::CircleIf *node) final
190   {
191     // Type of If is not used. Just use input 0
192     assert(node->input_count() > 0);
193     return luci::dtype_get(node->input(0));
194   }
195
196   loco::DataType visit(const luci::CircleL2Normalize *node) final
197   {
198     return luci::dtype_get(node->x());
199   }
200
201   loco::DataType visit(const luci::CircleL2Pool2D *node) final
202   {
203     return luci::dtype_get(node->value());
204   }
205
206   loco::DataType visit(const luci::CircleLeakyRelu *node) final
207   {
208     return luci::dtype_get(node->features());
209   }
210
211   loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; }
212
213   loco::DataType visit(const luci::CircleLessEqual *) final { return loco::DataType::BOOL; }
214
215   loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final
216   {
217     return luci::dtype_get(node->input());
218   }
219
220   loco::DataType visit(const luci::CircleLog *node) final { return luci::dtype_get(node->x()); }
221
222   loco::DataType visit(const luci::CircleLogicalAnd *node) final
223   {
224     return luci::dtype_get(node->x());
225   }
226
227   loco::DataType visit(const luci::CircleLogicalNot *node) final
228   {
229     return luci::dtype_get(node->x());
230   }
231
232   loco::DataType visit(const luci::CircleLogicalOr *node) final
233   {
234     return luci::dtype_get(node->x());
235   }
236
237   loco::DataType visit(const luci::CircleLogistic *node) final
238   {
239     return luci::dtype_get(node->x());
240   }
241
242   loco::DataType visit(const luci::CircleLogSoftmax *node) final
243   {
244     return luci::dtype_get(node->logits());
245   }
246
247   loco::DataType visit(const luci::CircleMatrixDiag *node) final
248   {
249     return luci::dtype_get(node->diagonal());
250   }
251
252   loco::DataType visit(const luci::CircleMatrixSetDiag *node) final
253   {
254     return luci::dtype_get(node->input());
255   }
256
257   loco::DataType visit(const luci::CircleMaximum *node) final { return luci::dtype_get(node->x()); }
258
259   loco::DataType visit(const luci::CircleMaxPool2D *node) final
260   {
261     return luci::dtype_get(node->value());
262   }
263
264   loco::DataType visit(const luci::CircleMean *node) final
265   {
266     return luci::dtype_get(node->input());
267   }
268
269   loco::DataType visit(const luci::CircleMinimum *node) final { return luci::dtype_get(node->x()); }
270
271   loco::DataType visit(const luci::CircleMirrorPad *node) final
272   {
273     return luci::dtype_get(node->input());
274   }
275
276   loco::DataType visit(const luci::CircleNeg *node) final { return luci::dtype_get(node->x()); }
277
278   loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final
279   {
280     return luci::dtype_get(node->boxes());
281   }
282
283   loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final
284   {
285     return luci::dtype_get(node->boxes());
286   }
287
288   loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; }
289
290   loco::DataType visit(const luci::CirclePack *node) final
291   {
292     // Only support CirclePack with one or more inputs
293     assert(node->values_count() > 0);
294
295     auto first_value_type = luci::dtype_get(node->values(0));
296     for (uint32_t i = 1; i < node->values_count(); ++i)
297       assert(first_value_type == luci::dtype_get(node->values(i)));
298
299     return first_value_type;
300   }
301
302   loco::DataType visit(const luci::CirclePad *node) final { return luci::dtype_get(node->input()); }
303
304   loco::DataType visit(const luci::CirclePadV2 *node) final
305   {
306     return luci::dtype_get(node->input());
307   }
308
309   loco::DataType visit(const luci::CirclePow *node) final
310   {
311     // TODO make sure types cannot differ
312     auto x_type = luci::dtype_get(node->x());
313     auto y_type = luci::dtype_get(node->y());
314
315     if (x_type != y_type)
316       INTERNAL_EXN("Different datatype for x and y are not supported");
317
318     return x_type;
319   }
320
321   loco::DataType visit(const luci::CirclePRelu *node) final
322   {
323     auto input_type = luci::dtype_get(node->input());
324     auto alpha_type = luci::dtype_get(node->alpha());
325
326     if (input_type != alpha_type)
327       INTERNAL_EXN("Different datatype for input and alpha are not supported");
328
329     return input_type;
330   }
331
332   loco::DataType visit(const luci::CircleQuantize *node) final { return luci::dtype_get(node); }
333
334   loco::DataType visit(const luci::CircleRange *node) final
335   {
336     return luci::dtype_get(node->start());
337   }
338
339   loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; }
340
341   loco::DataType visit(const luci::CircleMul *node) final { return luci::dtype_get(node->x()); }
342
343   loco::DataType visit(const luci::CircleOneHot *node) final
344   {
345     return luci::dtype_get(node->on_value());
346   }
347
348   loco::DataType visit(const luci::CircleReduceAny *node) final
349   {
350     return luci::dtype_get(node->input());
351   }
352
353   loco::DataType visit(const luci::CircleReduceMax *node) final
354   {
355     return luci::dtype_get(node->input());
356   }
357
358   loco::DataType visit(const luci::CircleReduceMin *node) final
359   {
360     return luci::dtype_get(node->input());
361   }
362
363   loco::DataType visit(const luci::CircleReduceProd *node) final
364   {
365     return luci::dtype_get(node->input());
366   }
367
368   loco::DataType visit(const luci::CircleRelu *node) final
369   {
370     return luci::dtype_get(node->features());
371   }
372
373   loco::DataType visit(const luci::CircleRelu6 *node) final
374   {
375     return luci::dtype_get(node->features());
376   }
377
378   loco::DataType visit(const luci::CircleReluN1To1 *node) final
379   {
380     return luci::dtype_get(node->features());
381   }
382
383   loco::DataType visit(const luci::CircleReshape *node) final
384   {
385     return luci::dtype_get(node->tensor());
386   }
387
388   loco::DataType visit(const luci::CircleResizeBilinear *node) final
389   {
390     return luci::dtype_get(node->input());
391   }
392
393   loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final
394   {
395     return luci::dtype_get(node->input());
396   }
397
398   loco::DataType visit(const luci::CircleReverseSequence *node) final
399   {
400     return luci::dtype_get(node->input());
401   }
402
403   loco::DataType visit(const luci::CircleReverseV2 *node) final
404   {
405     return luci::dtype_get(node->tensor());
406   }
407
408   loco::DataType visit(const luci::CircleRound *node) final { return luci::dtype_get(node->x()); }
409
410   loco::DataType visit(const luci::CircleRsqrt *node) final { return luci::dtype_get(node->x()); }
411
412   loco::DataType visit(const luci::CircleScatterNd *node) final
413   {
414     return luci::dtype_get(node->updates());
415   }
416
417   loco::DataType visit(const luci::CircleSegmentSum *node) final
418   {
419     return luci::dtype_get(node->input());
420   }
421
422   loco::DataType visit(const luci::CircleSelect *node) final
423   {
424     assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e()));
425     return luci::dtype_get(node->t());
426   }
427
428   loco::DataType visit(const luci::CircleSelectV2 *node) final
429   {
430     assert(luci::dtype_get(node->t()) == luci::dtype_get(node->e()));
431     return luci::dtype_get(node->t());
432   }
433
434   loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); }
435
436   loco::DataType visit(const luci::CircleSin *node) final { return luci::dtype_get(node->x()); }
437
438   loco::DataType visit(const luci::CircleSlice *node) final
439   {
440     return luci::dtype_get(node->input());
441   }
442
443   loco::DataType visit(const luci::CircleSoftmax *node) final
444   {
445     return luci::dtype_get(node->logits());
446   }
447
448   loco::DataType visit(const luci::CircleSpaceToBatchND *node) final
449   {
450     return luci::dtype_get(node->input());
451   }
452
453   loco::DataType visit(const luci::CircleSpaceToDepth *node) final
454   {
455     return luci::dtype_get(node->input());
456   }
457
458   loco::DataType visit(const luci::CircleSparseToDense *node) final
459   {
460     return luci::dtype_get(node->values());
461   }
462
463   loco::DataType visit(const luci::CircleSplit *node) final
464   {
465     return luci::dtype_get(node->input());
466   }
467
468   loco::DataType visit(const luci::CircleSplitV *node) final
469   {
470     return luci::dtype_get(node->input());
471   }
472
473   loco::DataType visit(const luci::CircleSqrt *node) final { return luci::dtype_get(node->x()); }
474
475   loco::DataType visit(const luci::CircleSquare *node) final { return luci::dtype_get(node->x()); }
476
477   loco::DataType visit(const luci::CircleSquaredDifference *node) final
478   {
479     return luci::dtype_get(node->x());
480   }
481
482   loco::DataType visit(const luci::CircleSqueeze *node) final
483   {
484     return luci::dtype_get(node->input());
485   }
486
487   loco::DataType visit(const luci::CircleStridedSlice *node) final
488   {
489     return luci::dtype_get(node->input());
490   }
491
492   loco::DataType visit(const luci::CircleSub *node) final { return luci::dtype_get(node->x()); }
493
494   loco::DataType visit(const luci::CircleSum *node) final { return luci::dtype_get(node->input()); }
495
496   loco::DataType visit(const luci::CircleSVDF *node) final
497   {
498     return luci::dtype_get(node->input());
499   }
500
501   loco::DataType visit(const luci::CircleTanh *node) final { return luci::dtype_get(node->x()); }
502
503   loco::DataType visit(const luci::CircleTile *node) final
504   {
505     return luci::dtype_get(node->input());
506   }
507
508   loco::DataType visit(const luci::CircleTopKV2 *node) final
509   {
510     return luci::dtype_get(node->input());
511   }
512
513   loco::DataType visit(const luci::CircleTranspose *node) final
514   {
515     return luci::dtype_get(node->a());
516   }
517
518   loco::DataType visit(const luci::CircleTransposeConv *node) final
519   {
520     return luci::dtype_get(node->outBackprop());
521   }
522
523   loco::DataType visit(const luci::CircleUnidirectionalSequenceLSTM *node) final
524   {
525     return luci::dtype_get(node->input());
526   }
527
528   loco::DataType visit(const luci::CircleUnique *node) final
529   {
530     return luci::dtype_get(node->input());
531   }
532
533   loco::DataType visit(const luci::CircleUnpack *node) final
534   {
535     return luci::dtype_get(node->value());
536   }
537
538   loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; }
539
540   loco::DataType visit(const luci::CircleWhile *node) final
541   {
542     // Type of While is not used. Just use input 0
543     assert(node->input_count() > 0);
544     return luci::dtype_get(node->input(0));
545   }
546
547   loco::DataType visit(const luci::CircleZerosLike *node) final
548   {
549     return luci::dtype_get(node->input());
550   }
551
552   // Circle Only
553   loco::DataType visit(const luci::CircleBCQFullyConnected *) final
554   {
555     return loco::DataType::FLOAT32;
556   }
557
558   loco::DataType visit(const luci::CircleBCQGather *) final { return loco::DataType::FLOAT32; }
559
560   loco::DataType visit(const luci::CircleInstanceNorm *node) final
561   {
562     return luci::dtype_get(node->input());
563   }
564
565   // Virtual
566   loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }
567
568   loco::DataType visit(const luci::CircleOutput *node) final
569   {
570     auto graph_outputs = node->graph()->outputs();
571     auto graph_output = graph_outputs->at(node->index());
572     auto output_dtype = graph_output->dtype();
573
574     if (dynamic_cast<luci::CircleOutputDummy *>(node->from()) == nullptr &&
575         dynamic_cast<luci::CircleOutputExclude *>(node->from()) == nullptr)
576     {
577       // We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude
578       // from() type should match that of CircleOutput
579       assert(output_dtype == luci::dtype_get(node->from()));
580     }
581     return output_dtype;
582   }
583
584   loco::DataType visit(const luci::CircleOutputDummy *node) final { return node->dtype(); }
585
586   loco::DataType visit(const luci::CircleOutputExclude *node) final
587   {
588     // NOTE We don't care CircleOutputExclude dtype, but set to FLOAT32
589     //      if it's Unknown to make type inference happy.
590     if (node->dtype() == loco::DataType::Unknown)
591       return loco::DataType::FLOAT32;
592     return node->dtype();
593   }
594
595   loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); }
596
597   loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final
598   {
599     (void)node;
600     assert(node->index() == 0 || node->index() == 1);
601     return loco::DataType::S32;
602   }
603
604   loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final
605   {
606     (void)node;
607     if (node->index() == 0 || node->index() == 2)
608     {
609       return loco::DataType::S32;
610     }
611     assert(node->index() == 1);
612     return loco::DataType::FLOAT32;
613   }
614
615   loco::DataType visit(const luci::CircleSplitOut *node) final
616   {
617     return luci::dtype_get(node->input());
618   }
619
620   loco::DataType visit(const luci::CircleSplitVOut *node) final
621   {
622     return luci::dtype_get(node->input());
623   }
624
625   loco::DataType visit(const luci::CircleTopKV2Out *node) final
626   {
627     // First output is same as input
628     if (node->index() == 0)
629       return luci::dtype_get(node->input());
630     // Second outout is always S32
631     assert(node->index() == 1);
632     return loco::DataType::S32;
633   }
634
635   loco::DataType visit(const luci::CircleVariable *node) final { return node->dtype(); }
636
637   loco::DataType visit(const luci::CircleUniqueOut *node) final
638   {
639     if (node->index() == 0)
640     {
641       return luci::dtype_get(node->input());
642     }
643     assert(node->index() == 1);
644     auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
645     return unique->idx_out_type();
646   }
647
648   loco::DataType visit(const luci::CircleUnpackOut *node) final
649   {
650     return luci::dtype_get(node->input());
651   }
652
653   loco::DataType visit(const luci::CircleWhileOut *node) final
654   {
655     /**
656      * @note  WHILE operator's type is the same with the "cond"
657      *        Graph Input.
658      */
659     auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
660     if (circle_while == nullptr)
661     {
662       INTERNAL_EXN("CircleWhile IR is not configured correctly");
663     }
664
665     auto index = node->index();
666     auto cond_graph = circle_while->cond_graph();
667     assert(cond_graph != nullptr);
668
669     // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
670     // loco::input_nodes
671     auto cond_inputs = loco::input_nodes(cond_graph);
672     auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
673
674     auto cond_graph_inputs = cond_graph->inputs();
675     auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
676
677     return cond_graph_input->dtype();
678   }
679 };
680
681 } // namespace
682
683 namespace luci
684 {
685
686 bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const
687 {
688   return CircleDialect::get() == d;
689 }
690
691 bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
692 {
693   assert(node->dialect() == CircleDialect::get());
694
695   TypeInferenceAlgorithm alg;
696
697   auto circle_node = loco::must_cast<const CircleNode *>(node);
698   dtype = circle_node->accept(&alg);
699   assert(dtype != loco::DataType::Unknown);
700
701   return true;
702 }
703
704 } // namespace luci