#ifndef __LOCO_IR_NODE_MIXINS_H__
#define __LOCO_IR_NODE_MIXINS_H__
+#include "loco/IR/Node.h"
#include "loco/IR/DataType.h"
#include "loco/IR/Dimension.h"
std::vector<Dimension> _dims;
};
+template <unsigned N> struct FixedArity
+{
+ template <typename Base> class Mixin : public virtual Base
+ {
+ public:
+ Mixin()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args[n] = std::unique_ptr<Use>{new Use{this}};
+ }
+ }
+
+ virtual ~Mixin() = default;
+
+ public:
+ unsigned arity(void) const final { return N; }
+
+ Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+ void drop(void) final
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args.at(n)->node(nullptr);
+ }
+ }
+
+ protected:
+ // This API allows inherited classes to access "_args" field.
+ Use *at(unsigned n) const { return _args.at(n).get(); }
+
+ private:
+ std::array<std::unique_ptr<Use>, N> _args;
+ };
+};
+
+template <NodeTrait Trait> struct With
+{
+ template <typename Base> struct Mixin : public virtual Base, public NodeMixin<Trait>
+ {
+ // DO NOTHING
+ };
+};
+
} // namespace loco
#endif // __LOCO_IR_NODE_MIXINS_H__
/**
* @brief Make a value visible to user
*/
-class Push /* to user */ final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::Push>>,
- public NodeMixin<NodeTrait::TensorShape>
+class Push /* to user */ final
+ : public CanonicalNodeDef<CanonicalOpcode::Push, FixedArity<1>::Mixin,
+ With<NodeTrait::TensorShape>::Mixin>
{
public:
Push() = default;
* @brief Create a value from user data
*/
class Pull /* from user */ final
- : public FixedArityNode<0, CanonicalNodeImpl<CanonicalOpcode::Pull>>,
- public NodeMixin<NodeTrait::DataType>,
- public NodeMixin<NodeTrait::TensorShape>
+ : public CanonicalNodeDef<CanonicalOpcode::Pull, FixedArity<0>::Mixin,
+ With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
{
public:
Pull() = default;
*
* This node may encode memory transfer (such as CPU -> GPU or GPU -> CPU)
*/
-class Forward final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::Forward>>
+class Forward final : public CanonicalNodeDef<CanonicalOpcode::Forward, FixedArity<1>::Mixin>
{
public:
Forward() = default;
/**
* @brief Create a new value that rectifies its input
*/
-class ReLU final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::ReLU>>
+class ReLU final : public CanonicalNodeDef<CanonicalOpcode::ReLU, FixedArity<1>::Mixin>
{
public:
ReLU() = default;
/**
* @brief Create a new value that rectifies its input capping the units at 6.
*/
-class ReLU6 final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::ReLU6>>
+class ReLU6 final : public CanonicalNodeDef<CanonicalOpcode::ReLU6, FixedArity<1>::Mixin>
{
public:
ReLU6() = default;
* return res;
* }
*/
-class ConstGen final : public FixedArityNode<0, CanonicalNodeImpl<CanonicalOpcode::ConstGen>>,
- public NodeMixin<NodeTrait::DataType>,
- public NodeMixin<NodeTrait::TensorShape>
+class ConstGen final
+ : public CanonicalNodeDef<CanonicalOpcode::ConstGen, FixedArity<0>::Mixin,
+ With<NodeTrait::DataType>::Mixin, With<NodeTrait::TensorShape>::Mixin>
{
public:
ConstGen() = default;
* before/after MaxPool2D node according to the semantics of the corresponding NN framework.
* ---
*/
-class MaxPool2D final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::MaxPool2D>>
+class MaxPool2D final : public CanonicalNodeDef<CanonicalOpcode::MaxPool2D, FixedArity<1>::Mixin>
{
public:
Node *ifm(void) const { return at(0)->node(); }
*
* @note Follows MaxPool2D (TODO: describe difference)
*/
-class AvgPool2D final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::AvgPool2D>>
+class AvgPool2D final : public CanonicalNodeDef<CanonicalOpcode::AvgPool2D, FixedArity<1>::Mixin>
{
public:
enum class Convention
* @brief Create a feature map from a tensor
*/
class FeatureEncode final
- : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FeatureEncode>>
+ : public CanonicalNodeDef<CanonicalOpcode::FeatureEncode, FixedArity<1>::Mixin>
{
public:
Node *input(void) const { return at(0)->node(); }
* @brief Create a tensor from a feature map
*/
class FeatureDecode final
- : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FeatureDecode>>
+ : public CanonicalNodeDef<CanonicalOpcode::FeatureDecode, FixedArity<1>::Mixin>
{
public:
Node *input(void) const { return at(0)->node(); }
* @brief Create a filter from a tensor
*/
class FilterEncode final
- : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FilterEncode>>
+ : public CanonicalNodeDef<CanonicalOpcode::FilterEncode, FixedArity<1>::Mixin>
{
public:
Node *input(void) const { return at(0)->node(); }
* @brief Create a depthwise filter from a tensor
*/
class DepthwiseFilterEncode final
- : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::DepthwiseFilterEncode>>
+ : public CanonicalNodeDef<CanonicalOpcode::DepthwiseFilterEncode, FixedArity<1>::Mixin>
{
public:
Node *input(void) const { return at(0)->node(); }
*/
template <>
class Reshape<ReshapeType::Fixed> final
- : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::FixedReshape>>,
- public NodeMixin<NodeTrait::TensorShape>
+ : public CanonicalNodeDef<CanonicalOpcode::FixedReshape, FixedArity<1>::Mixin,
+ With<NodeTrait::TensorShape>::Mixin>
{
public:
Node *input(void) const { return at(0)->node(); }
* concatenated along the given axis.
*/
class TensorConcat final
- : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::TensorConcat>>
+ : public CanonicalNodeDef<CanonicalOpcode::TensorConcat, FixedArity<2>::Mixin>
{
public:
Node *lhs(void) const { return at(0)->node(); }
/**
* @brief 2D Spatial Convolution
*/
-class Conv2D final : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::Conv2D>>
+class Conv2D final : public CanonicalNodeDef<CanonicalOpcode::Conv2D, FixedArity<2>::Mixin>
{
public:
Node *ifm(void) const { return at(0)->node(); }
*
* BiasEncode currently requires a rank-1 tensor as its input.
*/
-class BiasEncode final : public FixedArityNode<1, CanonicalNodeImpl<CanonicalOpcode::BiasEncode>>
+class BiasEncode final : public CanonicalNodeDef<CanonicalOpcode::BiasEncode, FixedArity<1>::Mixin>
{
public:
BiasEncode() = default;
*/
template <>
class BiasAdd<Domain::Tensor> final
- : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::TensorBiasAdd>>
+ : public CanonicalNodeDef<CanonicalOpcode::TensorBiasAdd, FixedArity<2>::Mixin>
{
public:
BiasAdd() = default;
*/
template <>
class BiasAdd<Domain::Feature> final
- : public FixedArityNode<2, CanonicalNodeImpl<CanonicalOpcode::FeatureBiasAdd>>
+ : public CanonicalNodeDef<CanonicalOpcode::FeatureBiasAdd, FixedArity<2>::Mixin>
{
public:
BiasAdd() = default;