From 00a7322c98d4a54c00be25029c1072fbb04c3a18 Mon Sep 17 00:00:00 2001
From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?=
Date: Wed, 17 Jul 2019 16:49:38 +0900
Subject: [PATCH] [loco] Introduce CanonicalNodeDef interface (#4302)

This commit introduces the CanonicalNodeDef interface, which takes multiple
Mixins all at once.

Compared with the existing CanonicalNodeImpl, CanonicalNodeDef produces
uniform declarations that all begin with "CanonicalNodeDef", which makes the
code easier to read.

Signed-off-by: Jonghyun Park
---
 contrib/loco/include/loco/IR/CanonicalNodeDecl.h |  9 +++++
 contrib/loco/include/loco/IR/NodeMixins.h        | 46 +++++++++++++++++++++++
 contrib/loco/include/loco/IR/Nodes.h             | 48 ++++++++++++------------
 3 files changed, 79 insertions(+), 24 deletions(-)

diff --git a/contrib/loco/include/loco/IR/CanonicalNodeDecl.h b/contrib/loco/include/loco/IR/CanonicalNodeDecl.h
index 6dadf88..a226d51 100644
--- a/contrib/loco/include/loco/IR/CanonicalNodeDecl.h
+++ b/contrib/loco/include/loco/IR/CanonicalNodeDecl.h
@@ -44,6 +44,15 @@ template <CanonicalOpcode Code> struct CanonicalNodeImpl : public CanonicalNode
   CanonicalOpcode opcode(void) const final { return Code; }
 };
 
+template <CanonicalOpcode Code, template <typename T> class... Mixins>
+struct CanonicalNodeDef : public virtual CanonicalNode, public Mixins<CanonicalNode>...
+{
+  virtual ~CanonicalNodeDef() = default;
+
+  uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); }
+  CanonicalOpcode opcode(void) const final { return Code; }
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_CANONICAL_NODE_H__
diff --git a/contrib/loco/include/loco/IR/NodeMixins.h b/contrib/loco/include/loco/IR/NodeMixins.h
index 1da4fee..89e0920 100644
--- a/contrib/loco/include/loco/IR/NodeMixins.h
+++ b/contrib/loco/include/loco/IR/NodeMixins.h
@@ -17,6 +17,7 @@
 #ifndef __LOCO_IR_NODE_MIXINS_H__
 #define __LOCO_IR_NODE_MIXINS_H__
 
+#include "loco/IR/Node.h"
 #include "loco/IR/DataType.h"
 #include "loco/IR/Dimension.h"
 
@@ -82,6 +83,51 @@ private:
   std::vector<Dimension> _dims;
 };
 
+template <unsigned N> struct FixedArity
+{
+  template <typename Base> class Mixin : public virtual Base
+  {
+  public:
+    Mixin()
+    {
+      for (uint32_t n = 0; n < N; ++n)
+      {
+        _args[n] = std::unique_ptr<Use>{new Use{this}};
+      }
+    }
+
+    virtual ~Mixin() = default;
+
+  public:
+    unsigned arity(void) const final { return N; }
+
+    Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+    void drop(void) final
+    {
+      for (uint32_t n = 0; n < N; ++n)
+      {
+        _args.at(n)->node(nullptr);
+      }
+    }
+
+  protected:
+    // This API allows inherited classes to access "_args" field.
+    Use *at(unsigned n) const { return _args.at(n).get(); }
+
+  private:
+    std::array<std::unique_ptr<Use>, N> _args;
+  };
+};
+
+template <NodeTrait Trait> struct With
+{
+  template <typename Base> struct Mixin : public virtual Base, public NodeMixin<Trait>
+  {
+    // DO NOTHING
+  };
+};
+
 } // namespace loco
 
 #endif // __LOCO_IR_NODE_MIXINS_H__
diff --git a/contrib/loco/include/loco/IR/Nodes.h b/contrib/loco/include/loco/IR/Nodes.h
index 932e6c1..073933e 100644
--- a/contrib/loco/include/loco/IR/Nodes.h
+++ b/contrib/loco/include/loco/IR/Nodes.h
@@ -38,8 +38,9 @@ namespace loco
 /**
  * @brief Make a value visible to user
  */
-class Push /* to user */ final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::Push>>,
-                                 public NodeMixin<NodeTrait::TensorShape>
+class Push /* to user */ final
+    : public CanonicalNodeDef<CanonicalOpcode::Push, FixedArity<1>::Mixin,
+                              With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   Push() = default;
@@ -53,9 +54,8 @@ public:
  * @brief Create a value from user data
  */
 class Pull /* from user */ final
-    : public FixedArityNode<0, CanonicalNodeImpl<CanonicalOpcode::Pull>>,
-      public NodeMixin<NodeTrait::DataType>,
-      public NodeMixin<NodeTrait::TensorShape>
+    : public CanonicalNodeDef<CanonicalOpcode::Pull, FixedArity<0>::Mixin,
+                              With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   Pull() = default;
@@ -66,7 +66,7 @@ public:
  *
  * This node may encode memory transfer (such as CPU -> GPU or GPU -> CPU)
  */
-class Forward final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::Forward>>
+class Forward final : public CanonicalNodeDef<CanonicalOpcode::Forward, FixedArity<1>::Mixin>
 {
 public:
   Forward() = default;
@@ -79,7 +79,7 @@ public:
 /**
  * @brief Create a new value that rectifies its input
  */
-class ReLU final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::ReLU>>
+class ReLU final : public CanonicalNodeDef<CanonicalOpcode::ReLU, FixedArity<1>::Mixin>
 {
 public:
   ReLU() = default;
@@ -92,7 +92,7 @@ public:
 /**
  * @brief Create a new value that rectifies its input capping the units at 6.
  */
-class ReLU6 final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::ReLU6>>
+class ReLU6 final : public CanonicalNodeDef<CanonicalOpcode::ReLU6, FixedArity<1>::Mixin>
 {
 public:
   ReLU6() = default;
@@ -123,9 +123,9 @@ public:
  * return res;
  * }
  */
-class ConstGen final : public FixedArityNode<0, CanonicalNodeImpl<CanonicalOpcode::ConstGen>>,
-                       public NodeMixin<NodeTrait::DataType>,
-                       public NodeMixin<NodeTrait::TensorShape>
+class ConstGen final
+    : public CanonicalNodeDef<CanonicalOpcode::ConstGen, FixedArity<0>::Mixin,
+                              With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   ConstGen() = default;
@@ -212,7 +212,7 @@ private:
  * before/after MaxPool2D node according to the semantics of the corresponding NN framework.
 * ---
 */
-class MaxPool2D final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::MaxPool2D>>
+class MaxPool2D final : public CanonicalNodeDef<CanonicalOpcode::MaxPool2D, FixedArity<1>::Mixin>
 {
 public:
   Node *ifm(void) const { return at(0)->node(); }
@@ -244,7 +244,7 @@ private:
  *
  * @note Follows MaxPool2D (TODO: describe difference)
  */
-class AvgPool2D final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::AvgPool2D>>
+class AvgPool2D final : public CanonicalNodeDef<CanonicalOpcode::AvgPool2D, FixedArity<1>::Mixin>
 {
 public:
   enum class Convention
@@ -287,7 +287,7 @@ private:
  * @brief Create a feature map from a tensor
  */
 class FeatureEncode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FeatureEncode>>
+    : public CanonicalNodeDef<CanonicalOpcode::FeatureEncode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -306,7 +306,7 @@ private:
  * @brief Create a tensor from a feature map
  */
 class FeatureDecode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FeatureDecode>>
+    : public CanonicalNodeDef<CanonicalOpcode::FeatureDecode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -325,7 +325,7 @@ private:
  * @brief Create a filter from a tensor
  */
 class FilterEncode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FilterEncode>>
+    : public CanonicalNodeDef<CanonicalOpcode::FilterEncode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -346,7 +346,7 @@ private:
  * @brief Create a depthwise filter from a tensor
  */
 class DepthwiseFilterEncode final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::DepthwiseFilterEncode>>
+    : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterEncode, FixedArity<1>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -389,8 +389,8 @@ template <ReshapeType RT> class Reshape;
  */
 template <>
 class Reshape<ReshapeType::Fixed> final
-    : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FixedReshape>>,
-      public NodeMixin<NodeTrait::TensorShape>
+    : public CanonicalNodeDef<CanonicalOpcode::FixedReshape, FixedArity<1>::Mixin,
+                              With<NodeTrait::TensorShape>::Mixin>
 {
 public:
   Node *input(void) const { return at(0)->node(); }
@@ -404,7 +404,7 @@ public:
  * concatenated along the given axis.
  */
 class TensorConcat final
-    : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::TensorConcat>>
+    : public CanonicalNodeDef<CanonicalOpcode::TensorConcat, FixedArity<2>::Mixin>
 {
 public:
   Node *lhs(void) const { return at(0)->node(); }
@@ -425,7 +425,7 @@ private:
 /**
  * @brief 2D Spatial Convolution
  */
-class Conv2D final : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::Conv2D>>
+class Conv2D final : public CanonicalNodeDef<CanonicalOpcode::Conv2D, FixedArity<2>::Mixin>
 {
 public:
   Node *ifm(void) const { return at(0)->node(); }
@@ -454,7 +454,7 @@ private:
  *
  * BiasEncode currently requires a rank-1 tensor as its input.
  */
-class BiasEncode final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::BiasEncode>>
+class BiasEncode final : public CanonicalNodeDef<CanonicalOpcode::BiasEncode, FixedArity<1>::Mixin>
 {
 public:
   BiasEncode() = default;
@@ -477,7 +477,7 @@ template <Domain D> class BiasAdd;
  */
 template <>
 class BiasAdd<Domain::Tensor> final
-    : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::TensorBiasAdd>>
+    : public CanonicalNodeDef<CanonicalOpcode::TensorBiasAdd, FixedArity<2>::Mixin>
 {
 public:
   BiasAdd() = default;
@@ -513,7 +513,7 @@ using TensorBiasAdd = BiasAdd<Domain::Tensor>;
  */
 template <>
 class BiasAdd<Domain::Feature> final
-    : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::FeatureBiasAdd>>
+    : public CanonicalNodeDef<CanonicalOpcode::FeatureBiasAdd, FixedArity<2>::Mixin>
 {
 public:
   BiasAdd() = default;
-- 
2.7.4
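
For illustration, the sketch below is a minimal, self-contained sketch of the mixin
composition that CanonicalNodeDef relies on. FakeNode and FakeOpcode are hypothetical,
simplified stand-ins (not the actual loco::CanonicalNode and loco::CanonicalOpcode),
used only so that the snippet compiles on its own; the real definitions live in the
headers modified above.

// Minimal, self-contained sketch of the CanonicalNodeDef mixin pattern.
// FakeNode / FakeOpcode / FixedArity below are simplified stand-ins, NOT the
// real loco classes; they only show how "template <typename T> class..." mixins
// compose into one uniform base class.
#include <array>
#include <cstdint>
#include <iostream>

enum class FakeOpcode : uint32_t { Push, Pull };

// Stand-in for loco::CanonicalNode: declares the virtuals the mixins implement.
struct FakeNode
{
  virtual ~FakeNode() = default;
  virtual uint32_t arity(void) const = 0;
  virtual uint32_t opnum(void) const = 0;
};

// Stand-in for loco::FixedArity<N>::Mixin: fixes the number of arguments.
template <uint32_t N> struct FixedArity
{
  template <typename Base> class Mixin : public virtual Base
  {
  public:
    uint32_t arity(void) const final { return N; }

  protected:
    std::array<FakeNode *, N> _args{};
  };
};

// Stand-in for loco::CanonicalNodeDef: takes the opcode and any number of mixins,
// and inherits each mixin instantiated over the common node base.
template <FakeOpcode Code, template <typename T> class... Mixins>
struct CanonicalNodeDef : public virtual FakeNode, public Mixins<FakeNode>...
{
  uint32_t opnum(void) const final { return static_cast<uint32_t>(Code); }
};

// Every node declaration now begins with the same uniform "CanonicalNodeDef<...>" base.
class Push final : public CanonicalNodeDef<FakeOpcode::Push, FixedArity<1>::Mixin>
{
};

int main(void)
{
  Push push;
  std::cout << "opnum = " << push.opnum() << ", arity = " << push.arity() << std::endl;
  return 0;
}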