[exo] TFLMixin for classes that have fused activation function (#8223)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Wed, 16 Oct 2019 07:53:35 +0000 (16:53 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 16 Oct 2019 07:53:35 +0000 (16:53 +0900)
This adds TFLMixin class for TFL nodes that have fused activation function

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo/src/Dialect/IR/TFLNodes.h

index 105399e..e235d4d 100644 (file)
@@ -115,10 +115,32 @@ private:
   int32_t _h;
 };
 
+/// @brief enumeration of mixin class
+enum class TFLNodeTrait
+{
+  FusedActFunc,
+};
+
+template <TFLNodeTrait T> class TFLNodeMixin;
+
+template <> class TFLNodeMixin<TFLNodeTrait::FusedActFunc>
+{
+public:
+  TFLNodeMixin() = default;
+
+public:
+  FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
+  void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
+
+private:
+  FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
+};
+
 /**
  * @brief ADD in TensorFlow Lite
  */
-class TFLAdd final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::ADD>>
+class TFLAdd final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::ADD>>,
+                     public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
 {
 public:
   loco::Node *x(void) const { return at(0)->node(); }
@@ -126,30 +148,21 @@ public:
 
   loco::Node *y(void) const { return at(1)->node(); }
   void y(loco::Node *node) { at(1)->node(node); }
-
-  FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
-  void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
-
-private:
-  FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
 };
 
 /**
  * @brief AVERAGE_POOL_2D in TensorFlow Lite
  */
-class TFLAveragePool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::AVERAGE_POOL_2D>>
+class TFLAveragePool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::AVERAGE_POOL_2D>>,
+                               public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
 {
 public:
-  TFLAveragePool2D()
-      : _fused_act_fun(FusedActFunc::UNDEFINED), _padding(Padding::UNDEFINED) { /* empty */}
+  TFLAveragePool2D() : _padding(Padding::UNDEFINED) { /* empty */}
 
 public:
   loco::Node *value(void) const { return at(0)->node(); }
   void value(loco::Node *node) { at(0)->node(node); }
 
-  FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
-  void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
-
   Padding padding() const { return _padding; }
   void padding(Padding padding) { _padding = padding; }
 
@@ -160,7 +173,6 @@ public:
   Stride *stride(void) { return &_stride; }
 
 private:
-  FusedActFunc _fused_act_fun;
   Padding _padding;
   Stride _stride;
   Filter _filter;
@@ -169,7 +181,8 @@ private:
 /**
  * @brief CONCATENATION in TensorFlow Lite
  */
-class TFLConcatenation final : public VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>
+class TFLConcatenation final : public VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>,
+                               public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
 {
 public:
   TFLConcatenation(uint32_t arity) : VariadicArityNode<TFLNodeImpl<TFLOpcode::CONCATENATION>>(arity)
@@ -193,16 +206,12 @@ public:
     at(index)->node(node);
   }
 
-  FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
-  void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
-
 public:
   uint32_t axis(void) const { return _axis; }
   void axis(uint32_t axis) { _axis = axis; }
 
 private:
   uint32_t _axis;
-  FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
 };
 
 /**
@@ -229,7 +238,8 @@ private:
 /**
  * @brief CONV_2D in TensorFlow Lite
  */
-class TFLConv2D final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::CONV_2D>>
+class TFLConv2D final : public FixedArityNode<3, TFLNodeImpl<TFLOpcode::CONV_2D>>,
+                        public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
 {
 public:
   loco::Node *input(void) const { return at(0)->node(); }
@@ -242,9 +252,6 @@ public:
   void bias(loco::Node *node) { at(2)->node(node); }
 
 public:
-  FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
-  void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
-
   Padding padding() const { return _padding; }
   void padding(Padding padding) { _padding = padding; }
 
@@ -252,7 +259,6 @@ public:
   Stride *stride(void) { return &_stride; }
 
 private:
-  FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
   Padding _padding = Padding::UNDEFINED;
   Stride _stride;
 };
@@ -278,19 +284,16 @@ public:
 /**
  * @brief MAX_POOL_2D in TensorFlow Lite
  */
-class TFLMaxPool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::MAX_POOL_2D>>
+class TFLMaxPool2D final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::MAX_POOL_2D>>,
+                           public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
 {
 public:
-  TFLMaxPool2D()
-      : _fused_act_fun(FusedActFunc::UNDEFINED), _padding(Padding::UNDEFINED) { /* empty */}
+  TFLMaxPool2D() : _padding(Padding::UNDEFINED) { /* empty */}
 
 public:
   loco::Node *value(void) const { return at(0)->node(); }
   void value(loco::Node *node) { at(0)->node(node); }
 
-  FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
-  void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
-
   Padding padding() const { return _padding; }
   void padding(Padding padding) { _padding = padding; }
 
@@ -301,7 +304,6 @@ public:
   Stride *stride(void) { return &_stride; }
 
 private:
-  FusedActFunc _fused_act_fun;
   Padding _padding;
   Stride _stride;
   Filter _filter;
@@ -310,7 +312,8 @@ private:
 /**
  * @brief MUL in TensorFlow Lite
  */
-class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>>
+class TFLMul final : public FixedArityNode<2, TFLNodeImpl<TFLOpcode::MUL>>,
+                     public TFLNodeMixin<TFLNodeTrait::FusedActFunc>
 {
 public:
   loco::Node *x(void) const { return at(0)->node(); }
@@ -318,12 +321,6 @@ public:
 
   loco::Node *y(void) const { return at(1)->node(); }
   void y(loco::Node *node) { at(1)->node(node); }
-
-  FusedActFunc fusedActivationFunction() const { return _fused_act_fun; }
-  void fusedActivationFunction(FusedActFunc fused_act_fun) { _fused_act_fun = fused_act_fun; }
-
-private:
-  FusedActFunc _fused_act_fun = FusedActFunc::UNDEFINED;
 };
 
 class TFLRelu final : public FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>>