From 8d56ef78f771faa5f30098913f818f232eec7960 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=84=B8=ED=9D=AC/On-Device=20Lab=28SR=29/Princip?= =?utf8?q?al=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Mon, 4 Nov 2019 15:48:17 +0900 Subject: [PATCH] [loco] GraphInputIndexQueryService and input_nodes() (#8713) This will introduce GraphInputIndexQueryService interface and input_nodes() method to gather Graph Input nodes Signed-off-by: SaeHie Park --- compiler/loco/include/loco/IR/Graph.h | 17 +++++++++++++++++ compiler/loco/src/IR/Graph.cpp | 30 ++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+) diff --git a/compiler/loco/include/loco/IR/Graph.h b/compiler/loco/include/loco/IR/Graph.h index 42ec68b..46b08e5 100644 --- a/compiler/loco/include/loco/IR/Graph.h +++ b/compiler/loco/include/loco/IR/Graph.h @@ -228,6 +228,23 @@ private: OutputContext _output_ctx; }; +struct GraphInputIndexQueryService : public DialectService +{ + virtual ~GraphInputIndexQueryService() = default; + + /** + * @brief Check whether a given node is associated with any Graph-level input + */ + virtual bool associated(const Node *node) const = 0; + + /** + * WARNING! CALLER SHOULD GUARANTEE that associated(node) is true before invoking this API. + */ + virtual GraphInputIndex index(const Node *node) const = 0; +}; + +std::vector input_nodes(const Graph *); + struct GraphOutputIndexQueryService : public DialectService { virtual ~GraphOutputIndexQueryService() = default; diff --git a/compiler/loco/src/IR/Graph.cpp b/compiler/loco/src/IR/Graph.cpp index 345525b..1d87522 100644 --- a/compiler/loco/src/IR/Graph.cpp +++ b/compiler/loco/src/IR/Graph.cpp @@ -72,6 +72,36 @@ std::set all_nodes(loco::Graph *g) return res; } +std::vector input_nodes(const Graph *g) +{ + std::map table; + + for (uint32_t n = 0; n < g->nodes()->size(); ++n) + { + auto node = g->nodes()->at(n); + + if (auto service = node->dialect()->service()) + { + if (service->associated(node)) + { + auto input_index = service->index(node); + assert(table.find(input_index) == table.end()); + table[input_index] = node; + } + } + } + + std::vector res; + + for (uint32_t n = 0; n < g->inputs()->size(); ++n) + { + auto it = table.find(n); + res.emplace_back(it == table.end() ? nullptr : it->second); + } + + return res; +} + std::vector output_nodes(loco::Graph *g) { std::map table; -- 2.7.4