Imported Upstream version 1.9.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / CircleTypeInferenceRule.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/CircleTypeInferenceRule.h"
18
19 #include <luci/IR/CircleDialect.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/IR/CircleNodes.h>
22
23 #include <cassert>
24
25 namespace
26 {
27
28 struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataType>
29 {
30   // TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or
31   // float64.
32   loco::DataType visit(const luci::CircleAbs *node) final { return loco::dtype_get(node->x()); }
33
34   loco::DataType visit(const luci::CircleAdd *node) final { return loco::dtype_get(node->x()); }
35
36   loco::DataType visit(const luci::CircleAddN *node) final
37   {
38     auto dtype = loco::dtype_get(node->inputs(0));
39
40     for (uint32_t idx = 1; idx < node->arity(); ++idx)
41     {
42       auto dtype_idx = loco::dtype_get(node->inputs(idx));
43       if (dtype != dtype_idx)
44       {
45         INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx);
46       }
47     }
48
49     return loco::dtype_get(node->inputs(0));
50   }
51
52   loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); }
53
54   loco::DataType visit(const luci::CircleArgMin *node) final { return node->output_type(); }
55
56   loco::DataType visit(const luci::CircleAveragePool2D *node) final
57   {
58     return loco::dtype_get(node->value());
59   }
60
61   loco::DataType visit(const luci::CircleBatchMatMul *node) final
62   {
63     return loco::dtype_get(node->x());
64   }
65
66   loco::DataType visit(const luci::CircleBatchToSpaceND *node) final
67   {
68     return loco::dtype_get(node->input());
69   }
70
71   loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); }
72
73   loco::DataType visit(const luci::CircleCeil *node) final { return loco::dtype_get(node->x()); }
74
75   loco::DataType visit(const luci::CircleConcatenation *node) final
76   {
77     // TODO Support when CircleConcatenation has 0 input
78     assert(node->numValues() > 0);
79
80     for (uint32_t i = 1; i < node->numValues(); ++i)
81       assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i)));
82
83     return loco::dtype_get(node->values(0));
84   }
85
86   loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); }
87
88   loco::DataType visit(const luci::CircleConv2D *node) final
89   {
90     return loco::dtype_get(node->input());
91   }
92
93   loco::DataType visit(const luci::CircleCos *node) final { return loco::dtype_get(node->x()); }
94
95   loco::DataType visit(const luci::CircleCustom *node) final
96   {
97     if (node->custom_code() == "BatchMatMulV2")
98     {
99       return loco::dtype_get(node->inputs(0));
100     }
101     return node->dtype();
102   }
103
104   loco::DataType visit(const luci::CircleDepthToSpace *node) final
105   {
106     return loco::dtype_get(node->input());
107   }
108
109   loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final
110   {
111     return loco::dtype_get(node->input());
112   }
113
114   loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); }
115
116   loco::DataType visit(const luci::CircleElu *node) final
117   {
118     return loco::dtype_get(node->features());
119   }
120
121   loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; }
122
123   loco::DataType visit(const luci::CircleExp *node) final { return loco::dtype_get(node->x()); }
124
125   loco::DataType visit(const luci::CircleExpandDims *node) final
126   {
127     return loco::dtype_get(node->input());
128   }
129
130   loco::DataType visit(const luci::CircleFill *node) final
131   {
132     return loco::dtype_get(node->value());
133   }
134
135   loco::DataType visit(const luci::CircleFloor *node) final { return loco::dtype_get(node->x()); }
136
137   loco::DataType visit(const luci::CircleFloorDiv *node) final
138   {
139     return loco::dtype_get(node->x());
140   }
141
142   loco::DataType visit(const luci::CircleFloorMod *node) final
143   {
144     return loco::dtype_get(node->x());
145   }
146
147   loco::DataType visit(const luci::CircleFullyConnected *node) final
148   {
149     return loco::dtype_get(node->input());
150   }
151
152   loco::DataType visit(const luci::CircleGather *node) final
153   {
154     return loco::dtype_get(node->params());
155   }
156
157   loco::DataType visit(const luci::CircleGatherNd *node) final
158   {
159     return loco::dtype_get(node->params());
160   }
161
162   loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; }
163
164   loco::DataType visit(const luci::CircleGreaterEqual *) final { return loco::DataType::BOOL; }
165
166   loco::DataType visit(const luci::CircleIf *node) final
167   {
168     // Type of If is not used. Just use input 0
169     assert(node->input_count() > 0);
170     return loco::dtype_get(node->input(0));
171   }
172
173   loco::DataType visit(const luci::CircleL2Normalize *node) final
174   {
175     return loco::dtype_get(node->x());
176   }
177
178   loco::DataType visit(const luci::CircleL2Pool2D *node) final
179   {
180     return loco::dtype_get(node->value());
181   }
182
183   loco::DataType visit(const luci::CircleLeakyRelu *node) final
184   {
185     return loco::dtype_get(node->features());
186   }
187
188   loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; }
189
190   loco::DataType visit(const luci::CircleLessEqual *) final { return loco::DataType::BOOL; }
191
192   loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final
193   {
194     return loco::dtype_get(node->input());
195   }
196
197   loco::DataType visit(const luci::CircleLog *node) final { return loco::dtype_get(node->x()); }
198
199   loco::DataType visit(const luci::CircleLogicalAnd *node) final
200   {
201     return loco::dtype_get(node->x());
202   }
203
204   loco::DataType visit(const luci::CircleLogicalNot *node) final
205   {
206     return loco::dtype_get(node->x());
207   }
208
209   loco::DataType visit(const luci::CircleLogicalOr *node) final
210   {
211     return loco::dtype_get(node->x());
212   }
213
214   loco::DataType visit(const luci::CircleLogistic *node) final
215   {
216     return loco::dtype_get(node->x());
217   }
218
219   loco::DataType visit(const luci::CircleLogSoftmax *node) final
220   {
221     return loco::dtype_get(node->logits());
222   }
223
224   loco::DataType visit(const luci::CircleMatrixDiag *node) final
225   {
226     return loco::dtype_get(node->diagonal());
227   }
228
229   loco::DataType visit(const luci::CircleMatrixSetDiag *node) final
230   {
231     return loco::dtype_get(node->input());
232   }
233
234   loco::DataType visit(const luci::CircleMaximum *node) final { return loco::dtype_get(node->x()); }
235
236   loco::DataType visit(const luci::CircleMaxPool2D *node) final
237   {
238     return loco::dtype_get(node->value());
239   }
240
241   loco::DataType visit(const luci::CircleMean *node) final
242   {
243     return loco::dtype_get(node->input());
244   }
245
246   loco::DataType visit(const luci::CircleMinimum *node) final { return loco::dtype_get(node->x()); }
247
248   loco::DataType visit(const luci::CircleMirrorPad *node) final
249   {
250     return loco::dtype_get(node->input());
251   }
252
253   loco::DataType visit(const luci::CircleNeg *node) final { return loco::dtype_get(node->x()); }
254
255   loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final
256   {
257     return loco::dtype_get(node->boxes());
258   }
259
260   loco::DataType visit(const luci::CircleNonMaxSuppressionV5 *node) final
261   {
262     return loco::dtype_get(node->boxes());
263   }
264
265   loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; }
266
267   loco::DataType visit(const luci::CirclePack *node) final
268   {
269     // Only support CirclePack with one or more inputs
270     assert(node->values_count() > 0);
271
272     auto first_value_type = loco::dtype_get(node->values(0));
273     for (uint32_t i = 1; i < node->values_count(); ++i)
274       assert(first_value_type == loco::dtype_get(node->values(i)));
275
276     return first_value_type;
277   }
278
279   loco::DataType visit(const luci::CirclePad *node) final { return loco::dtype_get(node->input()); }
280
281   loco::DataType visit(const luci::CirclePadV2 *node) final
282   {
283     return loco::dtype_get(node->input());
284   }
285
286   loco::DataType visit(const luci::CirclePow *node) final
287   {
288     // TODO make sure types cannot differ
289     auto x_type = loco::dtype_get(node->x());
290     auto y_type = loco::dtype_get(node->y());
291
292     if (x_type != y_type)
293       INTERNAL_EXN("Different datatype for x and y are not supported");
294
295     return x_type;
296   }
297
298   loco::DataType visit(const luci::CirclePRelu *node) final
299   {
300     auto input_type = loco::dtype_get(node->input());
301     auto alpha_type = loco::dtype_get(node->alpha());
302
303     if (input_type != alpha_type)
304       INTERNAL_EXN("Different datatype for input and alpha are not supported");
305
306     return input_type;
307   }
308
309   loco::DataType visit(const luci::CircleRange *node) final
310   {
311     return loco::dtype_get(node->start());
312   }
313
314   loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; }
315
316   loco::DataType visit(const luci::CircleMul *node) final { return loco::dtype_get(node->x()); }
317
318   loco::DataType visit(const luci::CircleOneHot *node) final
319   {
320     return loco::dtype_get(node->on_value());
321   }
322
323   loco::DataType visit(const luci::CircleReduceAny *node) final
324   {
325     return loco::dtype_get(node->input());
326   }
327
328   loco::DataType visit(const luci::CircleReduceMax *node) final
329   {
330     return loco::dtype_get(node->input());
331   }
332
333   loco::DataType visit(const luci::CircleReduceMin *node) final
334   {
335     return loco::dtype_get(node->input());
336   }
337
338   loco::DataType visit(const luci::CircleReduceProd *node) final
339   {
340     return loco::dtype_get(node->input());
341   }
342
343   loco::DataType visit(const luci::CircleRelu *node) final
344   {
345     return loco::dtype_get(node->features());
346   }
347
348   loco::DataType visit(const luci::CircleRelu6 *node) final
349   {
350     return loco::dtype_get(node->features());
351   }
352
353   loco::DataType visit(const luci::CircleReluN1To1 *node) final
354   {
355     return loco::dtype_get(node->features());
356   }
357
358   loco::DataType visit(const luci::CircleReshape *node) final
359   {
360     return loco::dtype_get(node->tensor());
361   }
362
363   loco::DataType visit(const luci::CircleResizeBilinear *node) final
364   {
365     return loco::dtype_get(node->input());
366   }
367
368   loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final
369   {
370     return loco::dtype_get(node->input());
371   }
372
373   loco::DataType visit(const luci::CircleReverseSequence *node) final
374   {
375     return loco::dtype_get(node->input());
376   }
377
378   loco::DataType visit(const luci::CircleReverseV2 *node) final
379   {
380     return loco::dtype_get(node->tensor());
381   }
382
383   loco::DataType visit(const luci::CircleRound *node) final { return loco::dtype_get(node->x()); }
384
385   loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); }
386
387   loco::DataType visit(const luci::CircleScatterNd *node) final
388   {
389     return loco::dtype_get(node->updates());
390   }
391
392   loco::DataType visit(const luci::CircleSegmentSum *node) final
393   {
394     return loco::dtype_get(node->input());
395   }
396
397   loco::DataType visit(const luci::CircleSelect *node) final
398   {
399     assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
400     return loco::dtype_get(node->t());
401   }
402
403   loco::DataType visit(const luci::CircleSelectV2 *node) final
404   {
405     assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
406     return loco::dtype_get(node->t());
407   }
408
409   loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); }
410
411   loco::DataType visit(const luci::CircleSin *node) final { return loco::dtype_get(node->x()); }
412
413   loco::DataType visit(const luci::CircleSlice *node) final
414   {
415     return loco::dtype_get(node->input());
416   }
417
418   loco::DataType visit(const luci::CircleSoftmax *node) final
419   {
420     return loco::dtype_get(node->logits());
421   }
422
423   loco::DataType visit(const luci::CircleSpaceToBatchND *node) final
424   {
425     return loco::dtype_get(node->input());
426   }
427
428   loco::DataType visit(const luci::CircleSpaceToDepth *node) final
429   {
430     return loco::dtype_get(node->input());
431   }
432
433   loco::DataType visit(const luci::CircleSparseToDense *node) final
434   {
435     return loco::dtype_get(node->values());
436   }
437
438   loco::DataType visit(const luci::CircleSplit *node) final
439   {
440     return loco::dtype_get(node->input());
441   }
442
443   loco::DataType visit(const luci::CircleSplitV *node) final
444   {
445     return loco::dtype_get(node->input());
446   }
447
448   loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); }
449
450   loco::DataType visit(const luci::CircleSquare *node) final { return loco::dtype_get(node->x()); }
451
452   loco::DataType visit(const luci::CircleSquaredDifference *node) final
453   {
454     return loco::dtype_get(node->x());
455   }
456
457   loco::DataType visit(const luci::CircleSqueeze *node) final
458   {
459     return loco::dtype_get(node->input());
460   }
461
462   loco::DataType visit(const luci::CircleStridedSlice *node) final
463   {
464     return loco::dtype_get(node->input());
465   }
466
467   loco::DataType visit(const luci::CircleSub *node) final { return loco::dtype_get(node->x()); }
468
469   loco::DataType visit(const luci::CircleSum *node) final { return loco::dtype_get(node->input()); }
470
471   loco::DataType visit(const luci::CircleTanh *node) final { return loco::dtype_get(node->x()); }
472
473   loco::DataType visit(const luci::CircleTile *node) final
474   {
475     return loco::dtype_get(node->input());
476   }
477
478   loco::DataType visit(const luci::CircleTopKV2 *node) final
479   {
480     return loco::dtype_get(node->input());
481   }
482
483   loco::DataType visit(const luci::CircleTranspose *node) final
484   {
485     return loco::dtype_get(node->a());
486   }
487
488   loco::DataType visit(const luci::CircleTransposeConv *node) final
489   {
490     return loco::dtype_get(node->outBackprop());
491   }
492
493   loco::DataType visit(const luci::CircleUnique *node) final
494   {
495     return loco::dtype_get(node->input());
496   }
497
498   loco::DataType visit(const luci::CircleUnpack *node) final
499   {
500     return loco::dtype_get(node->value());
501   }
502
503   loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; }
504
505   loco::DataType visit(const luci::CircleWhile *node) final
506   {
507     // Type of While is not used. Just use input 0
508     assert(node->input_count() > 0);
509     return loco::dtype_get(node->input(0));
510   }
511
512   loco::DataType visit(const luci::CircleZerosLike *node) final
513   {
514     return loco::dtype_get(node->input());
515   }
516
517   // Circle Only
518   loco::DataType visit(const luci::CircleBCQFullyConnected *) final
519   {
520     return loco::DataType::FLOAT32;
521   }
522
523   loco::DataType visit(const luci::CircleBCQGather *) final { return loco::DataType::FLOAT32; }
524
525   loco::DataType visit(const luci::CircleInstanceNorm *node) final
526   {
527     return loco::dtype_get(node->input());
528   }
529
530   // Virtual
531   loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }
532
533   loco::DataType visit(const luci::CircleOutput *node) final
534   {
535     auto graph_outputs = node->graph()->outputs();
536     auto graph_output = graph_outputs->at(node->index());
537     auto output_dtype = graph_output->dtype();
538
539     if (dynamic_cast<luci::CircleOutputDummy *>(node->from()) == nullptr &&
540         dynamic_cast<luci::CircleOutputExclude *>(node->from()) == nullptr)
541     {
542       // We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude
543       // from() type should match that of CircleOutput
544       assert(output_dtype == loco::dtype_get(node->from()));
545     }
546     return output_dtype;
547   }
548
549   loco::DataType visit(const luci::CircleOutputDummy *node) final { return node->dtype(); }
550
551   loco::DataType visit(const luci::CircleOutputExclude *node) final { return node->dtype(); }
552
553   loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); }
554
555   loco::DataType visit(const luci::CircleIfOut *node) final
556   {
557     /**
558      * @note  IF operator type and shape are that of the "then" and "else"
559      *        Graph Outputs.
560      */
561     auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
562     if (circle_if == nullptr)
563     {
564       INTERNAL_EXN("CircleIf IR is not configured correctly");
565     }
566
567     auto index = node->index();
568     auto then_graph = circle_if->then_graph();
569     auto else_graph = circle_if->else_graph();
570     assert(then_graph != nullptr);
571     assert(else_graph != nullptr);
572
573     // shape and type are assumed to be same
574     // these are checked at post_import_graph() in Import
575     auto then_outputs = loco::output_nodes(then_graph);
576     auto else_outputs = loco::output_nodes(else_graph);
577     assert(then_outputs.size() == else_outputs.size());
578     assert(index < static_cast<int32_t>(then_outputs.size()));
579
580     auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
581     auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
582
583     auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
584     auto else_graph_outputs = else_graph->outputs();
585     assert(then_graph_outputs->size() == else_graph_outputs->size());
586
587     auto then_graph_output = then_graph_outputs->at(then_out->index());
588     auto else_graph_output = else_graph_outputs->at(else_out->index());
589     (void)else_graph_output; // make compiler happy for unused variable warnings
590     assert(then_graph_output->dtype() == else_graph_output->dtype());
591
592     return then_graph_output->dtype();
593   }
594
595   loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final
596   {
597     (void)node;
598     assert(node->index() == 0 || node->index() == 1);
599     return loco::DataType::S32;
600   }
601
602   loco::DataType visit(const luci::CircleNonMaxSuppressionV5Out *node) final
603   {
604     (void)node;
605     if (node->index() == 0 || node->index() == 2)
606     {
607       return loco::DataType::S32;
608     }
609     assert(node->index() == 1);
610     return loco::DataType::FLOAT32;
611   }
612
613   loco::DataType visit(const luci::CircleSplitOut *node) final
614   {
615     return loco::dtype_get(node->input());
616   }
617
618   loco::DataType visit(const luci::CircleSplitVOut *node) final
619   {
620     return loco::dtype_get(node->input());
621   }
622
623   loco::DataType visit(const luci::CircleTopKV2Out *node) final
624   {
625     // First output is same as input
626     if (node->index() == 0)
627       return loco::dtype_get(node->input());
628     // Second outout is always S32
629     assert(node->index() == 1);
630     return loco::DataType::S32;
631   }
632
633   loco::DataType visit(const luci::CircleUniqueOut *node) final
634   {
635     if (node->index() == 0)
636     {
637       return loco::dtype_get(node->input());
638     }
639     assert(node->index() == 1);
640     auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
641     return unique->idx_out_type();
642   }
643
644   loco::DataType visit(const luci::CircleUnpackOut *node) final
645   {
646     return loco::dtype_get(node->input());
647   }
648
649   loco::DataType visit(const luci::CircleWhileOut *node) final
650   {
651     /**
652      * @note  WHILE operator's type is the same with the "cond"
653      *        Graph Input.
654      */
655     auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
656     if (circle_while == nullptr)
657     {
658       INTERNAL_EXN("CircleWhile IR is not configured correctly");
659     }
660
661     auto index = node->index();
662     auto cond_graph = circle_while->cond_graph();
663     assert(cond_graph != nullptr);
664
665     // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
666     // loco::input_nodes
667     auto cond_inputs = loco::input_nodes(cond_graph);
668     auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
669
670     auto cond_graph_inputs = cond_graph->inputs();
671     auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
672
673     return cond_graph_input->dtype();
674   }
675 };
676
677 } // namespace
678
679 namespace luci
680 {
681
682 bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const
683 {
684   return CircleDialect::get() == d;
685 }
686
687 bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
688 {
689   assert(node->dialect() == CircleDialect::get());
690
691   TypeInferenceAlgorithm alg;
692
693   auto circle_node = loco::must_cast<const CircleNode *>(node);
694   dtype = circle_node->accept(&alg);
695   assert(dtype != loco::DataType::Unknown);
696
697   return true;
698 }
699
700 } // namespace luci