Release 18.03
[platform/upstream/armnn.git] / src / armnnTfParser / TfParser.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
4 //
5 #pragma once
6
7 #include "armnnTfParser/ITfParser.hpp"
8
9 #include "armnn/Types.hpp"
10 #include "armnn/Tensor.hpp"
11 #include "armnn/INetwork.hpp"
12
13 #include <map>
14 #include <memory>
15 #include <unordered_map>
16 #include <vector>
17
18 namespace armnn
19 {
20 class TensorInfo;
21 }
22
23 namespace tensorflow
24 {
25 class GraphDef;
26 class NodeDef;
27 }
28
29 namespace armnnTfParser
30 {
31
32 using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
33
34 class ParsedTfOperation;
35 using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>;
36
37 ///
38 /// WithOutputTensorIndex wraps a value and an index. The purpose of
39 /// this template is to signify that in Tensorflow the input name of
40 /// a layer has the convention of 'inputTensorName:#index' where the
41 /// #index can be omitted and it implicitly means the 0. output of
42 /// the referenced layer. By supporting this notation we can handle
43 /// layers with multiple outputs, such as Split.
44 ///
45 template <typename T>
46 struct WithOutputTensorIndex
47 {
48     T                m_IndexedValue;
49     unsigned int     m_Index;
50
51     WithOutputTensorIndex(const T & value, unsigned int index)
52     : m_IndexedValue{value}
53     , m_Index{index} {}
54
55     WithOutputTensorIndex(T && value, unsigned int index)
56     : m_IndexedValue{value}
57     , m_Index{index} {}
58 };
59
60 using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>;
61 using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>;
62 using OutputId = WithOutputTensorIndex<std::string>;
63
64 class TfParser : public ITfParser
65 {
66 public:
67     /// Create the network from a protobuf text file on disk
68     virtual armnn::INetworkPtr CreateNetworkFromTextFile(
69         const char* graphFile,
70         const std::map<std::string, armnn::TensorShape>& inputShapes,
71         const std::vector<std::string>& requestedOutputs) override;
72
73     /// Create the network from a protobuf binary file on disk
74     virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
75         const char* graphFile,
76         const std::map<std::string, armnn::TensorShape>& inputShapes,
77         const std::vector<std::string>& requestedOutputs) override;
78
79     /// Create the network directly from protobuf text in a string. Useful for debugging/testing
80     virtual armnn::INetworkPtr CreateNetworkFromString(
81         const char* protoText,
82         const std::map<std::string, armnn::TensorShape>& inputShapes,
83         const std::vector<std::string>& requestedOutputs) override;
84
85     /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
86     virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
87
88     /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
89     virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
90
91 public:
92     TfParser();
93
94 private:
95     template <typename T>
96     friend class ParsedConstTfOperation;
97     friend class ParsedMatMulTfOperation;
98
99     /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*
100     armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
101         const std::map<std::string, armnn::TensorShape>& inputShapes,
102         const std::vector<std::string>& requestedOutputs);
103
104     /// sets up variables and then performs BFS to parse all nodes
105     void LoadGraphDef(const tensorflow::GraphDef& graphDef);
106
107     /// parses a given node, assuming nodes before it in graph have been done
108     void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
109
110     /// Handling identity layers as the input for Conv2D layer
111     const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef);
112     /// Finds the nodes connected as inputs of the given node in the graph.
113     std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const;
114     /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph,
115     /// and throws an exception if the number of inputs does not match the expected one.
116     /// This will automatically resolve any identity nodes. The result vector contains the parsed operation
117     /// together with the output tensor index to make the connection unambiguous.
118     std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
119                                                                              std::size_t expectedNumInputs);
120
121     ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
122
123     /// Checks if there is a pre-parsed const tensor is available with the given name and Type
124     template<typename Type>
125     bool HasParsedConstTensor(const std::string & nodeName) const;
126
127     ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
128     ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
129     ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
130     ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
131     ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
132     ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
133     ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
134     ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
135     ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
136     ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
137     ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
138     ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
139     ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
140     ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
141     ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
142     ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
143     ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
144     ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
145     ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
146     ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
147     ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
148     ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
149     ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
150     ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef,
151         armnn::PoolingAlgorithm pooltype);
152     ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
153     ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
154     armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
155         const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName);
156
157     static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
158         const char* bindingPointDesc,
159         const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
160
161     void TrackInputBinding(armnn::IConnectableLayer* layer,
162         armnn::LayerBindingId id,
163         const armnn::TensorInfo& tensorInfo);
164
165     void TrackOutputBinding(armnn::IConnectableLayer* layer,
166         armnn::LayerBindingId id,
167         const armnn::TensorInfo& tensorInfo);
168
169     static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
170         const armnn::TensorInfo& tensorInfo,
171         const char* bindingPointDesc,
172         std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
173
174     void Cleanup();
175
176     /// The network we're building. Gets cleared after it is passed to the user
177     armnn::INetworkPtr m_Network;
178
179     using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef,
180                                                                  const tensorflow::GraphDef& graphDef);
181
182     /// map of TensorFlow operation names to parsing member functions
183     static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
184
185     std::map<std::string, armnn::TensorShape> m_InputShapes;
186     std::vector<std::string> m_RequestedOutputs;
187
188     /// map of nodes extracted from the GraphDef to speed up parsing
189     std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
190
191     std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
192
193     /// maps input layer names to their corresponding ids and tensor infos
194     std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
195
196     /// maps output layer names to their corresponding ids and tensor infos
197     std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;
198 };
199 }