#include "nnc/core/linalg/TensorVariant.h"
#include "caffe_visitor.h"
+#include "caffe_op_creator.h"
namespace nncc
{
class ModelVisitor : public Visitor
{
public:
- ModelVisitor() : graph(new Graph()) {};
+ ModelVisitor() : graph(new Graph()), opCreator(graph) {};
void visit(const NetParameter&) override;
void visit(const LayerParameter&) override;
private:
Graph* graph = nullptr;
+ OpCreator opCreator;
std::vector<Shape> inputShapes;
std::map<std::string, INode::Ref> opsForBlobsTheyOutput;
#include <vector>
#include <cassert>
-#include <stdexcept>
#include "nncc/core/ADT/tensor/Shape.h"
#include "nnc/core/IR/model/operations/variable_op.h"
void ModelVisitor::visit(const LayerParameter& lp)
{
- (void)lp;
- throw std::runtime_error("Not yet implemented");
+ auto inputs = createOpInputs(lp);
+ auto params = createOpParams(lp);
+
+ std::vector<INode::Ref> outputs;
+
+ // TODO: support other layer types
+
+ // This is the Input layer
+ if (lp.has_input_param())
+ {
+ processInputLayer(lp);
+ }
+ else if (lp.has_convolution_param())
+ {
+ outputs = opCreator.createConv2D(inputs, params, lp.convolution_param());
+ }
+ else if (lp.has_inner_product_param())
+ {
+ outputs = opCreator.createFullyConnected(inputs, params, lp.inner_product_param());
+ }
+ else if (lp.has_pooling_param())
+ {
+ outputs = opCreator.createPool(inputs, params, lp.pooling_param());
+ }
+ else if (lp.has_concat_param())
+ {
+ outputs = opCreator.createConcat(inputs, params, lp.concat_param());
+ }
+ else if (lp.has_reshape_param())
+ {
+ outputs = opCreator.createReshape(inputs, params, lp.reshape_param());
+ }
+ else if (lp.has_relu_param() || lp.type() == "ReLU")
+ {
+ outputs = opCreator.createRelu(inputs, params, lp.relu_param());
+ }
+ else if (lp.has_softmax_param() || lp.type() == "Softmax")
+ {
+ outputs = opCreator.createSoftmax(inputs, params, lp.softmax_param());
+ }
+ else
+ {
+ throw PluginException("Encountered unsupported Caffe layer type");
+ }
+
+ for (auto item : outputs)
+ {
+ opsForBlobsTheyOutput[lp.top(0)] = item;
+ }
}
-void ModelVisitor::visit(const BlobProto&) { throw std::runtime_error("Not yet implemented"); }
-void ModelVisitor::visit(const BlobShape&) { throw std::runtime_error("Not yet implemented"); }
+void ModelVisitor::visit(const BlobProto&) {}
+void ModelVisitor::visit(const BlobShape&) {}
Graph *ModelVisitor::getGraph()
{