}
void mir::IrDotDumper::visit(ops::ElementwiseOp& op) {
- auto nodeInfo = DotIrNodeInfo().withType("TanhOp", op.getName())
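+ // Human-readable names for the elementwise kinds; std::map::at() throws
+ // if a new OpType is added without extending this table.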
+ static const std::map<ops::ElementwiseOp::OpType, const char*> op_types{
+ {ops::ElementwiseOp::OpType::mul, "mul"},
+ {ops::ElementwiseOp::OpType::add, "add"},
+ {ops::ElementwiseOp::OpType::max, "max"},
+ {ops::ElementwiseOp::OpType::div, "div"}
+ };
+
+ auto node_info = DotIrNodeInfo().withType("ElementwiseOp", op.getName())
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
- .withMisc("Operation", ( int ) op.getOpType());
+ .withMisc("Operation", op_types.at(op.getOpType()));
- dotBuilder.updateWithOp(&op, nodeInfo);
+ dotBuilder.updateWithOp(&op, node_info);
}
void IrDotDumper::visit(ops::SqueezeOp& op) {
}
void IrDotDumper::visit(ops::ReduceFOp& op) {
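+ // Readable names for the reduction functions (currently only mean).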
+ static const std::map<ops::ReduceFOp::FuncType, const char*> types{
+ {ops::ReduceFOp::FuncType::mean, "mean"}
+ };
+
auto node_info = DotIrNodeInfo().withType("ReduceFOp", op.getName())
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
.withShape("Reduction dims", Shape(op.getReductionDims())) // appropriated shape to dims
.withMisc("Keep Dims", op.getKeepDims())
- .withMisc("OPType", (float) op.getFuncType());
+ .withMisc("OPType", types.at(op.getFuncType()));
dotBuilder.updateWithOp(&op, node_info);
}
void IrDotDumper::visit(ops::ResizeOp& op) {
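+ // Readable names for the resize methods (currently only nearestNeighbor).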
+ static const std::map<ops::ResizeOp::ResizeMethod, const char*> modes{
+ {ops::ResizeOp::ResizeMethod::nearestNeighbor, "nearestNeighbor"}
+ };
+
auto node_info = DotIrNodeInfo().withType("Resize", op.getName())
.withInShapes(getInputShapes(op))
.withOutShapes(getOutputShapes(op))
- .withMisc("Mode", (int) op.getMode());
- // scale and resShape are only needed in Shape Inference
+ .withMisc("Mode", modes.at(op.getMode()));
+ // scale is only needed in Shape Inference
dotBuilder.updateWithOp(&op, node_info);
}
namespace mir
{
+template <>
+DotIrNodeInfo::Stringable::Stringable(std::string val) : _val(std::move(val)) {}
+
+template <>
+DotIrNodeInfo::Stringable::Stringable(const char* val) : _val(val) {}
+
DotIrNodeInfo &DotIrNodeInfo::withType(const std::string &typeName, const std::string &nodeName)
{
this->typeName = typeName;
this->nodeName = nodeName;
return *this;
}
/**
- * @brief Allows dumping arbitrary scalar parameters from layers as floats.
+ * @brief Allows dumping arbitrary parameters from a layer.
* The values that are actually integers get dumped as integers.
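+ * Example (hypothetical caller):
+ *   info.withMisc("Operation", "mul").withMisc("Keep Dims", 1);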
*/
-DotIrNodeInfo &DotIrNodeInfo::withMisc(const std::string &miscName, const float miscVal)
-{
- this->miscVals.emplace_back(miscName, miscVal);
+DotIrNodeInfo& DotIrNodeInfo::withMisc(const std::string& name, Stringable val) {
+ this->miscVals.emplace_back(name, std::move(val));
return *this;
}
}
} // namespace mir
-} // namespace nnc
\ No newline at end of file
+} // namespace nnc
public:
using Shapes = std::vector<Shape>;
using NamedShape = std::pair<std::string, Shape>;
- using MiscVal = std::pair<std::string, float>;
+ using MiscVal = std::pair<std::string, std::string>;
using PadType = ops::PaddingType;
using PoolType = ops::PoolOp::PoolingType;
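+ // Adapts a misc value to the std::string stored in a MiscVal: arithmetic types
+ // go through std::to_string, std::string and const char* are stored as-is.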
+ class Stringable {
+ public:
+ template <typename T>
+ /*implicit*/ Stringable(T val);
+
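+ // Lets the stored string be moved straight into a MiscVal.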
+ operator std::string&&() {
+ return std::move(_val);
+ }
+
+ private:
+ std::string _val;
+ };
+
DotIrNodeInfo() = default;
DotIrNodeInfo &withType(const std::string &typeName, const std::string &nodeName);
DotIrNodeInfo &withShape(const std::string &shapeName, const Shape &shape);
DotIrNodeInfo &withPadType(PadType padType);
DotIrNodeInfo &withPoolType(PoolType poolType);
- DotIrNodeInfo &withMisc(const std::string &miscName, const float miscVal);
+ DotIrNodeInfo& withMisc(const std::string& name, Stringable val);
/**
* Create a label in dot format for the Model IR node.
PoolType poolType = PoolType::MAX;
};
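+// Generic case: arithmetic values are stringified with std::to_string.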
+template <typename T>
+DotIrNodeInfo::Stringable::Stringable(T val) : _val(std::to_string(val)) {}
+
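+// String-like values bypass std::to_string; the specializations are defined in the .cpp.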
+template <>
+DotIrNodeInfo::Stringable::Stringable(std::string val);
+
+template <>
+DotIrNodeInfo::Stringable::Stringable(const char* val);
+
} // namespace mir
} // namespace nnc