[locop] Extensible Node Formatting (#4156)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Wed, 10 Jul 2019 03:59:31 +0000 (12:59 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Wed, 10 Jul 2019 03:59:31 +0000 (12:59 +0900)
This commit introduces NodeSummaryBuilder & NodeSummaryBuilderFactory
interfaces.

These interfaces allow users to inject their own Node Formatting
implementation.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
contrib/locop/include/locop/FormattedGraph.h
contrib/locop/src/FormattedGraph.cpp
contrib/locop/src/FormattedGraph.test.cpp

index dc29bfa..253759b 100644 (file)
@@ -83,6 +83,25 @@ private:
   State _state = State::Invalid;
 };
 
+using NodeSummary = NodeDesc;
+
+/**
+ * @brief Build a summary from loco Node
+ */
+struct NodeSummaryBuilder
+{
+  virtual ~NodeSummaryBuilder() = default;
+
+  virtual bool build(const loco::Node *, NodeSummary &) const = 0;
+};
+
+struct NodeSummaryBuilderFactory
+{
+  virtual ~NodeSummaryBuilderFactory() = default;
+
+  virtual std::unique_ptr<NodeSummaryBuilder> create(const SymbolTable *) const = 0;
+};
+
 struct FormattedGraph
 {
   virtual ~FormattedGraph() = default;
@@ -108,8 +127,20 @@ public:
 public:
   void dump(std::ostream &os) const final;
 
+public:
+  FormattedGraphImpl<Formatter::LinearV1> &with(std::unique_ptr<NodeSummaryBuilderFactory> &&f)
+  {
+    _factory = std::move(f);
+    return (*this);
+  }
+
 private:
   loco::Graph *_graph;
+
+  /**
+   * @brief User-provided NodeSummaryBuilderFactory
+   */
+  std::unique_ptr<NodeSummaryBuilderFactory> _factory = nullptr;
 };
 
 template <Formatter F> FormattedGraphImpl<F> fmt(loco::Graph *g)
index 1cceb5f..aa11b6c 100644 (file)
@@ -147,7 +147,8 @@ void NodeDesc::opname(const std::string &v) { _name = stdex::make_unique<std::st
 
 } // namespace locop
 
-namespace
+// TODO Remove this workaround
+namespace locop
 {
 
 std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
@@ -156,9 +157,9 @@ std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
 
   std::vector<std::string> values;
 
-  for (uint32_t n = 0; n < d.arg_size(); ++n)
+  for (uint32_t n = 0; n < d.args().count(); ++n)
   {
-    values.emplace_back(d.arg(n).first + ": " + d.arg(n).second);
+    values.emplace_back(d.args().at(n).first + ": " + d.args().at(n).second);
   }
 
   if (d.state() == NodeDesc::State::PartiallyKnown)
@@ -166,7 +167,7 @@ std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
     values.emplace_back("...");
   }
 
-  os << d.name();
+  os << d.opname();
   os << "(";
   if (values.size() > 0)
   {
@@ -181,6 +182,11 @@ std::ostream &operator<<(std::ostream &os, const NodeDesc &d)
   return os;
 }
 
+} // namespace locop
+
+namespace
+{
+
 NodeDesc default_node_desc(const SymbolTable &tbl, const loco::Node *node)
 {
   NodeDesc res{opname(node)};
@@ -280,6 +286,25 @@ NodeDesc node_desc(const SymbolTable &tbl, const loco::Node *node)
   return default_node_desc(tbl, node);
 }
 
+struct BuiltinNodeSummaryBuilder final : public locop::NodeSummaryBuilder
+{
+public:
+  BuiltinNodeSummaryBuilder(const locop::SymbolTable *symtbl) : _symtbl{symtbl}
+  {
+    // DO NOTHING
+  }
+
+public:
+  bool build(const loco::Node *node, locop::NodeSummary &summary) const final
+  {
+    summary = node_desc(*_symtbl, node);
+    return true;
+  }
+
+private:
+  const locop::SymbolTable *_symtbl;
+};
+
 } // namespace
 
 namespace locop
@@ -374,6 +399,19 @@ void FormattedGraphImpl<Formatter::LinearV1>::dump(std::ostream &os) const
     clusters.at(find(node)).insert(node);
   }
 
+  std::unique_ptr<locop::NodeSummaryBuilder> node_summary_builder;
+
+  if (_factory)
+  {
+    // Use User-defined NodeSummaryBuilder if NodeSummaryBuilderFactory is present
+    node_summary_builder = _factory->create(&symbols);
+  }
+  else
+  {
+    // Use Built-in NodeSummaryBuilder otherwise
+    node_summary_builder = stdex::make_unique<BuiltinNodeSummaryBuilder>(&symbols);
+  }
+
   for (auto it = clusters.begin(); it != clusters.end(); ++it)
   {
     std::vector<loco::Node *> cluster_outputs;
@@ -389,7 +427,15 @@ void FormattedGraphImpl<Formatter::LinearV1>::dump(std::ostream &os) const
 
     for (auto node : loco::postorder_traversal(cluster_outputs))
     {
-      os << symbol(node) << " = " << node_desc(symbols, node) << std::endl;
+      locop::NodeSummary node_summary;
+
+      // Build a node summary
+      if (!node_summary_builder->build(node, node_summary))
+      {
+        throw std::runtime_error{"Fail to build a node summary"};
+      }
+
+      os << symbol(node) << " = " << node_summary << std::endl;
     }
     os << std::endl;
   }
index 2f48052..0f8c595 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "locop/FormattedGraph.h"
 
+#include <stdex/Memory.h>
+
 #include <gtest/gtest.h>
 
 TEST(LinearV1FormatterTest, simple)
@@ -36,3 +38,42 @@ TEST(LinearV1FormatterTest, simple)
   // TODO Validate the output (when the implementation becomes stable)
   std::cout << locop::fmt<locop::LinearV1>(g) << std::endl;
 }
+
+TEST(LinearV1FormatterTest, user_defined_node_summary_builder)
+{
+  auto g = loco::make_graph();
+  {
+    auto pull = g->nodes()->create<loco::Pull>();
+
+    pull->rank(2);
+    pull->dim(0) = loco::make_dimension(); // Mark dim 0 as unknown
+    pull->dim(1) = 4;
+
+    auto push = g->nodes()->create<loco::Push>();
+
+    push->from(pull);
+  }
+
+  struct MyBuilder final : public locop::NodeSummaryBuilder
+  {
+    bool build(const loco::Node *, locop::NodeSummary &s) const final
+    {
+      s.opname("my.op");
+      s.state(locop::NodeSummary::State::PartiallyKnown);
+      return true;
+    }
+  };
+
+  struct MyFactory final : public locop::NodeSummaryBuilderFactory
+  {
+    std::unique_ptr<locop::NodeSummaryBuilder> create(const locop::SymbolTable *) const final
+    {
+      return stdex::make_unique<MyBuilder>();
+    }
+  };
+
+  std::cout << locop::fmt<locop::LinearV1>(g).with(stdex::make_unique<MyFactory>()) << std::endl;
+
+  // TODO Check whether MyBuilder actually sees all the nodes in a graph
+  SUCCEED();
+}