Add .dot Model IR dumper (#1027)
authorDmitry Mozolev/AI Tools Lab /SRR/Engineer/삼성전자 <d.mozolev@samsung.com>
Thu, 23 Aug 2018 15:32:45 +0000 (18:32 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Thu, 23 Aug 2018 15:32:45 +0000 (18:32 +0300)
Allows to present Model IR graph as a graph in .dot format.

contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_builder.h [new file with mode: 0644]
contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_dumper.h [new file with mode: 0644]
contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_node_info.h [new file with mode: 0644]
contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_builder.cpp [new file with mode: 0644]
contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_dumper.cpp [new file with mode: 0644]
contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_node_info.cpp [new file with mode: 0644]

diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_builder.h b/contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_builder.h
new file mode 100644 (file)
index 0000000..abf23dc
--- /dev/null
@@ -0,0 +1,44 @@
+#ifndef NNCC_IR_DOT_BUILDER_H
+#define NNCC_IR_DOT_BUILDER_H
+
+#include <sstream>
+
+#include "nnc/core/IR/model/graph/ir_node.h"
+#include "nnc/core/IR/dumpers/ir_dot_node_info.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace dumper
+{
+
+using nncc::contrib::core::IR::model::ADT::INode;
+
+/**
+ * @brief Provides an API to add nodes and edges to the .dot Model IR representation
+ * and then write the whole graph to a provided stream.
+ */
+class IrDotBuilder
+{
+public:
+  explicit IrDotBuilder() = default;
+
+  void updateWithNode(INode *node, const DotIrNodeInfo &irNodeInfo);
+  void writeDot(std::ostream &os);
+
+private:
+  void addNode(INode *node, const DotIrNodeInfo &irNode);
+  void addEdge(INode *node1, INode *node2);
+
+  std::stringstream dot;
+};
+
+} // namespace dumper
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif //NNCC_IR_DOT_BUILDER_H
diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_dumper.h b/contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_dumper.h
new file mode 100644 (file)
index 0000000..e285884
--- /dev/null
@@ -0,0 +1,62 @@
+#ifndef _NNC_BACKEND_INTERPRETER_CORE_DOTDUMPER_
+#define _NNC_BACKEND_INTERPRETER_CORE_DOTDUMPER_
+
+#include "nnc/core/IR/model/visitor/visitor.h"
+#include "nnc/core/IR/model/operations/fully_connected_op.h"
+#include "nnc/core/IR/model/operations/softmax_op.h"
+#include "nnc/core/IR/model/operations/capped_relu_op.h"
+#include "nnc/core/IR/model/operations/depthwise_conv2d_op.h"
+#include "nnc/core/IR/model/operations/conv_2d_op.h"
+#include "nnc/core/IR/model/operations/pool_op.h"
+#include "nnc/core/IR/model/operations/variable_op.h"
+#include "nnc/core/IR/model/operations/relu_op.h"
+#include "nnc/core/IR/model/operations/operation.h"
+#include "nnc/core/IR/model/operations/concat_op.h"
+#include "nnc/core/IR/model/operations/bias_add_op.h"
+#include "nnc/core/IR/model/operations/reshape_op.h"
+
+#include "nnc/core/IR/dumpers/ir_dot_builder.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace dumper
+{
+
+using nncc::contrib::core::IR::model::ADT::INode;
+using namespace nncc::contrib::core::IR::model;
+
+/**
+ * @breif Model IR visitor that can be used to output Model IR as a .dot graph.
+ * @usage Run on a Model IR graph as a visitor, and then call writeDot passing it a stream
+ */
+class IrDotDumper : public Visitor
+{
+public:
+  void visit(INode *node, ops::ConcatOp &op) override;
+  void visit(INode *node, ops::ReluOp &op) override;
+  void visit(INode *node, ops::Conv2DOp &op) override;
+  void visit(INode *node, ops::DepthwiseConv2DOp &op) override;
+  void visit(INode *node, ops::SoftmaxOp &op) override;
+  void visit(INode *node, ops::PoolOp &op) override;
+  void visit(INode *node, ops::FullyConnectedOp &op) override;
+  void visit(INode *node, ops::CappedReluOp &op) override;
+  void visit(INode *node, ops::BiasAddOp &op) override;
+  void visit(INode *node, ops::VariableOp &op) override;
+  void visit(INode *node, ops::ReshapeOp &op) override;
+
+  void writeDot(std::ostream &os) { dotBuilder.writeDot(os); };
+
+private:
+  IrDotBuilder dotBuilder;
+};
+
+} // namespace dumper
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif // _NNC_BACKEND_INTERPRETER_CORE_DOTDUMPER_
diff --git a/contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_node_info.h b/contrib/nnc/libs/core/include/nnc/core/IR/dumpers/ir_dot_node_info.h
new file mode 100644 (file)
index 0000000..cbb7e94
--- /dev/null
@@ -0,0 +1,92 @@
+#ifndef NNCC_IR_NODE_DOT_BUILDER_H
+#define NNCC_IR_NODE_DOT_BUILDER_H
+
+#include "nncc/core/ADT/tensor/Shape.h"
+#include "nnc/core/IR/model/operations/common.h"
+#include "nnc/core/IR/model/operations/pool_op.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace dumper
+{
+
+using namespace nncc::contrib::core::IR::model;
+using nncc::core::ADT::tensor::Shape;
+
+/**
+ * @brief Can collect information about a NN operator, and then use it to output
+ * this info as a node label in .dot format.
+ * @usage Provides a typical builder interface for collecting NN operator info, for example:
+ * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, 4}}).withAxis(0);
+ * Then resulting .dot node label is accessed with info.getLabel();
+ */
+class DotIrNodeInfo
+{
+public:
+  using Shapes = std::vector<Shape>;
+  using NamedShape = std::pair<std::string, Shape>;
+  using PadType = ops::PaddingType;
+  using PoolType = ops::PoolOp::PoolingType;
+
+  DotIrNodeInfo() = default;
+
+  DotIrNodeInfo &withType(const std::string &typeName, const std::string &nodeName);
+  DotIrNodeInfo &withInShapes(Shapes &&inShapes);
+  DotIrNodeInfo &withOutShapes(Shapes &&outShapes);
+  DotIrNodeInfo &withKernelShape(const Shape &kernelShape);
+  DotIrNodeInfo &withStride(const Shape &strideShape);
+  DotIrNodeInfo &withShape(const std::string &shapeName, const Shape &shape);
+  DotIrNodeInfo &withPadType(PadType padType);
+  DotIrNodeInfo &withPoolType(PoolType poolType);
+  DotIrNodeInfo &withAxis(float axis);
+
+/**
+ * Create a label in dot format for the Model IR node.
+ * Label is created in the form of the table with node name and type on top,
+ * then in and out shapes in the left column, and other parameters on the right column.
+ *
+ * Dot syntax for tables:
+ * - pipe ("|") symbol adds another line/column; by default it adds columns
+ * - when something gets inside of "{}", it changes what the pipe adds
+ * Example: label="leftCol | middleCol | {firstRowInRightCol | secondRowInRightCol }"
+ */
+  std::string getLabel() const;
+
+private:
+  void writeInShapesLabel(std::stringstream &ss) const;
+  void writeOutShapesLabel(std::stringstream &ss) const;
+
+  std::string labelForPadAndPool() const;
+  std::string labelForNodeParams() const;
+  void addPipeIfNeeded(std::stringstream &ss, bool needed, bool &needPipe) const;
+
+  std::string typeName;
+  std::string nodeName;
+
+  Shapes inShapes;
+  Shapes outShapes;
+
+  Shape kernelShape;
+
+  Shape strideShape;
+  std::vector<NamedShape> shapes;
+
+  bool hasPad = false;
+  PadType padType = PadType::Valid;
+
+  bool hasPool = false;
+  PoolType poolType = PoolType::MAX;
+
+  float axis = -1;
+};
+
+} // namespace dumper
+} // namespace core
+} // namespace contrib
+} // namespace nncc
+
+#endif // NNCC_IR_NODE_DOT_BUILDER_H
diff --git a/contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_builder.cpp b/contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_builder.cpp
new file mode 100644 (file)
index 0000000..14b4fad
--- /dev/null
@@ -0,0 +1,39 @@
+#include "nnc/core/IR/dumpers/ir_dot_builder.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace dumper
+{
+
+void IrDotBuilder::updateWithNode(INode *node, const DotIrNodeInfo &irNodeInfo)
+{
+  addNode(node, irNodeInfo);
+  for (auto &prev : node->getPrevNodes())
+  {
+    addEdge(prev.node, node);
+  }
+}
+
+void IrDotBuilder::writeDot(std::ostream &os)
+{
+  os << "digraph D {" << std::endl << dot.str() << std::endl << "}" << std::endl;
+}
+
+void IrDotBuilder::addNode(INode *node, const DotIrNodeInfo &irNode)
+{
+  dot << node->getId() << " [shape=record label=\"" << irNode.getLabel() << "\"];" << std::endl;
+}
+
+void IrDotBuilder::addEdge(INode *node1, INode *node2)
+{
+  dot << node1->getId() << " -> " << node2->getId() << ";" << std::endl;
+}
+
+} // namespace dumper
+} // namespace core
+} // namespace contrib
+} // namespace nncc
diff --git a/contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_dumper.cpp b/contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_dumper.cpp
new file mode 100644 (file)
index 0000000..35516d8
--- /dev/null
@@ -0,0 +1,158 @@
+#include <iostream>
+
+#include "nncc/core/ADT/tensor/Shape.h"
+#include "nnc/core/IR/model/graph/ir_node.h"
+
+#include "nnc/core/IR/dumpers/ir_dot_node_info.h"
+#include "nnc/core/IR/dumpers/ir_dot_dumper.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace dumper
+{
+
+using nncc::core::ADT::tensor::Shape;
+
+static std::vector<Shape> getInputShapes(OpDescription &op)
+{
+  std::vector<Shape> shapes;
+  for (std::size_t i = 0; i < op.getNumInputs(); ++i)
+  {
+    shapes.push_back(op.getInputShape(i));
+  }
+  return shapes;
+}
+
+static std::vector<Shape> getOutputShapes(const OpDescription &op)
+{
+  std::vector<Shape> shapes;
+  for (std::size_t i = 0; i < op.getNumOutputs(); ++i)
+  {
+    shapes.push_back(op.getOutputShape(i));
+  }
+  return shapes;
+}
+
+void IrDotDumper::visit(INode *node, ops::BiasAddOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("BiasAdd", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withKernelShape(op.getWeights().getShape());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::CappedReluOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("CappedRelu", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withAxis(op.getCap());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+
+}
+
+void IrDotDumper::visit(INode *node, ops::ConcatOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("Concat", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withAxis(op.getAxis());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::Conv2DOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("Conv2D", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withKernelShape(op.getKernel().getShape())
+                                 .withPadType(op.getPaddingType())
+                                 .withStride(op.getStrides());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::DepthwiseConv2DOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("DepthwiseConv2D", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withKernelShape(op.getKernel().getShape())
+                                 .withPadType(op.getPaddingType())
+                                 .withStride(op.getStrides());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::FullyConnectedOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("FullyConnected", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withKernelShape(op.getWeights().getShape());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::SoftmaxOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("Softmax", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withAxis(op.getAxis());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::PoolOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("Pool2D", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op))
+                                 .withShape("PoolWindow", op.getWindowShape())
+                                 .withPadType(op.getPaddingType())
+                                 .withPoolType(op.getPoolingType())
+                                 .withStride(op.getStrides());
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::ReluOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("ReLU", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op));
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::ReshapeOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("Reshape", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op));
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+void IrDotDumper::visit(INode *node, ops::VariableOp &op)
+{
+  auto nodeInfo = DotIrNodeInfo().withType("Input", node->getName())
+                                 .withInShapes(getInputShapes(op))
+                                 .withOutShapes(getOutputShapes(op));
+
+  dotBuilder.updateWithNode(node, nodeInfo);
+}
+
+} // namespace dumper
+} // namespace core
+} // namespace contrib
+} // namespace nncc
diff --git a/contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_node_info.cpp b/contrib/nnc/libs/core/src/core/IR/dumpers/ir_dot_node_info.cpp
new file mode 100644 (file)
index 0000000..14fafeb
--- /dev/null
@@ -0,0 +1,239 @@
+#include <sstream>
+#include <iostream>
+
+#include "nnc/core/IR/dumpers/ir_dot_node_info.h"
+
+namespace nncc
+{
+namespace contrib
+{
+namespace core
+{
+namespace dumper
+{
+
+std::ostream &operator<<(std::ostream &s, const Shape &shape) noexcept
+{
+  s << "[";
+  for (uint32_t d = 0; d < shape.rank(); ++d)
+  {
+    if (d != 0)
+      s << ", ";
+    s << shape.dim(d);
+  }
+  s << "]";
+  return s;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &typeName, const std::string &nodeName)
+{
+  this->typeName = typeName;
+  this->nodeName = nodeName;
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withInShapes(DotIrNodeInfo::Shapes &&inShapes)
+{
+  this->inShapes = inShapes;
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withOutShapes(DotIrNodeInfo::Shapes &&outShapes)
+{
+  this->outShapes = outShapes;
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withKernelShape(const Shape &kernelShape)
+{
+  this->kernelShape = kernelShape;
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withStride(const Shape &strideShape)
+{
+  this->strideShape = strideShape;
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withShape(const std::string &shapeName, const Shape &shape)
+{
+  this->shapes.emplace_back(shapeName, shape);
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withPadType(DotIrNodeInfo::PadType padType)
+{
+  this->padType = padType;
+  this->hasPad = true;
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withPoolType(DotIrNodeInfo::PoolType poolType)
+{
+  this->poolType = poolType;
+  this->hasPool = true;
+  return *this;
+}
+
+DotIrNodeInfo &DotIrNodeInfo::withAxis(float axis)
+{
+  this->axis = axis;
+  return *this;
+}
+
+std::string DotIrNodeInfo::getLabel() const
+{
+  std::stringstream ss;
+
+  ss << "{";
+
+  // Node type and name
+  ss << (!typeName.empty() ? typeName : "TYPE_NOT_SET")
+     << ": "
+     << (!nodeName.empty() ? nodeName : "NAME_NOT_SET");
+
+  if (typeName.empty())
+  {
+    std::cout << "Warning: Node type is not set for "
+         << (nodeName.empty() ? "one of the nodes" : "node " + nodeName) << std::endl;
+  }
+
+  ss << " | ";
+
+  // Note inputs and output shapes
+  ss << "{{";
+  writeInShapesLabel(ss);
+  ss << " | ";
+  writeOutShapesLabel(ss);
+  ss << "}";
+
+  // Other node parameters - kernel shape, stride, padding type etc
+  std::string label = labelForNodeParams();
+  if (!label.empty())
+    ss << " | {" << label << "}";
+
+  ss << "}}";
+
+  return ss.str();
+}
+
+std::string DotIrNodeInfo::labelForPadAndPool() const
+{
+  std::stringstream ss;
+
+  if (hasPad)
+  {
+    ss << "{";
+    ss << "PadType: " << (padType == PadType::Valid ? "VALID" : "SAME");
+    if (hasPool) ss << " | ";
+    else ss << "}";
+  }
+
+  if (hasPool)
+  {
+    if (!hasPad) ss << "{";
+
+    std::string poolTypeStr;
+    switch (poolType)
+    {
+      case PoolType::MAX:
+        poolTypeStr = "MAX";
+        break;
+      case PoolType::AVG:
+        poolTypeStr = "AVG";
+        break;
+      case PoolType::MIN:
+        poolTypeStr = "MIN";
+        break;
+      default:
+        assert(false && "Unknown PoolType");
+    }
+    ss << "PoolType: " << poolTypeStr;
+    ss << "}";
+  }
+
+  return ss.str();
+}
+
+void DotIrNodeInfo::writeInShapesLabel(std::stringstream &ss) const
+{
+  if (inShapes.empty())
+    ss << "IN_SHAPES_NOT_SET";
+  else
+  {
+    for (Shapes::size_type i = 0; i < inShapes.size(); ++i)
+    {
+      if (i != 0)
+        ss << " | ";
+      ss << "in" << i << ": " << inShapes[i];
+    }
+  }
+}
+
+void DotIrNodeInfo::writeOutShapesLabel(std::stringstream &ss) const
+{
+  if (outShapes.empty())
+    ss << "OUT_SHAPES_NOT_SET";
+  else
+  {
+    for (Shapes::size_type i = 0; i < outShapes.size(); ++i)
+    {
+      if (i != 0)
+        ss << "| ";
+      ss << "out" << i << ": " << outShapes[i];
+    }
+  }
+}
+
+std::string DotIrNodeInfo::labelForNodeParams() const
+{
+  std::stringstream ss;
+
+  bool needPipe = false;
+  if (kernelShape.rank() != 0)
+  {
+    ss << "Kernel: " << kernelShape;
+    needPipe = true;
+  }
+
+  std::string label = labelForPadAndPool();
+  addPipeIfNeeded(ss, !label.empty(), needPipe);
+  ss << label;
+
+  addPipeIfNeeded(ss, !shapes.empty(), needPipe);
+  for (Shapes::size_type i = 0; i < shapes.size(); ++i)
+  {
+    if (i != 0)
+      ss << " | ";
+    ss << shapes[i].first << ": " << shapes[i].second;
+  }
+
+  if (strideShape.rank() != 0)
+  {
+    addPipeIfNeeded(ss, true, needPipe);
+    ss << "Stride: " << strideShape;
+  }
+
+  if (axis != -1)
+  {
+    addPipeIfNeeded(ss, true, needPipe);
+    ss << "Axis: " << axis;
+  }
+
+  return ss.str();
+}
+
+void DotIrNodeInfo::addPipeIfNeeded(std::stringstream &ss, bool needed, bool &needPipe) const
+{
+  if (needed)
+  {
+    if (needPipe) ss << " | ";
+    else needPipe = true;
+  }
+}
+
+} // namespace dumper
+} // namespace core
+} // namespace contrib
+} // namespace nncc
\ No newline at end of file