From 911039e6300eebd517b6aba93b40969508cca193 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Dmitry=20Mozolev/AI=20Tools=20Lab=20/SRR/Engineer/=EC=82=BC?= =?utf8?q?=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Fri, 13 Jul 2018 16:26:40 +0300 Subject: [PATCH] Add Caffe model input processing (#596) Process Caffe model Input layer and set up Model IR graph inputs. Signed-off-by: Dmitry Mozolev --- .../frontend/caffe/src/caffe_model_visitor.cpp | 57 ++++++++++++++++++---- 1 file changed, 48 insertions(+), 9 deletions(-) diff --git a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp index f3f40f6..2b4c6a0 100644 --- a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp +++ b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp @@ -4,6 +4,7 @@ #include "nncc/core/ADT/tensor/Shape.h" #include "nnc/core/IR/model/operations/variable_op.h" +#include "PluginException.h" #include "shape_helper.h" #include "caffe_model_visitor.h" @@ -22,7 +23,7 @@ using nncc::core::ADT::tensor::Shape; void ModelVisitor::visit(const NetParameter& np) { - processDeprecatedInput(np); + processDeprecatedInput(np); } void ModelVisitor::visit(const LayerParameter& lp) @@ -36,27 +37,65 @@ void ModelVisitor::visit(const BlobShape&) { throw std::runtime_error("Not yet i Graph *ModelVisitor::getGraph() { - return graph; + return graph; } void ModelVisitor::createGraphInputs(const std::vector &names, const std::vector &shapes) { - (void)names; - (void)shapes; - throw std::runtime_error("Not yet implemented"); + assert(names.size() == shapes.size()); + + for (size_t i = 0; i < names.size(); ++i) + { + auto node = graph->create(names[i]); + opsForBlobsTheyOutput[names[i]] = node; + + Shape inputShape = shapes[i]; + // WARNING! Temporary solution! Assuming that every 4D input will be used for a convolution, + // so we change every 4D input from Caffe NCHW to Model IR HWC (batch is cut off earlier). + // TODO: Implement a more consistent way of handling shapes within the model. + if (shapes[i].rank() == 3) + { + const Shape &sh = shapes[i]; + inputShape = Shape{sh.dim(1), sh.dim(2), sh.dim(0)}; + } + // WARNING! Temporary solution! + + node->getOperation()->setOutputShape(0, inputShape); + } } void ModelVisitor::processDeprecatedInput(const NetParameter& np) { - (void)np; - throw std::runtime_error("Not yet implemented"); + if (np.input_dim_size() != 0 || np.input_shape_size() != 0) + { + throw PluginException("Deprecated Caffe input types are not supported"); + } } void ModelVisitor::processInputLayer(const LayerParameter& lp) { - (void)lp; - throw std::runtime_error("Not yet implemented"); + if (!inputShapes.empty()) + { + throw std::runtime_error("Model contains both Input layer and deprecated input methods"); + } + + std::vector inputNames; + for (const auto &name : lp.top()) + { + inputNames.push_back(name); + } + + for (const auto &shape : lp.input_param().shape()) + { + Shape sh = common::ShapeHelper::createShape(shape.dim(), shape.dim_size()); + inputShapes.push_back(common::ShapeHelper::cutOffBatchDim(sh)); + } + + if (!inputShapes.empty()) + { + createGraphInputs(inputNames, inputShapes); + } } std::shared_ptr ModelVisitor::createTensor(const BlobProto &bp) -- 2.7.4