IVGCVSW-5593 Implement Pimpl Idiom for serialization classes
[platform/upstream/armnn.git] / src / armnnDeserializer / Deserializer.hpp
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <armnn/INetwork.hpp>
9 #include <armnnDeserializer/IDeserializer.hpp>
10 #include <ArmnnSchema_generated.h>
11
12 #include <unordered_map>
13
14 namespace armnnDeserializer
15 {
16
17 // Shorthands for deserializer types
18 using ConstTensorRawPtr = const armnnSerializer::ConstTensor *;
19 using GraphPtr = const armnnSerializer::SerializedGraph *;
20 using TensorRawPtr = const armnnSerializer::TensorInfo *;
21 using PoolingDescriptor = const armnnSerializer::Pooling2dDescriptor *;
22 using NormalizationDescriptorPtr = const armnnSerializer::NormalizationDescriptor *;
23 using LstmDescriptorPtr = const armnnSerializer::LstmDescriptor *;
24 using LstmInputParamsPtr = const armnnSerializer::LstmInputParams *;
25 using QLstmDescriptorPtr = const armnnSerializer::QLstmDescriptor *;
26 using QunatizedLstmInputParamsPtr = const armnnSerializer::QuantizedLstmInputParams *;
27 using TensorRawPtrVector = std::vector<TensorRawPtr>;
28 using LayerRawPtr = const armnnSerializer::LayerBase *;
29 using LayerBaseRawPtr = const armnnSerializer::LayerBase *;
30 using LayerBaseRawPtrVector = std::vector<LayerBaseRawPtr>;
31
32 class IDeserializer::DeserializerImpl
33 {
34 public:
35
36     /// Create an input network from binary file contents
37     armnn::INetworkPtr CreateNetworkFromBinary(const std::vector<uint8_t>& binaryContent);
38
39     /// Create an input network from a binary input stream
40     armnn::INetworkPtr CreateNetworkFromBinary(std::istream& binaryContent);
41
42     /// Retrieve binding info (layer id and tensor info) for the network input identified by the given layer name
43     BindingPointInfo GetNetworkInputBindingInfo(unsigned int layerId, const std::string& name) const;
44
45     /// Retrieve binding info (layer id and tensor info) for the network output identified by the given layer name
46     BindingPointInfo GetNetworkOutputBindingInfo(unsigned int layerId, const std::string& name) const;
47
48     DeserializerImpl();
49     ~DeserializerImpl() = default;
50
51     // No copying allowed until it is wanted and properly implemented
52     DeserializerImpl(const DeserializerImpl&) = delete;
53     DeserializerImpl& operator=(const DeserializerImpl&) = delete;
54
55     // testable helpers
56     static GraphPtr LoadGraphFromBinary(const uint8_t* binaryContent, size_t len);
57     static TensorRawPtrVector GetInputs(const GraphPtr& graph, unsigned int layerIndex);
58     static TensorRawPtrVector GetOutputs(const GraphPtr& graph, unsigned int layerIndex);
59     static LayerBaseRawPtr GetBaseLayer(const GraphPtr& graphPtr, unsigned int layerIndex);
60     static int32_t GetBindingLayerInfo(const GraphPtr& graphPtr, unsigned int layerIndex);
61     static std::string GetLayerName(const GraphPtr& graph, unsigned int index);
62     static armnn::Pooling2dDescriptor GetPoolingDescriptor(PoolingDescriptor pooling2dDescriptor,
63                                                            unsigned int layerIndex);
64     static armnn::NormalizationDescriptor GetNormalizationDescriptor(
65         NormalizationDescriptorPtr normalizationDescriptor, unsigned int layerIndex);
66     static armnn::LstmDescriptor GetLstmDescriptor(LstmDescriptorPtr lstmDescriptor);
67     static armnn::LstmInputParams GetLstmInputParams(LstmDescriptorPtr lstmDescriptor,
68                                                      LstmInputParamsPtr lstmInputParams);
69     static armnn::QLstmDescriptor GetQLstmDescriptor(QLstmDescriptorPtr qLstmDescriptorPtr);
70     static armnn::TensorInfo OutputShapeOfReshape(const armnn::TensorInfo & inputTensorInfo,
71                                                   const std::vector<uint32_t> & targetDimsIn);
72
73 private:
74     /// Create the network from an already loaded flatbuffers graph
75     armnn::INetworkPtr CreateNetworkFromGraph(GraphPtr graph);
76
77     // signature for the parser functions
78     using LayerParsingFunction = void(DeserializerImpl::*)(GraphPtr graph, unsigned int layerIndex);
79
80     void ParseUnsupportedLayer(GraphPtr graph, unsigned int layerIndex);
81     void ParseAbs(GraphPtr graph, unsigned int layerIndex);
82     void ParseActivation(GraphPtr graph, unsigned int layerIndex);
83     void ParseAdd(GraphPtr graph, unsigned int layerIndex);
84     void ParseArgMinMax(GraphPtr graph, unsigned int layerIndex);
85     void ParseBatchToSpaceNd(GraphPtr graph, unsigned int layerIndex);
86     void ParseBatchNormalization(GraphPtr graph, unsigned int layerIndex);
87     void ParseComparison(GraphPtr graph, unsigned int layerIndex);
88     void ParseConcat(GraphPtr graph, unsigned int layerIndex);
89     void ParseConstant(GraphPtr graph, unsigned int layerIndex);
90     void ParseConvolution2d(GraphPtr graph, unsigned int layerIndex);
91     void ParseDepthToSpace(GraphPtr graph, unsigned int layerIndex);
92     void ParseDepthwiseConvolution2d(GraphPtr graph, unsigned int layerIndex);
93     void ParseDequantize(GraphPtr graph, unsigned int layerIndex);
94     void ParseDetectionPostProcess(GraphPtr graph, unsigned int layerIndex);
95     void ParseDivision(GraphPtr graph, unsigned int layerIndex);
96     void ParseElementwiseUnary(GraphPtr graph, unsigned int layerIndex);
97     void ParseEqual(GraphPtr graph, unsigned int layerIndex);
98     void ParseFill(GraphPtr graph, unsigned int layerIndex);
99     void ParseFloor(GraphPtr graph, unsigned int layerIndex);
100     void ParseFullyConnected(GraphPtr graph, unsigned int layerIndex);
101     void ParseGather(GraphPtr graph, unsigned int layerIndex);
102     void ParseGreater(GraphPtr graph, unsigned int layerIndex);
103     void ParseInstanceNormalization(GraphPtr graph, unsigned int layerIndex);
104     void ParseL2Normalization(GraphPtr graph, unsigned int layerIndex);
105     void ParseLogicalBinary(GraphPtr graph, unsigned int layerIndex);
106     void ParseLogSoftmax(GraphPtr graph, unsigned int layerIndex);
107     void ParseMaximum(GraphPtr graph, unsigned int layerIndex);
108     void ParseMean(GraphPtr graph, unsigned int layerIndex);
109     void ParseMinimum(GraphPtr graph, unsigned int layerIndex);
110     void ParseMerge(GraphPtr graph, unsigned int layerIndex);
111     void ParseMultiplication(GraphPtr graph, unsigned int layerIndex);
112     void ParseNormalization(GraphPtr graph, unsigned int layerIndex);
113     void ParseLstm(GraphPtr graph, unsigned int layerIndex);
114     void ParseQuantizedLstm(GraphPtr graph, unsigned int layerIndex);
115     void ParsePad(GraphPtr graph, unsigned int layerIndex);
116     void ParsePermute(GraphPtr graph, unsigned int layerIndex);
117     void ParsePooling2d(GraphPtr graph, unsigned int layerIndex);
118     void ParsePrelu(GraphPtr graph, unsigned int layerIndex);
119     void ParseQLstm(GraphPtr graph, unsigned int layerIndex);
120     void ParseQuantize(GraphPtr graph, unsigned int layerIndex);
121     void ParseRank(GraphPtr graph, unsigned int layerIndex);
122     void ParseReshape(GraphPtr graph, unsigned int layerIndex);
123     void ParseResize(GraphPtr graph, unsigned int layerIndex);
124     void ParseResizeBilinear(GraphPtr graph, unsigned int layerIndex);
125     void ParseRsqrt(GraphPtr graph, unsigned int layerIndex);
126     void ParseSlice(GraphPtr graph, unsigned int layerIndex);
127     void ParseSoftmax(GraphPtr graph, unsigned int layerIndex);
128     void ParseSpaceToBatchNd(GraphPtr graph, unsigned int layerIndex);
129     void ParseSpaceToDepth(GraphPtr graph, unsigned int layerIndex);
130     void ParseSplitter(GraphPtr graph, unsigned int layerIndex);
131     void ParseStack(GraphPtr graph, unsigned int layerIndex);
132     void ParseStandIn(GraphPtr graph, unsigned int layerIndex);
133     void ParseStridedSlice(GraphPtr graph, unsigned int layerIndex);
134     void ParseSubtraction(GraphPtr graph, unsigned int layerIndex);
135     void ParseSwitch(GraphPtr graph, unsigned int layerIndex);
136     void ParseTranspose(GraphPtr graph, unsigned int layerIndex);
137     void ParseTransposeConvolution2d(GraphPtr graph, unsigned int layerIndex);
138
139     void RegisterInputSlots(GraphPtr graph, uint32_t layerIndex,
140                             armnn::IConnectableLayer* layer);
141     void RegisterOutputSlots(GraphPtr graph, uint32_t layerIndex,
142                              armnn::IConnectableLayer* layer);
143
144     // NOTE index here must be from flatbuffer object index property
145     void RegisterOutputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IOutputSlot* slot);
146     void RegisterInputSlotOfConnection(uint32_t sourceLayerIndex, uint32_t outputSlotIndex, armnn::IInputSlot* slot);
147
148     void ResetParser();
149
150     void SetupInputLayers(GraphPtr graphPtr);
151     void SetupOutputLayers(GraphPtr graphPtr);
152
153     /// Helper to get the index of the layer in the flatbuffer vector from its bindingId property
154     unsigned int GetInputLayerInVector(GraphPtr graph, int targetId);
155     unsigned int GetOutputLayerInVector(GraphPtr graph, int targetId);
156
157     /// Helper to get the index of the layer in the flatbuffer vector from its index property
158     unsigned int GetLayerIndexInVector(GraphPtr graph, unsigned int index);
159
160     struct FeatureVersions
161     {
162         // Default values to zero for backward compatibility
163         unsigned int m_BindingIdScheme = 0;
164     };
165
166     FeatureVersions GetFeatureVersions(GraphPtr graph);
167
168     /// The network we're building. Gets cleared after it is passed to the user
169     armnn::INetworkPtr                    m_Network;
170     std::vector<LayerParsingFunction>     m_ParserFunctions;
171
172     using NameToBindingInfo = std::pair<std::string, BindingPointInfo >;
173     std::vector<NameToBindingInfo>    m_InputBindings;
174     std::vector<NameToBindingInfo>    m_OutputBindings;
175
176     /// This struct describe connections for each layer
177     struct Connections
178     {
179         // Maps output slot index (property in flatbuffer object) to IOutputSlot pointer
180         std::unordered_map<unsigned int, armnn::IOutputSlot*> outputSlots;
181
182         // Maps output slot index to IInputSlot pointer the output slot should be connected to
183         std::unordered_map<unsigned int, std::vector<armnn::IInputSlot*>> inputSlots;
184     };
185
186     /// Maps layer index (index property in flatbuffer object) to Connections for each layer
187     std::unordered_map<unsigned int, Connections> m_GraphConnections;
188 };
189
190 } // namespace armnnDeserializer