From 9bf5a83505d2be78d0f89bc7b732998c4e6ed3e1 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=EB=B0=95=EC=A2=85=ED=98=84/On-Device=20Lab=28SR=29/Staff?= =?utf8?q?=20Engineer/=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 10 Jul 2019 12:59:31 +0900 Subject: [PATCH] [locop] Extensible Node Formatting (#4156) This commit introduces NodeSummaryBuilder & NodeSummaryBuilderFactory interfaces. These interfaces allow users to inject their own Node Formatting implementation. Signed-off-by: Jonghyun Park --- contrib/locop/include/locop/FormattedGraph.h | 31 +++++++++++++++ contrib/locop/src/FormattedGraph.cpp | 56 +++++++++++++++++++++++++--- contrib/locop/src/FormattedGraph.test.cpp | 41 ++++++++++++++++++++ 3 files changed, 123 insertions(+), 5 deletions(-) diff --git a/contrib/locop/include/locop/FormattedGraph.h b/contrib/locop/include/locop/FormattedGraph.h index dc29bfa..253759b 100644 --- a/contrib/locop/include/locop/FormattedGraph.h +++ b/contrib/locop/include/locop/FormattedGraph.h @@ -83,6 +83,25 @@ private: State _state = State::Invalid; }; +using NodeSummary = NodeDesc; + +/** + * @brief Build a summary from loco Node + */ +struct NodeSummaryBuilder +{ + virtual ~NodeSummaryBuilder() = default; + + virtual bool build(const loco::Node *, NodeSummary &) const = 0; +}; + +struct NodeSummaryBuilderFactory +{ + virtual ~NodeSummaryBuilderFactory() = default; + + virtual std::unique_ptr create(const SymbolTable *) const = 0; +}; + struct FormattedGraph { virtual ~FormattedGraph() = default; @@ -108,8 +127,20 @@ public: public: void dump(std::ostream &os) const final; +public: + FormattedGraphImpl &with(std::unique_ptr &&f) + { + _factory = std::move(f); + return (*this); + } + private: loco::Graph *_graph; + + /** + * @brief User-provided NodeSummaryBuilderFactory + */ + std::unique_ptr _factory = nullptr; }; template FormattedGraphImpl fmt(loco::Graph *g) diff --git a/contrib/locop/src/FormattedGraph.cpp b/contrib/locop/src/FormattedGraph.cpp index 1cceb5f..aa11b6c 100644 --- a/contrib/locop/src/FormattedGraph.cpp +++ b/contrib/locop/src/FormattedGraph.cpp @@ -147,7 +147,8 @@ void NodeDesc::opname(const std::string &v) { _name = stdex::make_unique values; - for (uint32_t n = 0; n < d.arg_size(); ++n) + for (uint32_t n = 0; n < d.args().count(); ++n) { - values.emplace_back(d.arg(n).first + ": " + d.arg(n).second); + values.emplace_back(d.args().at(n).first + ": " + d.args().at(n).second); } if (d.state() == NodeDesc::State::PartiallyKnown) @@ -166,7 +167,7 @@ std::ostream &operator<<(std::ostream &os, const NodeDesc &d) values.emplace_back("..."); } - os << d.name(); + os << d.opname(); os << "("; if (values.size() > 0) { @@ -181,6 +182,11 @@ std::ostream &operator<<(std::ostream &os, const NodeDesc &d) return os; } +} // namespace locop + +namespace +{ + NodeDesc default_node_desc(const SymbolTable &tbl, const loco::Node *node) { NodeDesc res{opname(node)}; @@ -280,6 +286,25 @@ NodeDesc node_desc(const SymbolTable &tbl, const loco::Node *node) return default_node_desc(tbl, node); } +struct BuiltinNodeSummaryBuilder final : public locop::NodeSummaryBuilder +{ +public: + BuiltinNodeSummaryBuilder(const locop::SymbolTable *symtbl) : _symtbl{symtbl} + { + // DO NOTHING + } + +public: + bool build(const loco::Node *node, locop::NodeSummary &summary) const final + { + summary = node_desc(*_symtbl, node); + return true; + } + +private: + const locop::SymbolTable *_symtbl; +}; + } // namespace namespace locop @@ -374,6 +399,19 @@ void FormattedGraphImpl::dump(std::ostream &os) const clusters.at(find(node)).insert(node); } + std::unique_ptr node_summary_builder; + + if (_factory) + { + // Use User-defined NodeSummaryBuilder if NodeSummaryBuilderFactory is present + node_summary_builder = _factory->create(&symbols); + } + else + { + // Use Built-in NodeSummaryBuilder otherwise + node_summary_builder = stdex::make_unique(&symbols); + } + for (auto it = clusters.begin(); it != clusters.end(); ++it) { std::vector cluster_outputs; @@ -389,7 +427,15 @@ void FormattedGraphImpl::dump(std::ostream &os) const for (auto node : loco::postorder_traversal(cluster_outputs)) { - os << symbol(node) << " = " << node_desc(symbols, node) << std::endl; + locop::NodeSummary node_summary; + + // Build a node summary + if (!node_summary_builder->build(node, node_summary)) + { + throw std::runtime_error{"Fail to build a node summary"}; + } + + os << symbol(node) << " = " << node_summary << std::endl; } os << std::endl; } diff --git a/contrib/locop/src/FormattedGraph.test.cpp b/contrib/locop/src/FormattedGraph.test.cpp index 2f48052..0f8c595 100644 --- a/contrib/locop/src/FormattedGraph.test.cpp +++ b/contrib/locop/src/FormattedGraph.test.cpp @@ -16,6 +16,8 @@ #include "locop/FormattedGraph.h" +#include + #include TEST(LinearV1FormatterTest, simple) @@ -36,3 +38,42 @@ TEST(LinearV1FormatterTest, simple) // TODO Validate the output (when the implementation becomes stable) std::cout << locop::fmt(g) << std::endl; } + +TEST(LinearV1FormatterTest, user_defined_node_summary_builder) +{ + auto g = loco::make_graph(); + { + auto pull = g->nodes()->create(); + + pull->rank(2); + pull->dim(0) = loco::make_dimension(); // Mark dim 0 as unknown + pull->dim(1) = 4; + + auto push = g->nodes()->create(); + + push->from(pull); + } + + struct MyBuilder final : public locop::NodeSummaryBuilder + { + bool build(const loco::Node *, locop::NodeSummary &s) const final + { + s.opname("my.op"); + s.state(locop::NodeSummary::State::PartiallyKnown); + return true; + } + }; + + struct MyFactory final : public locop::NodeSummaryBuilderFactory + { + std::unique_ptr create(const locop::SymbolTable *) const final + { + return stdex::make_unique(); + } + }; + + std::cout << locop::fmt(g).with(stdex::make_unique()) << std::endl; + + // TODO Check whether MyBuilder actually sees all the nodes in a graph + SUCCEED(); +} -- 2.7.4