[nnc] Scalar param support in IR dumper (#1397)
authorАндрей Шедько/AI Tools Lab /SRR/Assistant Engineer/삼성전자 <a.shedko@partner.samsung.com>
Fri, 7 Sep 2018 10:37:26 +0000 (13:37 +0300)
committerРоман Михайлович Русяев/AI Tools Lab /SRR/Staff Engineer/삼성전자 <r.rusyaev@samsung.com>
Fri, 7 Sep 2018 10:37:26 +0000 (13:37 +0300)
Added scalar parameter support in IR Dot Dumper
The params and their names are stored in a `vector<pair<string,float>>`

Signed-off-by: Andrei Shedko <a.shedko@partner.samsung.com>
contrib/nnc/core/modelIR/ir_dot_dumper.cpp
contrib/nnc/core/modelIR/ir_dot_node_info.cpp
contrib/nnc/include/core/modelIR/ir_dot_node_info.h

index 2083537..72fbc68 100644 (file)
@@ -52,7 +52,7 @@ void IrDotDumper::visit(INode *node, ops::CappedReluOp &op)
   auto nodeInfo = DotIrNodeInfo().withType("CappedRelu", node->getName())
                                  .withInShapes(getInputShapes(op))
                                  .withOutShapes(getOutputShapes(op))
-                                 .withAxis(op.getCap());
+                                 .withMisc("Cap", op.getCap());
 
   dotBuilder.updateWithNode(node, nodeInfo);
 
@@ -63,7 +63,7 @@ void IrDotDumper::visit(INode *node, ops::ConcatOp &op)
   auto nodeInfo = DotIrNodeInfo().withType("Concat", node->getName())
                                  .withInShapes(getInputShapes(op))
                                  .withOutShapes(getOutputShapes(op))
-                                 .withAxis(op.getAxis());
+                                 .withMisc("Axis", op.getAxis());
 
   dotBuilder.updateWithNode(node, nodeInfo);
 }
@@ -107,7 +107,7 @@ void IrDotDumper::visit(INode *node, ops::SoftmaxOp &op)
   auto nodeInfo = DotIrNodeInfo().withType("Softmax", node->getName())
                                  .withInShapes(getInputShapes(op))
                                  .withOutShapes(getOutputShapes(op))
-                                 .withAxis(op.getAxis());
+                                 .withMisc("Axis", op.getAxis());
 
   dotBuilder.updateWithNode(node, nodeInfo);
 }
index de251f1..20ab89b 100644 (file)
@@ -43,6 +43,17 @@ DotIrNodeInfo &DotIrNodeInfo::withStride(const Shape &strideShape)
   return *this;
 }
 
+/**
+ * @brief Allows dumping arbitrary scalar parameters from layers as floats.
+ * The values that are actually integers get dumped as integers.
+ */
+DotIrNodeInfo &DotIrNodeInfo::withMisc(const std::string &miscName, const float miscVal)
+{
+  this->miscVals.emplace_back(miscName, miscVal);
+  return *this;
+}
+
+
 DotIrNodeInfo &DotIrNodeInfo::withShape(const std::string &shapeName, const Shape &shape)
 {
   this->shapes.emplace_back(shapeName, shape);
@@ -63,12 +74,6 @@ DotIrNodeInfo &DotIrNodeInfo::withPoolType(DotIrNodeInfo::PoolType poolType)
   return *this;
 }
 
-DotIrNodeInfo &DotIrNodeInfo::withAxis(float axis)
-{
-  this->axis = axis;
-  return *this;
-}
-
 std::string DotIrNodeInfo::getLabel() const
 {
   std::stringstream ss;
@@ -202,10 +207,12 @@ std::string DotIrNodeInfo::labelForNodeParams() const
     ss << "Stride: " << strideShape;
   }
 
-  if (axis != -1)
-  {
-    addPipeIfNeeded(ss, true, needPipe);
-    ss << "Axis: " << axis;
+  //misc scalar parameters (Cap, dropRate, etc..)
+  addPipeIfNeeded(ss, !miscVals.empty(), needPipe);
+  for (Shapes::size_type i = 0; i < miscVals.size(); ++i) {
+    if (i != 0)
+      ss << " | ";
+    ss << miscVals[i].first << ": "<< miscVals[i].second;
   }
 
   return ss.str();
index 1fdd32b..4c7fe72 100644 (file)
@@ -21,7 +21,7 @@ using namespace nncc::contrib::core::data;
  * @brief Can collect information about a NN operator, and then use it to output
  * this info as a node label in .dot format.
  * @usage Provides a typical builder interface for collecting NN operator info, for example:
- * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, 4}}).withAxis(0);
+ * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, 4}}).withMisc("Axis", 0));
  * Then resulting .dot node label is accessed with info.getLabel();
  */
 class DotIrNodeInfo
@@ -29,6 +29,7 @@ class DotIrNodeInfo
 public:
   using Shapes = std::vector<Shape>;
   using NamedShape = std::pair<std::string, Shape>;
+  using MiscVal = std::pair<std::string, float>;
   using PadType = ops::PaddingType;
   using PoolType = ops::PoolOp::PoolingType;
 
@@ -42,7 +43,7 @@ public:
   DotIrNodeInfo &withShape(const std::string &shapeName, const Shape &shape);
   DotIrNodeInfo &withPadType(PadType padType);
   DotIrNodeInfo &withPoolType(PoolType poolType);
-  DotIrNodeInfo &withAxis(float axis);
+  DotIrNodeInfo &withMisc(const std::string &miscName, const float miscVal);
 
 /**
  * Create a label in dot format for the Model IR node.
@@ -74,6 +75,7 @@ private:
 
   Shape strideShape;
   std::vector<NamedShape> shapes;
+  std::vector<MiscVal> miscVals;
 
   bool hasPad = false;
   PadType padType = PadType::Valid;