From 0d712db2c957b6c11289bca197d1a840b8187420 Mon Sep 17 00:00:00 2001 From: =?utf8?q?Vitaliy=20Cherepanov/AI=20Tools=20Lab=20/SRR/Engineer/?= =?utf8?q?=EC=82=BC=EC=84=B1=EC=A0=84=EC=9E=90?= Date: Wed, 15 Aug 2018 15:53:46 +0300 Subject: [PATCH] nnc: Fix Caffe output node names (#1024) This commit fix interpreting problem based on wrong IR output nodes names Signed-off-by: Vitaliy Cherepanov --- .../libs/frontend/caffe/include/caffe_model_visitor.h | 3 +++ contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp | 2 ++ .../nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp | 16 ++++++++++++++++ 3 files changed, 21 insertions(+) diff --git a/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h b/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h index 2088201..d594f9a 100644 --- a/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h +++ b/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h @@ -38,6 +38,8 @@ public: void visit(const BlobShape&) override; Graph* getGraph(); + void setGraphOutputs(); + void setIrNodeNames(); private: Graph* graph = nullptr; @@ -45,6 +47,7 @@ private: std::vector inputShapes; std::map opsForBlobsTheyOutput; + std::vector graphOutputs; std::shared_ptr createTensor(const BlobProto&); std::vector createOpInputs(const LayerParameter&); diff --git a/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp b/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp index e126d3a..8b69f18 100644 --- a/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp +++ b/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp @@ -32,6 +32,8 @@ void* CaffeImporter::createIR() ModelWalker caffeWalker(&irCreator); caffeWalker.walkNetParameter(*net); + irCreator.setIrNodeNames(); + irCreator.setGraphOutputs(); return irCreator.getGraph(); } diff --git a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp index 68a2fa0..878e09d 100644 --- a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp +++ b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp @@ -78,6 +78,7 @@ void ModelVisitor::visit(const LayerParameter& lp) { opsForBlobsTheyOutput[lp.top(0)] = item; } + graphOutputs.assign(outputs.begin(), outputs.end()); } void ModelVisitor::visit(const BlobProto&) {} @@ -220,6 +221,21 @@ std::vector> ModelVisitor::createOpParams(const LayerP return params; } +void ModelVisitor::setGraphOutputs() { + // Marking nodes as output nodes. + for (auto &outputIdx : graphOutputs) + { + graph->markOutput(outputIdx); + } +} + +void ModelVisitor::setIrNodeNames() { + for (auto &item : opsForBlobsTheyOutput) + { + item.second->setName(item.first); + } +} + } // namespace caffe } // namespace frontend } // namespace contrib -- 2.7.4