From 650c954397e45f469ce803ebf25e339db3e11af2 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vladimir=20Plazun/AI=20Tools=20Lab=20/SRR/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 19 Oct 2018 15:31:23 +0300 Subject: [PATCH] [nnc] Add ability to get inputs/outputs from model graph (#1905) This commits adds collectInputs/collectOutputs methods to graph Add _lastNodeId field to separate node indexing from nodes list size Fix code style Signed-off-by: Vladimir Plazun --- contrib/nnc/core/modelIR/graph.cpp | 34 +++++++++++++++++++++----------- contrib/nnc/include/core/modelIR/graph.h | 32 +++++++++++++++++++----------- 2 files changed, 44 insertions(+), 22 deletions(-) diff --git a/contrib/nnc/core/modelIR/graph.cpp b/contrib/nnc/core/modelIR/graph.cpp index 43613fa..6d9ee69 100644 --- a/contrib/nnc/core/modelIR/graph.cpp +++ b/contrib/nnc/core/modelIR/graph.cpp @@ -18,15 +18,11 @@ #include #include "core/modelIR/graph.h" -#include "core/modelIR/ir_node.h" -#include "core/modelIR/operations/operation.h" -namespace nnc -{ -namespace mir -{ +namespace nnc { +namespace mir { -INode::Ref Graph::getInput(const std::string &name) { +INode::Ref Graph::getInput(const std::string& name) { auto it = _inputs.find(name); if (it == _inputs.end()) return nullptr; @@ -34,7 +30,7 @@ INode::Ref Graph::getInput(const std::string &name) { return it->second; } -INode::Ref Graph::getOutput(const std::string &name) { +INode::Ref Graph::getOutput(const std::string& name) { auto it = _outputs.find(name); if (it == _outputs.end()) return nullptr; @@ -42,11 +38,11 @@ INode::Ref Graph::getOutput(const std::string &name) { return it->second; } -void Graph::accept(IVisitor *visitor) { +void Graph::accept(IVisitor* visitor) { std::deque q; std::set known_nodes; - for (const auto &e : _inputs) { + for (const auto& e : _inputs) { q.push_back(e.second); known_nodes.insert(e.second); //Consider all input _nodes resolved by default } @@ -75,7 +71,7 @@ void Graph::accept(IVisitor *visitor) { } Graph::~Graph() { - for (auto &node : _nodes) { + for (auto& node : _nodes) { delete node; } } @@ -89,5 +85,21 @@ void Graph::markOutput(INode::Ref node) { _outputs[node->getName()] = node; } +std::vector Graph::collectInputs() { + std::vector res; + for (auto& e : _inputs) { + res.emplace_back(e.second); + } + return res; +} + +std::vector Graph::collectOutputs() { + std::vector res; + for (auto& e : _outputs) { + res.emplace_back(e.second); + } + return res; +} + } // namespace mir } // namespace nnc diff --git a/contrib/nnc/include/core/modelIR/graph.h b/contrib/nnc/include/core/modelIR/graph.h index d02e257..1dc143f 100644 --- a/contrib/nnc/include/core/modelIR/graph.h +++ b/contrib/nnc/include/core/modelIR/graph.h @@ -26,10 +26,8 @@ #include "core/modelIR/operations/variable_op.h" #include "core/modelIR/ir_node.h" -namespace nnc -{ -namespace mir -{ +namespace nnc { +namespace mir { class IVisitor; @@ -39,28 +37,39 @@ class Graph { virtual ~Graph(); - template + template //make this method callable only with OpDescription subclasses typename std::enable_if::value, INode::Ref>::type - create(const std::string &name, Args &&...args) { - auto node = Node::createNode(name, _nodes.size(), std::forward(args)...); + create(const std::string& name, Args&&...args) { + auto node = Node::createNode(name, _lastNodeId++, std::forward(args)...); registerNode(node); return node; } - void accept(IVisitor *visitor); + void accept(IVisitor* visitor); void markOutput(INode::Ref node); - INode::Ref getInput(const std::string &name); - INode::Ref getOutput(const std::string &name); + INode::Ref getInput(const std::string& name); + INode::Ref getOutput(const std::string& name); + /** + * @brief Returns all inputs from graph + * @returns vector containing all graph input nodes + */ + std::vector collectInputs(); + + /** + * @brief Returns all outputs from graph + * @returns vector containing all graph outputs nodes + */ + std::vector collectOutputs(); private: void registerNode(INode::Ref node) { _nodes.push_back(node); } //TODO: maybe make user to mark input _nodes in a more obvious way - void registerNode(Node *node) { + void registerNode(Node* node) { auto it = _inputs.find(node->getName()); if( it != _inputs.end()) { throw std::runtime_error("Input name collision"); @@ -70,6 +79,7 @@ class Graph { } std::vector _nodes; + size_t _lastNodeId = 0; std::unordered_map _inputs; std::unordered_map _outputs; }; -- 2.7.4