[moco-tf] Introduce FixedArityNode (#6907)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 26 Aug 2019 04:48:04 +0000 (13:48 +0900)
committer오형석/On-Device Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Mon, 26 Aug 2019 04:48:04 +0000 (13:48 +0900)
This commit introduces moco-tf internal FixedArityNode class as the
first step to remove FixedArityNode from loco.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
24 files changed:
compiler/moco-tf/src/Dialect/TFNodeDecl.h
compiler/moco-tf/src/IR/TFAdd.h
compiler/moco-tf/src/IR/TFAvgPool.h
compiler/moco-tf/src/IR/TFBiasAdd.h
compiler/moco-tf/src/IR/TFConst.h
compiler/moco-tf/src/IR/TFConv2D.h
compiler/moco-tf/src/IR/TFDepthwiseConv2dNative.h
compiler/moco-tf/src/IR/TFFusedBatchNorm.h
compiler/moco-tf/src/IR/TFIdentity.h
compiler/moco-tf/src/IR/TFMaxPool.h
compiler/moco-tf/src/IR/TFMul.h
compiler/moco-tf/src/IR/TFRealDiv.h
compiler/moco-tf/src/IR/TFRelu.h
compiler/moco-tf/src/IR/TFRelu6.h
compiler/moco-tf/src/IR/TFReshape.h
compiler/moco-tf/src/IR/TFRsqrt.h
compiler/moco-tf/src/IR/TFShape.h
compiler/moco-tf/src/IR/TFSoftmax.h
compiler/moco-tf/src/IR/TFSqrt.h
compiler/moco-tf/src/IR/TFSquaredDifference.h
compiler/moco-tf/src/IR/TFSqueeze.h
compiler/moco-tf/src/IR/TFStopGradient.h
compiler/moco-tf/src/IR/TFSub.h
compiler/moco-tf/src/IR/TFTanh.h

index 0e9ba0c..922165b 100644 (file)
@@ -23,6 +23,8 @@
 #include "TFOpcode.h"
 #include "TFNodeVisitor.forward.h"
 
+#include <array>
+
 namespace moco
 {
 namespace tf
@@ -51,6 +53,43 @@ template <TFOpcode Code> struct TFNodeImpl : public TFNode
   TFOpcode opcode(void) const final { return Code; }
 };
 
+/**
+ * @brief Nodes with the fixed number of inputs
+ */
+template <unsigned N, typename Base> class FixedArityNode : public Base
+{
+public:
+  FixedArityNode()
+  {
+    for (uint32_t n = 0; n < N; ++n)
+    {
+      _args[n] = std::unique_ptr<loco::Use>{new loco::Use{this}};
+    }
+  }
+
+  virtual ~FixedArityNode() = default;
+
+public:
+  unsigned arity(void) const final { return N; }
+
+  loco::Node *arg(uint32_t n) const final { return _args.at(n)->node(); }
+
+  void drop(void) final
+  {
+    for (uint32_t n = 0; n < N; ++n)
+    {
+      _args.at(n)->node(nullptr);
+    }
+  }
+
+protected:
+  // This API allows inherited classes to access "_args" field.
+  loco::Use *at(unsigned n) const { return _args.at(n).get(); }
+
+private:
+  std::array<std::unique_ptr<loco::Use>, N> _args;
+};
+
 } // namespace tf
 } // namespace moco
 
index 7042630..d2489fb 100644 (file)
@@ -40,7 +40,7 @@ node {
 }
 */
 
-class TFAdd final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Add>>
+class TFAdd final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Add>>
 {
 public:
   TFAdd() = default;
index a562594..93a72bb 100644 (file)
@@ -69,7 +69,7 @@ node {
 }
 */
 
-class TFAvgPool final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::AvgPool>>
+class TFAvgPool final : public FixedArityNode<1, TFNodeImpl<TFOpcode::AvgPool>>
 {
 public:
   TFAvgPool() = default;
index 58f47e2..468e02d 100644 (file)
@@ -46,7 +46,7 @@ node {
 }
 */
 
-class TFBiasAdd final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::BiasAdd>>
+class TFBiasAdd final : public FixedArityNode<2, TFNodeImpl<TFOpcode::BiasAdd>>
 {
 public:
   TFBiasAdd() = default;
index a44ecad..b63d37d 100644 (file)
@@ -62,7 +62,7 @@ node {
  * @note   Implementation for this class came from Canonical ConstGen
  *         Read comments in loco::ConstGen for details
  */
-class TFConst final : public loco::FixedArityNode<0, TFNodeImpl<TFOpcode::Const>>,
+class TFConst final : public FixedArityNode<0, TFNodeImpl<TFOpcode::Const>>,
                       public loco::NodeMixin<loco::NodeTrait::DataType>,
                       public loco::NodeMixin<loco::NodeTrait::TensorShape>
 {
index a433ff6..e061d2b 100644 (file)
@@ -30,7 +30,7 @@ namespace moco
 namespace tf
 {
 
-class TFConv2D final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Conv2D>>
+class TFConv2D final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Conv2D>>
 {
 public:
   loco::Node *ifm(void) const { return at(0)->node(); }
index 6eadaa8..2e9c7c9 100644 (file)
@@ -33,7 +33,7 @@ namespace tf
 {
 
 class TFDepthwiseConv2dNative final
-    : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::DepthwiseConv2dNative>>
+    : public FixedArityNode<2, TFNodeImpl<TFOpcode::DepthwiseConv2dNative>>
 {
 public:
   loco::Node *ifm(void) const { return at(0)->node(); }
index 6a9e0c5..297f439 100644 (file)
@@ -24,7 +24,7 @@ namespace moco
 namespace tf
 {
 
-class TFFusedBatchNorm final : public loco::FixedArityNode<5, TFNodeImpl<TFOpcode::FusedBatchNorm>>
+class TFFusedBatchNorm final : public FixedArityNode<5, TFNodeImpl<TFOpcode::FusedBatchNorm>>
 {
 public:
   TFFusedBatchNorm() = default;
index a1ef1c8..9eeab8d 100644 (file)
@@ -39,7 +39,7 @@ node {
 }
 */
 
-class TFIdentity final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Identity>>
+class TFIdentity final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Identity>>
 {
 public:
   TFIdentity() = default;
index 847e85d..14dae70 100644 (file)
@@ -69,7 +69,7 @@ node {
 }
 */
 
-class TFMaxPool final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::MaxPool>>
+class TFMaxPool final : public FixedArityNode<1, TFNodeImpl<TFOpcode::MaxPool>>
 {
 public:
   TFMaxPool() = default;
index 5390612..95826f0 100644 (file)
@@ -40,7 +40,7 @@ node {
 }
 */
 
-class TFMul final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Mul>>
+class TFMul final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Mul>>
 {
 public:
   TFMul() = default;
index 16bbb10..8ef3786 100644 (file)
@@ -40,7 +40,7 @@ node {
 }
 */
 
-class TFRealDiv final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::RealDiv>>
+class TFRealDiv final : public FixedArityNode<2, TFNodeImpl<TFOpcode::RealDiv>>
 {
 public:
   TFRealDiv() = default;
index cbcc227..7df958b 100644 (file)
@@ -24,7 +24,7 @@ namespace moco
 namespace tf
 {
 
-class TFRelu final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Relu>>
+class TFRelu final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Relu>>
 {
 public:
   TFRelu() = default;
index 360c2de..eba83a9 100644 (file)
@@ -24,7 +24,7 @@ namespace moco
 namespace tf
 {
 
-class TFRelu6 final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Relu6>>
+class TFRelu6 final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Relu6>>
 {
 public:
   TFRelu6() = default;
index 7b8fb4a..4359a49 100644 (file)
@@ -38,7 +38,7 @@ node {
 }
 */
 
-class TFReshape final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Reshape>>
+class TFReshape final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Reshape>>
 {
 public:
   TFReshape() = default;
index f03f36d..f371e39 100644 (file)
@@ -39,7 +39,7 @@ node {
 }
 */
 
-class TFRsqrt final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Rsqrt>>
+class TFRsqrt final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Rsqrt>>
 {
 public:
   TFRsqrt() = default;
index a932cd2..d50cabf 100644 (file)
@@ -46,7 +46,7 @@ node {
 */
 
 /// @note  Mixed in dtype() is for 'out_type' attribute
-class TFShape final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Shape>>,
+class TFShape final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Shape>>,
                       public loco::NodeMixin<loco::NodeTrait::DataType>
 {
 public:
index 15e3ef0..22b7b9e 100644 (file)
@@ -24,7 +24,7 @@ namespace moco
 namespace tf
 {
 
-class TFSoftmax final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Softmax>>
+class TFSoftmax final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Softmax>>
 {
 public:
   TFSoftmax() = default;
index a95ff16..fda032e 100644 (file)
@@ -39,7 +39,7 @@ node {
 }
 */
 
-class TFSqrt final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Sqrt>>
+class TFSqrt final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Sqrt>>
 {
 public:
   TFSqrt() = default;
index b0d8051..83ecdb8 100644 (file)
@@ -40,8 +40,7 @@ node {
 }
 */
 
-class TFSquaredDifference final
-    : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::SquaredDifference>>
+class TFSquaredDifference final : public FixedArityNode<2, TFNodeImpl<TFOpcode::SquaredDifference>>
 {
 public:
   TFSquaredDifference() = default;
index 5f30ea5..e986441 100644 (file)
@@ -51,7 +51,7 @@ node {
 }
 */
 
-class TFSqueeze final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Squeeze>>
+class TFSqueeze final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Squeeze>>
 {
 public:
   TFSqueeze() = default;
index d39f14e..4b8f1b8 100644 (file)
@@ -39,7 +39,7 @@ node {
 }
 */
 
-class TFStopGradient final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::StopGradient>>
+class TFStopGradient final : public FixedArityNode<1, TFNodeImpl<TFOpcode::StopGradient>>
 {
 public:
   TFStopGradient() = default;
index 00ebf44..5f4e48b 100644 (file)
@@ -40,7 +40,7 @@ node {
 }
 */
 
-class TFSub final : public loco::FixedArityNode<2, TFNodeImpl<TFOpcode::Sub>>
+class TFSub final : public FixedArityNode<2, TFNodeImpl<TFOpcode::Sub>>
 {
 public:
   TFSub() = default;
index eebee4b..c85663e 100644 (file)
@@ -24,7 +24,7 @@ namespace moco
 namespace tf
 {
 
-class TFTanh final : public loco::FixedArityNode<1, TFNodeImpl<TFOpcode::Tanh>>
+class TFTanh final : public FixedArityNode<1, TFNodeImpl<TFOpcode::Tanh>>
 {
 public:
   TFTanh() = default;