[neurun] Introduce OperandConstraint (#3182)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Wed, 17 Oct 2018 01:15:24 +0000 (10:15 +0900)
committer오형석/동작제어Lab(SR)/Staff Engineer/삼성전자 <hseok82.oh@samsung.com>
Wed, 17 Oct 2018 01:15:24 +0000 (10:15 +0900)
Introduce OperandConstraint to simplify implementation of
setInputs/Outputs. They were virtual methods but this commit makes it
non-virtual to make the code simpler. Instead `Node` class now uses
`OperandConstraint` which limits the number of input operands(and
possibly output operands).

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
23 files changed:
runtimes/neurun/src/graph/operation/AvgPool2D.cc
runtimes/neurun/src/graph/operation/AvgPool2D.h
runtimes/neurun/src/graph/operation/Concat.cc
runtimes/neurun/src/graph/operation/Concat.h
runtimes/neurun/src/graph/operation/Conv2D.cc
runtimes/neurun/src/graph/operation/Conv2D.h
runtimes/neurun/src/graph/operation/FullyConnected.cc
runtimes/neurun/src/graph/operation/FullyConnected.h
runtimes/neurun/src/graph/operation/MaxPool2D.cc
runtimes/neurun/src/graph/operation/MaxPool2D.h
runtimes/neurun/src/graph/operation/NOP.cc
runtimes/neurun/src/graph/operation/NOP.h
runtimes/neurun/src/graph/operation/Node.cc
runtimes/neurun/src/graph/operation/Node.h
runtimes/neurun/src/graph/operation/OperandConstraint.cc [new file with mode: 0644]
runtimes/neurun/src/graph/operation/OperandConstraint.h [new file with mode: 0644]
runtimes/neurun/src/graph/operation/Permute.cc
runtimes/neurun/src/graph/operation/Permute.h
runtimes/neurun/src/graph/operation/Reshape.cc
runtimes/neurun/src/graph/operation/Reshape.h
runtimes/neurun/src/graph/operation/Softmax.cc
runtimes/neurun/src/graph/operation/Softmax.h
runtimes/neurun/test/graph/MockNode.h

index adcdd3c..1ab06b6 100644 (file)
@@ -35,6 +35,7 @@ namespace Implicit
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const graph::operation::Node::InitParam &init_param)
+    : operation::Node{OperandConstraint::createEQ(1u)}
 {
   assert(init_param.input_count == 7);
   assert(init_param.output_count == 1);
@@ -61,20 +62,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param)
   _param.activation_index = operand::Index{init_param.inputs[6]};
 }
 
-void Node::setInputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace Implicit
 } // namespace AvgPool2D
 } // namespace operation
index 41ec0db..8f4c455 100644 (file)
@@ -59,10 +59,6 @@ public:
   virtual std::string getName() const override { return "AvgPool2D"; }
 
 public:
-  virtual void setInputs(const operand::IndexSet &indexes) override;
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
   const Param &param() const { return _param; }
 
 private:
index ae311d7..2e130da 100644 (file)
@@ -33,6 +33,7 @@ namespace Concat
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const graph::operation::Node::InitParam &init_param)
+    : operation::Node{OperandConstraint::createGE(2u)}
 {
   assert(init_param.input_count >= 2); // At least one one input tensor and axis
   assert(init_param.output_count == 1);
@@ -56,13 +57,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param)
   _param.axis_index = operand::Index{init_param.inputs[init_param.input_count - 1]};
 }
 
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace Concat
 } // namespace operation
 } // namespace graph
index 69aabcc..7710d46 100644 (file)
@@ -45,9 +45,6 @@ public:
   virtual std::string getName() const override { return "Concat"; }
 
 public:
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
   const Param &param() const { return _param; }
 
 private:
index 61b1d79..47e469d 100644 (file)
@@ -35,6 +35,7 @@ namespace Implicit
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const graph::operation::Node::InitParam &init_param)
+    : operation::Node{OperandConstraint::createEQ(3u)}
 {
   assert(init_param.input_count == 7 && init_param.output_count == 1);
 
@@ -58,20 +59,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param)
   _param.activation_index = operand::Index{init_param.inputs[6]};
 }
 
-void Node::setInputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 3);
-
-  graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace Implicit
 } // namespace Conv2D
 } // namespace operation
index 6221eab..b0eec01 100644 (file)
@@ -58,10 +58,6 @@ public:
   virtual std::string getName() const override { return "Conv2D"; }
 
 public:
-  virtual void setInputs(const operand::IndexSet &indexes) override;
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
   const Param &param() const { return _param; }
 
 private:
index db3cf05..086db66 100644 (file)
@@ -33,6 +33,7 @@ namespace FullyConnected
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const graph::operation::Node::InitParam &init_param)
+    : operation::Node{OperandConstraint::createEQ(3u)}
 {
   assert(init_param.input_count == 4 && init_param.output_count == 1);
 
@@ -49,20 +50,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param)
   _param.activation_index = operand::Index{init_param.inputs[3]};
 }
 
-void Node::setInputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 3);
-
-  graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace FullyConnected
 } // namespace operation
 } // namespace graph
index b935f7b..2c28452 100644 (file)
@@ -52,10 +52,6 @@ public:
   virtual std::string getName() const override { return "FullyConnected"; }
 
 public:
-  virtual void setInputs(const operand::IndexSet &indexes) override;
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
   const Param &param() const { return _param; }
 
 private:
index eec0972..b08f6a6 100644 (file)
@@ -35,6 +35,7 @@ namespace Implicit
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const graph::operation::Node::InitParam &init_param)
+    : operation::Node{OperandConstraint::createEQ(1u)}
 {
   assert(init_param.input_count == 7);
   assert(init_param.output_count == 1);
@@ -61,20 +62,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param)
   _param.activation_index = operand::Index{init_param.inputs[6]};
 }
 
-void Node::setInputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace Implicit
 } // namespace MaxPool2D
 } // namespace operation
index 09557f0..d626ec5 100644 (file)
@@ -59,10 +59,6 @@ public:
   Node(const graph::operation::Node::InitParam &init_param);
 
 public:
-  virtual void setInputs(const operand::IndexSet &indexes) override;
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
   const Param &param() const { return _param; }
 
 private:
index 18c3246..ea98c6e 100644 (file)
@@ -30,6 +30,11 @@ namespace NOP
 
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
+Node::Node(const graph::operation::Node::InitParam &)
+    : operation::Node{OperandConstraint::createEQ(1u)}
+{
+}
+
 } // namespace NOP
 } // namespace operation
 } // namespace graph
index 8da8c8b..193c5d4 100644 (file)
@@ -33,7 +33,7 @@ namespace NOP
 class Node : public graph::operation::Node
 {
 public:
-  Node(const graph::operation::Node::InitParam &) {}
+  Node(const graph::operation::Node::InitParam &);
 
 public:
   virtual void accept(NodeVisitor &&) const override;
index eeb3ff4..74b37c4 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "Node.h"
 
+#include <cassert>
+
 #include "LowerInfo.h"
 
 namespace neurun
@@ -25,10 +27,18 @@ namespace graph
 namespace operation
 {
 
-Node::Node() = default;
+Node::Node(OperandConstraint input_constr) : _input_constr{input_constr} {}
 
 Node::~Node() = default;
 
+void Node::setInputs(const operand::IndexSet &indexes)
+{
+  assert(_input_constr.check(indexes.size()));
+  _inputs = indexes;
+}
+
+void Node::setOutputs(const operand::IndexSet &indexes) { _outputs = indexes; }
+
 void Node::replaceInput(const operand::Index &from, const operand::Index &to)
 {
   _inputs.replace(from, to);
index bab51b3..a1c5be8 100644 (file)
@@ -20,6 +20,7 @@
 #include <memory>
 
 #include "graph/operand/IndexSet.h"
+#include "OperandConstraint.h"
 
 namespace neurun
 {
@@ -43,7 +44,7 @@ public:
   };
 
 public:
-  Node();
+  Node(OperandConstraint input_constr);
   virtual ~Node();
 
 public:
@@ -51,13 +52,13 @@ public:
   virtual std::string getName() const = 0;
 
 public:
-  virtual void replaceInput(const operand::Index &from, const operand::Index &to);
-  virtual void replaceOutput(const operand::Index &from, const operand::Index &to);
-  virtual const operand::IndexSet &getInputs() const { return _inputs; }
-  virtual const operand::IndexSet &getOutputs() const { return _outputs; }
+  void replaceInput(const operand::Index &from, const operand::Index &to);
+  void replaceOutput(const operand::Index &from, const operand::Index &to);
+  const operand::IndexSet &getInputs() const { return _inputs; }
+  const operand::IndexSet &getOutputs() const { return _outputs; }
   // It's for only input/output tensors but const data.
-  virtual void setInputs(const operand::IndexSet &indexes) { _inputs = indexes; }
-  virtual void setOutputs(const operand::IndexSet &indexes) { _outputs = indexes; }
+  void setInputs(const operand::IndexSet &indexes);
+  void setOutputs(const operand::IndexSet &indexes);
 
 public:
   void lower_info(std::unique_ptr<LowerInfo> &&lower_info);
@@ -66,6 +67,7 @@ public:
 private:
   operand::IndexSet _inputs;
   operand::IndexSet _outputs;
+  OperandConstraint _input_constr;
   std::unique_ptr<LowerInfo> _lower_info;
 };
 
diff --git a/runtimes/neurun/src/graph/operation/OperandConstraint.cc b/runtimes/neurun/src/graph/operation/OperandConstraint.cc
new file mode 100644 (file)
index 0000000..f434cb8
--- /dev/null
@@ -0,0 +1,12 @@
+#include "OperandConstraint.h"
+
+namespace neurun
+{
+namespace graph
+{
+namespace operation
+{
+
+} // namespace operation
+} // namespace graph
+} // namespace neurun
diff --git a/runtimes/neurun/src/graph/operation/OperandConstraint.h b/runtimes/neurun/src/graph/operation/OperandConstraint.h
new file mode 100644 (file)
index 0000000..1f656c4
--- /dev/null
@@ -0,0 +1,45 @@
+#ifndef __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__
+#define __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__
+
+#include <stdint.h>
+#include <limits>
+#include <set>
+
+namespace neurun
+{
+namespace graph
+{
+namespace operation
+{
+
+class OperandConstraint
+{
+private:
+  static const uint32_t INF = std::numeric_limits<uint32_t>::max();
+
+public:
+  static OperandConstraint createAny() { return OperandConstraint{0u, INF}; }
+  static OperandConstraint createEQ(uint32_t exact) { return OperandConstraint{exact, exact}; }
+  static OperandConstraint createLE(uint32_t end) { return OperandConstraint{0u, end}; }
+  static OperandConstraint createGE(uint32_t begin) { return OperandConstraint{begin, INF}; }
+  static OperandConstraint createRange(uint32_t begin, uint32_t end)
+  {
+    return OperandConstraint{begin, end};
+  }
+
+private:
+  OperandConstraint(uint32_t begin, uint32_t end) : _begin{begin}, _end{end} {}
+
+public:
+  bool check(uint32_t ind) const { return _begin <= ind && ind <= _end; }
+
+private:
+  uint32_t _begin;
+  uint32_t _end;
+};
+
+} // namespace operation
+} // namespace graph
+} // namespace neurun
+
+#endif // __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__
index 2688e5e..57a40f7 100644 (file)
@@ -16,25 +16,12 @@ namespace Permute
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const operand::Index &input, const operand::Index &output)
+    : operation::Node{OperandConstraint::createEQ(1u)}
 {
   setInputs({input});
   setOutputs({output});
 }
 
-void Node::setInputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace Permute
 } // namespace operation
 } // namespace graph
index d0b8e9a..5a64ca8 100644 (file)
@@ -20,10 +20,6 @@ public:
 
 public:
   Node(const operand::Index &input, const operand::Index &output);
-
-public:
-  virtual void setInputs(const operand::IndexSet &indexes) override;
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
 };
 
 } // namespace Permute
index e6bc211..184f907 100644 (file)
@@ -33,6 +33,7 @@ namespace Reshape
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const graph::operation::Node::InitParam &init_param)
+    : operation::Node{OperandConstraint::createEQ(1u)}
 {
   assert(init_param.input_count == 2 && init_param.output_count == 1);
 
@@ -47,20 +48,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param)
   setOutputs({init_param.outputs[0]});
 }
 
-void Node::setInputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1); // TODO Should be 2 (See also the constructor)
-
-  graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace Reshape
 } // namespace operation
 } // namespace graph
index ede243a..3959007 100644 (file)
@@ -43,10 +43,6 @@ public:
 
 public:
   Node(const graph::operation::Node::InitParam &init_param);
-
-public:
-  virtual void setInputs(const operand::IndexSet &indexes) override;
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
 };
 
 } // namespace Reshape
index a6385a2..307258a 100644 (file)
@@ -33,6 +33,7 @@ namespace Softmax
 void Node::accept(NodeVisitor &&v) const { v.visit(*this); }
 
 Node::Node(const graph::operation::Node::InitParam &init_param)
+    : operation::Node{OperandConstraint::createEQ(1u)}
 {
   assert(init_param.input_count == 2 && init_param.output_count == 1);
 
@@ -47,20 +48,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param)
   _param.scale_index = operand::Index{init_param.inputs[1]};
 }
 
-void Node::setInputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setInputs(indexes);
-}
-
-void Node::setOutputs(const operand::IndexSet &indexes)
-{
-  assert(indexes.size() == 1);
-
-  graph::operation::Node::setOutputs(indexes);
-}
-
 } // namespace Softmax
 } // namespace operation
 } // namespace graph
index a1575bc..8fc8ae8 100644 (file)
@@ -50,10 +50,6 @@ public:
   Node(const graph::operation::Node::InitParam &init_param);
 
 public:
-  virtual void setInputs(const operand::IndexSet &indexes) override;
-  virtual void setOutputs(const operand::IndexSet &indexes) override;
-
-public:
   const Param &param() const { return _param; }
 
 private:
index e0baf6e..4f4ea4d 100644 (file)
@@ -30,6 +30,7 @@ class SimpleMockNode : public neurun::graph::operation::Node
 public:
   SimpleMockNode(const neurun::graph::operand::IndexSet &inputs,
                  const neurun::graph::operand::IndexSet &outputs)
+      : neurun::graph::operation::Node{neurun::graph::operation::OperandConstraint::createAny()}
   {
     setInputs(inputs);
     setOutputs(outputs);