[locop] Extract CanonicalNodeSummaryBuilder (#6169)
author박종현/On-Device Lab(SR)/Staff Engineer/삼성전자 <jh1302.park@samsung.com>
Mon, 5 Aug 2019 00:15:43 +0000 (09:15 +0900)
committerGitHub Enterprise <noreply-CODE@samsung.com>
Mon, 5 Aug 2019 00:15:43 +0000 (09:15 +0900)
This commit extracts CanonicalNodeSummaryBuilder from FormattedGraph
module, and put them in a header/source.

Signed-off-by: Jonghyun Park <jh1302.park@samsung.com>
compiler/locop/include/locop/CanonicalNodeSummaryBuilder.h [new file with mode: 0644]
compiler/locop/include/locop/FormattedGraph.h
compiler/locop/src/CanonicalNodeSummaryBuilder.cpp [new file with mode: 0644]
compiler/locop/src/FormattedGraph.cpp

diff --git a/compiler/locop/include/locop/CanonicalNodeSummaryBuilder.h b/compiler/locop/include/locop/CanonicalNodeSummaryBuilder.h
new file mode 100644 (file)
index 0000000..e9ced3f
--- /dev/null
@@ -0,0 +1,45 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef __LOCOP_CANONICAL_NODE_SUMMARY_BUILDER_H__
+#define __LOCOP_CANONICAL_NODE_SUMMARY_BUILDER_H__
+
+#include "locop/NodeSummaryBuilder.h"
+
+namespace locop
+{
+
+/**
+ * @brief Built-in Node Summary Builder for Canonical Dialect
+ */
+class CanonicalNodeSummaryBuilder final : public NodeSummaryBuilder
+{
+public:
+  CanonicalNodeSummaryBuilder(const SymbolTable *tbl) : _tbl{tbl}
+  {
+    // DO NOTHING
+  }
+
+public:
+  bool build(const loco::Node *node, locop::NodeSummary &out) const final;
+
+private:
+  const SymbolTable *_tbl;
+};
+
+} // namespace locop
+
+#endif // __LOCOP_CANONICAL_NODE_SUMMARY_BUILDER_H__
index 3de2936..0805c0e 100644 (file)
@@ -20,6 +20,8 @@
 #include "locop/SymbolTable.h"
 #include "locop/NodeSummary.h"
 #include "locop/NodeSummaryBuilder.h"
+// TODO Remove this redundant include
+#include "locop/CanonicalNodeSummaryBuilder.h"
 
 #include <loco.h>
 
 namespace locop
 {
 
-/**
- * @brief Built-in Node Summary Builder for Canonical Dialect
- */
-class CanonicalNodeSummaryBuilder final : public NodeSummaryBuilder
-{
-public:
-  CanonicalNodeSummaryBuilder(const SymbolTable *tbl) : _tbl{tbl}
-  {
-    // DO NOTHING
-  }
-
-public:
-  bool build(const loco::Node *node, locop::NodeSummary &out) const final;
-
-private:
-  const SymbolTable *_tbl;
-};
-
 struct FormattedGraph
 {
   virtual ~FormattedGraph() = default;
diff --git a/compiler/locop/src/CanonicalNodeSummaryBuilder.cpp b/compiler/locop/src/CanonicalNodeSummaryBuilder.cpp
new file mode 100644 (file)
index 0000000..c11e4b6
--- /dev/null
@@ -0,0 +1,301 @@
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "locop/CanonicalNodeSummaryBuilder.h"
+
+#include <loco/IR/CanonicalOpcode.h>
+#include <loco/IR/CanonicalDialect.h>
+#include <loco/IR/CanonicalNodeVisitor.h>
+#include <loco/IR/CanonicalNodeImpl.h>
+
+#include <pp/Format.h>
+
+#include <stdex/Memory.h>
+
+#include <map>
+#include <set>
+
+#include <cassert>
+
+using locop::SymbolTable;
+
+namespace
+{
+
+std::string symbol_lookup(const SymbolTable &tbl, const loco::Node *node)
+{
+  // This helper is now redundant but left to reduce code diffs
+  // TODO Remove this helper.
+  return tbl.lookup(node);
+};
+
+} // namespace
+
+namespace
+{
+
+// TODO Move this into loco
+loco::TensorShape tensor_shape(const loco::NodeMixin<loco::NodeTrait::TensorShape> *m)
+{
+  loco::TensorShape res;
+
+  res.rank(m->rank());
+
+  for (uint32_t axis = 0; axis < m->rank(); ++axis)
+  {
+    res.dim(axis) = m->dim(axis);
+  }
+
+  return res;
+}
+
+std::ostream &operator<<(std::ostream &os, const loco::Dimension &d)
+{
+  os << (d.known() ? std::to_string(d.value()) : std::string{"?"});
+  return os;
+}
+
+class FormattedTensorShape
+{
+public:
+  FormattedTensorShape(const loco::TensorShape *ptr) : _ptr{ptr}
+  {
+    // DO NOTHING
+  }
+
+public:
+  const loco::TensorShape &get(void) const { return *_ptr; }
+
+private:
+  const loco::TensorShape *_ptr;
+};
+
+inline FormattedTensorShape pretty(const loco::TensorShape &shape)
+{
+  return FormattedTensorShape{&shape};
+}
+
+std::ostream &operator<<(std::ostream &os, const FormattedTensorShape &f)
+{
+  const auto &shape = f.get();
+
+  os << "[";
+
+  if (shape.rank() > 0)
+  {
+    os << " " << shape.dim(0);
+
+    for (uint32_t axis = 1; axis < shape.rank(); ++axis)
+    {
+      os << " x " << shape.dim(axis);
+    }
+  }
+
+  os << " ]";
+  return os;
+}
+
+} // namespace
+
+namespace
+{
+
+/**
+ * @brief Return the opname as "<dialect>.<op>"
+ */
+std::string opname(const loco::Node *node)
+{
+  if (node->dialect() == loco::CanonicalDialect::get())
+  {
+    auto canonical_node = dynamic_cast<const loco::CanonicalNode *>(node);
+
+    assert(canonical_node != nullptr);
+
+    switch (canonical_node->opcode())
+    {
+#define CANONICAL_NODE(OPCODE, CLASS) \
+  case loco::CanonicalOpcode::OPCODE: \
+    return "canonical." #OPCODE;
+#include "loco/IR/CanonicalNodes.lst"
+#undef CANONICAL_NODE
+      default:
+        break;
+    };
+
+    return "canonical."
+           "Invalid";
+  }
+
+  return "unknown."
+         "Unknown";
+}
+
+struct NodeDesc : public locop::NodeDesc
+{
+public:
+  NodeDesc() = default;
+  NodeDesc(const locop::OpName &opname) : locop::NodeDesc{opname}
+  {
+    // DO NOTHING
+  }
+
+public:
+  // DEPRECATED
+  const locop::OpName &name(void) const { return opname(); }
+
+  // DEPRECATED
+  uint32_t arg_size(void) const { return args().count(); }
+  // DEPRECATED
+  const locop::ArgElem &arg(uint32_t n) const { return args().at(n); }
+  // DEPRECATED
+  void arg(const locop::ArgName &name, const locop::ArgValue &value) { args().append(name, value); }
+};
+
+NodeDesc default_node_desc(const SymbolTable &tbl, const loco::Node *node)
+{
+  NodeDesc res{opname(node)};
+
+  for (uint32_t n = 0; n < node->arity(); ++n)
+  {
+    res.arg(std::string{"arg"} + std::to_string(n), symbol_lookup(tbl, node->arg(n)));
+  }
+  res.state(NodeDesc::State::PartiallyKnown);
+
+  return res;
+}
+
+class CanonicalNodeDescBuilder final : public loco::CanonicalNodeVisitor<NodeDesc>
+{
+public:
+  CanonicalNodeDescBuilder(const SymbolTable *symtbl) : _symtbl{symtbl}
+  {
+    // DO NOTHING
+  }
+
+public:
+  // TODO Build a node description for each canonical node
+  NodeDesc visit(const loco::Push *node) final
+  {
+    NodeDesc res{opname(node)};
+
+    res.arg("index", node->indexed() ? pp::fmt(node->index()) : pp::fmt('?'));
+    res.arg("from", symbol_lookup(*_symtbl, node->from()));
+    res.arg("shape", pp::fmt(pretty(tensor_shape(node))));
+    res.state(NodeDesc::State::Complete);
+
+    return res;
+  }
+
+  NodeDesc visit(const loco::Pull *node) final
+  {
+    NodeDesc res{opname(node)};
+
+    res.arg("index", node->indexed() ? pp::fmt(node->index()) : pp::fmt('?'));
+    // TODO Print dtype
+    res.arg("shape", pp::fmt(pretty(tensor_shape(node))));
+    res.state(NodeDesc::State::PartiallyKnown);
+
+    return res;
+  }
+
+  NodeDesc visit(const loco::Forward *node) final
+  {
+    NodeDesc res{opname(node)};
+
+    res.arg("input", symbol_lookup(*_symtbl, node->input()));
+    res.state(NodeDesc::State::Complete);
+
+    return res;
+  }
+
+  NodeDesc visit(const loco::ConstGen *node) final
+  {
+    NodeDesc res{opname(node)};
+
+    // TODO Print data type
+    res.arg("shape", pp::fmt(pretty(tensor_shape(node))));
+    res.state(NodeDesc::State::PartiallyKnown);
+
+    return res;
+  }
+
+  NodeDesc visit(const loco::TensorConcat *node) final
+  {
+    NodeDesc res{opname(node)};
+
+    res.arg("lhs", symbol_lookup(*_symtbl, node->lhs()));
+    res.arg("rhs", symbol_lookup(*_symtbl, node->rhs()));
+    res.arg("axis", pp::fmt(node->axis()));
+    res.state(NodeDesc::State::Complete);
+
+    return res;
+  }
+
+  NodeDesc visit(const loco::EltwiseAdd *node) final
+  {
+    NodeDesc res{opname(node)};
+
+    res.arg("lhs", symbol_lookup(*_symtbl, node->lhs()));
+    res.arg("rhs", symbol_lookup(*_symtbl, node->rhs()));
+    res.state(NodeDesc::State::Complete);
+
+    return res;
+  }
+
+  NodeDesc visit(const loco::EltwiseMul *node) final
+  {
+    NodeDesc res{opname(node)};
+
+    res.arg("lhs", symbol_lookup(*_symtbl, node->lhs()));
+    res.arg("rhs", symbol_lookup(*_symtbl, node->rhs()));
+    res.state(NodeDesc::State::Complete);
+
+    return res;
+  }
+
+public:
+  NodeDesc visit(const loco::Node *node) final { return default_node_desc(*_symtbl, node); }
+
+private:
+  const SymbolTable *_symtbl;
+};
+
+NodeDesc canonical_node_desc(const SymbolTable &tbl, const loco::CanonicalNode *canonical_node)
+{
+  CanonicalNodeDescBuilder builder{&tbl};
+  return canonical_node->accept(&builder);
+}
+
+} // namespace
+
+namespace locop
+{
+
+bool CanonicalNodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &out) const
+{
+  // Skip if a given node does not belong to loco.canonical
+  if (node->dialect() != loco::CanonicalDialect::get())
+  {
+    return false;
+  }
+
+  auto canonical_node = dynamic_cast<const loco::CanonicalNode *>(node);
+  assert(canonical_node != nullptr);
+  out = canonical_node_desc(*_tbl, canonical_node);
+  return true;
+}
+
+} // namespace locop
index 9b09a85..aa75bd1 100644 (file)
@@ -48,26 +48,6 @@ namespace
 {
 
 // TODO Move this into loco
-loco::TensorShape tensor_shape(const loco::NodeMixin<loco::NodeTrait::TensorShape> *m)
-{
-  loco::TensorShape res;
-
-  res.rank(m->rank());
-
-  for (uint32_t axis = 0; axis < m->rank(); ++axis)
-  {
-    res.dim(axis) = m->dim(axis);
-  }
-
-  return res;
-}
-
-std::ostream &operator<<(std::ostream &os, const loco::Dimension &d)
-{
-  os << (d.known() ? std::to_string(d.value()) : std::string{"?"});
-  return os;
-}
-
 class FormattedTensorShape
 {
 public:
@@ -88,26 +68,6 @@ inline FormattedTensorShape pretty(const loco::TensorShape &shape)
   return FormattedTensorShape{&shape};
 }
 
-std::ostream &operator<<(std::ostream &os, const FormattedTensorShape &f)
-{
-  const auto &shape = f.get();
-
-  os << "[";
-
-  if (shape.rank() > 0)
-  {
-    os << " " << shape.dim(0);
-
-    for (uint32_t axis = 1; axis < shape.rank(); ++axis)
-    {
-      os << " x " << shape.dim(axis);
-    }
-  }
-
-  os << " ]";
-  return os;
-}
-
 } // namespace
 
 namespace
@@ -224,108 +184,6 @@ NodeDesc default_node_desc(const SymbolTable &tbl, const loco::Node *node)
   return res;
 }
 
-class CanonicalNodeDescBuilder final : public loco::CanonicalNodeVisitor<NodeDesc>
-{
-public:
-  CanonicalNodeDescBuilder(const SymbolTable *symtbl) : _symtbl{symtbl}
-  {
-    // DO NOTHING
-  }
-
-public:
-  // TODO Build a node description for each canonical node
-  NodeDesc visit(const loco::Push *node) final
-  {
-    NodeDesc res{opname(node)};
-
-    res.arg("index", node->indexed() ? pp::fmt(node->index()) : pp::fmt('?'));
-    res.arg("from", symbol_lookup(*_symtbl, node->from()));
-    res.arg("shape", pp::fmt(pretty(tensor_shape(node))));
-    res.state(NodeDesc::State::Complete);
-
-    return res;
-  }
-
-  NodeDesc visit(const loco::Pull *node) final
-  {
-    NodeDesc res{opname(node)};
-
-    res.arg("index", node->indexed() ? pp::fmt(node->index()) : pp::fmt('?'));
-    // TODO Print dtype
-    res.arg("shape", pp::fmt(pretty(tensor_shape(node))));
-    res.state(NodeDesc::State::PartiallyKnown);
-
-    return res;
-  }
-
-  NodeDesc visit(const loco::Forward *node) final
-  {
-    NodeDesc res{opname(node)};
-
-    res.arg("input", symbol_lookup(*_symtbl, node->input()));
-    res.state(NodeDesc::State::Complete);
-
-    return res;
-  }
-
-  NodeDesc visit(const loco::ConstGen *node) final
-  {
-    NodeDesc res{opname(node)};
-
-    // TODO Print data type
-    res.arg("shape", pp::fmt(pretty(tensor_shape(node))));
-    res.state(NodeDesc::State::PartiallyKnown);
-
-    return res;
-  }
-
-  NodeDesc visit(const loco::TensorConcat *node) final
-  {
-    NodeDesc res{opname(node)};
-
-    res.arg("lhs", symbol_lookup(*_symtbl, node->lhs()));
-    res.arg("rhs", symbol_lookup(*_symtbl, node->rhs()));
-    res.arg("axis", pp::fmt(node->axis()));
-    res.state(NodeDesc::State::Complete);
-
-    return res;
-  }
-
-  NodeDesc visit(const loco::EltwiseAdd *node) final
-  {
-    NodeDesc res{opname(node)};
-
-    res.arg("lhs", symbol_lookup(*_symtbl, node->lhs()));
-    res.arg("rhs", symbol_lookup(*_symtbl, node->rhs()));
-    res.state(NodeDesc::State::Complete);
-
-    return res;
-  }
-
-  NodeDesc visit(const loco::EltwiseMul *node) final
-  {
-    NodeDesc res{opname(node)};
-
-    res.arg("lhs", symbol_lookup(*_symtbl, node->lhs()));
-    res.arg("rhs", symbol_lookup(*_symtbl, node->rhs()));
-    res.state(NodeDesc::State::Complete);
-
-    return res;
-  }
-
-public:
-  NodeDesc visit(const loco::Node *node) final { return default_node_desc(*_symtbl, node); }
-
-private:
-  const SymbolTable *_symtbl;
-};
-
-NodeDesc canonical_node_desc(const SymbolTable &tbl, const loco::CanonicalNode *canonical_node)
-{
-  CanonicalNodeDescBuilder builder{&tbl};
-  return canonical_node->accept(&builder);
-}
-
 struct BuiltinNodeSummaryBuilder final : public locop::NodeSummaryBuilder
 {
 public:
@@ -350,25 +208,6 @@ private:
 namespace locop
 {
 
-bool CanonicalNodeSummaryBuilder::build(const loco::Node *node, locop::NodeSummary &out) const
-{
-  // Skip if a given node does not belong to loco.canonical
-  if (node->dialect() != loco::CanonicalDialect::get())
-  {
-    return false;
-  }
-
-  auto canonical_node = dynamic_cast<const loco::CanonicalNode *>(node);
-  assert(canonical_node != nullptr);
-  out = canonical_node_desc(*_tbl, canonical_node);
-  return true;
-}
-
-} // namespace locop
-
-namespace locop
-{
-
 std::ostream &operator<<(std::ostream &os, const FormattedGraph &fmt)
 {
   fmt.dump(os);