[nnc] Change IrDotDumper handling of misc values (#2542)
authorVladimir Plazun/AI Tools Lab /SRR/Engineer/삼성전자 <v.plazun@samsung.com>
Wed, 12 Dec 2018 12:50:02 +0000 (15:50 +0300)
committerEfimov Alexander/AI Tools Lab/./Samsung Electronics <a.efimov@samsung.com>
Wed, 12 Dec 2018 12:50:02 +0000 (15:50 +0300)
Makes all misc values into strings and adds automatic conversion routine

Signed-off-by: Vladimir Plazun <v.plazun@partner.samsung.com>
contrib/nnc/core/modelIR/IrDotDumper.cpp
contrib/nnc/core/modelIR/ir_dot_node_info.cpp
contrib/nnc/include/core/modelIR/ir_dot_node_info.h

index c0ad967..84d261e 100644 (file)
@@ -232,12 +232,19 @@ void IrDotDumper::visit(ops::TanhOp& op) {
 }
 
 void mir::IrDotDumper::visit(ops::ElementwiseOp& op) {
-  auto nodeInfo = DotIrNodeInfo().withType("TanhOp", op.getName())
+  static const std::map<ops::ElementwiseOp::OpType, const char*> op_types{
+    {ops::ElementwiseOp::OpType::mul, "mul"},
+    {ops::ElementwiseOp::OpType::add, "add"},
+    {ops::ElementwiseOp::OpType::max, "max"},
+    {ops::ElementwiseOp::OpType::div, "div"}
+  };
+
+  auto node_info = DotIrNodeInfo().withType("ElementwiseOp", op.getName())
     .withInShapes(getInputShapes(op))
     .withOutShapes(getOutputShapes(op))
-    .withMisc("Operation", ( int ) op.getOpType());
+    .withMisc("Operation", op_types.at(op.getOpType()));
 
-  dotBuilder.updateWithOp(&op, nodeInfo);
+  dotBuilder.updateWithOp(&op, node_info);
 }
 
 void IrDotDumper::visit(ops::SqueezeOp& op) {
@@ -261,22 +268,30 @@ void mir::IrDotDumper::visit(ops::PadOp& op) {
 }
 
 void IrDotDumper::visit(ops::ReduceFOp& op) {
+  static const std::map<ops::ReduceFOp::FuncType, const char*> types{
+    {ops::ReduceFOp::FuncType::mean, "mean"}
+  };
+
   auto node_info = DotIrNodeInfo().withType("ReduceFOp", op.getName())
     .withInShapes(getInputShapes(op))
     .withOutShapes(getOutputShapes(op))
     .withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims
     .withMisc("Keep Dims", op.getKeepDims())
-    .withMisc("OPType", (float) op.getFuncType());
+    .withMisc("OPType", types.at(op.getFuncType()));
 
   dotBuilder.updateWithOp(&op, node_info);
 }
 
 void IrDotDumper::visit(ops::ResizeOp& op) {
+  static const std::map<ops::ResizeOp::ResizeMethod, const char*> modes{
+    {ops::ResizeOp::ResizeMethod::nearestNeighbor, "nearestNeighbor"}
+  };
+
   auto node_info = DotIrNodeInfo().withType("Resize", op.getName())
     .withInShapes(getInputShapes(op))
     .withOutShapes(getOutputShapes(op))
-    .withMisc("Mode", (int) op.getMode());
-  // scale and resShape are only needed in Shape Inference
+    .withMisc("Mode", modes.at(op.getMode()));
+  // scale is only needed in Shape Inference
   
   dotBuilder.updateWithOp(&op, node_info);
 }
index ffb42e6..f49dba1 100644 (file)
@@ -24,6 +24,12 @@ namespace nnc
 namespace mir
 {
 
+template <>
+DotIrNodeInfo::Stringable::Stringable(std::string val) : _val(std::move(val)) {}
+
+template <>
+DotIrNodeInfo::Stringable::Stringable(const char* val) : _val(val) {}
+
 DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &typeName, const std::string &nodeName)
 {
   this->typeName = typeName;
@@ -56,12 +62,11 @@ DotIrNodeInfo &DotIrNodeInfo::withStride(const Shape &strideShape)
 }
 
 /**
- * @brief Allows dumping arbitrary scalar parameters from layers as floats.
+ * @brief Allows dumping arbitrary parameters from layer
  * The values that are actually integers get dumped as integers.
  */
-DotIrNodeInfo &DotIrNodeInfo::withMisc(const std::string &miscName, const float miscVal)
-{
-  this->miscVals.emplace_back(miscName, miscVal);
+DotIrNodeInfo& DotIrNodeInfo::withMisc(const std::string& name, Stringable val) {
+  this->miscVals.emplace_back(name, std::move(val));
   return *this;
 }
 
@@ -254,4 +259,4 @@ void DotIrNodeInfo::addPipeIfNeeded(std::stringstream &ss, bool needed, bool &ne
 }
 
 } // namespace mir
-} // namespace nnc
\ No newline at end of file
+} // namespace nnc
index c6e97f7..ec0fc75 100644 (file)
@@ -38,10 +38,24 @@ class DotIrNodeInfo
 public:
   using Shapes = std::vector<Shape>;
   using NamedShape = std::pair<std::string, Shape>;
-  using MiscVal = std::pair<std::string, float>;
+  using MiscVal = std::pair<std::string, std::string>;
   using PadType = ops::PaddingType;
   using PoolType = ops::PoolOp::PoolingType;
 
+  class Stringable {
+  public:
+    template <typename T>
+    /*implicit*/ Stringable(T val);
+
+
+    operator std::string&&() {
+      return std::move(_val);
+    }
+
+  private:
+    std::string _val;
+  };
+
   DotIrNodeInfo() = default;
 
   DotIrNodeInfo &withType(const std::string &typeName, const std::string &nodeName);
@@ -52,7 +66,7 @@ public:
   DotIrNodeInfo &withShape(const std::string &shapeName, const Shape &shape);
   DotIrNodeInfo &withPadType(PadType padType);
   DotIrNodeInfo &withPoolType(PoolType poolType);
-  DotIrNodeInfo &withMisc(const std::string &miscName, const float miscVal);
+  DotIrNodeInfo& withMisc(const std::string&, Stringable);
 
 /**
  * Create a label in dot format for the Model IR node.
@@ -93,6 +107,15 @@ private:
   PoolType poolType = PoolType::MAX;
 };
 
+template <typename T>
+DotIrNodeInfo::Stringable::Stringable(T val) : _val(std::to_string(val)) {}
+
+template <>
+DotIrNodeInfo::Stringable::Stringable(std::string val);
+
+template <>
+DotIrNodeInfo::Stringable::Stringable(const char* val);
+
 } // namespace mir
 } // namespace nnc