From 40d3f27fb4cd353a4edcc6c65d003b16242a991a Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dmitry=20Mozolev/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 13 Jul 2018 16:48:37 +0300 Subject: [PATCH] Create Model IR graph from Caffe model visitor (#635) Implemented Caffe LayerParameter visiting function to call Caffe to Model IR operation converter. Signed-off-by: Dmitry Mozolev --- .../frontend/caffe/include/caffe_model_visitor.h | 4 +- .../libs/frontend/caffe/include/caffe_op_creator.h | 1 - .../frontend/caffe/src/caffe_model_visitor.cpp | 56 ++++++++++++++++++++-- 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h b/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h index 87510be..2088201 100644 --- a/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h +++ b/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h @@ -9,6 +9,7 @@ #include "nnc/core/linalg/TensorVariant.h" #include "caffe_visitor.h" +#include "caffe_op_creator.h" namespace nncc { @@ -29,7 +30,7 @@ using nncc::core::ADT::tensor::Shape; class ModelVisitor : public Visitor { public: - ModelVisitor() : graph(new Graph()) {}; + ModelVisitor() : graph(new Graph()), opCreator(graph) {}; void visit(const NetParameter&) override; void visit(const LayerParameter&) override; @@ -40,6 +41,7 @@ public: private: Graph* graph = nullptr; + OpCreator opCreator; std::vector inputShapes; std::map opsForBlobsTheyOutput; diff --git a/contrib/nnc/libs/frontend/caffe/include/caffe_op_creator.h b/contrib/nnc/libs/frontend/caffe/include/caffe_op_creator.h index 583ec46..61d9a3a 100644 --- a/contrib/nnc/libs/frontend/caffe/include/caffe_op_creator.h +++ b/contrib/nnc/libs/frontend/caffe/include/caffe_op_creator.h @@ -13,7 +13,6 @@ #include "nncc/core/ADT/tensor/Shape.h" #include "caffe/proto/caffe.pb.h" -#include "caffe_model_visitor.h" namespace nncc { diff --git a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp index c5e1198..141c6d5 100644 --- a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp +++ b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp @@ -1,6 +1,5 @@ #include #include -#include #include "nncc/core/ADT/tensor/Shape.h" #include "nnc/core/IR/model/operations/variable_op.h" @@ -28,12 +27,59 @@ void ModelVisitor::visit(const NetParameter& np) void ModelVisitor::visit(const LayerParameter& lp) { - (void)lp; - throw std::runtime_error("Not yet implemented"); + auto inputs = createOpInputs(lp); + auto params = createOpParams(lp); + + std::vector outputs; + + // TODO: support other layer types + + // This is the Input layer + if (lp.has_input_param()) + { + processInputLayer(lp); + } + else if (lp.has_convolution_param()) + { + outputs = opCreator.createConv2D(inputs, params, lp.convolution_param()); + } + else if (lp.has_inner_product_param()) + { + outputs = opCreator.createFullyConnected(inputs, params, lp.inner_product_param()); + } + else if (lp.has_pooling_param()) + { + outputs = opCreator.createPool(inputs, params, lp.pooling_param()); + } + else if (lp.has_concat_param()) + { + outputs = opCreator.createConcat(inputs, params, lp.concat_param()); + } + else if (lp.has_reshape_param()) + { + outputs = opCreator.createReshape(inputs, params, lp.reshape_param()); + } + else if (lp.has_relu_param() || lp.type() == "ReLU") + { + outputs = opCreator.createRelu(inputs, params, lp.relu_param()); + } + else if (lp.has_softmax_param() || lp.type() == "Softmax") + { + outputs = opCreator.createSoftmax(inputs, params, lp.softmax_param()); + } + else + { + throw PluginException("Encountered unsupported Caffe layer type"); + } + + for (auto item : outputs) + { + opsForBlobsTheyOutput[lp.top(0)] = item; + } } -void ModelVisitor::visit(const BlobProto&) { throw std::runtime_error("Not yet implemented"); } -void ModelVisitor::visit(const BlobShape&) { throw std::runtime_error("Not yet implemented"); } +void ModelVisitor::visit(const BlobProto&) {} +void ModelVisitor::visit(const BlobShape&) {} Graph *ModelVisitor::getGraph() { -- 2.7.4