void visit(const BlobShape&) override;
Graph* getGraph();
+ void setGraphOutputs();
+ void setIrNodeNames();
private:
Graph* graph = nullptr;
std::vector<Shape> inputShapes;
std::map<std::string, INode::Ref> opsForBlobsTheyOutput;
+ std::vector<INode::Ref> graphOutputs;
std::shared_ptr<IrTensor> createTensor(const BlobProto&);
std::vector<INode::Ref> createOpInputs(const LayerParameter&);
ModelWalker caffeWalker(&irCreator);
caffeWalker.walkNetParameter(*net);
+ irCreator.setIrNodeNames();
+ irCreator.setGraphOutputs();
return irCreator.getGraph();
}
{
opsForBlobsTheyOutput[lp.top(0)] = item;
}
+ graphOutputs.assign(outputs.begin(), outputs.end());
}
void ModelVisitor::visit(const BlobProto&) {}
return params;
}
+void ModelVisitor::setGraphOutputs() {
+ // Marking nodes as output nodes.
+ for (auto &outputIdx : graphOutputs)
+ {
+ graph->markOutput(outputIdx);
+ }
+}
+
+void ModelVisitor::setIrNodeNames() {
+ for (auto &item : opsForBlobsTheyOutput)
+ {
+ item.second->setName(item.first);
+ }
+}
+
} // namespace caffe
} // namespace frontend
} // namespace contrib