From: 이한종/동작제어Lab(SR)/Engineer/삼성전자 Date: Wed, 17 Oct 2018 01:15:24 +0000 (+0900) Subject: [neurun] Introduce OperandConstraint (#3182) X-Git-Tag: 0.3~605 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=6ade4151fe14bb9f004529bafcd9ea8aabc1d7b3;p=platform%2Fcore%2Fml%2Fnnfw.git [neurun] Introduce OperandConstraint (#3182) Introduce OperandConstraint to simplify implementation of setInputs/Outputs. They were virtual methods but this commit makes it non-virtual to make the code simpler. Instead `Node` class now uses `OperandConstraint` which limits the number of input operands(and possibly output operands). Signed-off-by: Hanjoung Lee --- diff --git a/runtimes/neurun/src/graph/operation/AvgPool2D.cc b/runtimes/neurun/src/graph/operation/AvgPool2D.cc index adcdd3c..1ab06b6 100644 --- a/runtimes/neurun/src/graph/operation/AvgPool2D.cc +++ b/runtimes/neurun/src/graph/operation/AvgPool2D.cc @@ -35,6 +35,7 @@ namespace Implicit void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const graph::operation::Node::InitParam &init_param) + : operation::Node{OperandConstraint::createEQ(1u)} { assert(init_param.input_count == 7); assert(init_param.output_count == 1); @@ -61,20 +62,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param) _param.activation_index = operand::Index{init_param.inputs[6]}; } -void Node::setInputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setInputs(indexes); -} - -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace Implicit } // namespace AvgPool2D } // namespace operation diff --git a/runtimes/neurun/src/graph/operation/AvgPool2D.h b/runtimes/neurun/src/graph/operation/AvgPool2D.h index 41ec0db..8f4c455 100644 --- a/runtimes/neurun/src/graph/operation/AvgPool2D.h +++ b/runtimes/neurun/src/graph/operation/AvgPool2D.h @@ -59,10 +59,6 @@ public: virtual std::string getName() const override { return "AvgPool2D"; } public: - virtual void setInputs(const operand::IndexSet &indexes) override; - virtual void setOutputs(const operand::IndexSet &indexes) override; - -public: const Param ¶m() const { return _param; } private: diff --git a/runtimes/neurun/src/graph/operation/Concat.cc b/runtimes/neurun/src/graph/operation/Concat.cc index ae311d7..2e130da 100644 --- a/runtimes/neurun/src/graph/operation/Concat.cc +++ b/runtimes/neurun/src/graph/operation/Concat.cc @@ -33,6 +33,7 @@ namespace Concat void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const graph::operation::Node::InitParam &init_param) + : operation::Node{OperandConstraint::createGE(2u)} { assert(init_param.input_count >= 2); // At least one one input tensor and axis assert(init_param.output_count == 1); @@ -56,13 +57,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param) _param.axis_index = operand::Index{init_param.inputs[init_param.input_count - 1]}; } -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace Concat } // namespace operation } // namespace graph diff --git a/runtimes/neurun/src/graph/operation/Concat.h b/runtimes/neurun/src/graph/operation/Concat.h index 69aabcc..7710d46 100644 --- a/runtimes/neurun/src/graph/operation/Concat.h +++ b/runtimes/neurun/src/graph/operation/Concat.h @@ -45,9 +45,6 @@ public: virtual std::string getName() const override { return "Concat"; } public: - virtual void setOutputs(const operand::IndexSet &indexes) override; - -public: const Param ¶m() const { return _param; } private: diff --git a/runtimes/neurun/src/graph/operation/Conv2D.cc b/runtimes/neurun/src/graph/operation/Conv2D.cc index 61b1d79..47e469d 100644 --- a/runtimes/neurun/src/graph/operation/Conv2D.cc +++ b/runtimes/neurun/src/graph/operation/Conv2D.cc @@ -35,6 +35,7 @@ namespace Implicit void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const graph::operation::Node::InitParam &init_param) + : operation::Node{OperandConstraint::createEQ(3u)} { assert(init_param.input_count == 7 && init_param.output_count == 1); @@ -58,20 +59,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param) _param.activation_index = operand::Index{init_param.inputs[6]}; } -void Node::setInputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 3); - - graph::operation::Node::setInputs(indexes); -} - -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace Implicit } // namespace Conv2D } // namespace operation diff --git a/runtimes/neurun/src/graph/operation/Conv2D.h b/runtimes/neurun/src/graph/operation/Conv2D.h index 6221eab..b0eec01 100644 --- a/runtimes/neurun/src/graph/operation/Conv2D.h +++ b/runtimes/neurun/src/graph/operation/Conv2D.h @@ -58,10 +58,6 @@ public: virtual std::string getName() const override { return "Conv2D"; } public: - virtual void setInputs(const operand::IndexSet &indexes) override; - virtual void setOutputs(const operand::IndexSet &indexes) override; - -public: const Param ¶m() const { return _param; } private: diff --git a/runtimes/neurun/src/graph/operation/FullyConnected.cc b/runtimes/neurun/src/graph/operation/FullyConnected.cc index db3cf05..086db66 100644 --- a/runtimes/neurun/src/graph/operation/FullyConnected.cc +++ b/runtimes/neurun/src/graph/operation/FullyConnected.cc @@ -33,6 +33,7 @@ namespace FullyConnected void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const graph::operation::Node::InitParam &init_param) + : operation::Node{OperandConstraint::createEQ(3u)} { assert(init_param.input_count == 4 && init_param.output_count == 1); @@ -49,20 +50,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param) _param.activation_index = operand::Index{init_param.inputs[3]}; } -void Node::setInputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 3); - - graph::operation::Node::setInputs(indexes); -} - -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace FullyConnected } // namespace operation } // namespace graph diff --git a/runtimes/neurun/src/graph/operation/FullyConnected.h b/runtimes/neurun/src/graph/operation/FullyConnected.h index b935f7b..2c28452 100644 --- a/runtimes/neurun/src/graph/operation/FullyConnected.h +++ b/runtimes/neurun/src/graph/operation/FullyConnected.h @@ -52,10 +52,6 @@ public: virtual std::string getName() const override { return "FullyConnected"; } public: - virtual void setInputs(const operand::IndexSet &indexes) override; - virtual void setOutputs(const operand::IndexSet &indexes) override; - -public: const Param ¶m() const { return _param; } private: diff --git a/runtimes/neurun/src/graph/operation/MaxPool2D.cc b/runtimes/neurun/src/graph/operation/MaxPool2D.cc index eec0972..b08f6a6 100644 --- a/runtimes/neurun/src/graph/operation/MaxPool2D.cc +++ b/runtimes/neurun/src/graph/operation/MaxPool2D.cc @@ -35,6 +35,7 @@ namespace Implicit void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const graph::operation::Node::InitParam &init_param) + : operation::Node{OperandConstraint::createEQ(1u)} { assert(init_param.input_count == 7); assert(init_param.output_count == 1); @@ -61,20 +62,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param) _param.activation_index = operand::Index{init_param.inputs[6]}; } -void Node::setInputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setInputs(indexes); -} - -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace Implicit } // namespace MaxPool2D } // namespace operation diff --git a/runtimes/neurun/src/graph/operation/MaxPool2D.h b/runtimes/neurun/src/graph/operation/MaxPool2D.h index 09557f0..d626ec5 100644 --- a/runtimes/neurun/src/graph/operation/MaxPool2D.h +++ b/runtimes/neurun/src/graph/operation/MaxPool2D.h @@ -59,10 +59,6 @@ public: Node(const graph::operation::Node::InitParam &init_param); public: - virtual void setInputs(const operand::IndexSet &indexes) override; - virtual void setOutputs(const operand::IndexSet &indexes) override; - -public: const Param ¶m() const { return _param; } private: diff --git a/runtimes/neurun/src/graph/operation/NOP.cc b/runtimes/neurun/src/graph/operation/NOP.cc index 18c3246..ea98c6e 100644 --- a/runtimes/neurun/src/graph/operation/NOP.cc +++ b/runtimes/neurun/src/graph/operation/NOP.cc @@ -30,6 +30,11 @@ namespace NOP void Node::accept(NodeVisitor &&v) const { v.visit(*this); } +Node::Node(const graph::operation::Node::InitParam &) + : operation::Node{OperandConstraint::createEQ(1u)} +{ +} + } // namespace NOP } // namespace operation } // namespace graph diff --git a/runtimes/neurun/src/graph/operation/NOP.h b/runtimes/neurun/src/graph/operation/NOP.h index 8da8c8b..193c5d4 100644 --- a/runtimes/neurun/src/graph/operation/NOP.h +++ b/runtimes/neurun/src/graph/operation/NOP.h @@ -33,7 +33,7 @@ namespace NOP class Node : public graph::operation::Node { public: - Node(const graph::operation::Node::InitParam &) {} + Node(const graph::operation::Node::InitParam &); public: virtual void accept(NodeVisitor &&) const override; diff --git a/runtimes/neurun/src/graph/operation/Node.cc b/runtimes/neurun/src/graph/operation/Node.cc index eeb3ff4..74b37c4 100644 --- a/runtimes/neurun/src/graph/operation/Node.cc +++ b/runtimes/neurun/src/graph/operation/Node.cc @@ -16,6 +16,8 @@ #include "Node.h" +#include + #include "LowerInfo.h" namespace neurun @@ -25,10 +27,18 @@ namespace graph namespace operation { -Node::Node() = default; +Node::Node(OperandConstraint input_constr) : _input_constr{input_constr} {} Node::~Node() = default; +void Node::setInputs(const operand::IndexSet &indexes) +{ + assert(_input_constr.check(indexes.size())); + _inputs = indexes; +} + +void Node::setOutputs(const operand::IndexSet &indexes) { _outputs = indexes; } + void Node::replaceInput(const operand::Index &from, const operand::Index &to) { _inputs.replace(from, to); diff --git a/runtimes/neurun/src/graph/operation/Node.h b/runtimes/neurun/src/graph/operation/Node.h index bab51b3..a1c5be8 100644 --- a/runtimes/neurun/src/graph/operation/Node.h +++ b/runtimes/neurun/src/graph/operation/Node.h @@ -20,6 +20,7 @@ #include #include "graph/operand/IndexSet.h" +#include "OperandConstraint.h" namespace neurun { @@ -43,7 +44,7 @@ public: }; public: - Node(); + Node(OperandConstraint input_constr); virtual ~Node(); public: @@ -51,13 +52,13 @@ public: virtual std::string getName() const = 0; public: - virtual void replaceInput(const operand::Index &from, const operand::Index &to); - virtual void replaceOutput(const operand::Index &from, const operand::Index &to); - virtual const operand::IndexSet &getInputs() const { return _inputs; } - virtual const operand::IndexSet &getOutputs() const { return _outputs; } + void replaceInput(const operand::Index &from, const operand::Index &to); + void replaceOutput(const operand::Index &from, const operand::Index &to); + const operand::IndexSet &getInputs() const { return _inputs; } + const operand::IndexSet &getOutputs() const { return _outputs; } // It's for only input/output tensors but const data. - virtual void setInputs(const operand::IndexSet &indexes) { _inputs = indexes; } - virtual void setOutputs(const operand::IndexSet &indexes) { _outputs = indexes; } + void setInputs(const operand::IndexSet &indexes); + void setOutputs(const operand::IndexSet &indexes); public: void lower_info(std::unique_ptr &&lower_info); @@ -66,6 +67,7 @@ public: private: operand::IndexSet _inputs; operand::IndexSet _outputs; + OperandConstraint _input_constr; std::unique_ptr _lower_info; }; diff --git a/runtimes/neurun/src/graph/operation/OperandConstraint.cc b/runtimes/neurun/src/graph/operation/OperandConstraint.cc new file mode 100644 index 0000000..f434cb8 --- /dev/null +++ b/runtimes/neurun/src/graph/operation/OperandConstraint.cc @@ -0,0 +1,12 @@ +#include "OperandConstraint.h" + +namespace neurun +{ +namespace graph +{ +namespace operation +{ + +} // namespace operation +} // namespace graph +} // namespace neurun diff --git a/runtimes/neurun/src/graph/operation/OperandConstraint.h b/runtimes/neurun/src/graph/operation/OperandConstraint.h new file mode 100644 index 0000000..1f656c4 --- /dev/null +++ b/runtimes/neurun/src/graph/operation/OperandConstraint.h @@ -0,0 +1,45 @@ +#ifndef __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__ +#define __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__ + +#include +#include +#include + +namespace neurun +{ +namespace graph +{ +namespace operation +{ + +class OperandConstraint +{ +private: + static const uint32_t INF = std::numeric_limits::max(); + +public: + static OperandConstraint createAny() { return OperandConstraint{0u, INF}; } + static OperandConstraint createEQ(uint32_t exact) { return OperandConstraint{exact, exact}; } + static OperandConstraint createLE(uint32_t end) { return OperandConstraint{0u, end}; } + static OperandConstraint createGE(uint32_t begin) { return OperandConstraint{begin, INF}; } + static OperandConstraint createRange(uint32_t begin, uint32_t end) + { + return OperandConstraint{begin, end}; + } + +private: + OperandConstraint(uint32_t begin, uint32_t end) : _begin{begin}, _end{end} {} + +public: + bool check(uint32_t ind) const { return _begin <= ind && ind <= _end; } + +private: + uint32_t _begin; + uint32_t _end; +}; + +} // namespace operation +} // namespace graph +} // namespace neurun + +#endif // __NEURUN_GRAPH_OPERATION_OPERAND_CONSTRAINT_H__ diff --git a/runtimes/neurun/src/graph/operation/Permute.cc b/runtimes/neurun/src/graph/operation/Permute.cc index 2688e5e..57a40f7 100644 --- a/runtimes/neurun/src/graph/operation/Permute.cc +++ b/runtimes/neurun/src/graph/operation/Permute.cc @@ -16,25 +16,12 @@ namespace Permute void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const operand::Index &input, const operand::Index &output) + : operation::Node{OperandConstraint::createEQ(1u)} { setInputs({input}); setOutputs({output}); } -void Node::setInputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setInputs(indexes); -} - -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace Permute } // namespace operation } // namespace graph diff --git a/runtimes/neurun/src/graph/operation/Permute.h b/runtimes/neurun/src/graph/operation/Permute.h index d0b8e9a..5a64ca8 100644 --- a/runtimes/neurun/src/graph/operation/Permute.h +++ b/runtimes/neurun/src/graph/operation/Permute.h @@ -20,10 +20,6 @@ public: public: Node(const operand::Index &input, const operand::Index &output); - -public: - virtual void setInputs(const operand::IndexSet &indexes) override; - virtual void setOutputs(const operand::IndexSet &indexes) override; }; } // namespace Permute diff --git a/runtimes/neurun/src/graph/operation/Reshape.cc b/runtimes/neurun/src/graph/operation/Reshape.cc index e6bc211..184f907 100644 --- a/runtimes/neurun/src/graph/operation/Reshape.cc +++ b/runtimes/neurun/src/graph/operation/Reshape.cc @@ -33,6 +33,7 @@ namespace Reshape void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const graph::operation::Node::InitParam &init_param) + : operation::Node{OperandConstraint::createEQ(1u)} { assert(init_param.input_count == 2 && init_param.output_count == 1); @@ -47,20 +48,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param) setOutputs({init_param.outputs[0]}); } -void Node::setInputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); // TODO Should be 2 (See also the constructor) - - graph::operation::Node::setInputs(indexes); -} - -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace Reshape } // namespace operation } // namespace graph diff --git a/runtimes/neurun/src/graph/operation/Reshape.h b/runtimes/neurun/src/graph/operation/Reshape.h index ede243a..3959007 100644 --- a/runtimes/neurun/src/graph/operation/Reshape.h +++ b/runtimes/neurun/src/graph/operation/Reshape.h @@ -43,10 +43,6 @@ public: public: Node(const graph::operation::Node::InitParam &init_param); - -public: - virtual void setInputs(const operand::IndexSet &indexes) override; - virtual void setOutputs(const operand::IndexSet &indexes) override; }; } // namespace Reshape diff --git a/runtimes/neurun/src/graph/operation/Softmax.cc b/runtimes/neurun/src/graph/operation/Softmax.cc index a6385a2..307258a 100644 --- a/runtimes/neurun/src/graph/operation/Softmax.cc +++ b/runtimes/neurun/src/graph/operation/Softmax.cc @@ -33,6 +33,7 @@ namespace Softmax void Node::accept(NodeVisitor &&v) const { v.visit(*this); } Node::Node(const graph::operation::Node::InitParam &init_param) + : operation::Node{OperandConstraint::createEQ(1u)} { assert(init_param.input_count == 2 && init_param.output_count == 1); @@ -47,20 +48,6 @@ Node::Node(const graph::operation::Node::InitParam &init_param) _param.scale_index = operand::Index{init_param.inputs[1]}; } -void Node::setInputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setInputs(indexes); -} - -void Node::setOutputs(const operand::IndexSet &indexes) -{ - assert(indexes.size() == 1); - - graph::operation::Node::setOutputs(indexes); -} - } // namespace Softmax } // namespace operation } // namespace graph diff --git a/runtimes/neurun/src/graph/operation/Softmax.h b/runtimes/neurun/src/graph/operation/Softmax.h index a1575bc..8fc8ae8 100644 --- a/runtimes/neurun/src/graph/operation/Softmax.h +++ b/runtimes/neurun/src/graph/operation/Softmax.h @@ -50,10 +50,6 @@ public: Node(const graph::operation::Node::InitParam &init_param); public: - virtual void setInputs(const operand::IndexSet &indexes) override; - virtual void setOutputs(const operand::IndexSet &indexes) override; - -public: const Param ¶m() const { return _param; } private: diff --git a/runtimes/neurun/test/graph/MockNode.h b/runtimes/neurun/test/graph/MockNode.h index e0baf6e..4f4ea4d 100644 --- a/runtimes/neurun/test/graph/MockNode.h +++ b/runtimes/neurun/test/graph/MockNode.h @@ -30,6 +30,7 @@ class SimpleMockNode : public neurun::graph::operation::Node public: SimpleMockNode(const neurun::graph::operand::IndexSet &inputs, const neurun::graph::operand::IndexSet &outputs) + : neurun::graph::operation::Node{neurun::graph::operation::OperandConstraint::createAny()} { setInputs(inputs); setOutputs(outputs);