From: Vitaliy Cherepanov/AI Tools Lab /SRR/Engineer/삼성전자 Date: Wed, 15 Aug 2018 12:53:46 +0000 (+0300) Subject: nnc: Fix Caffe output node names (#1024) X-Git-Tag: nncc_backup~2158 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=0d712db2c957b6c11289bca197d1a840b8187420;p=platform%2Fcore%2Fml%2Fnnfw.git nnc: Fix Caffe output node names (#1024) This commit fix interpreting problem based on wrong IR output nodes names Signed-off-by: Vitaliy Cherepanov --- diff --git a/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h b/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h index 2088201..d594f9a 100644 --- a/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h +++ b/contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h @@ -38,6 +38,8 @@ public: void visit(const BlobShape&) override; Graph* getGraph(); + void setGraphOutputs(); + void setIrNodeNames(); private: Graph* graph = nullptr; @@ -45,6 +47,7 @@ private: std::vector inputShapes; std::map opsForBlobsTheyOutput; + std::vector graphOutputs; std::shared_ptr createTensor(const BlobProto&); std::vector createOpInputs(const LayerParameter&); diff --git a/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp b/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp index e126d3a..8b69f18 100644 --- a/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp +++ b/contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp @@ -32,6 +32,8 @@ void* CaffeImporter::createIR() ModelWalker caffeWalker(&irCreator); caffeWalker.walkNetParameter(*net); + irCreator.setIrNodeNames(); + irCreator.setGraphOutputs(); return irCreator.getGraph(); } diff --git a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp index 68a2fa0..878e09d 100644 --- a/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp +++ b/contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp @@ -78,6 +78,7 @@ void ModelVisitor::visit(const LayerParameter& lp) { opsForBlobsTheyOutput[lp.top(0)] = item; } + graphOutputs.assign(outputs.begin(), outputs.end()); } void ModelVisitor::visit(const BlobProto&) {} @@ -220,6 +221,21 @@ std::vector> ModelVisitor::createOpParams(const LayerP return params; } +void ModelVisitor::setGraphOutputs() { + // Marking nodes as output nodes. + for (auto &outputIdx : graphOutputs) + { + graph->markOutput(outputIdx); + } +} + +void ModelVisitor::setIrNodeNames() { + for (auto &item : opsForBlobsTheyOutput) + { + item.second->setName(item.first); + } +} + } // namespace caffe } // namespace frontend } // namespace contrib