Imported Upstream version 1.8.0
[platform/core/ml/nnfw.git] / compiler / luci / service / src / CircleTypeInferenceRule.cpp
1 /*
2  * Copyright (c) 2020 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #include "luci/Service/CircleTypeInferenceRule.h"
18
19 #include <luci/IR/CircleDialect.h>
20 #include <luci/IR/CircleNodeVisitor.h>
21 #include <luci/IR/CircleNodes.h>
22
23 #include <cassert>
24
25 namespace
26 {
27
28 struct TypeInferenceAlgorithm final : public luci::CircleNodeVisitor<loco::DataType>
29 {
30   // TODO Given a tensor x of complex numbers, Abs operation returns a tensor of type float32 or
31   // float64.
32   loco::DataType visit(const luci::CircleAbs *node) final { return loco::dtype_get(node->x()); }
33
34   loco::DataType visit(const luci::CircleAdd *node) final { return loco::dtype_get(node->x()); }
35
36   loco::DataType visit(const luci::CircleAddN *node) final
37   {
38     auto dtype = loco::dtype_get(node->inputs(0));
39
40     for (uint32_t idx = 1; idx < node->arity(); ++idx)
41     {
42       auto dtype_idx = loco::dtype_get(node->inputs(idx));
43       if (dtype != dtype_idx)
44       {
45         INTERNAL_EXN_V("ADD_N dtype not same as the first input: ", idx);
46       }
47     }
48
49     return loco::dtype_get(node->inputs(0));
50   }
51
52   loco::DataType visit(const luci::CircleArgMax *node) final { return node->output_type(); }
53
54   loco::DataType visit(const luci::CircleArgMin *node) final { return node->output_type(); }
55
56   loco::DataType visit(const luci::CircleAveragePool2D *node) final
57   {
58     return loco::dtype_get(node->value());
59   }
60
61   loco::DataType visit(const luci::CircleBatchMatMul *node) final
62   {
63     return loco::dtype_get(node->x());
64   }
65
66   loco::DataType visit(const luci::CircleBatchToSpaceND *node) final
67   {
68     return loco::dtype_get(node->input());
69   }
70
71   loco::DataType visit(const luci::CircleCast *node) final { return node->dtype(); }
72
73   loco::DataType visit(const luci::CircleCeil *node) final { return loco::dtype_get(node->x()); }
74
75   loco::DataType visit(const luci::CircleConcatenation *node) final
76   {
77     // TODO Support when CircleConcatenation has 0 input
78     assert(node->numValues() > 0);
79
80     for (uint32_t i = 1; i < node->numValues(); ++i)
81       assert(loco::dtype_get(node->values(i - 1)) == loco::dtype_get(node->values(i)));
82
83     return loco::dtype_get(node->values(0));
84   }
85
86   loco::DataType visit(const luci::CircleConst *node) final { return node->dtype(); }
87
88   loco::DataType visit(const luci::CircleConv2D *node) final
89   {
90     return loco::dtype_get(node->input());
91   }
92
93   loco::DataType visit(const luci::CircleCos *node) final { return loco::dtype_get(node->x()); }
94
95   loco::DataType visit(const luci::CircleCustom *node) final
96   {
97     if (node->custom_code() == "BatchMatMulV2")
98     {
99       return loco::dtype_get(node->inputs(0));
100     }
101     return node->dtype();
102   }
103
104   loco::DataType visit(const luci::CircleDepthToSpace *node) final
105   {
106     return loco::dtype_get(node->input());
107   }
108
109   loco::DataType visit(const luci::CircleDepthwiseConv2D *node) final
110   {
111     return loco::dtype_get(node->input());
112   }
113
114   loco::DataType visit(const luci::CircleDiv *node) final { return loco::dtype_get(node->x()); }
115
116   loco::DataType visit(const luci::CircleElu *node) final
117   {
118     return loco::dtype_get(node->features());
119   }
120
121   loco::DataType visit(const luci::CircleEqual *) final { return loco::DataType::BOOL; }
122
123   loco::DataType visit(const luci::CircleExp *node) final { return loco::dtype_get(node->x()); }
124
125   loco::DataType visit(const luci::CircleExpandDims *node) final
126   {
127     return loco::dtype_get(node->input());
128   }
129
130   loco::DataType visit(const luci::CircleFill *node) final
131   {
132     return loco::dtype_get(node->value());
133   }
134
135   loco::DataType visit(const luci::CircleFloor *node) final { return loco::dtype_get(node->x()); }
136
137   loco::DataType visit(const luci::CircleFloorDiv *node) final
138   {
139     return loco::dtype_get(node->x());
140   }
141
142   loco::DataType visit(const luci::CircleFloorMod *node) final
143   {
144     return loco::dtype_get(node->x());
145   }
146
147   loco::DataType visit(const luci::CircleFullyConnected *node) final
148   {
149     return loco::dtype_get(node->input());
150   }
151
152   loco::DataType visit(const luci::CircleGather *node) final
153   {
154     return loco::dtype_get(node->params());
155   }
156
157   loco::DataType visit(const luci::CircleGatherNd *node) final
158   {
159     return loco::dtype_get(node->params());
160   }
161
162   loco::DataType visit(const luci::CircleGreater *) final { return loco::DataType::BOOL; }
163
164   loco::DataType visit(const luci::CircleGreaterEqual *) final { return loco::DataType::BOOL; }
165
166   loco::DataType visit(const luci::CircleIf *node) final
167   {
168     // Type of If is not used. Just use input 0
169     assert(node->input_count() > 0);
170     return loco::dtype_get(node->input(0));
171   }
172
173   loco::DataType visit(const luci::CircleL2Normalize *node) final
174   {
175     return loco::dtype_get(node->x());
176   }
177
178   loco::DataType visit(const luci::CircleL2Pool2D *node) final
179   {
180     return loco::dtype_get(node->value());
181   }
182
183   loco::DataType visit(const luci::CircleLeakyRelu *node) final
184   {
185     return loco::dtype_get(node->features());
186   }
187
188   loco::DataType visit(const luci::CircleLess *) final { return loco::DataType::BOOL; }
189
190   loco::DataType visit(const luci::CircleLessEqual *) final { return loco::DataType::BOOL; }
191
192   loco::DataType visit(const luci::CircleLocalResponseNormalization *node) final
193   {
194     return loco::dtype_get(node->input());
195   }
196
197   loco::DataType visit(const luci::CircleLog *node) final { return loco::dtype_get(node->x()); }
198
199   loco::DataType visit(const luci::CircleLogicalAnd *node) final
200   {
201     return loco::dtype_get(node->x());
202   }
203
204   loco::DataType visit(const luci::CircleLogicalNot *node) final
205   {
206     return loco::dtype_get(node->x());
207   }
208
209   loco::DataType visit(const luci::CircleLogicalOr *node) final
210   {
211     return loco::dtype_get(node->x());
212   }
213
214   loco::DataType visit(const luci::CircleLogistic *node) final
215   {
216     return loco::dtype_get(node->x());
217   }
218
219   loco::DataType visit(const luci::CircleLogSoftmax *node) final
220   {
221     return loco::dtype_get(node->logits());
222   }
223
224   loco::DataType visit(const luci::CircleMatrixDiag *node) final
225   {
226     return loco::dtype_get(node->diagonal());
227   }
228
229   loco::DataType visit(const luci::CircleMatrixSetDiag *node) final
230   {
231     return loco::dtype_get(node->input());
232   }
233
234   loco::DataType visit(const luci::CircleMaximum *node) final { return loco::dtype_get(node->x()); }
235
236   loco::DataType visit(const luci::CircleMaxPool2D *node) final
237   {
238     return loco::dtype_get(node->value());
239   }
240
241   loco::DataType visit(const luci::CircleMean *node) final
242   {
243     return loco::dtype_get(node->input());
244   }
245
246   loco::DataType visit(const luci::CircleMinimum *node) final { return loco::dtype_get(node->x()); }
247
248   loco::DataType visit(const luci::CircleMirrorPad *node) final
249   {
250     return loco::dtype_get(node->input());
251   }
252
253   loco::DataType visit(const luci::CircleNeg *node) final { return loco::dtype_get(node->x()); }
254
255   loco::DataType visit(const luci::CircleNonMaxSuppressionV4 *node) final
256   {
257     return loco::dtype_get(node->boxes());
258   }
259
260   loco::DataType visit(const luci::CircleNotEqual *) final { return loco::DataType::BOOL; }
261
262   loco::DataType visit(const luci::CirclePack *node) final
263   {
264     // Only support CirclePack with one or more inputs
265     assert(node->values_count() > 0);
266
267     auto first_value_type = loco::dtype_get(node->values(0));
268     for (uint32_t i = 1; i < node->values_count(); ++i)
269       assert(first_value_type == loco::dtype_get(node->values(i)));
270
271     return first_value_type;
272   }
273
274   loco::DataType visit(const luci::CirclePad *node) final { return loco::dtype_get(node->input()); }
275
276   loco::DataType visit(const luci::CirclePow *node) final
277   {
278     // TODO make sure types cannot differ
279     auto x_type = loco::dtype_get(node->x());
280     auto y_type = loco::dtype_get(node->y());
281
282     if (x_type != y_type)
283       INTERNAL_EXN("Different datatype for x and y are not supported");
284
285     return x_type;
286   }
287
288   loco::DataType visit(const luci::CirclePRelu *node) final
289   {
290     auto input_type = loco::dtype_get(node->input());
291     auto alpha_type = loco::dtype_get(node->alpha());
292
293     if (input_type != alpha_type)
294       INTERNAL_EXN("Different datatype for input and alpha are not supported");
295
296     return input_type;
297   }
298
299   loco::DataType visit(const luci::CircleRange *node) final
300   {
301     return loco::dtype_get(node->start());
302   }
303
304   loco::DataType visit(const luci::CircleRank *) final { return loco::DataType::S32; }
305
306   loco::DataType visit(const luci::CircleMul *node) final { return loco::dtype_get(node->x()); }
307
308   loco::DataType visit(const luci::CircleOneHot *node) final
309   {
310     return loco::dtype_get(node->on_value());
311   }
312
313   loco::DataType visit(const luci::CircleReduceAny *node) final
314   {
315     return loco::dtype_get(node->input());
316   }
317
318   loco::DataType visit(const luci::CircleReduceMax *node) final
319   {
320     return loco::dtype_get(node->input());
321   }
322
323   loco::DataType visit(const luci::CircleReduceMin *node) final
324   {
325     return loco::dtype_get(node->input());
326   }
327
328   loco::DataType visit(const luci::CircleReduceProd *node) final
329   {
330     return loco::dtype_get(node->input());
331   }
332
333   loco::DataType visit(const luci::CircleRelu *node) final
334   {
335     return loco::dtype_get(node->features());
336   }
337
338   loco::DataType visit(const luci::CircleRelu6 *node) final
339   {
340     return loco::dtype_get(node->features());
341   }
342
343   loco::DataType visit(const luci::CircleReluN1To1 *node) final
344   {
345     return loco::dtype_get(node->features());
346   }
347
348   loco::DataType visit(const luci::CircleReshape *node) final
349   {
350     return loco::dtype_get(node->tensor());
351   }
352
353   loco::DataType visit(const luci::CircleResizeBilinear *node) final
354   {
355     return loco::dtype_get(node->input());
356   }
357
358   loco::DataType visit(const luci::CircleResizeNearestNeighbor *node) final
359   {
360     return loco::dtype_get(node->input());
361   }
362
363   loco::DataType visit(const luci::CircleReverseSequence *node) final
364   {
365     return loco::dtype_get(node->input());
366   }
367
368   loco::DataType visit(const luci::CircleReverseV2 *node) final
369   {
370     return loco::dtype_get(node->tensor());
371   }
372
373   loco::DataType visit(const luci::CircleRound *node) final { return loco::dtype_get(node->x()); }
374
375   loco::DataType visit(const luci::CircleRsqrt *node) final { return loco::dtype_get(node->x()); }
376
377   loco::DataType visit(const luci::CircleScatterNd *node) final
378   {
379     return loco::dtype_get(node->updates());
380   }
381
382   loco::DataType visit(const luci::CircleSegmentSum *node) final
383   {
384     return loco::dtype_get(node->input());
385   }
386
387   loco::DataType visit(const luci::CircleSelect *node) final
388   {
389     assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
390     return loco::dtype_get(node->t());
391   }
392
393   loco::DataType visit(const luci::CircleSelectV2 *node) final
394   {
395     assert(loco::dtype_get(node->t()) == loco::dtype_get(node->e()));
396     return loco::dtype_get(node->t());
397   }
398
399   loco::DataType visit(const luci::CircleShape *node) final { return node->out_type(); }
400
401   loco::DataType visit(const luci::CircleSin *node) final { return loco::dtype_get(node->x()); }
402
403   loco::DataType visit(const luci::CircleSlice *node) final
404   {
405     return loco::dtype_get(node->input());
406   }
407
408   loco::DataType visit(const luci::CircleSoftmax *node) final
409   {
410     return loco::dtype_get(node->logits());
411   }
412
413   loco::DataType visit(const luci::CircleSpaceToBatchND *node) final
414   {
415     return loco::dtype_get(node->input());
416   }
417
418   loco::DataType visit(const luci::CircleSpaceToDepth *node) final
419   {
420     return loco::dtype_get(node->input());
421   }
422
423   loco::DataType visit(const luci::CircleSparseToDense *node) final
424   {
425     return loco::dtype_get(node->values());
426   }
427
428   loco::DataType visit(const luci::CircleSplit *node) final
429   {
430     return loco::dtype_get(node->input());
431   }
432
433   loco::DataType visit(const luci::CircleSplitV *node) final
434   {
435     return loco::dtype_get(node->input());
436   }
437
438   loco::DataType visit(const luci::CircleSqrt *node) final { return loco::dtype_get(node->x()); }
439
440   loco::DataType visit(const luci::CircleSquare *node) final { return loco::dtype_get(node->x()); }
441
442   loco::DataType visit(const luci::CircleSquaredDifference *node) final
443   {
444     return loco::dtype_get(node->x());
445   }
446
447   loco::DataType visit(const luci::CircleSqueeze *node) final
448   {
449     return loco::dtype_get(node->input());
450   }
451
452   loco::DataType visit(const luci::CircleStridedSlice *node) final
453   {
454     return loco::dtype_get(node->input());
455   }
456
457   loco::DataType visit(const luci::CircleSub *node) final { return loco::dtype_get(node->x()); }
458
459   loco::DataType visit(const luci::CircleSum *node) final { return loco::dtype_get(node->input()); }
460
461   loco::DataType visit(const luci::CircleTanh *node) final { return loco::dtype_get(node->x()); }
462
463   loco::DataType visit(const luci::CircleTile *node) final
464   {
465     return loco::dtype_get(node->input());
466   }
467
468   loco::DataType visit(const luci::CircleTopKV2 *node) final
469   {
470     return loco::dtype_get(node->input());
471   }
472
473   loco::DataType visit(const luci::CircleTranspose *node) final
474   {
475     return loco::dtype_get(node->a());
476   }
477
478   loco::DataType visit(const luci::CircleTransposeConv *node) final
479   {
480     return loco::dtype_get(node->outBackprop());
481   }
482
483   loco::DataType visit(const luci::CircleUnique *node) final
484   {
485     return loco::dtype_get(node->input());
486   }
487
488   loco::DataType visit(const luci::CircleUnpack *node) final
489   {
490     return loco::dtype_get(node->value());
491   }
492
493   loco::DataType visit(const luci::CircleWhere *) final { return loco::DataType::S64; }
494
495   loco::DataType visit(const luci::CircleWhile *node) final
496   {
497     // Type of While is not used. Just use input 0
498     assert(node->input_count() > 0);
499     return loco::dtype_get(node->input(0));
500   }
501
502   loco::DataType visit(const luci::CircleZerosLike *node) final
503   {
504     return loco::dtype_get(node->input());
505   }
506
507   // Circle Only
508   loco::DataType visit(const luci::CircleBCQFullyConnected *) final
509   {
510     return loco::DataType::FLOAT32;
511   }
512
513   loco::DataType visit(const luci::CircleBCQGather *) final { return loco::DataType::FLOAT32; }
514
515   loco::DataType visit(const luci::CircleInstanceNorm *node) final
516   {
517     return loco::dtype_get(node->input());
518   }
519
520   // Virtual
521   loco::DataType visit(const luci::CircleInput *node) final { return node->dtype(); }
522
523   loco::DataType visit(const luci::CircleOutput *node) final
524   {
525     auto graph_outputs = node->graph()->outputs();
526     auto graph_output = graph_outputs->at(node->index());
527     auto output_dtype = graph_output->dtype();
528
529     if (dynamic_cast<luci::CircleOutputDummy *>(node->from()) == nullptr &&
530         dynamic_cast<luci::CircleOutputExclude *>(node->from()) == nullptr)
531     {
532       // We don't care for the type if from() is CircleOutputDummy or CircleOutputExclude
533       // from() type should match that of CircleOutput
534       assert(output_dtype == loco::dtype_get(node->from()));
535     }
536     return output_dtype;
537   }
538
539   loco::DataType visit(const luci::CircleOutputDummy *node) final { return node->dtype(); }
540
541   loco::DataType visit(const luci::CircleOutputExclude *node) final { return node->dtype(); }
542
543   loco::DataType visit(const luci::CircleCustomOut *node) final { return node->dtype(); }
544
545   loco::DataType visit(const luci::CircleIfOut *node) final
546   {
547     /**
548      * @note  IF operator type and shape are that of the "then" and "else"
549      *        Graph Outputs.
550      */
551     auto circle_if = dynamic_cast<const luci::CircleIf *>(node->input());
552     if (circle_if == nullptr)
553     {
554       INTERNAL_EXN("CircleIf IR is not configured correctly");
555     }
556
557     auto index = node->index();
558     auto then_graph = circle_if->then_graph();
559     auto else_graph = circle_if->else_graph();
560     assert(then_graph != nullptr);
561     assert(else_graph != nullptr);
562
563     // shape and type are assumed to be same
564     // these are checked at post_import_graph() in Import
565     auto then_outputs = loco::output_nodes(then_graph);
566     auto else_outputs = loco::output_nodes(else_graph);
567     assert(then_outputs.size() == else_outputs.size());
568     assert(index < static_cast<int32_t>(then_outputs.size()));
569
570     auto then_out = loco::must_cast<luci::CircleOutput *>(then_outputs.at(index));
571     auto else_out = loco::must_cast<luci::CircleOutput *>(else_outputs.at(index));
572
573     auto then_graph_outputs = then_graph->outputs(); // loco::GraphOutput items
574     auto else_graph_outputs = else_graph->outputs();
575     assert(then_graph_outputs->size() == else_graph_outputs->size());
576
577     auto then_graph_output = then_graph_outputs->at(then_out->index());
578     auto else_graph_output = else_graph_outputs->at(else_out->index());
579     (void)else_graph_output; // make compiler happy for unused variable warnings
580     assert(then_graph_output->dtype() == else_graph_output->dtype());
581
582     return then_graph_output->dtype();
583   }
584
585   loco::DataType visit(const luci::CircleNonMaxSuppressionV4Out *node) final
586   {
587     (void)node;
588     assert(node->index() == 0 || node->index() == 1);
589     return loco::DataType::S32;
590   }
591
592   loco::DataType visit(const luci::CircleSplitOut *node) final
593   {
594     return loco::dtype_get(node->input());
595   }
596
597   loco::DataType visit(const luci::CircleSplitVOut *node) final
598   {
599     return loco::dtype_get(node->input());
600   }
601
602   loco::DataType visit(const luci::CircleTopKV2Out *node) final
603   {
604     // First output is same as input
605     if (node->index() == 0)
606       return loco::dtype_get(node->input());
607     // Second outout is always S32
608     assert(node->index() == 1);
609     return loco::DataType::S32;
610   }
611
612   loco::DataType visit(const luci::CircleUniqueOut *node) final
613   {
614     if (node->index() == 0)
615     {
616       return loco::dtype_get(node->input());
617     }
618     assert(node->index() == 1);
619     auto unique = loco::must_cast<luci::CircleUnique *>(node->input());
620     return unique->idx_out_type();
621   }
622
623   loco::DataType visit(const luci::CircleUnpackOut *node) final
624   {
625     return loco::dtype_get(node->input());
626   }
627
628   loco::DataType visit(const luci::CircleWhileOut *node) final
629   {
630     /**
631      * @note  WHILE operator's type is the same with the "cond"
632      *        Graph Input.
633      */
634     auto circle_while = dynamic_cast<const luci::CircleWhile *>(node->input());
635     if (circle_while == nullptr)
636     {
637       INTERNAL_EXN("CircleWhile IR is not configured correctly");
638     }
639
640     auto index = node->index();
641     auto cond_graph = circle_while->cond_graph();
642     assert(cond_graph != nullptr);
643
644     // Assumption: the index of CircleWhileOut matches with the index of input nodes returned by
645     // loco::input_nodes
646     auto cond_inputs = loco::input_nodes(cond_graph);
647     auto cond_in = loco::must_cast<luci::CircleInput *>(cond_inputs.at(index));
648
649     auto cond_graph_inputs = cond_graph->inputs();
650     auto cond_graph_input = cond_graph_inputs->at(cond_in->index());
651
652     return cond_graph_input->dtype();
653   }
654 };
655
656 } // namespace
657
658 namespace luci
659 {
660
661 bool CircleTypeInferenceRule::recognize(const loco::Dialect *d) const
662 {
663   return CircleDialect::get() == d;
664 }
665
666 bool CircleTypeInferenceRule::infer(const loco::Node *node, loco::DataType &dtype) const
667 {
668   assert(node->dialect() == CircleDialect::get());
669
670   TypeInferenceAlgorithm alg;
671
672   auto circle_node = loco::must_cast<const CircleNode *>(node);
673   dtype = circle_node->accept(&alg);
674   assert(dtype != loco::DataType::Unknown);
675
676   return true;
677 }
678
679 } // namespace luci