[loco] Rewrite FixedArityNode as a mix-in (#3699)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Tue, 11 Jun 2019 00:55:35 +0000 (09:55 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Tue, 11 Jun 2019 00:55:35 +0000 (09:55 +0900)
* [loco] Rewrite FixedArityNode as a mix-in

FixedArityNode now serves as a mix-in rather than a concrete
implementation.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
* Inherit Base

contrib/loco/include/loco/IR/Node.h
contrib/loco/include/loco/IR/Nodes.h
contrib/loco/src/IR/Node.test.cpp

index a6901ce..e8a4def 100644 (file)
@@ -84,7 +84,7 @@ std::set<Node *> succs(const Node *node);
 /**
  * @brief Nodes with the fixed number of inputs
  */
-template <unsigned N> class FixedArityNode : public Node
+template <unsigned N, typename Base> class FixedArityNode : public Base
 {
 public:
   FixedArityNode()
index 9ad132e..efad495 100644 (file)
@@ -35,7 +35,8 @@ namespace loco
 /**
  * @brief Make a value visible to user
  */
-class Push /* to user */ final : public FixedArityNode<1>, public NodeMixin<NodeTrait::TensorShape>
+class Push /* to user */ final : public FixedArityNode<1, Node>,
+                                 public NodeMixin<NodeTrait::TensorShape>
 {
 public:
   Push() = default;
@@ -48,7 +49,7 @@ public:
 /**
  * @brief Create a value from user data
  */
-class Pull /* from user */ final : public FixedArityNode<0>,
+class Pull /* from user */ final : public FixedArityNode<0, Node>,
                                    public NodeMixin<NodeTrait::DataType>,
                                    public NodeMixin<NodeTrait::TensorShape>
 {
@@ -61,7 +62,7 @@ public:
  *
  * This node may encode memory transfer (such as CPU -> GPU or GPU -> CPU)
  */
-class Forward final : public FixedArityNode<1>
+class Forward final : public FixedArityNode<1, Node>
 {
 public:
   Forward() = default;
@@ -74,7 +75,7 @@ public:
 /**
  * @brief Create a new value that rectifies its input
  */
-class ReLU final : public FixedArityNode<1>
+class ReLU final : public FixedArityNode<1, Node>
 {
 public:
   ReLU() = default;
@@ -105,7 +106,7 @@ public:
  *   return res;
  * }
  */
-class ConstGen final : public FixedArityNode<0>,
+class ConstGen final : public FixedArityNode<0, Node>,
                        public NodeMixin<NodeTrait::DataType>,
                        public NodeMixin<NodeTrait::TensorShape>
 {
@@ -196,7 +197,7 @@ private:
  *   before/after MaxPool2D node according to the semantics of the corresponding NN framework.
  * ---
  */
-struct MaxPool2D final : public FixedArityNode<1>
+struct MaxPool2D final : public FixedArityNode<1, Node>
 {
 public:
   Node *ifm(void) const { return at(0)->node(); }
@@ -227,7 +228,7 @@ private:
 /**
  * @brief Create a feature map from a tensor
  */
-class FeatureEncode final : public FixedArityNode<1>
+class FeatureEncode final : public FixedArityNode<1, Node>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -245,7 +246,7 @@ private:
 /**
  * @brief Create a tensor from a feature map
  */
-class FeatureDecode final : public FixedArityNode<1>
+class FeatureDecode final : public FixedArityNode<1, Node>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -263,7 +264,7 @@ private:
 /**
  * @brief Create a filter from a tensor
  */
-class FilterEncode final : public FixedArityNode<1>
+class FilterEncode final : public FixedArityNode<1, Node>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -305,7 +306,7 @@ template <ReshapeType RT> class Reshape;
  *         input[1, 1, 1, 1] => output [3, 3]
  */
 template <>
-class Reshape<ReshapeType::Fixed> final : public FixedArityNode<1>,
+class Reshape<ReshapeType::Fixed> final : public FixedArityNode<1, Node>,
                                           public NodeMixin<NodeTrait::TensorShape>
 {
 public:
@@ -319,7 +320,7 @@ public:
  * Given an axis, TensorConcat takes as input two tensors and produces a tensor
  * concatenated along the given axis.
  */
-class TensorConcat final : public FixedArityNode<2>
+class TensorConcat final : public FixedArityNode<2, Node>
 {
 public:
   Node *lhs(void) const { return at(0)->node(); }
@@ -340,7 +341,7 @@ private:
 /**
  * @brief 2D Spatial Convolution
  */
-class Conv2D final : public FixedArityNode<2>
+class Conv2D final : public FixedArityNode<2, Node>
 {
 public:
   Node *ifm(void) const { return at(0)->node(); }
index 4d98552..57a72e2 100644 (file)
@@ -51,7 +51,7 @@ TEST(NodeTest, succs)
 
 TEST(FixedArityNodeTest, constructor)
 {
-  loco::FixedArityNode<1> node;
+  loco::FixedArityNode<1, loco::Node> node;
 
   ASSERT_EQ(node.arity(), 1);
   ASSERT_EQ(node.arg(0), nullptr);