auto nodeInfo = DotIrNodeInfo().withType("CappedRelu", node->getName())
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
- .withAxis(op.getCap());
+ .withMisc("Cap", op.getCap());
dotBuilder.updateWithNode(node, nodeInfo);
auto nodeInfo = DotIrNodeInfo().withType("Concat", node->getName())
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
- .withAxis(op.getAxis());
+ .withMisc("Axis", op.getAxis());
dotBuilder.updateWithNode(node, nodeInfo);
}
auto nodeInfo = DotIrNodeInfo().withType("Softmax", node->getName())
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
- .withAxis(op.getAxis());
+ .withMisc("Axis", op.getAxis());
dotBuilder.updateWithNode(node, nodeInfo);
}
return *this;
}
+/**
+ * @brief Allows dumping arbitrary scalar parameters from layers as floats.
+ * The values that are actually integers get dumped as integers.
+ */
+DotIrNodeInfo &DotIrNodeInfo::withMisc(const std::string &miscName, const float miscVal)
+{
+ this->miscVals.emplace_back(miscName, miscVal);
+ return *this;
+}
+
+
DotIrNodeInfo &DotIrNodeInfo::withShape(const std::string &shapeName, const Shape &shape)
{
this->shapes.emplace_back(shapeName, shape);
return *this;
}
-DotIrNodeInfo &DotIrNodeInfo::withAxis(float axis)
-{
- this->axis = axis;
- return *this;
-}
-
std::string DotIrNodeInfo::getLabel() const
{
std::stringstream ss;
ss << "Stride: " << strideShape;
}
- if (axis != -1)
- {
- addPipeIfNeeded(ss, true, needPipe);
- ss << "Axis: " << axis;
+ //misc scalar parameters (Cap, dropRate, etc..)
+ addPipeIfNeeded(ss, !miscVals.empty(), needPipe);
+ for (Shapes::size_type i = 0; i < miscVals.size(); ++i) {
+ if (i != 0)
+ ss << " | ";
+ ss << miscVals[i].first << ": "<< miscVals[i].second;
}
return ss.str();
* @brief Can collect information about a NN operator, and then use it to output
* this info as a node label in .dot format.
* @usage Provides a typical builder interface for collecting NN operator info, for example:
- * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, 4}}).withAxis(0);
+ * auto info = DotIrNodeInfo().withType("Softmax").withInShapes({{1, 2, 3}, {2, 3, 4}}).withMisc("Axis", 0));
* Then resulting .dot node label is accessed with info.getLabel();
*/
class DotIrNodeInfo
public:
using Shapes = std::vector<Shape>;
using NamedShape = std::pair<std::string, Shape>;
+ using MiscVal = std::pair<std::string, float>;
using PadType = ops::PaddingType;
using PoolType = ops::PoolOp::PoolingType;
DotIrNodeInfo &withShape(const std::string &shapeName, const Shape &shape);
DotIrNodeInfo &withPadType(PadType padType);
DotIrNodeInfo &withPoolType(PoolType poolType);
- DotIrNodeInfo &withAxis(float axis);
+ DotIrNodeInfo &withMisc(const std::string &miscName, const float miscVal);
/**
* Create a label in dot format for the Model IR node.
Shape strideShape;
std::vector<NamedShape> shapes;
+ std::vector<MiscVal> miscVals;
bool hasPad = false;
PadType padType = PadType::Valid;