From: 박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 Date: Tue, 11 Jun 2019 00:55:35 +0000 (+0900) Subject: [loco] Rewrite FixedArityNode as a mix-in (#3699) X-Git-Tag: nncc_backup~447 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=16647499568e9e49f7a861d132aa2c8fb8d86e6b;p=platform%2Fcore%2Fml%2Fnnfw.git [loco] Rewrite FixedArityNode as a mix-in (#3699) * [loco] Rewrite FixedArityNode as a mix-in FixedArityNode now serves as a mix-in rather than a concrete implementation. Signed-off-by: Jonghyun Park * Inherit Base --- diff --git a/contrib/loco/include/loco/IR/Node.h b/contrib/loco/include/loco/IR/Node.h index a6901ce..e8a4def 100644 --- a/contrib/loco/include/loco/IR/Node.h +++ b/contrib/loco/include/loco/IR/Node.h @@ -84,7 +84,7 @@ std::set succs(const Node *node); /** * @brief Nodes with the fixed number of inputs */ -template class FixedArityNode : public Node +template class FixedArityNode : public Base { public: FixedArityNode() diff --git a/contrib/loco/include/loco/IR/Nodes.h b/contrib/loco/include/loco/IR/Nodes.h index 9ad132e..efad495 100644 --- a/contrib/loco/include/loco/IR/Nodes.h +++ b/contrib/loco/include/loco/IR/Nodes.h @@ -35,7 +35,8 @@ namespace loco /** * @brief Make a value visible to user */ -class Push /* to user */ final : public FixedArityNode<1>, public NodeMixin +class Push /* to user */ final : public FixedArityNode<1, Node>, + public NodeMixin { public: Push() = default; @@ -48,7 +49,7 @@ public: /** * @brief Create a value from user data */ -class Pull /* from user */ final : public FixedArityNode<0>, +class Pull /* from user */ final : public FixedArityNode<0, Node>, public NodeMixin, public NodeMixin { @@ -61,7 +62,7 @@ public: * * This node may encode memory transfer (such as CPU -> GPU or GPU -> CPU) */ -class Forward final : public FixedArityNode<1> +class Forward final : public FixedArityNode<1, Node> { public: Forward() = default; @@ -74,7 +75,7 @@ public: /** * @brief Create a new value that rectifies its input */ -class ReLU final : public FixedArityNode<1> +class ReLU final : public FixedArityNode<1, Node> { public: ReLU() = default; @@ -105,7 +106,7 @@ public: * return res; * } */ -class ConstGen final : public FixedArityNode<0>, +class ConstGen final : public FixedArityNode<0, Node>, public NodeMixin, public NodeMixin { @@ -196,7 +197,7 @@ private: * before/after MaxPool2D node according to the semantics of the corresponding NN framework. * --- */ -struct MaxPool2D final : public FixedArityNode<1> +struct MaxPool2D final : public FixedArityNode<1, Node> { public: Node *ifm(void) const { return at(0)->node(); } @@ -227,7 +228,7 @@ private: /** * @brief Create a feature map from a tensor */ -class FeatureEncode final : public FixedArityNode<1> +class FeatureEncode final : public FixedArityNode<1, Node> { public: Node *input(void) const { return at(0)->node(); } @@ -245,7 +246,7 @@ private: /** * @brief Create a tensor from a feature map */ -class FeatureDecode final : public FixedArityNode<1> +class FeatureDecode final : public FixedArityNode<1, Node> { public: Node *input(void) const { return at(0)->node(); } @@ -263,7 +264,7 @@ private: /** * @brief Create a filter from a tensor */ -class FilterEncode final : public FixedArityNode<1> +class FilterEncode final : public FixedArityNode<1, Node> { public: Node *input(void) const { return at(0)->node(); } @@ -305,7 +306,7 @@ template class Reshape; * input[1, 1, 1, 1] => output [3, 3] */ template <> -class Reshape final : public FixedArityNode<1>, +class Reshape final : public FixedArityNode<1, Node>, public NodeMixin { public: @@ -319,7 +320,7 @@ public: * Given an axis, TensorConcat takes as input two tensors and produces a tensor * concatenated along the given axis. */ -class TensorConcat final : public FixedArityNode<2> +class TensorConcat final : public FixedArityNode<2, Node> { public: Node *lhs(void) const { return at(0)->node(); } @@ -340,7 +341,7 @@ private: /** * @brief 2D Spatial Convolution */ -class Conv2D final : public FixedArityNode<2> +class Conv2D final : public FixedArityNode<2, Node> { public: Node *ifm(void) const { return at(0)->node(); } diff --git a/contrib/loco/src/IR/Node.test.cpp b/contrib/loco/src/IR/Node.test.cpp index 4d98552..57a72e2 100644 --- a/contrib/loco/src/IR/Node.test.cpp +++ b/contrib/loco/src/IR/Node.test.cpp @@ -51,7 +51,7 @@ TEST(NodeTest, succs) TEST(FixedArityNodeTest, constructor) { - loco::FixedArityNode<1> node; + loco::FixedArityNode<1, loco::Node> node; ASSERT_EQ(node.arity(), 1); ASSERT_EQ(node.arg(0), nullptr);