7400885ab110277010e07c29b11da86e1ef6d078
[platform/upstream/armnn.git] / src / armnnSerializer / Serializer.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <armnn/ILayerVisitor.hpp>
8 #include <armnn/LayerVisitorBase.hpp>
9
10 #include <armnnSerializer/ISerializer.hpp>
11
12 #include <unordered_map>
13
14 #include <ArmnnSchema_generated.h>
15
16 namespace armnnSerializer
17 {
18
19 class SerializerVisitor : public armnn::ILayerVisitor
20 {
21 public:
22     SerializerVisitor() : m_layerId(0) {}
23     ~SerializerVisitor() {}
24
25     flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder()
26     {
27         return m_flatBufferBuilder;
28     }
29
30     std::vector<unsigned int>& GetInputIds()
31     {
32         return m_inputIds;
33     }
34
35     std::vector<unsigned int>& GetOutputIds()
36     {
37         return m_outputIds;
38     }
39
40     std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>>& GetSerializedLayers()
41     {
42         return m_serializedLayers;
43     }
44
45     void VisitAbsLayer(const armnn::IConnectableLayer* layer,
46                        const char* name = nullptr) override;
47
48     void VisitActivationLayer(const armnn::IConnectableLayer* layer,
49                               const armnn::ActivationDescriptor& descriptor,
50                               const char* name = nullptr) override;
51
52     void VisitAdditionLayer(const armnn::IConnectableLayer* layer,
53                             const char* name = nullptr) override;
54
55     void VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
56                                   const armnn::BatchToSpaceNdDescriptor& descriptor,
57                                   const char* name = nullptr) override;
58
59     void VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
60                                       const armnn::BatchNormalizationDescriptor& BatchNormalizationDescriptor,
61                                       const armnn::ConstTensor& mean,
62                                       const armnn::ConstTensor& variance,
63                                       const armnn::ConstTensor& beta,
64                                       const armnn::ConstTensor& gamma,
65                                       const char* name = nullptr) override;
66
67     void VisitConcatLayer(const armnn::IConnectableLayer* layer,
68                           const armnn::ConcatDescriptor& concatDescriptor,
69                           const char* name = nullptr) override;
70
71     void VisitConstantLayer(const armnn::IConnectableLayer* layer,
72                             const armnn::ConstTensor& input,
73                             const char* = nullptr) override;
74
75     void VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
76                                  const armnn::Convolution2dDescriptor& descriptor,
77                                  const armnn::ConstTensor& weights,
78                                  const armnn::Optional<armnn::ConstTensor>& biases,
79                                  const char* = nullptr) override;
80
81     void VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
82                                           const armnn::DepthwiseConvolution2dDescriptor& descriptor,
83                                           const armnn::ConstTensor& weights,
84                                           const armnn::Optional<armnn::ConstTensor>& biases,
85                                           const char* name = nullptr) override;
86
87     void VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
88                               const char* name = nullptr) override;
89
90     void VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
91                                         const armnn::DetectionPostProcessDescriptor& descriptor,
92                                         const armnn::ConstTensor& anchors,
93                                         const char* name = nullptr) override;
94
95     void VisitDivisionLayer(const armnn::IConnectableLayer* layer,
96                             const char* name = nullptr) override;
97
98     void VisitEqualLayer(const armnn::IConnectableLayer* layer,
99                          const char* name = nullptr) override;
100
101     void VisitFloorLayer(const armnn::IConnectableLayer *layer,
102                          const char *name = nullptr) override;
103
104     void VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
105                                   const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
106                                   const armnn::ConstTensor& weights,
107                                   const armnn::Optional<armnn::ConstTensor>& biases,
108                                   const char* name = nullptr) override;
109
110     void VisitGatherLayer(const armnn::IConnectableLayer* layer,
111                           const char* name = nullptr) override;
112
113     void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
114                            const char* name = nullptr) override;
115
116     void VisitInputLayer(const armnn::IConnectableLayer* layer,
117                          armnn::LayerBindingId id,
118                          const char* name = nullptr) override;
119
120     void VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
121                                    const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
122                                    const char* name = nullptr) override;
123
124     void VisitLstmLayer(const armnn::IConnectableLayer* layer,
125                         const armnn::LstmDescriptor& descriptor,
126                         const armnn::LstmInputParams& params,
127                         const char* name = nullptr) override;
128
129     void VisitMeanLayer(const armnn::IConnectableLayer* layer,
130                         const armnn::MeanDescriptor& descriptor,
131                         const char* name) override;
132
133     void VisitMinimumLayer(const armnn::IConnectableLayer* layer,
134                            const char* name = nullptr) override;
135
136     void VisitMaximumLayer(const armnn::IConnectableLayer* layer,
137                            const char* name = nullptr) override;
138
139     void VisitMergeLayer(const armnn::IConnectableLayer* layer,
140                          const char* name = nullptr) override;
141
142     ARMNN_DEPRECATED_MSG("Use VisitConcatLayer instead")
143     void VisitMergerLayer(const armnn::IConnectableLayer* layer,
144                           const armnn::MergerDescriptor& mergerDescriptor,
145                           const char* name = nullptr) override;
146
147     void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
148                                   const char* name = nullptr) override;
149
150     void VisitOutputLayer(const armnn::IConnectableLayer* layer,
151                           armnn::LayerBindingId id,
152                           const char* name = nullptr) override;
153
154     void VisitPadLayer(const armnn::IConnectableLayer* layer,
155                        const armnn::PadDescriptor& PadDescriptor,
156                        const char* name = nullptr) override;
157
158     void VisitPermuteLayer(const armnn::IConnectableLayer* layer,
159                            const armnn::PermuteDescriptor& PermuteDescriptor,
160                            const char* name = nullptr) override;
161
162     void VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
163                              const armnn::Pooling2dDescriptor& pooling2dDescriptor,
164                              const char* name = nullptr) override;
165
166     void VisitPreluLayer(const armnn::IConnectableLayer* layer,
167                          const char* name = nullptr) override;
168
169     void VisitQuantizeLayer(const armnn::IConnectableLayer* layer,
170                             const char* name = nullptr) override;
171
172     void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
173                                  const armnn::QuantizedLstmInputParams& params,
174                                  const char* name = nullptr) override;
175
176     void VisitReshapeLayer(const armnn::IConnectableLayer* layer,
177                            const armnn::ReshapeDescriptor& reshapeDescriptor,
178                            const char* name = nullptr) override;
179
180     void VisitResizeLayer(const armnn::IConnectableLayer* layer,
181                           const armnn::ResizeDescriptor& resizeDescriptor,
182                           const char* name = nullptr) override;
183
184     ARMNN_DEPRECATED_MSG("Use VisitResizeLayer instead")
185     void VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
186                                   const armnn::ResizeBilinearDescriptor& resizeDescriptor,
187                                   const char* name = nullptr) override;
188
189     void VisitRsqrtLayer(const armnn::IConnectableLayer* layer,
190                          const char* name = nullptr) override;
191
192     void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
193                            const armnn::SoftmaxDescriptor& softmaxDescriptor,
194                            const char* name = nullptr) override;
195
196     void VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
197                                   const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
198                                   const char* name = nullptr) override;
199
200     void VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
201                                 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
202                                 const char* name = nullptr) override;
203
204     void VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
205                                  const armnn::NormalizationDescriptor& normalizationDescriptor,
206                                  const char* name = nullptr) override;
207
208     void VisitSplitterLayer(const armnn::IConnectableLayer* layer,
209                             const armnn::ViewsDescriptor& viewsDescriptor,
210                             const char* name = nullptr) override;
211
212     void VisitStackLayer(const armnn::IConnectableLayer* layer,
213                          const armnn::StackDescriptor& stackDescriptor,
214                          const char* name = nullptr) override;
215
216     void VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
217                                 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
218                                 const char* name = nullptr) override;
219
220     void VisitSubtractionLayer(const armnn::IConnectableLayer* layer,
221                                const char* name = nullptr) override;
222
223     void VisitSwitchLayer(const armnn::IConnectableLayer* layer,
224                           const char* name = nullptr) override;
225
226     void VisitTransposeConvolution2dLayer(const armnn::IConnectableLayer* layer,
227                                           const armnn::TransposeConvolution2dDescriptor& descriptor,
228                                           const armnn::ConstTensor& weights,
229                                           const armnn::Optional<armnn::ConstTensor>& biases,
230                                           const char* = nullptr) override;
231
232 private:
233
234     /// Creates the Input Slots and Output Slots and LayerBase for the layer.
235     flatbuffers::Offset<armnnSerializer::LayerBase> CreateLayerBase(
236             const armnn::IConnectableLayer* layer,
237             const armnnSerializer::LayerType layerType);
238
239     /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers.
240     void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnnSerializer::Layer serializerLayer);
241
242     /// Creates the serializer ConstTensor for the armnn ConstTensor.
243     flatbuffers::Offset<armnnSerializer::ConstTensor> CreateConstTensorInfo(
244             const armnn::ConstTensor& constTensor);
245
246     template <typename T>
247     flatbuffers::Offset<flatbuffers::Vector<T>> CreateDataVector(const void* memory, unsigned int size);
248
249     ///Function which maps Guid to an index
250     uint32_t GetSerializedId(unsigned int guid);
251
252     /// Creates the serializer InputSlots for the layer.
253     std::vector<flatbuffers::Offset<armnnSerializer::InputSlot>> CreateInputSlots(
254             const armnn::IConnectableLayer* layer);
255
256     /// Creates the serializer OutputSlots for the layer.
257     std::vector<flatbuffers::Offset<armnnSerializer::OutputSlot>> CreateOutputSlots(
258             const armnn::IConnectableLayer* layer);
259
260     /// FlatBufferBuilder to create our layers' FlatBuffers.
261     flatbuffers::FlatBufferBuilder m_flatBufferBuilder;
262
263     /// AnyLayers required by the SerializedGraph.
264     std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>> m_serializedLayers;
265
266     /// Guids of all Input Layers required by the SerializedGraph.
267     std::vector<unsigned int> m_inputIds;
268
269     /// Guids of all Output Layers required by the SerializedGraph.
270     std::vector<unsigned int> m_outputIds;
271
272     /// Mapped Guids of all Layers to match our index.
273     std::unordered_map<unsigned int, uint32_t > m_guidMap;
274
275     /// layer within our FlatBuffer index.
276     uint32_t m_layerId;
277 };
278
279 class Serializer : public ISerializer
280 {
281 public:
282     Serializer() {}
283     ~Serializer() {}
284
285     /// Serializes the network to ArmNN SerializedGraph.
286     /// @param [in] inNetwork The network to be serialized.
287     void Serialize(const armnn::INetwork& inNetwork) override;
288
289     /// Serializes the SerializedGraph to the stream.
290     /// @param [stream] the stream to save to
291     /// @return true if graph is Serialized to the Stream, false otherwise
292     bool SaveSerializedToStream(std::ostream& stream) override;
293
294 private:
295
296     /// Visitor to contruct serialized network
297     SerializerVisitor m_SerializerVisitor;
298 };
299
300 } //namespace armnnSerializer