Imported Upstream version 1.7.0
[platform/core/ml/nnfw.git] / compiler / exo / src / Dialect / IR / TFLNodes.h
1 /*
2  * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *    http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 #ifndef __LOCOEX_IR_TFLNODES_H__
18 #define __LOCOEX_IR_TFLNODES_H__
19
20 #include "TFLNodeDecl.h"
21 #include "TFLOpcode.h"
22
23 #include "FusedActFunc.h"
24 #include "NodeMixins.h"
25
26 #include <loco/IR/Node.h>
27 #include <loco/IR/NodeMixins.h>
28 #include <loco/IR/DataTypeTraits.h>
29
30 #include <locoex/VariadicArityNode.h>
31
32 #include <array>
33
34 namespace locoex
35 {
36
37 enum class Padding
38 {
39   UNDEFINED, // This is not defined by TFLite. This was added to prevent programming error.
40   SAME,
41   VALID,
42 };
43
44 class Filter final
45 {
46 public:
47   Filter() : _w(1), _h(1) {}
48
49   int32_t w() const { return _w; }
50   void w(int32_t w) { _w = w; }
51
52   int32_t h() const { return _h; }
53   void h(int32_t h) { _h = h; }
54
55 private:
56   int32_t _w;
57   int32_t _h;
58 };
59
60 class Stride final
61 {
62 public:
63   Stride() : _w(1), _h(1) {}
64
65   int32_t w() const { return _w; }
66   void w(int32_t w) { _w = w; }
67
68   int32_t h() const { return _h; }
69   void h(int32_t h) { _h = h; }
70
71 private:
72   int32_t _w;
73   int32_t _h;
74 };
75
76 /// @brief enumeration of mixin class
77 enum class TFLNodeTrait
78 {
79   FusedActFunc,
80   Bias
81 };
82
83 template <TFLNodeTrait T> class TFLNodeMixin;
84
85 template <> class TFLNodeMixin<TFLNodeTrait::FusedActFunc>
86 {
87 public:
88   TFLNodeMixin() = default;
89
90 public:
91   FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
92   void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
93
94 private:
95   FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
96 };
97
98 /**
99  * @brief Mixin class for nodes that has a bias input
100  */
101 template <> class TFLNodeMixin<TFLNodeTrait::Bias>
102 {
103 public:
104   TFLNodeMixin() = default;
105
106 public:
107   virtual loco::Node *bias(void) const = 0; /// @brief get the input for bias.
108   virtual void bias(loco::Node *node) = 0;  /// @brief set the input for bias.
109 };
110
111 /**
112  * @brief ADD in TensorFlow Lite
113  */
114 class TFLAdd final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::ADD>>,
115                      public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
116 {
117 public:
118   loco::Node *x(void) const { return at(0)->node(); }
119   void x(loco::Node *node) { at(0)->node(node); }
120
121   loco::Node *y(void) const { return at(1)->node(); }
122   void y(loco::Node *node) { at(1)->node(node); }
123 };
124
125 /**
126  * @brief AVERAGE_POOL_2D in TensorFlow Lite
127  */
128 class TFLAveragePool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::AVERAGE_POOL_2D>>,
129                                public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
130 {
131 public:
132   TFLAveragePool2D() : _padding(Padding::UNDEFINED) { /* empty */}
133
134 public:
135   loco::Node *value(void) const { return at(0)->node(); }
136   void value(loco::Node *node) { at(0)->node(node); }
137
138   Padding padding() const { return _padding; }
139   void padding(Padding padding) { _padding = padding; }
140
141   const Filter *filter(void) const { return &_filter; }
142   Filter *filter(void) { return &_filter; }
143
144   const Stride *stride(void) const { return &_stride; }
145   Stride *stride(void) { return &_stride; }
146
147 private:
148   Padding _padding;
149   Stride _stride;
150   Filter _filter;
151 };
152
153 /**
154  * @brief CONCATENATION in TensorFlow Lite
155  */
156 class TFLConcatenation final : public VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>,
157                                public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
158 {
159 public:
160   TFLConcatenation(uint32_t arity) : VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>(arity)
161   {
162     // TODO Support when arity is 0
163     assert(arity >= 1);
164   }
165
166 public:
167   uint32_t numValues(void) const { return arity(); }
168
169 public:
170   Node *values(uint32_t index) const
171   {
172     assert(index < numValues());
173     return at(index)->node();
174   }
175   void values(uint32_t index, Node *node)
176   {
177     assert(index < numValues());
178     at(index)->node(node);
179   }
180
181 public:
182   uint32_t axis(void) const { return _axis; }
183   void axis(uint32_t axis) { _axis = axis; }
184
185 private:
186   uint32_t _axis{0};
187 };
188
189 /**
190  * @brief Class to build tensor data
191  * @note  This will not be exported as a specific op
192  */
193 class TFLConst final : public FixedArityNode<0, TFLNodeImpl<TFLOpcode::CONST>>,
194                        public loco::NodeMixin<loco::NodeTrait::DataType>,
195                        public loco::NodeMixin<loco::NodeTrait::TensorShape>
196 {
197 public:
198   TFLConst() = default;
199
200 public:
201   template <loco::DataType DT> uint32_t size(void) const;
202   template <loco::DataType DT> void size(uint32_t size);
203   template <loco::DataType DT> const typename loco::DataTypeImpl<DT>::Type &at(uint32_t n) const;
204   template <loco::DataType DT> typename loco::DataTypeImpl<DT>::Type &at(uint32_t n);
205
206 private:
207   std::vector<uint8_t> _data;
208 };
209
210 /**
211  * @brief CONV_2D in TensorFlow Lite
212  */
213 class TFLConv2D final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::CONV_2D>>,
214                         public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
215                         public TFLNodeMixin<TFLNodeTrait::Bias>
216 {
217 public:
218   loco::Node *input(void) const { return at(0)->node(); }
219   void input(loco::Node *node) { at(0)->node(node); }
220
221   loco::Node *filter(void) const { return at(1)->node(); }
222   void filter(loco::Node *node) { at(1)->node(node); }
223
224   loco::Node *bias(void) const override { return at(2)->node(); }
225   void bias(loco::Node *node) override { at(2)->node(node); }
226
227 public:
228   Padding padding() const { return _padding; }
229   void padding(Padding padding) { _padding = padding; }
230
231   const Stride *stride(void) const { return &_stride; }
232   Stride *stride(void) { return &_stride; }
233
234 private:
235   Padding _padding = Padding::UNDEFINED;
236   Stride _stride;
237 };
238
239 /**
240  * @brief DEPTHWISE_CONV_2D in TensorFlow Lite
241  */
242 class TFLDepthwiseConv2D final
243     : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::DEPTHWISE_CONV_2D>>,
244       public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
245       public TFLNodeMixin<TFLNodeTrait::Bias>
246 {
247 public:
248   loco::Node *input(void) const { return at(0)->node(); }
249   void input(loco::Node *node) { at(0)->node(node); }
250
251   loco::Node *filter(void) const { return at(1)->node(); }
252   void filter(loco::Node *node) { at(1)->node(node); }
253
254   loco::Node *bias(void) const override { return at(2)->node(); }
255   void bias(loco::Node *node) override { at(2)->node(node); }
256
257 public:
258   Padding padding() const { return _padding; }
259   void padding(Padding padding) { _padding = padding; }
260
261   const Stride *stride(void) const { return &_stride; }
262   Stride *stride(void) { return &_stride; }
263
264   int32_t depthMultiplier(void) const { return _depth_multiplier; }
265   void depthMultiplier(int32_t arg) { _depth_multiplier = arg; }
266
267 private:
268   Padding _padding = Padding::UNDEFINED;
269   Stride _stride;
270   int32_t _depth_multiplier = 0;
271 };
272
273 /**
274  * @brief DIV in TensorFlow Lite
275  */
276 class TFLDiv final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::DIV>>,
277                      public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
278 {
279 public:
280   TFLDiv() = default;
281
282 public:
283   loco::Node *x(void) const { return at(0)->node(); }
284   void x(loco::Node *node) { at(0)->node(node); }
285
286   loco::Node *y(void) const { return at(1)->node(); }
287   void y(loco::Node *node) { at(1)->node(node); }
288 };
289
290 /**
291  * @brief FULLY_CONNECTED in TensorFlow Lite
292  */
293 class TFLFullyConnected final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::FULLY_CONNECTED>>,
294                                 public TFLNodeMixin<TFLNodeTrait::FusedActFunc>,
295                                 public TFLNodeMixin<TFLNodeTrait::Bias>
296 {
297 public:
298   loco::Node *input(void) const { return at(0)->node(); }
299   void input(loco::Node *node) { at(0)->node(node); }
300
301   loco::Node *weights(void) const { return at(1)->node(); }
302   void weights(loco::Node *node) { at(1)->node(node); }
303
304   loco::Node *bias(void) const override { return at(2)->node(); }
305   void bias(loco::Node *node) override { at(2)->node(node); }
306 };
307
308 /**
309  * @brief MAXIMUM in TensorFlow Lite
310  */
311 class TFLMaximum final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MAXIMUM>>
312 {
313 public:
314   loco::Node *x(void) const { return at(0)->node(); }
315   void x(loco::Node *node) { at(0)->node(node); }
316
317   loco::Node *y(void) const { return at(1)->node(); }
318   void y(loco::Node *node) { at(1)->node(node); }
319 };
320
321 /**
322  * @brief MAX_POOL_2D in TensorFlow Lite
323  */
324 class TFLMaxPool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::MAX_POOL_2D>>,
325                            public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
326 {
327 public:
328   TFLMaxPool2D() : _padding(Padding::UNDEFINED) { /* empty */}
329
330 public:
331   loco::Node *value(void) const { return at(0)->node(); }
332   void value(loco::Node *node) { at(0)->node(node); }
333
334   Padding padding() const { return _padding; }
335   void padding(Padding padding) { _padding = padding; }
336
337   const Filter *filter(void) const { return &_filter; }
338   Filter *filter(void) { return &_filter; }
339
340   const Stride *stride(void) const { return &_stride; }
341   Stride *stride(void) { return &_stride; }
342
343 private:
344   Padding _padding;
345   Stride _stride;
346   Filter _filter;
347 };
348
349 class TFLMean final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MEAN>>
350 {
351 public:
352   loco::Node *input(void) const { return at(0)->node(); }
353   void input(loco::Node *node) { at(0)->node(node); }
354
355   loco::Node *reduction_indices(void) const { return at(1)->node(); }
356   void reduction_indices(loco::Node *node) { at(1)->node(node); }
357
358 public:
359   bool keep_dims(void) const { return _keep_dims; }
360   void keep_dims(bool keep_dims) { _keep_dims = keep_dims; }
361
362 private:
363   bool _keep_dims = false;
364 };
365
366 /**
367  * @brief MUL in TensorFlow Lite
368  */
369 class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>>,
370                      public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
371 {
372 public:
373   loco::Node *x(void) const { return at(0)->node(); }
374   void x(loco::Node *node) { at(0)->node(node); }
375
376   loco::Node *y(void) const { return at(1)->node(); }
377   void y(loco::Node *node) { at(1)->node(node); }
378 };
379
380 class TFLRelu final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>>
381 {
382 public:
383   TFLRelu() = default;
384
385 public:
386   loco::Node *features(void) const { return at(0)->node(); }
387   void features(loco::Node *node) { at(0)->node(node); }
388 };
389
390 class TFLRelu6 final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU6>>
391 {
392 public:
393   TFLRelu6() = default;
394
395 public:
396   loco::Node *features(void) const { return at(0)->node(); }
397   void features(loco::Node *node) { at(0)->node(node); }
398 };
399
400 class TFLReshape final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::RESHAPE>>
401 {
402 public:
403   TFLReshape() = default;
404
405 public:
406   loco::Node *tensor(void) const { return at(0)->node(); }
407   void tensor(loco::Node *node) { at(0)->node(node); }
408
409   // TODO Make this input optional. That is, loco system does not emit error
410   //      with this input being null
411   loco::Node *shape(void) const { return at(1)->node(); }
412   void shape(loco::Node *node) { at(1)->node(node); }
413
414 public:
415   class Shape
416   {
417   public:
418     uint32_t rank(void) const { return _shape.size(); }
419     void rank(uint32_t rank) { _shape.resize(rank); }
420
421     int32_t dim(uint32_t n) const { return _shape.at(n); }
422     int32_t &dim(uint32_t n) { return _shape.at(n); }
423
424   private:
425     std::vector<int32_t> _shape;
426   };
427
428   const Shape *newShape(void) const { return &_new_shape; }
429   Shape *newShape(void) { return &_new_shape; }
430
431 private:
432   Shape _new_shape;
433 };
434
435 /**
436  * @brief  Set both TFLReshape's 2nd input as TFLConst, and newShape attribute
437  *         with same value
438  * @note   Shape inference for TFLReshape forces them to be same
439  * TODO find better place for this helper
440  */
441 void set_new_shape(locoex::TFLReshape *node, int32_t *base, uint32_t size);
442
443 class TFLRsqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RSQRT>>
444 {
445 public:
446   TFLRsqrt() = default;
447
448 public:
449   loco::Node *x(void) const { return at(0)->node(); }
450   void x(loco::Node *node) { at(0)->node(node); }
451 };
452
453 // TODO TFLSoftmax
454
455 class TFLSqrt final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::SQRT>>
456 {
457 public:
458   TFLSqrt() = default;
459
460 public:
461   loco::Node *x(void) const { return at(0)->node(); }
462   void x(loco::Node *node) { at(0)->node(node); }
463 };
464
465 class TFLSquaredDifference final
466     : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SQUARED_DIFFERENCE>>
467 {
468 public:
469   TFLSquaredDifference() = default;
470
471 public:
472   loco::Node *x(void) const { return at(0)->node(); }
473   void x(loco::Node *node) { at(0)->node(node); }
474
475   loco::Node *y(void) const { return at(1)->node(); }
476   void y(loco::Node *node) { at(1)->node(node); }
477 };
478
479 /**
480  * @brief SUB in TensorFlow Lite
481  */
482 class TFLSub final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::SUB>>,
483                      public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
484 {
485 public:
486   TFLSub() = default;
487
488 public:
489   loco::Node *x(void) const { return at(0)->node(); }
490   void x(loco::Node *node) { at(0)->node(node); }
491
492   loco::Node *y(void) const { return at(1)->node(); }
493   void y(loco::Node *node) { at(1)->node(node); }
494 };
495
496 // TODO TFLTanh
497
498 /**
499  * @brief TRANSPOSE in TensorFlow Lite
500  */
501 class TFLTranspose final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::TRANSPOSE>>
502 {
503 public:
504   TFLTranspose() = default;
505
506 public:
507   /// @brief Get the input node to transpose
508   loco::Node *a(void) const { return at(0)->node(); }
509
510   /// @brief Set the input node to transpose
511   void a(loco::Node *node) { at(0)->node(node); }
512
513   loco::Node *perm(void) const { return at(1)->node(); }
514   void perm(loco::Node *node) { at(1)->node(node); }
515 };
516
517 /**
518  * @brief TRANSPOSE_CONV in TensorFlow Lite
519  *
520  * @note  Argument node function names are from TensorFlow. So refering 'in' and
521  *        'out' acutally means 'out' and 'in' of the this node.
522  */
523 class TFLTransposeConv final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::TRANSPOSE_CONV>>
524 {
525 public:
526   loco::Node *inputSizes(void) const { return at(0)->node(); }
527   void inputSizes(Node *node) { at(0)->node(node); }
528
529   loco::Node *filter(void) const { return at(1)->node(); }
530   void filter(Node *node) { at(1)->node(node); }
531
532   loco::Node *outBackprop(void) const { return at(2)->node(); }
533   void outBackprop(Node *node) { at(2)->node(node); }
534
535 public:
536   const Padding &padding(void) const { return _padding; }
537   void padding(const Padding &padding) { _padding = padding; }
538
539   const Stride *stride(void) const { return &_stride; }
540   Stride *stride(void) { return &_stride; }
541
542 private:
543   Padding _padding{Padding::UNDEFINED};
544   Stride _stride;
545 };
546
547 // TODO define more children of TFLNode
548
549 } // namespace locoex
550
551 #endif // __LOCOEX_IR_TFLNODES_H__