[loco] Introduce CanonicalNodeDef interface (#4302)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 17 Jul 2019 07:49:38 +0000 (16:49 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 17 Jul 2019 07:49:38 +0000 (16:49 +0900)
This commit introduces CanonicalNodeDef interfaces which takes multiple
Mixins all at once.

Compared with existing CanonicalNodeImpl, CanonicalNodeDef emits uniform
declarations that begin with "CanonicalNodeDef<OPCODE, ...>", which makes
it easy to read the code.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/loco/include/loco/IR/CanonicalNodeDecl.h
contrib/loco/include/loco/IR/NodeMixins.h
contrib/loco/include/loco/IR/Nodes.h

index 6dadf88..a226d51 100644 (file)
@@ -44,6 +44,15 @@ template <CanonicalOpcode Code> struct CanonicalNodeImpl : public CanonicalNode
   CanonicalOpcode opcode(void) const final { return Code; }
 };
 
+template <CanonicalOpcode Code, template <typename T> class... Mixins>
+struct CanonicalNodeDef : public virtual CanonicalNode, public Mixins<CanonicalNode>...
+{
+  virtual ~CanonicalNodeDef() = default;
+
+  uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); }
+  CanonicalOpcode opcode(void) const final { return Code; }
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_CANONICAL_NODE_H__
index 1da4fee..89e0920 100644 (file)
@@ -17,6 +17,7 @@
 #ifndef __LOCO_IR_NODE_MIXINS_H__
 #define __LOCO_IR_NODE_MIXINS_H__
 
+#include "loco/IR/Node.h"
 #include "loco/IR/DataType.h"
 #include "loco/IR/Dimension.h"
 
@@ -82,6 +83,51 @@ private:
   std::vector<Dimension> _dims;
 };
 
+template <unsigned N> struct FixedArity
+{
+  template <typename Base> class Mixin : public virtual Base
+  {
+  public:
+    Mixin()
+    {
+      for (uint32_t n = 0; n < N; ++n)
+      {
+        _args[n] = std::unique_ptr<Use>{new Use{this}};
+      }
+    }
+
+    virtual ~Mixin() = default;
+
+  public:
+    unsigned arity(void) const final { return N; }
+
+    Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+    void drop(void) final
+    {
+      for (uint32_t n = 0; n < N; ++n)
+      {
+        _args.at(n)->node(nullptr);
+      }
+    }
+
+  protected:
+    // This API allows inherited classes to access "_args" field.
+    Use *at(unsigned n) const { return _args.at(n).get(); }
+
+  private:
+    std::array<std::unique_ptr<Use>, N> _args;
+  };
+};
+
+template <NodeTrait Trait> struct With
+{
+  template <typename Base> struct Mixin : public virtual Base, public NodeMixin<Trait>
+  {
+    // DO NOTHING
+  };
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_NODE_MIXINS_H__
index 932e6c1..073933e 100644 (file)
@@ -38,8 +38,9 @@ namespace loco
 /**
  * @brief Make a value visible to user
  */
-class Push /* to user */ final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::Push>>,
-                                 public NodeMixin<NodeTrait::TensorShape>
+class Push /* to user */ final
+    : public CanonicalNodeDef<CanonicalOpcode::Push, FixedArity<1>::Mixin,
+                              With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   Push() = default;
@@ -53,9 +54,8 @@ public:
  * @brief Create a value from user data
  */
 class Pull /* from user */ final
-    : public FixedArityNode<0, CanonicalNodeImpl<CanonicalOpcode::Pull>>,
-      public NodeMixin<NodeTrait::DataType>,
-      public NodeMixin<NodeTrait::TensorShape>
+    : public CanonicalNodeDef<CanonicalOpcode::Pull, FixedArity<0>::Mixin,
+                              With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   Pull() = default;
@@ -66,7 +66,7 @@ public:
  *
  * This node may encode memory transfer (such as CPU -> GPU or GPU -> CPU)
  */
-class Forward final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::Forward>>
+class Forward final : public CanonicalNodeDef<CanonicalOpcode::Forward, FixedArity<1>::Mixin>
 {
 public:
   Forward() = default;
@@ -79,7 +79,7 @@ public:
 /**
  * @brief Create a new value that rectifies its input
  */
-class ReLU final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::ReLU>>
+class ReLU final : public CanonicalNodeDef<CanonicalOpcode::ReLU, FixedArity<1>::Mixin>
 {
 public:
   ReLU() = default;
@@ -92,7 +92,7 @@ public:
 /**
  * @brief Create a new value that rectifies its input capping the units at 6.
  */
-class ReLU6 final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::ReLU6>>
+class ReLU6 final : public CanonicalNodeDef<CanonicalOpcode::ReLU6, FixedArity<1>::Mixin>
 {
 public:
   ReLU6() = default;
@@ -123,9 +123,9 @@ public:
  *   return res;
  * }
  */
-class ConstGen final : public FixedArityNode<0, CanonicalNodeImpl<CanonicalOpcode::ConstGen>>,
-                       public NodeMixin<NodeTrait::DataType>,
-                       public NodeMixin<NodeTrait::TensorShape>
+class ConstGen final
+    : public CanonicalNodeDef<CanonicalOpcode::ConstGen, FixedArity<0>::Mixin,
+                              With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   ConstGen() = default;
@@ -212,7 +212,7 @@ private:
  *   before/after MaxPool2D node according to the semantics of the corresponding NN framework.
  * ---
  */
-class MaxPool2D final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::MaxPool2D>>
+class MaxPool2D final : public CanonicalNodeDef<CanonicalOpcode::MaxPool2D, FixedArity<1>::Mixin>
 {
 public:
   Node *ifm(void) const { return at(0)->node(); }
@@ -244,7 +244,7 @@ private:
  *
  * @note Follows MaxPool2D (TODO: describe difference)
  */
-class AvgPool2D final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::AvgPool2D>>
+class AvgPool2D final : public CanonicalNodeDef<CanonicalOpcode::AvgPool2D, FixedArity<1>::Mixin>
 {
 public:
   enum class Convention
@@ -287,7 +287,7 @@ private:
  * @brief Create a feature map from a tensor
  */
 class FeatureEncode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FeatureEncode>>
+    : public CanonicalNodeDef<CanonicalOpcode::FeatureEncode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -306,7 +306,7 @@ private:
  * @brief Create a tensor from a feature map
  */
 class FeatureDecode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FeatureDecode>>
+    : public CanonicalNodeDef<CanonicalOpcode::FeatureDecode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -325,7 +325,7 @@ private:
  * @brief Create a filter from a tensor
  */
 class FilterEncode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FilterEncode>>
+    : public CanonicalNodeDef<CanonicalOpcode::FilterEncode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -346,7 +346,7 @@ private:
  * @brief Create a depthwise filter from a tensor
  */
 class DepthwiseFilterEncode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::DepthwiseFilterEncode>>
+    : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterEncode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -389,8 +389,8 @@ template <ReshapeType RT> class Reshape;
  */
 template <>
 class Reshape<ReshapeType::Fixed> final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FixedReshape>>,
-      public NodeMixin<NodeTrait::TensorShape>
+    : public CanonicalNodeDef<CanonicalOpcode::FixedReshape, FixedArity<1>::Mixin,
+                              With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -404,7 +404,7 @@ public:
  * concatenated along the given axis.
  */
 class TensorConcat final
-    : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::TensorConcat>>
+    : public CanonicalNodeDef<CanonicalOpcode::TensorConcat, FixedArity<2>::Mixin>
 {
 public:
   Node *lhs(void) const { return at(0)->node(); }
@@ -425,7 +425,7 @@ private:
 /**
  * @brief 2D Spatial Convolution
  */
-class Conv2D final : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::Conv2D>>
+class Conv2D final : public CanonicalNodeDef<CanonicalOpcode::Conv2D, FixedArity<2>::Mixin>
 {
 public:
   Node *ifm(void) const { return at(0)->node(); }
@@ -454,7 +454,7 @@ private:
  *
  * BiasEncode currently requires a rank-1 tensor as its input.
  */
-class BiasEncode final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::BiasEncode>>
+class BiasEncode final : public CanonicalNodeDef<CanonicalOpcode::BiasEncode, FixedArity<1>::Mixin>
 {
 public:
   BiasEncode() = default;
@@ -477,7 +477,7 @@ template <Domain D> class BiasAdd;
  */
 template <>
 class BiasAdd<Domain::Tensor> final
-    : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::TensorBiasAdd>>
+    : public CanonicalNodeDef<CanonicalOpcode::TensorBiasAdd, FixedArity<2>::Mixin>
 {
 public:
   BiasAdd() = default;
@@ -513,7 +513,7 @@ using TensorBiasAdd = BiasAdd<Domain::Tensor>;
  */
 template <>
 class BiasAdd<Domain::Feature> final
-    : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::FeatureBiasAdd>>
+    : public CanonicalNodeDef<CanonicalOpcode::FeatureBiasAdd, FixedArity<2>::Mixin>
 {
 public:
   BiasAdd() = default;