Release 18.08
[platform/upstream/armnn.git] / src / armnnCaffeParser / RecordByRecordCaffeParser.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5
6 #pragma once
7
8 #include <string>
9 #include <vector>
10 #include <iostream>
11
12 #include "caffe/proto/caffe.pb.h"
13
14 #include "CaffeParser.hpp"
15
16
17
18 namespace armnnCaffeParser
19 {
20
21 class NetParameterInfo;
22 class LayerParameterInfo;
23
24
25 class RecordByRecordCaffeParser : public CaffeParserBase
26 {
27 public:
28
29     /// Create the network from a protobuf binary file on disk
30     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
31         const char* graphFile,
32         const std::map<std::string, armnn::TensorShape>& inputShapes,
33         const std::vector<std::string>& requestedOutputs) override;
34
35     RecordByRecordCaffeParser();
36
37 private:
38     void ProcessLayers(const NetParameterInfo& netParameterInfo,
39                        std::vector<LayerParameterInfo>& layerInfo,
40                        const std::vector<std::string>& m_RequestedOutputs,
41                        std::vector<const LayerParameterInfo*>& sortedNodes);
42     armnn::INetworkPtr LoadLayers(std::ifstream& ifs,
43                                   std::vector<const LayerParameterInfo *>& sortedNodes,
44                                   const NetParameterInfo& netParameterInfo);
45     std::vector<const LayerParameterInfo*> GetInputs(
46         const LayerParameterInfo& layerParam);
47
48     std::map<std::string, const LayerParameterInfo*> m_CaffeLayersByTopName;
49     std::vector<std::string> m_RequestedOutputs;
50 };
51
52 } // namespace armnnCaffeParser
53