#include <loco/IR/Node.h>
#include <loco/IR/NodeMixins.h>
+#include <array>
+
namespace locoex
{
+/**
+ * @brief Nodes with the fixed number of inputs
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+ FixedArityNode()
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args[n] = std::unique_ptr<loco::Use>(new loco::Use{this});
+ }
+ }
+
+ virtual ~FixedArityNode() = default;
+
+public:
+ unsigned arity(void) const final { return N; }
+
+ loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+ void drop(void) final
+ {
+ for (uint32_t n = 0; n < N; ++n)
+ {
+ _args.at(n)->node(nullptr);
+ }
+ }
+
+protected:
+ // This API allows inherited classes to access "_args" field.
+ loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+ std::array<std::unique_ptr<loco::Use>, N> _args;
+};
+
class TFLRelu final : public loco::FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>>
{
public: