nnc: Fix Caffe output node names (#1024)
authorVitaliy Cherepanov/AI Tools Lab /SRR/Engineer/삼성전자 <v.cherepanov@samsung.com>
Wed, 15 Aug 2018 12:53:46 +0000 (15:53 +0300)
committerSergey Vostokov/AI Tools Lab /SRR/Staff Engineer/삼성전자 <s.vostokov@samsung.com>
Wed, 15 Aug 2018 12:53:46 +0000 (15:53 +0300)
This commit fix interpreting problem
based on wrong IR output nodes names

Signed-off-by: Vitaliy Cherepanov <v.cherepanov@samsung.com>
contrib/nnc/libs/frontend/caffe/include/caffe_model_visitor.h
contrib/nnc/libs/frontend/caffe/src/caffe_importer.cpp
contrib/nnc/libs/frontend/caffe/src/caffe_model_visitor.cpp

index 2088201..d594f9a 100644 (file)
@@ -38,6 +38,8 @@ public:
     void visit(const BlobShape&) override;
 
     Graph* getGraph();
+    void setGraphOutputs();
+    void setIrNodeNames();
 
 private:
     Graph* graph = nullptr;
@@ -45,6 +47,7 @@ private:
 
     std::vector<Shape> inputShapes;
     std::map<std::string, INode::Ref> opsForBlobsTheyOutput;
+    std::vector<INode::Ref> graphOutputs;
 
     std::shared_ptr<IrTensor> createTensor(const BlobProto&);
     std::vector<INode::Ref> createOpInputs(const LayerParameter&);
index e126d3a..8b69f18 100644 (file)
@@ -32,6 +32,8 @@ void* CaffeImporter::createIR()
     ModelWalker caffeWalker(&irCreator);
 
     caffeWalker.walkNetParameter(*net);
+    irCreator.setIrNodeNames();
+    irCreator.setGraphOutputs();
 
     return irCreator.getGraph();
 }
index 68a2fa0..878e09d 100644 (file)
@@ -78,6 +78,7 @@ void ModelVisitor::visit(const LayerParameter& lp)
   {
     opsForBlobsTheyOutput[lp.top(0)] = item;
   }
+  graphOutputs.assign(outputs.begin(), outputs.end());
 }
 
 void ModelVisitor::visit(const BlobProto&) {}
@@ -220,6 +221,21 @@ std::vector<std::shared_ptr<IrTensor>> ModelVisitor::createOpParams(const LayerP
   return params;
 }
 
+void ModelVisitor::setGraphOutputs() {
+  // Marking nodes as output nodes.
+  for (auto &outputIdx : graphOutputs)
+  {
+    graph->markOutput(outputIdx);
+  }
+}
+
+void ModelVisitor::setIrNodeNames() {
+  for (auto &item : opsForBlobsTheyOutput)
+  {
+    item.second->setName(item.first);
+  }
+}
+
 } // namespace caffe
 } // namespace frontend
 } // namespace contrib