2a8639f7fc8bd88436f591b7990d7cc9ae3f3dcc
[platform/upstream/armnn.git] / src / armnnDeserializer / Deserializer.hpp
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include "armnn/INetwork.hpp"
9 #include "armnnDeserializer/IDeserializer.hpp"
10 #include <ArmnnSchema_generated.h>
11
12 #include <unordered_map>
13
14 namespace armnnDeserializer
15 {
16 class Deserializer : public IDeserializer
17 {
18 public:
19     // Shorthands for deserializer types
20     using ConstTensorRawPtr = const armnnSerializer::ConstTensor *;
21     using GraphPtr = const armnnSerializer::SerializedGraph *;
22     using TensorRawPtr = const armnnSerializer::TensorInfo *;
23     using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
24     using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
25     using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
26     using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
27     using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *;
28     using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *;
29     using TensorRawPtrVector = std::vector<TensorRawPtr>;
30     using LayerRawPtr = const armnnSerializer::LayerBase *;
31     using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
32     using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
33
34 public:
35
36     /// Create an input network from binary file contents
37     armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent) override;
38
39     /// Create an input network from a binary input stream
40     armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent) override;
41
42     /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
43     BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const override;
44
45     /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
46     BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const override;
47
48     Deserializer();
49     ~Deserializer() {}
50
51 public:
52     // testable helpers
53     static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
54     static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
55     static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);
56     static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
57     static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
58     static std::string GetLayerName(const GraphPtr& graph, unsigned int index);
59     static armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor,
60                                                            unsigned int layerIndex);
61     static armnn::NormalizationDescriptor GetNormalizationDescriptor(
62         NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
63     static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
64     static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
65                                                      LstmInputParamsPtr lstmInputParams);
66     static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr);
67     static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
68                                                   const std::vector<uint32_t> & targetDimsIn);
69
70 private:
71     // No copying allowed until it is wanted and properly implemented
72     Deserializer(const Deserializer&) = delete;
73     Deserializer& operator=(const Deserializer&) = delete;
74
75     /// Create the network from an already loaded flatbuffers graph
76     armnn::INetworkPtr CreateNetworkFromGraph(GraphPtr graph);
77
78     // signature for the parser functions
79     using LayerParsingFunction = void(Deserializer::*)(GraphPtr graph, unsigned int layerIndex);
80
81     void ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex);
82     void ParseAbs(GraphPtr graph, unsigned int layerIndex);
83     void ParseActivation(GraphPtr graph, unsigned int layerIndex);
84     void ParseAdd(GraphPtr graph, unsigned int layerIndex);
85     void ParseArgMinMax(GraphPtr graph, unsigned int layerIndex);
86     void ParseBatchToSpaceNd(GraphPtr graph, unsigned int layerIndex);
87     void ParseBatchNormalization(GraphPtr graph, unsigned int layerIndex);
88     void ParseComparison(GraphPtr graph, unsigned int layerIndex);
89     void ParseConcat(GraphPtr graph, unsigned int layerIndex);
90     void ParseConstant(GraphPtr graph, unsigned int layerIndex);
91     void ParseConvolution2d(GraphPtr graph, unsigned int layerIndex);
92     void ParseDepthToSpace(GraphPtr graph, unsigned int layerIndex);
93     void ParseDepthwiseConvolution2d(GraphPtr graph, unsigned int layerIndex);
94     void ParseDequantize(GraphPtr graph, unsigned int layerIndex);
95     void ParseDetectionPostProcess(GraphPtr graph, unsigned int layerIndex);
96     void ParseDivision(GraphPtr graph, unsigned int layerIndex);
97     void ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex);
98     void ParseEqual(GraphPtr graph, unsigned int layerIndex);
99     void ParseFill(GraphPtr graph, unsigned int layerIndex);
100     void ParseFloor(GraphPtr graph, unsigned int layerIndex);
101     void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex);
102     void ParseGather(GraphPtr graph, unsigned int layerIndex);
103     void ParseGreater(GraphPtr graph, unsigned int layerIndex);
104     void ParseInstanceNormalization(GraphPtr graph, unsigned int layerIndex);
105     void ParseL2Normalization(GraphPtr graph, unsigned int layerIndex);
106     void ParseLogicalBinary(GraphPtr graph, unsigned int layerIndex);
107     void ParseLogSoftmax(GraphPtr graph, unsigned int layerIndex);
108     void ParseMaximum(GraphPtr graph, unsigned int layerIndex);
109     void ParseMean(GraphPtr graph, unsigned int layerIndex);
110     void ParseMinimum(GraphPtr graph, unsigned int layerIndex);
111     void ParseMerge(GraphPtr graph, unsigned int layerIndex);
112     void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
113     void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
114     void ParseLstm(GraphPtr graph, unsigned int layerIndex);
115     void ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex);
116     void ParsePad(GraphPtr graph, unsigned int layerIndex);
117     void ParsePermute(GraphPtr graph, unsigned int layerIndex);
118     void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
119     void ParsePrelu(GraphPtr graph, unsigned int layerIndex);
120     void ParseQLstm(GraphPtr graph, unsigned int layerIndex);
121     void ParseQuantize(GraphPtr graph, unsigned int layerIndex);
122     void ParseRank(GraphPtr graph, unsigned int layerIndex);
123     void ParseReshape(GraphPtr graph, unsigned int layerIndex);
124     void ParseResize(GraphPtr graph, unsigned int layerIndex);
125     void ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex);
126     void ParseRsqrt(GraphPtr graph, unsigned int layerIndex);
127     void ParseSlice(GraphPtr graph, unsigned int layerIndex);
128     void ParseSoftmax(GraphPtr graph, unsigned int layerIndex);
129     void ParseSpaceToBatchNd(GraphPtr graph, unsigned int layerIndex);
130     void ParseSpaceToDepth(GraphPtr graph, unsigned int layerIndex);
131     void ParseSplitter(GraphPtr graph, unsigned int layerIndex);
132     void ParseStack(GraphPtr graph, unsigned int layerIndex);
133     void ParseStandIn(GraphPtr graph, unsigned int layerIndex);
134     void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
135     void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
136     void ParseSwitch(GraphPtr graph, unsigned int layerIndex);
137     void ParseTranspose(GraphPtr graph, unsigned int layerIndex);
138     void ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex);
139
140     void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
141                             armnn::IConnectableLayer* layer);
142     void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex,
143                              armnn::IConnectableLayer* layer);
144
145     // NOTE index here must be from flatbuffer object index property
146     void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IOutputSlot* slot);
147     void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
148
149     void ResetParser();
150
151     void SetupInputLayers(GraphPtr graphPtr);
152     void SetupOutputLayers(GraphPtr graphPtr);
153
154     /// Helper to get the index of the layer in the flatbuffer vector from its bindingId property
155     unsigned int GetInputLayerInVector(GraphPtr graph, int targetId);
156     unsigned int GetOutputLayerInVector(GraphPtr graph, int targetId);
157
158     /// Helper to get the index of the layer in the flatbuffer vector from its index property
159     unsigned int GetLayerIndexInVector(GraphPtr graph, unsigned int index);
160
161     struct FeatureVersions
162     {
163         // Default values to zero for backward compatibility
164         unsigned int m_BindingIdScheme = 0;
165     };
166
167     FeatureVersions GetFeatureVersions(GraphPtr graph);
168
169     /// The network we're building. Gets cleared after it is passed to the user
170     armnn::INetworkPtr                    m_Network;
171     std::vector<LayerParsingFunction>     m_ParserFunctions;
172
173     using NameToBindingInfo = std::pair<std::string, BindingPointInfo >;
174     std::vector<NameToBindingInfo>    m_InputBindings;
175     std::vector<NameToBindingInfo>    m_OutputBindings;
176
177     /// This struct describe connections for each layer
178     struct Connections
179     {
180         // Maps output slot index (property in flatbuffer object) to IOutputSlot pointer
181         std::unordered_map<unsigned int, armnn::IOutputSlot*> outputSlots;
182
183         // Maps output slot index to IInputSlot pointer the output slot should be connected to
184         std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots;
185     };
186
187     /// Maps layer index (index property in flatbuffer object) to Connections for each layer
188     std::unordered_map<unsigned int, Connections> m_GraphConnections;
189 };
190
191 } // namespace armnnDeserializer