2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // See LICENSE file in the project root for full license information.
#include "armnnTfParser/ITfParser.hpp"

#include "armnn/Types.hpp"
#include "armnn/Tensor.hpp"
#include "armnn/INetwork.hpp"

#include <unordered_map>
#include <utility>
29 namespace armnnTfParser
32 using BindingPointInfo = std::pair<armnn::LayerBindingId, armnn::TensorInfo>;
34 class ParsedTfOperation;
35 using ParsedTfOperationPtr = std::unique_ptr<ParsedTfOperation>;
/// WithOutputTensorIndex wraps a value and an index. The purpose of
/// this template is to signify that, in TensorFlow, the input name of
/// a layer has the convention of 'inputTensorName:#index', where the
/// #index can be omitted and it implicitly means output 0 of the
/// referenced layer. By supporting this notation we can handle
/// layers with multiple outputs, such as Split.
///
/// @tparam T Type of the wrapped value (an operation pointer, a NodeDef pointer or a name).
template<typename T>
struct WithOutputTensorIndex
{
    /// The wrapped value that identifies the referenced layer.
    T m_IndexedValue;
    /// Which output of the referenced layer is meant (the '#index' part of the name).
    unsigned int m_Index;

    /// Stores a copy of @p value together with @p index.
    WithOutputTensorIndex(const T& value, unsigned int index)
        : m_IndexedValue{value}
        , m_Index{index}
    {}

    /// Takes ownership of @p value together with @p index.
    /// `value` is a named rvalue reference and therefore an lvalue: without std::move the
    /// member would be copy-constructed, defeating this overload (and failing for move-only T).
    WithOutputTensorIndex(T&& value, unsigned int index)
        : m_IndexedValue{std::move(value)}
        , m_Index{index}
    {}
};
60 using OutputOfParsedTfOperation = WithOutputTensorIndex<ParsedTfOperation *>;
61 using OutputOfConstNodeDef = WithOutputTensorIndex<const tensorflow::NodeDef*>;
62 using OutputId = WithOutputTensorIndex<std::string>;
64 class TfParser : public ITfParser
67 /// Creates the network from a protobuf text file on the disk.
68 virtual armnn::INetworkPtr CreateNetworkFromTextFile(
69 const char* graphFile,
70 const std::map<std::string, armnn::TensorShape>& inputShapes,
71 const std::vector<std::string>& requestedOutputs) override;
73 /// Creates the network from a protobuf binary file on the disk.
74 virtual armnn::INetworkPtr CreateNetworkFromBinaryFile(
75 const char* graphFile,
76 const std::map<std::string, armnn::TensorShape>& inputShapes,
77 const std::vector<std::string>& requestedOutputs) override;
79 /// Creates the network directly from protobuf text in a string. Useful for debugging/testing.
80 virtual armnn::INetworkPtr CreateNetworkFromString(
81 const char* protoText,
82 const std::map<std::string, armnn::TensorShape>& inputShapes,
83 const std::vector<std::string>& requestedOutputs) override;
85 /// Retrieves binding info (layer id and tensor info) for the network input identified by the given layer name.
86 virtual BindingPointInfo GetNetworkInputBindingInfo(const std::string& name) const override;
88 /// Retrieves binding info (layer id and tensor info) for the network output identified by the given layer name.
89 virtual BindingPointInfo GetNetworkOutputBindingInfo(const std::string& name) const override;
96 friend class ParsedConstTfOperation;
97 friend class ParsedMatMulTfOperation;
98 friend class ParsedMulTfOperation;
100 /// Parses a GraphDef loaded into memory from one of the other CreateNetwork*.
101 armnn::INetworkPtr CreateNetworkFromGraphDef(const tensorflow::GraphDef& graphDef,
102 const std::map<std::string, armnn::TensorShape>& inputShapes,
103 const std::vector<std::string>& requestedOutputs);
105 /// Sets up variables and then performs BFS to parse all nodes.
106 void LoadGraphDef(const tensorflow::GraphDef& graphDef);
108 /// Parses a given node, assuming nodes before it in the graph have been done.
109 void LoadNodeDef(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
111 /// Handling identity layers as the input for Conv2D layer.
112 const tensorflow::NodeDef* ResolveIdentityNode(const tensorflow::NodeDef* nodeDef);
113 /// Finds the nodes connected as inputs of the given node in the graph.
114 std::vector<OutputOfConstNodeDef> GetTfInputNodes(const tensorflow::NodeDef& nodeDef) const;
115 /// Finds the IParsedTfOperations for the nodes connected as inputs of the given node in the graph,
116 /// and throws an exception if the number of inputs does not match the expected one.
117 /// This will automatically resolve any identity nodes. The result vector contains the parsed operation
118 /// together with the output tensor index to make the connection unambiguous.
119 std::vector<OutputOfParsedTfOperation> GetInputParsedTfOperationsChecked(const tensorflow::NodeDef& nodeDef,
120 std::size_t expectedNumInputs);
122 ParsedTfOperationPtr ParseConst(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
124 /// Checks if there is a pre-parsed const tensor available with the given name and Type.
125 template<typename Type>
126 bool HasParsedConstTensor(const std::string & nodeName) const;
128 ParsedTfOperationPtr ParseAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
129 ParsedTfOperationPtr ParseBiasAdd(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
130 ParsedTfOperationPtr ParseConv2D(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
131 ParsedTfOperationPtr ParseDepthwiseConv2D(const tensorflow::NodeDef& nodeDef,const tensorflow::GraphDef& graphDef);
132 ParsedTfOperationPtr ParseFusedBatchNorm(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
133 ParsedTfOperationPtr ParseConcat(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
134 ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
135 ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
136 ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
137 ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
138 ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
139 ParsedTfOperationPtr ParseRelu(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
140 ParsedTfOperationPtr ParseRelu6(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
141 ParsedTfOperationPtr ParseReshape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
142 ParsedTfOperationPtr ParseResizeBilinear(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
143 ParsedTfOperationPtr ParseShape(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
144 ParsedTfOperationPtr ParseSqueeze(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
145 ParsedTfOperationPtr ParseSigmoid(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
146 ParsedTfOperationPtr ParseSoftmax(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
147 ParsedTfOperationPtr ParseSoftplus(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
148 ParsedTfOperationPtr ParseTanh(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
149 ParsedTfOperationPtr ParseMaxPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
150 ParsedTfOperationPtr ParseAvgPool(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
151 ParsedTfOperationPtr ParsePooling2d(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef,
152 armnn::PoolingAlgorithm pooltype);
153 ParsedTfOperationPtr ParseMaximum(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef);
154 ParsedTfOperationPtr AddActivationLayer(const tensorflow::NodeDef& nodeDef, armnn::ActivationDescriptor& desc);
155 ParsedTfOperationPtr AddAdditionLayer(const tensorflow::NodeDef& nodeDef, bool isBiasAdd = false);
158 armnn::IConnectableLayer* AddMultiplicationLayer(const tensorflow::NodeDef& nodeDef);
160 armnn::IConnectableLayer* AddFullyConnectedLayer(const tensorflow::NodeDef& matMulNodeDef,
161 const tensorflow::NodeDef* addNodeDef, const char* armnnLayerName);
163 bool IsSupportedLeakyReluPattern(const tensorflow::NodeDef& mulNodeDef,
164 size_t alphaLayerIndex,
165 const OutputOfParsedTfOperation& otherOp,
166 armnn::IOutputSlot** outputOfLeakyRelu,
167 armnn::ActivationDescriptor & desc);
169 static std::pair<armnn::LayerBindingId, armnn::TensorInfo> GetBindingInfo(const std::string& layerName,
170 const char* bindingPointDesc,
171 const std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
173 void TrackInputBinding(armnn::IConnectableLayer* layer,
174 armnn::LayerBindingId id,
175 const armnn::TensorInfo& tensorInfo);
177 void TrackOutputBinding(armnn::IConnectableLayer* layer,
178 armnn::LayerBindingId id,
179 const armnn::TensorInfo& tensorInfo);
181 static void TrackBindingPoint(armnn::IConnectableLayer* layer, armnn::LayerBindingId id,
182 const armnn::TensorInfo& tensorInfo,
183 const char* bindingPointDesc,
184 std::unordered_map<std::string, BindingPointInfo>& nameToBindingInfo);
188 /// The network we're building. Gets cleared after it is passed to the user.
189 armnn::INetworkPtr m_Network;
191 using OperationParsingFunction = ParsedTfOperationPtr(TfParser::*)(const tensorflow::NodeDef& nodeDef,
192 const tensorflow::GraphDef& graphDef);
194 /// Map of TensorFlow operation names to parsing member functions.
195 static const std::map<std::string, OperationParsingFunction> ms_OperationNameToParsingFunctions;
197 std::map<std::string, armnn::TensorShape> m_InputShapes;
198 std::vector<std::string> m_RequestedOutputs;
200 /// Map of nodes extracted from the GraphDef to speed up parsing.
201 std::unordered_map<std::string, const tensorflow::NodeDef*> m_NodesByName;
203 std::unordered_map<std::string, ParsedTfOperationPtr> m_ParsedTfOperations;
205 /// Maps input layer names to their corresponding ids and tensor info.
206 std::unordered_map<std::string, BindingPointInfo> m_NetworkInputsBindingInfo;
208 /// Maps output layer names to their corresponding ids and tensor info.
209 std::unordered_map<std::string, BindingPointInfo> m_NetworkOutputsBindingInfo;