[neurun] DotDumper shows model outputs (#3590)
author이한종/동작제어Lab(SR)/Engineer/삼성전자 <hanjoung.lee@samsung.com>
Thu, 15 Nov 2018 04:17:01 +0000 (13:17 +0900)
committer이춘석/동작제어Lab(SR)/Staff Engineer/삼성전자 <chunseok.lee@samsung.com>
Thu, 15 Nov 2018 04:17:01 +0000 (13:17 +0900)
Just like model inputs, this commit shows model outputs as a different
shape.

Signed-off-by: Hanjoung Lee <hanjoung.lee@samsung.com>
runtimes/neurun/src/dumper/dot/DotDumper.cc
runtimes/neurun/src/dumper/dot/DotOperandInfo.cc
runtimes/neurun/src/dumper/dot/DotOperandInfo.h

index dc431da..6186d0e 100644 (file)
@@ -40,8 +40,9 @@ void DotDumper::dump(const std::string &tag, uint32_t option)
 
     for (auto output : node.getOutputs())
     {
-      auto child =
-          std::make_shared<neurun::dumper::dot::DotOperandInfo>(output, operands.at(output));
+      using neurun::dumper::dot::DotOperandInfo;
+      auto child = std::make_shared<DotOperandInfo>(output, operands.at(output),
+                                                    DotOperandInfo::Type::MODEL_OUTPUT);
       node_info.appendChild(child);
     }
 
@@ -66,7 +67,17 @@ void DotDumper::dump(const std::string &tag, uint32_t option)
     }
     if (showing_cond)
     {
-      neurun::dumper::dot::DotOperandInfo operand_info(index, object);
+      auto type = [&]() {
+        using neurun::dumper::dot::DotOperandInfo;
+        if (_graph.getInputs().contains(index))
+          return DotOperandInfo::Type::MODEL_INPUT;
+        if (_graph.getOutputs().contains(index))
+          return DotOperandInfo::Type::MODEL_OUTPUT;
+        return DotOperandInfo::Type::INTERNAL;
+      }();
+
+      neurun::dumper::dot::DotOperandInfo operand_info(index, object, type);
+
       for (auto operation_index : object.getUses().list())
       {
         auto &node = operations.at(operation_index);
index b138173..334e3b3 100644 (file)
@@ -29,14 +29,15 @@ namespace dot
 {
 
 const std::string DotOperandInfo::INPUT_SHAPE = "Mdiamond";
+const std::string DotOperandInfo::OUTPUT_SHAPE = "Mdiamond";
 const std::string DotOperandInfo::OPERAND_SHAPE = "ellipse";
 const std::string DotOperandInfo::BG_COLOR_SCHEME = "set38";
 // RED BLUE ORANGE YELLOW GREEN PUPLE CYAN PINK
 const std::string DotOperandInfo::BG_COLORS[8] = {"4", "5", "6", "2", "7", "3", "1", "8"};
 
 DotOperandInfo::DotOperandInfo(const neurun::graph::operand::Index &index,
-                               const neurun::graph::operand::Object &object)
-    : _index(index), _object(object)
+                               const neurun::graph::operand::Object &object, Type type)
+    : _index(index), _object(object), _type(type)
 {
   const auto &lower_info = object.lower_info();
   if (lower_info)
@@ -67,12 +68,19 @@ std::string DotOperandInfo::label() const
 
 std::string DotOperandInfo::dot_shape() const
 {
-  if (_object.isModelInput())
+  switch (_type)
   {
-    return INPUT_SHAPE;
-  }
+    case Type::MODEL_INPUT:
+      return INPUT_SHAPE;
+
+    case Type::MODEL_OUTPUT:
+      return OUTPUT_SHAPE;
 
-  return OPERAND_SHAPE;
+    case Type::UNDEFINED:
+    case Type::INTERNAL:
+    default:
+      return OPERAND_SHAPE;
+  }
 }
 
 std::string DotOperandInfo::bg_color_scheme() const { return BG_COLOR_SCHEME; }
index bb937c3..18f633d 100644 (file)
@@ -33,14 +33,24 @@ namespace dot
 class DotOperandInfo : public IDotInfo
 {
 public:
+  enum class Type
+  {
+    UNDEFINED,
+    MODEL_INPUT,
+    MODEL_OUTPUT,
+    INTERNAL
+  };
+
+public:
   static const std::string INPUT_SHAPE;
+  static const std::string OUTPUT_SHAPE;
   static const std::string OPERAND_SHAPE;
   static const std::string BG_COLOR_SCHEME;
   static const std::string BG_COLORS[8];
 
 public:
   DotOperandInfo(const neurun::graph::operand::Index &index,
-                 const neurun::graph::operand::Object &object);
+                 const neurun::graph::operand::Object &object, Type type);
 
 public:
   virtual std::string index_str() const override;
@@ -55,6 +65,7 @@ private:
 private:
   const neurun::graph::operand::Index &_index;
   const neurun::graph::operand::Object &_object;
+  Type _type;
 
   std::vector<std::string> _labels;
 };