Release 18.08
[platform/upstream/armnn.git] / src / armnnOnnxParser / OnnxParser.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "armnnOnnxParser/IOnnxParser.hpp"
8 #include "google/protobuf/repeated_field.h"
9 #include <unordered_map>
10
11 #include <onnx/onnx.pb.h>
12
13
14 namespace armnn
15 {
16 class TensorInfo;
17 }
18
19 namespace armnnOnnxParser
20 {
21
22 using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
23 using ModelPtr = std::unique_ptr<onnx::ModelProto>;
24
25 class OnnxParser : public IOnnxParser
26 {
27
28 using OperationParsingFunction = void(OnnxParser::*)(const onnx::NodeProto& NodeProto);
29
30 public:
31
32     using GraphPtr = std::unique_ptr<onnx::GraphProto>;
33
34     /// Create the network from a protobuf binary file on disk
35     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(const char* graphFile) override;
36
37     /// Create the network from a protobuf text file on disk
38     virtual armnn::INetworkPtr CreateNetworkFromTextFile(const char* graphFile) override;
39
40     /// Create the network directly from protobuf text in a string. Useful for debugging/testing
41     virtual armnn::INetworkPtr CreateNetworkFromString(const std::string& protoText) override;
42
43     /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
44     virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
45
46     /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
47     virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
48
49 public:
50
51     OnnxParser();
52
53     static ModelPtr LoadModelFromBinaryFile(const char * fileName);
54     static ModelPtr LoadModelFromTextFile(const char * fileName);
55     static ModelPtr LoadModelFromString(const std::string& inputString);
56
57     ///Retrieve inputs names
58     static std::vector<std::string> GetInputs(ModelPtr& model);
59
60     ///Retrieve outputs names
61     static std::vector<std::string> GetOutputs(ModelPtr& model);
62
63 private:
64
65     /// Parses a ModelProto loaded into memory from one of the other CreateNetwork*
66     armnn::INetworkPtr CreateNetworkFromModel(onnx::ModelProto& model);
67
68     ///Parse every node and make the connection between the resulting tensors
69     void LoadGraph();
70
71     void SetupInfo(const google::protobuf::RepeatedPtrField<onnx::ValueInfoProto >* list);
72
73     std::vector<armnn::TensorInfo> ComputeOutputInfo(std::vector<std::string> outNames,
74                                                      const armnn::IConnectableLayer* layer,
75                                                      std::vector<armnn::TensorShape> inputShapes);
76
77     void DetectFullyConnected();
78
79     template <typename Location>
80     void GetInputAndParam(const onnx::NodeProto& node,
81                           std::string* inputName,
82                           std::string* constName,
83                           const Location& location);
84
85     template <typename Location>
86     void To1DTensor(const std::string &name, const Location& location);
87
88     //Broadcast Preparation functions
89     std::pair<std::string, std::string> AddPrepareBroadcast(const std::string& input0, const std::string& input1);
90     void PrependForBroadcast(const std::string& outputName, const std::string& input0, const std::string& input1);
91
92     void CreateConstantLayer(const std::string& tensorName, const std::string& layerName);
93     void CreateReshapeLayer(const std::string& inputName,
94                             const std::string& outputName,
95                             const std::string& layerName);
96
97     void ParseBatchNormalization(const onnx::NodeProto& node);
98     void ParseConstant(const onnx::NodeProto& nodeProto);
99
100     void ParseMaxPool(const onnx::NodeProto& nodeProto);
101     void ParseAveragePool(const onnx::NodeProto& nodeProto);
102     void ParseGlobalAveragePool(const onnx::NodeProto& node);
103
104     void AddPoolingLayer(const onnx::NodeProto& nodeProto, armnn::Pooling2dDescriptor& desc);
105
106     void ParseReshape(const onnx::NodeProto& nodeProto);
107     void ParseRelu(const onnx::NodeProto& nodeProto);
108
109     void AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node, const armnn::Convolution2dDescriptor& convDesc);
110     void ParseConv(const onnx::NodeProto& nodeProto);
111
112     void ParseAdd(const onnx::NodeProto& nodeProto);
113     void AddFullyConnected(const onnx::NodeProto& matmulNode, const onnx::NodeProto* addNode = nullptr);
114
115     void RegisterInputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
116     void RegisterOutputSlots(armnn::IConnectableLayer* layer, const std::vector<std::string>& tensorIndexes);
117
118     void SetupInputLayers();
119     void SetupOutputLayers();
120
121     void ResetParser();
122     void Cleanup();
123
124     std::pair<armnn::ConstTensor, std::unique_ptr<float[]>> CreateConstTensor(const std::string name);
125
126     template <typename TypeList, typename Location>
127     void ValidateInputs(const onnx::NodeProto& node,
128                         TypeList validInputs,
129                         const Location& location);
130
131     /// The network we're building. Gets cleared after it is passed to the user
132     armnn::INetworkPtr m_Network;
133
134     ///Ptr to the graph we're building the network from
135     GraphPtr m_Graph;
136
137     ///Map of the information for every tensor
138     struct OnnxTensor
139     {
140         std::unique_ptr<armnn::TensorInfo>          m_info;
141         std::unique_ptr<const onnx::TensorProto>    m_tensor;
142         onnx::TensorProto::DataType                 m_dtype;
143
144         OnnxTensor() : m_info(nullptr), m_tensor(nullptr), m_dtype(onnx::TensorProto::FLOAT) { }
145         bool isConstant() { return m_tensor != nullptr; }
146
147     };
148
149     std::unordered_map<std::string, OnnxTensor> m_TensorsInfo;
150
151     /// map of onnx operation names to parsing member functions
152     static const std::map<std::string, OperationParsingFunction> m_ParserFunctions;
153
154     /// A mapping of an output slot to each of the input slots it should be connected to
155     /// The outputSlot is from the layer that creates this tensor as one of its ouputs
156     /// The inputSlots are from the layers that use this tensor as one of their inputs
157     struct TensorSlots
158     {
159         armnn::IOutputSlot* outputSlot;
160         std::vector<armnn::IInputSlot*> inputSlots;
161
162         TensorSlots() : outputSlot(nullptr) { }
163     };
164     ///Map of the tensor names to their connections for the connections of the layers of the graph
165     std::unordered_map<std::string, TensorSlots> m_TensorConnections;
166
167     //Map of the tensor names to their node and index in graph.node()
168     std::unordered_map<std::string, std::pair<const onnx::NodeProto*, int>> m_OutputsMap;
169
170     /// Number of times a specific node (identified by his index number) was used as input
171     /// and list of the nodes it was fused with
172     struct UsageSummary
173     {
174         std::vector<size_t> fusedWithNodes;
175         size_t inputForNodes;
176
177         UsageSummary() : fusedWithNodes({}), inputForNodes(0) { }
178
179     };
180
181     std::vector<UsageSummary> m_OutputsFusedAndUsed;
182 };
183 }