From 68d8bb6ee945ea9af4a1a632f023c231d2db91ca Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EC=9D=B4=ED=95=9C=EC=A2=85/On-Device=20Lab=28SR=29/Enginee?= =?utf8?q?r/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Thu, 3 Jan 2019 15:06:36 +0900 Subject: [PATCH] [neurun] Extract Operation LowerInfo from Node (#4136) As `model::operation::Node` is pure node info so LowerInfo should not be Node class. `graph::Graph` holds them instead. Signed-off-by: Hanjoung Lee --- runtimes/neurun/src/compiler/Compiler.cc | 6 +-- runtimes/neurun/src/dumper/dot/DotDumper.cc | 5 ++- runtimes/neurun/src/dumper/dot/DotNodeInfo.cc | 26 +++++------ runtimes/neurun/src/dumper/dot/DotNodeInfo.h | 11 ++++- runtimes/neurun/src/graph/Graph.cc | 51 ++++++++++++++-------- runtimes/neurun/src/graph/Graph.h | 12 ++++- .../src/graph/pass/PermutationInsertionPass.cc | 7 +-- runtimes/neurun/src/linear/Linear.cc | 13 +++--- runtimes/neurun/src/linear/Linear.h | 15 ++++++- runtimes/neurun/src/model/operation/Node.cc | 7 --- runtimes/neurun/src/model/operation/Node.h | 5 --- 11 files changed, 97 insertions(+), 61 deletions(-) diff --git a/runtimes/neurun/src/compiler/Compiler.cc b/runtimes/neurun/src/compiler/Compiler.cc index e1d24d8..9d5382e 100644 --- a/runtimes/neurun/src/compiler/Compiler.cc +++ b/runtimes/neurun/src/compiler/Compiler.cc @@ -88,12 +88,12 @@ void Compiler::compile(void) PlanBuilder plan_builder{plan}; // Plan building - linear->iterate([&](const neurun::model::operation::Node *node) { - auto backend = node->lower_info()->backend(); + linear->iterate([&](const linear::Element &element) { + auto backend = element.lower_info->backend(); // Generate Stage auto stage_gen = backend->stage_gen(); - plan_builder.addStage(stage_gen->generate(*node)); + plan_builder.addStage(stage_gen->generate(*element.node)); }); auto tensor_builders = linear->planTensors(); diff --git a/runtimes/neurun/src/dumper/dot/DotDumper.cc b/runtimes/neurun/src/dumper/dot/DotDumper.cc index ed9590a..1e53ece 100644 --- a/runtimes/neurun/src/dumper/dot/DotDumper.cc +++ b/runtimes/neurun/src/dumper/dot/DotDumper.cc @@ -40,7 +40,7 @@ void DotDumper::dumpIfNeeded(const std::string &tag) auto &operands = _graph.operands(); operations.iterate([&](const model::operation::Index &index, const model::operation::Node &node) { - neurun::dumper::dot::DotNodeInfo node_info(index, node); + neurun::dumper::dot::DotNodeInfo node_info(_graph, index, node); for (auto output : node.getOutputs()) { @@ -85,7 +85,8 @@ void DotDumper::dumpIfNeeded(const std::string &tag) for (auto operation_index : object.getUses().list()) { auto &node = operations.at(operation_index); - auto child = std::make_shared(operation_index, node); + auto child = + std::make_shared(_graph, operation_index, node); operand_info.appendChild(child); } diff --git a/runtimes/neurun/src/dumper/dot/DotNodeInfo.cc b/runtimes/neurun/src/dumper/dot/DotNodeInfo.cc index 6e66d87..aefe12e 100644 --- a/runtimes/neurun/src/dumper/dot/DotNodeInfo.cc +++ b/runtimes/neurun/src/dumper/dot/DotNodeInfo.cc @@ -17,6 +17,7 @@ #include #include "DotNodeInfo.h" +#include "graph/Graph.h" #include "graph/operation/LowerInfo.h" #include "backend/interface/IConfig.h" @@ -32,15 +33,12 @@ const std::string DotNodeInfo::BG_COLOR_SCHEME = "pastel18"; // RED BLUE ORANGE YELLOW GREEN PUPLE CYAN PINK const std::string DotNodeInfo::BG_COLORS[8] = {"1", "2", "5", "6", "3", "4", "7", "8"}; -DotNodeInfo::DotNodeInfo(const neurun::model::operation::Index &index, +DotNodeInfo::DotNodeInfo(const neurun::graph::Graph &graph, + const neurun::model::operation::Index &index, const neurun::model::operation::Node &node) - : _index(index), _node(node) + : _index(index), _node(node), _lower_info(graph.getLowerInfo(index)) { - const auto &lower_info = _node.lower_info(); - if (lower_info) - { - addBackendLabel(); - } + addBackendLabel(); } std::string DotNodeInfo::index_str() const @@ -69,11 +67,10 @@ std::string DotNodeInfo::bg_color_scheme() const { return BG_COLOR_SCHEME; } std::string DotNodeInfo::bg_color() const { - const auto &lower_info = _node.lower_info(); - if (!lower_info) + if (!_lower_info) return DEFAULT_BG_COLOR; - assert(lower_info != nullptr); - const auto &backend = lower_info->backend(); + assert(_lower_info != nullptr); + const auto &backend = _lower_info->backend(); assert(backend != nullptr); std::string backend_id = backend->config()->id(); @@ -94,10 +91,11 @@ std::string DotNodeInfo::bg_color() const void DotNodeInfo::addBackendLabel() { + if (!_lower_info) + return; + std::string label; - const auto &lower_info = _node.lower_info(); - assert(lower_info != nullptr); - const auto &backend = lower_info->backend(); + const auto &backend = _lower_info->backend(); assert(backend != nullptr); label += "[Backend] : "; diff --git a/runtimes/neurun/src/dumper/dot/DotNodeInfo.h b/runtimes/neurun/src/dumper/dot/DotNodeInfo.h index ba50514..656a05a 100644 --- a/runtimes/neurun/src/dumper/dot/DotNodeInfo.h +++ b/runtimes/neurun/src/dumper/dot/DotNodeInfo.h @@ -23,6 +23,14 @@ namespace neurun { +namespace graph +{ +class Graph; +} // namespace graph +} // namespace neurun + +namespace neurun +{ namespace dumper { namespace dot @@ -36,7 +44,7 @@ public: static const std::string BG_COLORS[8]; public: - DotNodeInfo(const neurun::model::operation::Index &index, + DotNodeInfo(const neurun::graph::Graph &graph, const neurun::model::operation::Index &index, const neurun::model::operation::Node &node); public: @@ -52,6 +60,7 @@ private: private: neurun::model::operation::Index _index; const neurun::model::operation::Node &_node; + const neurun::graph::operation::LowerInfo *_lower_info; std::vector _labels; }; diff --git a/runtimes/neurun/src/graph/Graph.cc b/runtimes/neurun/src/graph/Graph.cc index 280125a..832e2b8 100644 --- a/runtimes/neurun/src/graph/Graph.cc +++ b/runtimes/neurun/src/graph/Graph.cc @@ -105,24 +105,25 @@ void Graph::lower(void) _backend_resolver = nnfw::cpp14::make_unique(_model->operands); - _model->operations.iterate([&](const model::operation::Index &, model::operation::Node &node) { - auto backend = _backend_resolver->getBackend(typeid(node)); + _model->operations.iterate( + [&](const model::operation::Index &index, model::operation::Node &node) { + auto backend = _backend_resolver->getBackend(typeid(node)); - // Operation LowerInfo - node.lower_info(nnfw::cpp14::make_unique(backend)); + // Operation LowerInfo + setLowerInfo(index, nnfw::cpp14::make_unique(backend)); - // LowerInfo for in/output operands - for (auto operand : node.getInputs()) - { - auto &&lower_info = operands_lower_info.at(operand); - lower_info->addUseBackend(backend); - } - for (auto operand : node.getOutputs()) - { - auto &&lower_info = operands_lower_info.at(operand); - lower_info->addDefBackend(backend); - } - }); + // LowerInfo for in/output operands + for (auto operand : node.getInputs()) + { + auto &&lower_info = operands_lower_info.at(operand); + lower_info->addUseBackend(backend); + } + for (auto operand : node.getOutputs()) + { + auto &&lower_info = operands_lower_info.at(operand); + lower_info->addDefBackend(backend); + } + }); // Add def backend to model input/output operand as default backend for (auto index : getInputs()) @@ -249,6 +250,20 @@ void Graph::initializeUseDef() }); } +const operation::LowerInfo *Graph::getLowerInfo(const model::operation::Index &index) const +{ + auto itr = _operation_lower_info.find(index); + if (itr == _operation_lower_info.end()) + return nullptr; + return itr->second.get(); +} + +void Graph::setLowerInfo(const model::operation::Index &index, + std::unique_ptr &&lower_info) +{ + _operation_lower_info.insert(std::make_pair(index, std::move(lower_info))); +} + } // namespace graph } // namespace neurun @@ -273,7 +288,7 @@ template void Graph::DefaultIterator::iterate(GraphRef graph, const IterFn &fn) const { graph.operations().iterate( - [&](const model::operation::Index &, NodeRef node) -> void { fn(node); }); + [&](const model::operation::Index &index, NodeRef node) -> void { fn(index, node); }); } // @@ -304,7 +319,7 @@ void Graph::PostDfsIterator::iterate(GraphRef graph, const IterFn &fn) } } - fn(node); + fn(index, node); }; graph.operations().iterate(dfs_recursive); diff --git a/runtimes/neurun/src/graph/Graph.h b/runtimes/neurun/src/graph/Graph.h index 7b9d639..afcfdce 100644 --- a/runtimes/neurun/src/graph/Graph.h +++ b/runtimes/neurun/src/graph/Graph.h @@ -57,9 +57,10 @@ public: { public: using GraphRef = typename std::conditional::type; + using IndexRef = const model::operation::Index &; using NodeRef = typename std::conditional::type; - using IterFn = std::function; + using IterFn = std::function; public: virtual ~Iterator() = default; @@ -70,6 +71,7 @@ public: { public: using GraphRef = typename Iterator::GraphRef; + using IndexRef = typename Iterator::IndexRef; using NodeRef = typename Iterator::NodeRef; using IterFn = typename Iterator::IterFn; @@ -82,6 +84,7 @@ public: { public: using GraphRef = typename Iterator::GraphRef; + using IndexRef = typename Iterator::IndexRef; using NodeRef = typename Iterator::NodeRef; using IterFn = typename Iterator::IterFn; @@ -132,8 +135,15 @@ private: std::unique_ptr _model{new Model}; // For LOWERED phase +public: + const operation::LowerInfo *getLowerInfo(const model::operation::Index &index) const; + void setLowerInfo(const model::operation::Index &index, + std::unique_ptr &&lower_info); + private: std::unique_ptr _backend_resolver; + std::unordered_map> + _operation_lower_info; }; } // namespace graph diff --git a/runtimes/neurun/src/graph/pass/PermutationInsertionPass.cc b/runtimes/neurun/src/graph/pass/PermutationInsertionPass.cc index a5d2752..d400ac0 100644 --- a/runtimes/neurun/src/graph/pass/PermutationInsertionPass.cc +++ b/runtimes/neurun/src/graph/pass/PermutationInsertionPass.cc @@ -86,7 +86,7 @@ void PermutationInsertionPass::callback(const model::operand::Index &index, continue; auto &operation = _graph.operations().at(use); - auto operation_li = operation.lower_info(); + auto operation_li = _graph.getLowerInfo(use); assert(operation_li); auto backend = operation_li->backend(); @@ -145,12 +145,13 @@ PermutationInsertionPass::insertPermute(const model::operand::Index &operand_ind // Insert permute operation to the graph auto insert_node = nnfw::cpp14::make_unique(operand_index, out_operand_index); - insert_node->lower_info(nnfw::cpp14::make_unique( - _graph.backend_resolver()->getDefaultBackend())); auto node_index = _graph.operations().append(std::move(insert_node)); const auto &node = _graph.operations().at(node_index); + _graph.setLowerInfo(node_index, nnfw::cpp14::make_unique( + _graph.backend_resolver()->getDefaultBackend())); + // Update Use/Def info { _graph.operands().at(operand_index).appendUse(node_index); diff --git a/runtimes/neurun/src/linear/Linear.cc b/runtimes/neurun/src/linear/Linear.cc index 109d85b..6452bbd 100644 --- a/runtimes/neurun/src/linear/Linear.cc +++ b/runtimes/neurun/src/linear/Linear.cc @@ -43,7 +43,10 @@ Linear::Linear(const graph::Graph &graph) : _graph(graph) // 3. Reverse the order of nodes graph::Graph::PostDfsConstIterator().iterate( - graph, [&](const neurun::model::operation::Node &node) { _operations.emplace_back(&node); }); + graph, [&](const model::operation::Index &index, const model::operation::Node &node) { + const auto lower_info = graph.getLowerInfo(index); + _operations.emplace_back(&node, lower_info); + }); std::reverse(std::begin(_operations), std::end(_operations)); } @@ -52,7 +55,7 @@ void Linear::accept(model::operation::NodeVisitor &&visitor) const { for (const auto op : _operations) { - op->accept(std::move(visitor)); + op.node->accept(std::move(visitor)); } } @@ -150,7 +153,7 @@ backend::TensorBuilderSet Linear::planTensors() VERBOSE(LINEAR) << "TENSORS" << std::endl; for (const auto op : _operations) { - for (const auto &ind : op->getOutputs()) + for (const auto &ind : op.node->getOutputs()) { const auto &obj = operands.at(ind); if (obj.getDef().size()) @@ -162,7 +165,7 @@ backend::TensorBuilderSet Linear::planTensors() } } - for (const auto &ind : op->getInputs()) + for (const auto &ind : op.node->getInputs()) { uses_map[ind]--; if (uses_map[ind] == 0) @@ -184,7 +187,7 @@ backend::TensorBuilderSet Linear::planTensors() return tensor_builders; } -void Linear::iterate(const std::function &fn) const +void Linear::iterate(const std::function &fn) const { for (const auto op : _operations) { diff --git a/runtimes/neurun/src/linear/Linear.h b/runtimes/neurun/src/linear/Linear.h index c65abfc..fb3f539 100644 --- a/runtimes/neurun/src/linear/Linear.h +++ b/runtimes/neurun/src/linear/Linear.h @@ -46,6 +46,17 @@ namespace neurun namespace linear { +struct Element +{ + const model::operation::Node *node; + const graph::operation::LowerInfo *lower_info; + + Element(const model::operation::Node *node, const graph::operation::LowerInfo *lower_info) + : node{node}, lower_info{lower_info} + { + } +}; + class Linear { public: @@ -60,11 +71,11 @@ public: // TODO Should not return TensorBuilderSet backend::TensorBuilderSet planTensors(); - void iterate(const std::function &fn) const; + void iterate(const std::function &fn) const; private: const graph::Graph &_graph; - std::vector _operations; + std::vector _operations; }; } // namespace linear diff --git a/runtimes/neurun/src/model/operation/Node.cc b/runtimes/neurun/src/model/operation/Node.cc index bac67ff..76397af 100644 --- a/runtimes/neurun/src/model/operation/Node.cc +++ b/runtimes/neurun/src/model/operation/Node.cc @@ -49,13 +49,6 @@ void Node::replaceOutput(const operand::Index &from, const operand::Index &to) _outputs.replace(from, to); } -void Node::lower_info(std::unique_ptr &&lower_info) -{ - _lower_info = std::move(lower_info); -} - -const graph::operation::LowerInfo *Node::lower_info() const { return _lower_info.get(); } - } // namespace operation } // namespace model } // namespace neurun diff --git a/runtimes/neurun/src/model/operation/Node.h b/runtimes/neurun/src/model/operation/Node.h index d4e3686..76f0d2d 100644 --- a/runtimes/neurun/src/model/operation/Node.h +++ b/runtimes/neurun/src/model/operation/Node.h @@ -71,15 +71,10 @@ public: void setInputs(const operand::IndexSet &indexes); void setOutputs(const operand::IndexSet &indexes); -public: - void lower_info(std::unique_ptr &&lower_info); - const graph::operation::LowerInfo *lower_info() const; - private: operand::IndexSet _inputs; operand::IndexSet _outputs; OperandConstraint _input_constr; - std::unique_ptr _lower_info; }; } // namespace operation -- 2.7.4