From c0bda1855ab023decb2e636cc5397d66b3de8d06 Mon Sep 17 00:00:00 2001 From: =?utf8?q?=D0=90=D0=BD=D0=B4=D1=80=D0=B5=D0=B9=20=D0=A8=D0=B5=D0=B4?= =?utf8?q?=D1=8C=D0=BA=D0=BE/AI=20Tools=20Lab=20/SRR/Assistant=20Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 7 Sep 2018 13:37:26 +0300 Subject: [PATCH] [nnc] Scalar param support in IR dumper (#1397) Added scalar parameter support in IR Dot Dumper The params and their names are stored in a `vector>` Signed-off-by: Andrei Shedko --- contrib/nnc/core/modelIR/ir_dot_dumper.cpp | 6 ++--- contrib/nnc/core/modelIR/ir_dot_node_info.cpp | 27 ++++++++++++++-------- .../nnc/include/core/modelIR/ir_dot_node_info.h | 6 +++-- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/contrib/nnc/core/modelIR/ir_dot_dumper.cpp b/contrib/nnc/core/modelIR/ir_dot_dumper.cpp index 2083537..72fbc68 100644 --- a/contrib/nnc/core/modelIR/ir_dot_dumper.cpp +++ b/contrib/nnc/core/modelIR/ir_dot_dumper.cpp @@ -52,7 +52,7 @@ void IrDotDumper::visit(INode *node, ops::CappedReluOp &op) auto nodeInfo = DotIrNodeInfo().withType("CappedRelu", node->getName()) .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) - .withAxis(op.getCap()); + .withMisc("Cap", op.getCap()); dotBuilder.updateWithNode(node, nodeInfo); @@ -63,7 +63,7 @@ void IrDotDumper::visit(INode *node, ops::ConcatOp &op) auto nodeInfo = DotIrNodeInfo().withType("Concat", node->getName()) .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) - .withAxis(op.getAxis()); + .withMisc("Axis", op.getAxis()); dotBuilder.updateWithNode(node, nodeInfo); } @@ -107,7 +107,7 @@ void IrDotDumper::visit(INode *node, ops::SoftmaxOp &op) auto nodeInfo = DotIrNodeInfo().withType("Softmax", node->getName()) .withInShapes(getInputShapes(op)) .withOutShapes(getOutputShapes(op)) - .withAxis(op.getAxis()); + .withMisc("Axis", op.getAxis()); dotBuilder.updateWithNode(node, nodeInfo); } diff --git a/contrib/nnc/core/modelIR/ir_dot_node_info.cpp b/contrib/nnc/core/modelIR/ir_dot_node_info.cpp index de251f1..20ab89b 100644 --- a/contrib/nnc/core/modelIR/ir_dot_node_info.cpp +++ b/contrib/nnc/core/modelIR/ir_dot_node_info.cpp @@ -43,6 +43,17 @@ DotIrNodeInfo &DotIrNodeInfo::withStride(const Shape &strideShape) return *this; } +/** + * @brief Allows dumping arbitrary scalar parameters from layers as floats. + * The values that are actually integers get dumped as integers. + */ +DotIrNodeInfo &DotIrNodeInfo::withMisc(const std::string &miscName, const float miscVal) +{ + this->miscVals.emplace_back(miscName, miscVal); + return *this; +} + + DotIrNodeInfo &DotIrNodeInfo::withShape(const std::string &shapeName, const Shape &shape) { this->shapes.emplace_back(shapeName, shape); @@ -63,12 +74,6 @@ DotIrNodeInfo &DotIrNodeInfo::withPoolType(DotIrNodeInfo::PoolType poolType) return *this; } -DotIrNodeInfo &DotIrNodeInfo::withAxis(float axis) -{ - this->axis = axis; - return *this; -} - std::string DotIrNodeInfo::getLabel() const { std::stringstream ss; @@ -202,10 +207,12 @@ std::string DotIrNodeInfo::labelForNodeParams() const ss << "Stride: " << strideShape; } - if (axis != -1) - { - addPipeIfNeeded(ss, true, needPipe); - ss << "Axis: " << axis; + //misc scalar parameters (Cap, dropRate, etc..) + addPipeIfNeeded(ss, !miscVals.empty(), needPipe); + for (Shapes::size_type i = 0; i < miscVals.size(); ++i) { + if (i != 0) + ss << " | "; + ss << miscVals[i].first << ": "<< miscVals[i].second; } return ss.str(); diff --git a/contrib/nnc/include/core/modelIR/ir_dot_node_info.h b/contrib/nnc/include/core/modelIR/ir_dot_node_info.h index 1fdd32b..4c7fe72 100644 --- a/contrib/nnc/include/core/modelIR/ir_dot_node_info.h +++ b/contrib/nnc/include/core/modelIR/ir_dot_node_info.h @@ -21,7 +21,7 @@ using namespace nncc::contrib::core::data; * @brief Can collect information about a NN operator, and then use it to output * this info as a node label in .dot format. * @usage Provides a typical builder interface for collecting NN operator info, for example: - * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, 4}}).withAxis(0); + * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, 4}}).withMisc("Axis", 0)); * Then resulting .dot node label is accessed with info.getLabel(); */ class DotIrNodeInfo @@ -29,6 +29,7 @@ class DotIrNodeInfo public: using Shapes = std::vector; using NamedShape = std::pair; + using MiscVal = std::pair; using PadType = ops::PaddingType; using PoolType = ops::PoolOp::PoolingType; @@ -42,7 +43,7 @@ public: DotIrNodeInfo &withShape(const std::string &shapeName, const Shape &shape); DotIrNodeInfo &withPadType(PadType padType); DotIrNodeInfo &withPoolType(PoolType poolType); - DotIrNodeInfo &withAxis(float axis); + DotIrNodeInfo &withMisc(const std::string &miscName, const float miscVal); /** * Create a label in dot format for the Model IR node. @@ -74,6 +75,7 @@ private: Shape strideShape; std::vector shapes; + std::vector miscVals; bool hasPad = false; PadType padType = PadType::Valid; -- 2.7.4