[exo-tflite] Adding FixedArityNode (#6917)
author윤현식/On-Device Lab(SR)/Principal Engineer/삼성전자 <hyunsik.yoon@samsung.com>
Mon, 26 Aug 2019 08:19:46 +0000 (17:19 +0900)
committer박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 26 Aug 2019 08:19:46 +0000 (17:19 +0900)
FixedArityNode is copied from loco since loco::FixedArityNode's destiny is doomed.

Signed-off-by: Hyun Sik Yoon <hyunsik.yoon@samsung.com>
compiler/exo-tflite/src/Dialect/IR/TFLNodes.h

index 5feee48..4b648ac 100644 (file)
 #include <loco/IR/Node.h>
 #include <loco/IR/NodeMixins.h>
 
+#include <array>
+
 namespace locoex
 {
 
+/**
+ * @brief Nodes with the fixed number of inputs
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+  FixedArityNode()
+  {
+    for (uint32_t n = 0; n < N; ++n)
+    {
+      _args[n] = std::unique_ptr<loco::Use>(new loco::Use{this});
+    }
+  }
+
+  virtual ~FixedArityNode() = default;
+
+public:
+  unsigned arity(void) const final { return N; }
+
+  loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+  void drop(void) final
+  {
+    for (uint32_t n = 0; n < N; ++n)
+    {
+      _args.at(n)->node(nullptr);
+    }
+  }
+
+protected:
+  // This API allows inherited classes to access "_args" field.
+  loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+  std::array<std::unique_ptr<loco::Use>, N> _args;
+};
+
 class TFLRelu final : public loco::FixedArityNode<1, TFLNodeImpl<TFLOpcode::RELU>>
 {
 public: