return it->second;
}
-void Graph::accept(Visitor *visitor) {
+void Graph::accept(IVisitor *visitor) {
std::deque<INode::Ref> q;
std::set<INode::Ref> known_nodes;
--- /dev/null
+#include "core/modelIR/visitor.h"
+
+namespace nncc {
+namespace contrib {
+namespace core {
+namespace IR {
+namespace model {
+
+void Visitor::visit(ADT::INode *node, ops::ConcatOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::Conv2DOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::DepthwiseConv2DOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::SoftmaxOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::PoolOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::FullyConnectedOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::CappedReluOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::BiasAddOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::VariableOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::ReluOp &op) {(void)node; (void)op;};
+void Visitor::visit(ADT::INode *node, ops::ReshapeOp &op) {(void)node; (void)op;};
+
+} // namespace model
+} // namespace IR
+} // namespace core
+} // namespace contrib
+} // namespace nncc
using namespace nncc::contrib::core::IR::model;
-class ShapeInference : public Visitor {
+class ShapeInference : public IVisitor {
public:
static const auto AUTO_DIM = std::numeric_limits<uint32_t>::max();
namespace model {
using ADT::INode;
-class Visitor;
+class IVisitor;
class Graph {
public:
return node;
}
- void accept(Visitor *visitor);
+ void accept(IVisitor *visitor);
void markOutput(INode::Ref node);
INode::Ref getInput(const std::string &name);
* @breif Model IR visitor that can be used to output Model IR as a .dot graph.
* @usage Run on a Model IR graph as a visitor, and then call writeDot passing it a stream
*/
-class IrDotDumper : public Visitor
+class IrDotDumper : public IVisitor
{
public:
void visit(INode *node, ops::ConcatOp &op) override;
virtual const std::string &getName() const = 0;
virtual void setName(const std::string &name) = 0;
- virtual void accept(Visitor *v) = 0;
+ virtual void accept(IVisitor *v) = 0;
virtual const IODescriptor getOutput(const size_t index) = 0;
virtual void connectInputTo(const int inputIndex, const IODescriptor &descriptor) = 0;
void setName(const std::string &name) override { _props.name = name; }
- void accept(Visitor *v) override
+ void accept(IVisitor *v) override
{
v->visit(this, *static_cast<OpType*>(_props.op));
}
class ReshapeOp;
}
-class Visitor {
+/**
+ * @brief Visitor Interface declaration
+ */
+class IVisitor {
public:
virtual void visit(ADT::INode *node, ops::ConcatOp &op) = 0;
virtual void visit(ADT::INode *node, ops::Conv2DOp &op) = 0;
virtual void visit(ADT::INode *node, ops::BiasAddOp &op) = 0;
virtual void visit(ADT::INode *node, ops::VariableOp &op) = 0;
virtual void visit(ADT::INode *node, ops::ReluOp &op) = 0;
- virtual void visit(ADT::INode* node, ops::ReshapeOp &op) = 0;
+ virtual void visit(ADT::INode *node, ops::ReshapeOp &op) = 0;
- virtual ~Visitor() = default;
+ virtual ~IVisitor() = default;
+};
+
+/**
+ * @brief Non Pure Virtual implementation of IVisitor
+ * It is used to facilitate adding new operations,
+ * so that we don't have to add more declarations and
+ * only need to define an implementation of `visit` for a subset of operations in the graph,
+ * while not doing anything for all others.
+ */
+class Visitor: public IVisitor{
+public:
+ void visit(ADT::INode *node, ops::ConcatOp &op) override;
+ void visit(ADT::INode *node, ops::Conv2DOp &op) override;
+ void visit(ADT::INode *node, ops::DepthwiseConv2DOp &op) override;
+ void visit(ADT::INode *node, ops::SoftmaxOp &op) override;
+ void visit(ADT::INode *node, ops::PoolOp &op) override;
+ void visit(ADT::INode *node, ops::FullyConnectedOp &op) override;
+ void visit(ADT::INode *node, ops::CappedReluOp &op) override;
+ void visit(ADT::INode *node, ops::BiasAddOp &op) override;
+ void visit(ADT::INode *node, ops::VariableOp &op) override;
+ void visit(ADT::INode *node, ops::ReluOp &op) override;
+ void visit(ADT::INode *node, ops::ReshapeOp &op) override;
};
} // namespace model
using nncc::contrib::core::data::Index;
using nncc::contrib::core::data::Tensor;
-class NNInterpreter : public Visitor
+class NNInterpreter : public IVisitor
{
public:
explicit NNInterpreter() = default;
const size_t INVALID_TENSOR_ID = std::numeric_limits<size_t>::max();
-class ModelAnalyzer: public model::Visitor
+class ModelAnalyzer: public model::IVisitor
{
public:
void visit(ADT::INode *node, ops::ConcatOp &op) override;
namespace ADT = model::ADT;
namespace ops = model::ops;
-class Serializer: public model::Visitor
+class Serializer: public model::IVisitor
{
public: