[mir] Refactor graph dot dumper (#7273)
authorСергей Баранников/AI Tools Lab /SRR/Engineer/삼성전자 <s.barannikov@samsung.com>
Fri, 6 Sep 2019 21:11:14 +0000 (00:11 +0300)
committerAlexander Efimov/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Fri, 6 Sep 2019 21:11:14 +0000 (00:11 +0300)
Refactor graph dot dumper to loosen dependencies.

Signed-off-by: Sergei Barannikov <s.barannikov@samsung.com>
12 files changed:
compiler/mir/CMakeLists.txt
compiler/mir/include/mir/DataFormat.h
compiler/mir/include/mir/IrDotDumper.h
compiler/mir/include/mir/Shape.h
compiler/mir/include/mir/ir_dot_node_info.h [deleted file]
compiler/mir/src/DotGraph.cpp [moved from compiler/mir/src/ir_dot_builder.cpp with 50% similarity]
compiler/mir/src/DotGraph.h [moved from compiler/mir/include/mir/ir_dot_builder.h with 54% similarity]
compiler/mir/src/DotNodeBuilder.cpp [new file with mode: 0644]
compiler/mir/src/DotNodeBuilder.h [new file with mode: 0644]
compiler/mir/src/IrDotDumper.cpp
compiler/mir/src/Shape.cpp
compiler/mir/src/ir_dot_node_info.cpp [deleted file]

index 56a8037..9f6db34 100644 (file)
@@ -1,29 +1,29 @@
 set(MIR_SOURCES
-    src/ops/AvgPool2DOp.cpp
-    src/ops/BinaryElementwiseOp.cpp
-    src/ops/ConcatOp.cpp
-    src/ops/Conv2DOp.cpp
-    src/ops/DeConv2DOp.cpp
-    src/ops/DepthwiseConv2DOp.cpp
-    src/ops/FullyConnectedOp.cpp
-    src/ops/GatherOp.cpp
-    src/ops/MaxPool2DOp.cpp
-    src/ops/PadOp.cpp
-    src/ops/ReduceOp.cpp
-    src/ops/SqueezeOp.cpp
-    src/ops/SliceOp.cpp
-    src/ops/TransposeOp.cpp
-    src/Graph.cpp
-    src/Index.cpp
-    src/ir_dot_builder.cpp
-    src/IrDotDumper.cpp
-    src/GraphPatternMatcher.cpp
-    src/ir_dot_node_info.cpp
-    src/Operation.cpp
-    src/Shape.cpp
-    src/Tensor.cpp
-    src/TensorVariant.cpp
-    src/Visitor.cpp)
+        src/ops/AvgPool2DOp.cpp
+        src/ops/BinaryElementwiseOp.cpp
+        src/ops/ConcatOp.cpp
+        src/ops/Conv2DOp.cpp
+        src/ops/DeConv2DOp.cpp
+        src/ops/DepthwiseConv2DOp.cpp
+        src/ops/FullyConnectedOp.cpp
+        src/ops/GatherOp.cpp
+        src/ops/MaxPool2DOp.cpp
+        src/ops/PadOp.cpp
+        src/ops/ReduceOp.cpp
+        src/ops/SqueezeOp.cpp
+        src/ops/SliceOp.cpp
+        src/ops/TransposeOp.cpp
+        src/DotGraph.cpp
+        src/DotNodeBuilder.cpp
+        src/Graph.cpp
+        src/GraphPatternMatcher.cpp
+        src/Index.cpp
+        src/IrDotDumper.cpp
+        src/Operation.cpp
+        src/Shape.cpp
+        src/Tensor.cpp
+        src/TensorVariant.cpp
+        src/Visitor.cpp)
 
 add_library(mir STATIC ${MIR_SOURCES})
 target_include_directories(mir PUBLIC include)
index 5979d08..44edfa8 100644 (file)
@@ -17,6 +17,9 @@
 #ifndef _MIR_DATA_FORMAT_H_
 #define _MIR_DATA_FORMAT_H_
 
+#include <cassert>
+#include <string>
+
 namespace mir
 {
 
@@ -68,6 +71,20 @@ inline int getDataSpatialDimIndex(DataFormat data_format, int dim)
   }
 }
 
+inline std::string toString(DataFormat data_format)
+{
+  switch (data_format)
+  {
+    case DataFormat::NCHW:
+      return "NCHW";
+    case DataFormat::NHWC:
+      return "NHWC";
+    default:
+      assert(false);
+      return ""; // Dummy value to silence compiler warning.
+  }
+}
+
 } // namespace mir
 
 #endif //_MIR_DATA_FORMAT_H_
index ae86c16..e6c295c 100644 (file)
 #ifndef _MIR_IR_DOT_DUMPER_H_
 #define _MIR_IR_DOT_DUMPER_H_
 
-#include "mir/Visitor.h"
-#include "mir/ir_dot_builder.h"
+#include <ostream>
 
 namespace mir
 {
 
-/**
- * @brief Model IR visitor that can be used to output Model IR as a .dot graph.
- * @usage Run on a Model IR graph as a visitor, and then call writeDot passing it a stream
- */
-class IrDotDumper : public IVisitor
-{
-public:
-  void visit(ops::AddOp &op) override;
-  void visit(ops::AvgPool2DOp &op) override;
-  void visit(ops::CappedReluOp &op) override;
-  void visit(ops::ConcatOp &op) override;
-  void visit(ops::ConstantOp &op) override;
-  void visit(ops::Conv2DOp &op) override;
-  void visit(ops::DeConv2DOp &op) override;
-  void visit(ops::DepthwiseConv2DOp &op) override;
-  void visit(ops::DivOp &op) override;
-  void visit(ops::EluOp &op) override;
-  void visit(ops::FullyConnectedOp &op) override;
-  void visit(ops::GatherOp &op) override;
-  void visit(ops::InputOp &op) override;
-  void visit(ops::LeakyReluOp &op) override;
-  void visit(ops::MaxOp &op) override;
-  void visit(ops::MaxPool2DOp &op) override;
-  void visit(ops::MulOp &op) override;
-  void visit(ops::OutputOp &op) override;
-  void visit(ops::PadOp &op) override;
-  void visit(ops::ReduceMeanOp &op) override;
-  void visit(ops::ReluOp &op) override;
-  void visit(ops::ReshapeOp &op) override;
-  void visit(ops::ResizeOp &op) override;
-  void visit(ops::SigmoidOp &op) override;
-  void visit(ops::SliceOp &op) override;
-  void visit(ops::SoftmaxOp &op) override;
-  void visit(ops::SqrtOp &op) override;
-  void visit(ops::SqueezeOp &op) override;
-  void visit(ops::SubOp &op) override;
-  void visit(ops::TanhOp &op) override;
-  void visit(ops::TransposeOp &op) override;
-
-  void writeDot(std::ostream &os) { _dot_builder.writeDot(os); };
-
-private:
-  void addSimpleOperation(const Operation &op);
-
-  IrDotBuilder _dot_builder;
-};
+class Graph;
 
 void dumpGraph(const Graph *graph, std::ostream &stream);
 
index f3082e8..fbd4951 100644 (file)
@@ -20,8 +20,6 @@
 #include <initializer_list>
 #include <vector>
 #include <cstdint>
-#include <ostream>
-#include <cassert>
 
 #include "adtidas/SmallVector.h"
 #include "mir/Common.h"
@@ -68,7 +66,7 @@ private:
   adt::small_vector<int32_t, MAX_DIMENSION_COUNT> _dims;
 };
 
-std::ostream &operator<<(std::ostream &s, const Shape &sh);
+std::string toString(const Shape &shape);
 
 } // namespace mir
 
diff --git a/compiler/mir/include/mir/ir_dot_node_info.h b/compiler/mir/include/mir/ir_dot_node_info.h
deleted file mode 100644 (file)
index 8852dd2..0000000
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef _MIR_IR_DOT_NODE_INFO_H_
-#define _MIR_IR_DOT_NODE_INFO_H_
-
-#include "mir/Operation.h"
-#include "mir/Shape.h"
-#include "mir/ops/CommonProps.h"
-
-namespace mir
-{
-
-/**
- * @brief Can collect information about a NN operator, and then use it to output
- * this info as a node label in .dot format.
- * @usage Provides a typical builder interface for collecting NN operator info, for example:
- * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3,
- * 4}}).withMisc("Axis", 0));
- * Then resulting .dot node label is accessed with info.getLabel();
- */
-class DotIrNodeInfo
-{
-public:
-  using Shapes = std::vector<Shape>;
-  using NamedShape = std::pair<std::string, Shape>;
-  using MiscVal = std::pair<std::string, std::string>;
-  using PadType = ops::PaddingType;
-
-  class Stringable
-  {
-  public:
-    template <typename T>
-    /*implicit*/ Stringable(
-        T val); // NOLINT(google-explicit-constructor, hicpp-explicit-conversions)
-
-    /*implicit*/ operator std::string &&()
-    { // NOLINT(google-explicit-constructor, hicpp-explicit-conversions)
-      return std::move(_val);
-    }
-
-  private:
-    std::string _val;
-  };
-
-  DotIrNodeInfo() = default;
-
-  DotIrNodeInfo &withType(const std::string &type_name);
-  DotIrNodeInfo &withInShapes(Shapes &&in_shapes);
-  DotIrNodeInfo &withOutShapes(Shapes &&out_shapes);
-
-  DotIrNodeInfo &withStride(const Shape &stride_shape);
-  DotIrNodeInfo &withShape(const std::string &shape_name, const Shape &shape);
-  DotIrNodeInfo &withPadType(PadType pad_type);
-  DotIrNodeInfo &withMisc(const std::string &name, Stringable val);
-
-  /**
-   * Create a label in dot format for the Model IR node.
-   * Label is created in the form of the table with node name and type on top,
-   * then in and out shapes in the left column, and other parameters on the right column.
-   *
-   * Dot syntax for tables:
-   * - pipe ("|") symbol adds another line/column; by default it adds columns
-   * - when something gets inside of "{}", it changes what the pipe adds
-   * Example: label="leftCol | middleCol | {firstRowInRightCol | secondRowInRightCol }"
-   */
-  std::string getLabel() const;
-
-private:
-  void writeInShapesLabel(std::stringstream &ss) const;
-  void writeOutShapesLabel(std::stringstream &ss) const;
-
-  std::string labelForPadAndPool() const;
-  std::string labelForNodeParams() const;
-  void addPipeIfNeeded(std::stringstream &ss, bool needed, bool &need_pipe) const;
-
-  std::string _type_name;
-
-  Shapes _in_shapes;
-  Shapes _out_shapes;
-
-  Shape _kernel_shape;
-
-  Shape _stride_shape;
-  std::vector<NamedShape> _shapes;
-  std::vector<MiscVal> _misc_vals;
-
-  bool _has_pad = false;
-  PadType _pad_type = PadType::Valid;
-
-  bool _has_pool = false;
-};
-
-template <typename T> DotIrNodeInfo::Stringable::Stringable(T val) : _val(std::to_string(val)) {}
-
-template <> DotIrNodeInfo::Stringable::Stringable(std::string val);
-
-template <> DotIrNodeInfo::Stringable::Stringable(const char *val);
-
-} // namespace mir
-
-#endif //_MIR_IR_DOT_NODE_INFO_H_
similarity index 50%
rename from compiler/mir/src/ir_dot_builder.cpp
rename to compiler/mir/src/DotGraph.cpp
index e085275..b0e92e1 100644 (file)
  * limitations under the License.
  */
 
-#include "mir/ir_dot_builder.h"
+#include "DotGraph.h"
 
 namespace mir
 {
 
-void IrDotBuilder::updateWithOp(const Operation *op, const DotIrNodeInfo &ir_node_info)
-{
-  addNode(op, ir_node_info);
-  for (auto &prev : op->getInputs())
-  {
-    addEdge(prev.getProducer()->getNode(), op);
-  }
-}
+void DotGraph::addNode(DotNode node) { _nodes.emplace_back(std::move(node)); }
 
-void IrDotBuilder::writeDot(std::ostream &os)
-{
-  os << "digraph D {" << std::endl << dot.str() << std::endl << "}" << std::endl;
-}
+void DotGraph::addEdge(DotEdge edge) { _edges.emplace_back(edge); }
 
-void IrDotBuilder::addNode(const Operation *op, const DotIrNodeInfo &ir_node)
+std::ostream &operator<<(std::ostream &stream, const DotGraph &graph)
 {
-  dot << op->getId() << " [shape=record label=\"" << ir_node.getLabel() << "\"];" << std::endl;
-}
-
-void IrDotBuilder::addEdge(const Operation *op1, const Operation *op2)
-{
-  dot << op1->getId() << " -> " << op2->getId() << ";" << std::endl;
+  stream << "digraph D {" << std::endl;
+  for (const auto &node : graph._nodes)
+  {
+    stream << node.id << " [shape=record label=\"" << node.label << "\"];" << std::endl;
+  }
+  for (const auto &edge : graph._edges)
+  {
+    stream << edge.src_id << " -> " << edge.dst_id << ";" << std::endl;
+  }
+  stream << "}" << std::endl;
+  return stream;
 }
 
 } // namespace mir
similarity index 54%
rename from compiler/mir/include/mir/ir_dot_builder.h
rename to compiler/mir/src/DotGraph.h
index 311ec71..29698bb 100644 (file)
  * limitations under the License.
  */
 
-#ifndef _MIR_IR_DOT_BUILDER_H_
-#define _MIR_IR_DOT_BUILDER_H_
+#ifndef _MIR_DOT_GRAPH_
+#define _MIR_DOT_GRAPH_
 
+#include <cstddef>
 #include <sstream>
-
-#include "mir/ir_dot_node_info.h"
+#include <string>
+#include <vector>
 
 namespace mir
 {
 
-/**
- * @brief Provides an API to add nodes and edges to the .dot Model IR representation
- * and then write the whole graph to a provided stream.
- */
-class IrDotBuilder
+struct DotNode
+{
+  std::size_t id;
+  std::string label;
+};
+
+struct DotEdge
+{
+  std::size_t src_id;
+  std::size_t dst_id;
+};
+
+class DotGraph
 {
 public:
-  explicit IrDotBuilder() = default;
+  void addNode(DotNode node);
+  void addEdge(DotEdge edge);
 
-  void updateWithOp(const Operation *op, const DotIrNodeInfo &ir_node_info);
-  void writeDot(std::ostream &os);
+  friend std::ostream &operator<<(std::ostream &stream, const DotGraph &graph);
 
 private:
-  void addNode(const Operation *op, const DotIrNodeInfo &ir_node);
-  void addEdge(const Operation *op1, const Operation *op2);
-
-  std::stringstream dot;
+  std::vector<DotNode> _nodes;
+  std::vector<DotEdge> _edges;
 };
 
 } // namespace mir
 
-#endif //_MIR_IR_DOT_BUILDER_H_
+#endif //_MIR_DOT_GRAPH_
diff --git a/compiler/mir/src/DotNodeBuilder.cpp b/compiler/mir/src/DotNodeBuilder.cpp
new file mode 100644 (file)
index 0000000..0afaa37
--- /dev/null
@@ -0,0 +1,203 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "DotNodeBuilder.h"
+#include "mir/OpDefs.h"
+
+#include <sstream>
+
+namespace mir
+{
+
+template <typename T> static std::string toString(const std::vector<T> &v)
+{
+  std::stringstream ss;
+  ss << "[";
+  for (std::size_t i = 0; i < v.size(); ++i)
+  {
+    if (i != 0)
+      ss << ", ";
+    ss << v[i];
+  }
+  return ss.str();
+}
+
+DotNodeBuilder::DotNodeBuilder(const Operation &op)
+{
+  _type_name = getTypeName(op.getType());
+  _id = op.getId();
+
+  for (std::size_t i = 0; i < op.getNumInputs(); ++i)
+  {
+    _in_shapes.push_back(toString(op.getInputShape(i)));
+  }
+
+  for (std::size_t i = 0; i < op.getNumOutputs(); ++i)
+  {
+    _out_shapes.push_back(toString(op.getOutputShape(i)));
+  }
+
+  // Get attributes.
+  const_cast<Operation &>(op).accept(this);
+}
+
+void DotNodeBuilder::visit(ops::AvgPool2DOp &op)
+{
+  addAttribute("Window size", toString(op.getWindowSize()));
+  addAttribute("Strides", toString(op.getStrides()));
+  addAttribute("Padding before", toString(op.getPaddingBefore()));
+  addAttribute("Padding after", toString(op.getPaddingAfter()));
+  addAttribute("Include pad", std::to_string(op.getIncludePad()));
+}
+
+void DotNodeBuilder::visit(ops::CappedReluOp &op)
+{
+  addAttribute("Cap", std::to_string(op.getCap()));
+}
+
+void DotNodeBuilder::visit(ops::ConcatOp &op)
+{
+  addAttribute("Axis", std::to_string(op.getAxis()));
+}
+
+void DotNodeBuilder::visit(ops::Conv2DOp &op)
+{
+  addAttribute("Strides", toString(op.getStrides()));
+  addAttribute("Padding before", toString(op.getPaddingBefore()));
+  addAttribute("Padding after", toString(op.getPaddingAfter()));
+  addAttribute("Data format", toString(op.getDataFormat()));
+}
+
+void DotNodeBuilder::visit(ops::DepthwiseConv2DOp &op)
+{
+  addAttribute("Strides", toString(op.getStrides()));
+  addAttribute("Padding before", toString(op.getPaddingBefore()));
+  addAttribute("Padding after", toString(op.getPaddingAfter()));
+  addAttribute("Data format", toString(op.getDataFormat()));
+}
+
+void DotNodeBuilder::visit(ops::MaxPool2DOp &op)
+{
+  addAttribute("Window size", toString(op.getWindowSize()));
+  addAttribute("Strides", toString(op.getStrides()));
+  addAttribute("Padding before", toString(op.getPaddingBefore()));
+  addAttribute("Padding after", toString(op.getPaddingAfter()));
+  addAttribute("Data format", toString(op.getDataFormat()));
+}
+
+void DotNodeBuilder::visit(ops::SoftmaxOp &op)
+{
+  addAttribute("Axis", std::to_string(op.getAxis()));
+}
+
+void DotNodeBuilder::visit(ops::SliceOp &op)
+{
+  addAttribute("Starts", toString(op.getStarts()));
+  addAttribute("Sizes", toString(op.getSizes()));
+}
+
+void DotNodeBuilder::visit(ops::DeConv2DOp &op)
+{
+  addAttribute("Padding before", toString(op.getPaddingBefore()));
+  addAttribute("Padding after", toString(op.getPaddingAfter()));
+  addAttribute("Strides", toString(op.getStrides()));
+  addAttribute("Data format", toString(op.getDataFormat()));
+}
+
+void DotNodeBuilder::visit(ops::EluOp &op) { addAttribute("Alpha", std::to_string(op.getAlpha())); }
+
+void DotNodeBuilder::visit(ops::SqueezeOp &op)
+{
+  addAttribute("Dims to squeeze", toString(op.getDimsToSqueeze()));
+}
+
+void mir::DotNodeBuilder::visit(ops::PadOp &op)
+{
+  addAttribute("Padding before", toString(op.getPaddingBefore()));
+  addAttribute("Padding after", toString(op.getPaddingAfter()));
+  addAttribute("Padding value", std::to_string(op.getPaddingValue()));
+}
+
+void DotNodeBuilder::visit(ops::ReduceMeanOp &op)
+{
+  addAttribute("Reduction dims", toString(op.getReductionDims()));
+  addAttribute("Keep dims", std::to_string(op.getKeepDims()));
+}
+
+void DotNodeBuilder::visit(ops::ResizeOp &op)
+{
+  assert(op.getMode() == ops::ResizeOp::ResizeMethod::nearestNeighbor);
+
+  addAttribute("Interpolation mode", "nearestNeighbor");
+}
+
+void DotNodeBuilder::visit(ops::TransposeOp &op)
+{
+  addAttribute("Axis order", toString(op.getAxisOrder()));
+}
+
+void DotNodeBuilder::visit(ops::GatherOp &op)
+{
+  addAttribute("Axis", std::to_string(op.getAxis()));
+}
+
+void DotNodeBuilder::visit(mir::ops::LeakyReluOp &op)
+{
+  addAttribute("Alpha", std::to_string(op.getAlpha()));
+}
+
+void DotNodeBuilder::addAttribute(std::string name, std::string val)
+{
+  this->_attributes.emplace_back(std::move(name), std::move(val));
+}
+
+std::string DotNodeBuilder::getLabel() const
+{
+  std::stringstream ss;
+
+  ss << "{" << _type_name << " | {{";
+
+  for (std::size_t i = 0; i < _in_shapes.size(); ++i)
+  {
+    if (i != 0)
+      ss << " | ";
+    ss << "in" << i << ": " << _in_shapes[i];
+  }
+
+  ss << " | ";
+
+  for (std::size_t i = 0; i < _out_shapes.size(); ++i)
+  {
+    if (i != 0)
+      ss << " | ";
+    ss << "out" << i << ": " << _out_shapes[i];
+  }
+
+  ss << "} | {";
+
+  for (std::size_t i = 0; i < _attributes.size(); ++i)
+  {
+    if (i != 0)
+      ss << " | ";
+    ss << _attributes[i].first << ": " << _attributes[i].second;
+  }
+
+  ss << "}}}";
+
+  return ss.str();
+}
+
+} // namespace mir
diff --git a/compiler/mir/src/DotNodeBuilder.h b/compiler/mir/src/DotNodeBuilder.h
new file mode 100644 (file)
index 0000000..09eba0c
--- /dev/null
@@ -0,0 +1,70 @@
+/*
+ * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef _MIR_DOT_NODE_BUILDER_H_
+#define _MIR_DOT_NODE_BUILDER_H_
+
+#include "DotGraph.h"
+
+#include "mir/Visitor.h"
+
+#include <cstddef>
+#include <string>
+#include <vector>
+
+namespace mir
+{
+
+class DotNodeBuilder : public Visitor
+{
+public:
+  explicit DotNodeBuilder(const Operation &op);
+
+  void visit(ops::AvgPool2DOp &op) override;
+  void visit(ops::CappedReluOp &op) override;
+  void visit(ops::ConcatOp &op) override;
+  void visit(ops::Conv2DOp &op) override;
+  void visit(ops::DeConv2DOp &op) override;
+  void visit(ops::DepthwiseConv2DOp &op) override;
+  void visit(ops::EluOp &op) override;
+  void visit(ops::GatherOp &op) override;
+  void visit(ops::LeakyReluOp &op) override;
+  void visit(ops::MaxPool2DOp &op) override;
+  void visit(ops::PadOp &op) override;
+  void visit(ops::ReduceMeanOp &op) override;
+  void visit(ops::ResizeOp &op) override;
+  void visit(ops::SliceOp &op) override;
+  void visit(ops::SoftmaxOp &op) override;
+  void visit(ops::SqueezeOp &op) override;
+  void visit(ops::TransposeOp &op) override;
+
+  void addAttribute(std::string name, std::string val);
+
+  DotNode getDotNode() const { return {_id, getLabel()}; }
+
+private:
+  std::string getLabel() const;
+
+  std::size_t _id;
+  std::string _type_name;
+  std::vector<std::string> _in_shapes;
+  std::vector<std::string> _out_shapes;
+  std::vector<std::pair<std::string, std::string>> _attributes;
+};
+
+} // namespace mir
+
+#endif //_MIR_DOT_NODE_BUILDER_H_
index 829b53a..e40255c 100644 (file)
  * limitations under the License.
  */
 
-#include "mir/Graph.h"
 #include "mir/IrDotDumper.h"
-#include "mir/OpDefs.h"
-
-#include <map>
+#include "mir/Graph.h"
+#include "DotGraph.h"
+#include "DotNodeBuilder.h"
 
 namespace mir
 {
 
-static std::vector<Shape> getInputShapes(const Operation &op)
-{
-  std::vector<Shape> shapes;
-  for (std::size_t i = 0; i < op.getNumInputs(); ++i)
-  {
-    shapes.push_back(op.getInputShape(i));
-  }
-  return shapes;
-}
-
-static std::vector<Shape> getOutputShapes(const Operation &op)
-{
-  std::vector<Shape> shapes;
-  for (std::size_t i = 0; i < op.getNumOutputs(); ++i)
-  {
-    shapes.push_back(op.getOutputShape(i));
-  }
-  return shapes;
-}
-
-void IrDotDumper::visit(ops::AddOp &op) { addSimpleOperation(op); }
-
-void IrDotDumper::visit(ops::AvgPool2DOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("AvgPool2D")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withShape("Window size", Shape(op.getWindowSize()))
-                      .withStride(Shape(op.getStrides()))
-                      .withShape("Padding before", Shape(op.getPaddingBefore()))
-                      .withShape("Padding after", Shape(op.getPaddingAfter()));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::CappedReluOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("CappedRelu")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withMisc("Cap", op.getCap());
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::ConcatOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("Concat")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withMisc("Axis", op.getAxis());
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::Conv2DOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("Conv2D")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withStride(Shape(op.getStrides()))
-                      .withShape("Padding before", Shape(op.getPaddingBefore()))
-                      .withShape("Padding after", Shape(op.getPaddingAfter()));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::DepthwiseConv2DOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("DepthwiseConv2D")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withStride(Shape(op.getStrides()))
-                      .withShape("Padding before", Shape(op.getPaddingBefore()))
-                      .withShape("Padding after", Shape(op.getPaddingAfter()));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::DivOp &op) { addSimpleOperation(op); }
-
-void IrDotDumper::visit(ops::FullyConnectedOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("FullyConnected")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::MaxPool2DOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("MaxPool2D")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withShape("Window size", Shape(op.getWindowSize()))
-                      .withStride(Shape(op.getStrides()))
-                      .withShape("Padding before", Shape(op.getPaddingBefore()))
-                      .withShape("Padding after", Shape(op.getPaddingAfter()));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::SoftmaxOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("Softmax")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withMisc("Axis", op.getAxis());
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::ReluOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("ReLU")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::ReshapeOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("Reshape")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::InputOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("Input")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::ConstantOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("Constant")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::SliceOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("SliceOp")
-                       .withInShapes(getInputShapes(op))
-                       .withShape("Starts", op.getStarts())
-                       .withShape("Sizes", op.getSizes())
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::DeConv2DOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("DeConv2D")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op))
-                       .withPadType(op.getPaddingType())
-                       .withStride(Shape(op.getStrides()));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::EluOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("EluOp")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op))
-                      .withMisc("Alpha", op.getAlpha());
-}
-
-void IrDotDumper::visit(ops::TanhOp &op)
-{
-  auto nodeInfo = DotIrNodeInfo()
-                      .withType("TanhOp")
-                      .withInShapes(getInputShapes(op))
-                      .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
-void IrDotDumper::visit(ops::SqueezeOp &op)
+void dumpGraph(const Graph *graph, std::ostream &stream)
 {
-  auto node_info = DotIrNodeInfo()
-                       .withType("SqueezeOp")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
+  DotGraph dot_graph;
 
-  for (auto dim : op.getDimsToSqueeze())
+  for (const auto *node : graph->getNodes())
   {
-    node_info.withMisc("SqueezeDim", dim);
+    dot_graph.addNode(DotNodeBuilder(*node).getDotNode());
+    for (const auto &input : node->getInputs())
+    {
+      dot_graph.addEdge({input.getProducer()->getNode()->getId(), node->getId()});
+    }
   }
 
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::SubOp &op) { addSimpleOperation(op); }
-
-void mir::IrDotDumper::visit(ops::PadOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("PadOp")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::SqrtOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("Sqrt")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::ReduceMeanOp &op)
-{
-  auto node_info =
-      DotIrNodeInfo()
-          .withType("ReduceMeanOp")
-          .withInShapes(getInputShapes(op))
-          .withOutShapes(getOutputShapes(op))
-          .withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims
-          .withMisc("Keep dims", op.getKeepDims());
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::ResizeOp &op)
-{
-  static const std::map<ops::ResizeOp::ResizeMethod, const char *> modes{
-      {ops::ResizeOp::ResizeMethod::nearestNeighbor, "nearestNeighbor"}};
-
-  auto node_info = DotIrNodeInfo()
-                       .withType("Resize")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op))
-                       .withMisc("Mode", modes.at(op.getMode()));
-  // scale is only needed in Shape Inference
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::TransposeOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("TransposeOp")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::GatherOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("GatherOp")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::SigmoidOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("SigmoidOp")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(mir::ops::LeakyReluOp &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType("LeakyReluOp")
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op))
-                       .withMisc("alpha", op.getAlpha());
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::visit(ops::MaxOp &op) { addSimpleOperation(op); }
-
-void IrDotDumper::visit(ops::MulOp &op) { addSimpleOperation(op); }
-
-void IrDotDumper::visit(ops::OutputOp &op)
-{
-  auto node_info = DotIrNodeInfo().withType("OutputOp").withInShapes(getInputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void IrDotDumper::addSimpleOperation(const Operation &op)
-{
-  auto node_info = DotIrNodeInfo()
-                       .withType(getTypeName(op.getType()))
-                       .withInShapes(getInputShapes(op))
-                       .withOutShapes(getOutputShapes(op));
-
-  _dot_builder.updateWithOp(&op, node_info);
-}
-
-void dumpGraph(const Graph *graph, std::ostream &stream)
-{
-  IrDotDumper dumper;
-  const_cast<Graph *>(graph)->accept(&dumper);
-  dumper.writeDot(stream);
+  stream << dot_graph;
 }
 
 } // namespace mir
index 0ef79b5..d6250e2 100644 (file)
@@ -18,6 +18,7 @@
 
 #include <algorithm>
 #include <cassert>
+#include <sstream>
 
 namespace mir
 {
@@ -39,22 +40,23 @@ int32_t Shape::numElements() const
   return res;
 }
 
-std::ostream &operator<<(std::ostream &s, const Shape &sh)
+std::string toString(const Shape &shape)
 {
-  int32_t rank = sh.rank();
-  s << "[";
+  std::stringstream ss;
 
-  for (int32_t axis = 0; axis < rank; ++axis)
+  ss << "[";
+  for (int32_t axis = 0; axis < shape.rank(); ++axis)
   {
     if (axis != 0)
-      s << ", ";
-    if (sh.dim(axis) == Shape::autoDim)
-      s << "AUTO";
+      ss << ", ";
+    if (shape.dim(axis) == Shape::autoDim)
+      ss << "AUTO";
     else
-      s << sh.dim(axis);
+      ss << shape.dim(axis);
   }
-  s << "]";
-  return s;
+  ss << "]";
+
+  return ss.str();
 }
 
 } // namespace mir
diff --git a/compiler/mir/src/ir_dot_node_info.cpp b/compiler/mir/src/ir_dot_node_info.cpp
deleted file mode 100644 (file)
index e0968ff..0000000
+++ /dev/null
@@ -1,210 +0,0 @@
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#include <sstream>
-#include <iostream>
-
-#include "mir/ir_dot_node_info.h"
-
-namespace mir
-{
-
-template <> DotIrNodeInfo::Stringable::Stringable(std::string val) : _val(std::move(val)) {}
-
-template <> DotIrNodeInfo::Stringable::Stringable(const char *val) : _val(val) {}
-
-DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &type_name)
-{
-  this->_type_name = type_name;
-  return *this;
-}
-
-DotIrNodeInfo &DotIrNodeInfo::withInShapes(DotIrNodeInfo::Shapes &&in_shapes)
-{
-  this->_in_shapes = in_shapes;
-  return *this;
-}
-
-DotIrNodeInfo &DotIrNodeInfo::withOutShapes(DotIrNodeInfo::Shapes &&out_shapes)
-{
-  this->_out_shapes = out_shapes;
-  return *this;
-}
-
-DotIrNodeInfo &DotIrNodeInfo::withStride(const Shape &stride_shape)
-{
-  this->_stride_shape = stride_shape;
-  return *this;
-}
-
-/**
- * @brief Allows dumping arbitrary parameters from layer
- * The values that are actually integers get dumped as integers.
- */
-DotIrNodeInfo &DotIrNodeInfo::withMisc(const std::string &name, Stringable val)
-{
-  this->_misc_vals.emplace_back(name, std::move(val));
-  return *this;
-}
-
-DotIrNodeInfo &DotIrNodeInfo::withShape(const std::string &shape_name, const Shape &shape)
-{
-  this->_shapes.emplace_back(shape_name, shape);
-  return *this;
-}
-
-DotIrNodeInfo &DotIrNodeInfo::withPadType(DotIrNodeInfo::PadType pad_type)
-{
-  this->_pad_type = pad_type;
-  this->_has_pad = true;
-  return *this;
-}
-
-std::string DotIrNodeInfo::getLabel() const
-{
-  std::stringstream ss;
-
-  ss << "{" << _type_name << " | ";
-
-  // Note inputs and output shapes
-  ss << "{{";
-  writeInShapesLabel(ss);
-  ss << " | ";
-  writeOutShapesLabel(ss);
-  ss << "}";
-
-  // Other node parameters - kernel shape, stride, padding type etc
-  std::string label = labelForNodeParams();
-  if (!label.empty())
-    ss << " | {" << label << "}";
-
-  ss << "}}";
-
-  return ss.str();
-}
-
-std::string DotIrNodeInfo::labelForPadAndPool() const
-{
-  std::stringstream ss;
-
-  if (_has_pad)
-  {
-    ss << "{PadType: ";
-    switch (_pad_type)
-    {
-      case PadType::Valid:
-        ss << "VALID";
-        break;
-      case PadType::Same:
-        ss << "SAME";
-        break;
-      case PadType::Custom:
-        ss << "CUSTOM";
-        break;
-      default:
-        assert(false && "Unknown Padding type");
-        break;
-    }
-
-    ss << "}";
-  }
-
-  return ss.str();
-}
-
-void DotIrNodeInfo::writeInShapesLabel(std::stringstream &ss) const
-{
-  if (_in_shapes.empty())
-    ss << "IN_SHAPES_NOT_SET";
-  else
-  {
-    for (Shapes::size_type i = 0; i < _in_shapes.size(); ++i)
-    {
-      if (i != 0)
-        ss << " | ";
-      ss << "in" << i << ": " << _in_shapes[i];
-    }
-  }
-}
-
-void DotIrNodeInfo::writeOutShapesLabel(std::stringstream &ss) const
-{
-  if (_out_shapes.empty())
-    ss << "OUT_SHAPES_NOT_SET";
-  else
-  {
-    for (Shapes::size_type i = 0; i < _out_shapes.size(); ++i)
-    {
-      if (i != 0)
-        ss << "| ";
-      ss << "out" << i << ": " << _out_shapes[i];
-    }
-  }
-}
-
-std::string DotIrNodeInfo::labelForNodeParams() const
-{
-  std::stringstream ss;
-
-  bool need_pipe = false;
-  if (_kernel_shape.rank() != 0)
-  {
-    ss << "Kernel: " << _kernel_shape;
-    need_pipe = true;
-  }
-
-  std::string label = labelForPadAndPool();
-  addPipeIfNeeded(ss, !label.empty(), need_pipe);
-  ss << label;
-
-  addPipeIfNeeded(ss, !_shapes.empty(), need_pipe);
-  for (Shapes::size_type i = 0; i < _shapes.size(); ++i)
-  {
-    if (i != 0)
-      ss << " | ";
-    ss << _shapes[i].first << ": " << _shapes[i].second;
-  }
-
-  if (_stride_shape.rank() != 0)
-  {
-    addPipeIfNeeded(ss, true, need_pipe);
-    ss << "Stride: " << _stride_shape;
-  }
-
-  // misc scalar parameters (Cap, dropRate, etc..)
-  addPipeIfNeeded(ss, !_misc_vals.empty(), need_pipe);
-  for (Shapes::size_type i = 0; i < _misc_vals.size(); ++i)
-  {
-    if (i != 0)
-      ss << " | ";
-    ss << _misc_vals[i].first << ": " << _misc_vals[i].second;
-  }
-
-  return ss.str();
-}
-
-void DotIrNodeInfo::addPipeIfNeeded(std::stringstream &ss, bool needed, bool &need_pipe) const
-{
-  if (needed)
-  {
-    if (need_pipe)
-      ss << " | ";
-    else
-      need_pipe = true;
-  }
-}
-
-} // namespace mir