IVGCVSW-3722 Add front end support for ArgMinMax
[platform/upstream/armnn.git] / src / armnnSerializer / Serializer.hpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
#include <armnn/ILayerVisitor.hpp>
#include <armnn/LayerVisitorBase.hpp>

#include <armnnSerializer/ISerializer.hpp>

#include <cstdint>
#include <unordered_map>
#include <vector>

#include <ArmnnSchema_generated.h>
15
16 namespace armnnSerializer
17 {
18
19 class SerializerVisitor : public armnn::ILayerVisitor
20 {
21 public:
22     SerializerVisitor() : m_layerId(0) {}
23     ~SerializerVisitor() {}
24
25     flatbuffers::FlatBufferBuilder& GetFlatBufferBuilder()
26     {
27         return m_flatBufferBuilder;
28     }
29
30     std::vector<unsigned int>& GetInputIds()
31     {
32         return m_inputIds;
33     }
34
35     std::vector<unsigned int>& GetOutputIds()
36     {
37         return m_outputIds;
38     }
39
40     std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>>& GetSerializedLayers()
41     {
42         return m_serializedLayers;
43     }
44
45     void VisitAbsLayer(const armnn::IConnectableLayer* layer,
46                        const char* name = nullptr) override;
47
48     void VisitActivationLayer(const armnn::IConnectableLayer* layer,
49                               const armnn::ActivationDescriptor& descriptor,
50                               const char* name = nullptr) override;
51
52     void VisitAdditionLayer(const armnn::IConnectableLayer* layer,
53                             const char* name = nullptr) override;
54
55     void VisitArgMinMaxLayer(const armnn::IConnectableLayer* layer,
56                              const armnn::ArgMinMaxDescriptor& argMinMaxDescriptor,
57                              const char* name = nullptr) override;
58
59     void VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
60                                   const armnn::BatchToSpaceNdDescriptor& descriptor,
61                                   const char* name = nullptr) override;
62
63     void VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
64                                       const armnn::BatchNormalizationDescriptor& BatchNormalizationDescriptor,
65                                       const armnn::ConstTensor& mean,
66                                       const armnn::ConstTensor& variance,
67                                       const armnn::ConstTensor& beta,
68                                       const armnn::ConstTensor& gamma,
69                                       const char* name = nullptr) override;
70
71     void VisitConcatLayer(const armnn::IConnectableLayer* layer,
72                           const armnn::ConcatDescriptor& concatDescriptor,
73                           const char* name = nullptr) override;
74
75     void VisitConstantLayer(const armnn::IConnectableLayer* layer,
76                             const armnn::ConstTensor& input,
77                             const char* = nullptr) override;
78
79     void VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
80                                  const armnn::Convolution2dDescriptor& descriptor,
81                                  const armnn::ConstTensor& weights,
82                                  const armnn::Optional<armnn::ConstTensor>& biases,
83                                  const char* = nullptr) override;
84
85     void VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
86                                           const armnn::DepthwiseConvolution2dDescriptor& descriptor,
87                                           const armnn::ConstTensor& weights,
88                                           const armnn::Optional<armnn::ConstTensor>& biases,
89                                           const char* name = nullptr) override;
90
91     void VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
92                               const char* name = nullptr) override;
93
94     void VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
95                                         const armnn::DetectionPostProcessDescriptor& descriptor,
96                                         const armnn::ConstTensor& anchors,
97                                         const char* name = nullptr) override;
98
99     void VisitDivisionLayer(const armnn::IConnectableLayer* layer,
100                             const char* name = nullptr) override;
101
102     void VisitEqualLayer(const armnn::IConnectableLayer* layer,
103                          const char* name = nullptr) override;
104
105     void VisitFloorLayer(const armnn::IConnectableLayer *layer,
106                          const char *name = nullptr) override;
107
108     void VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
109                                   const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
110                                   const armnn::ConstTensor& weights,
111                                   const armnn::Optional<armnn::ConstTensor>& biases,
112                                   const char* name = nullptr) override;
113
114     void VisitGatherLayer(const armnn::IConnectableLayer* layer,
115                           const char* name = nullptr) override;
116
117     void VisitGreaterLayer(const armnn::IConnectableLayer* layer,
118                            const char* name = nullptr) override;
119
120     void VisitInputLayer(const armnn::IConnectableLayer* layer,
121                          armnn::LayerBindingId id,
122                          const char* name = nullptr) override;
123
124     void VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
125                                    const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
126                                    const char* name = nullptr) override;
127
128     void VisitLstmLayer(const armnn::IConnectableLayer* layer,
129                         const armnn::LstmDescriptor& descriptor,
130                         const armnn::LstmInputParams& params,
131                         const char* name = nullptr) override;
132
133     void VisitMeanLayer(const armnn::IConnectableLayer* layer,
134                         const armnn::MeanDescriptor& descriptor,
135                         const char* name) override;
136
137     void VisitMinimumLayer(const armnn::IConnectableLayer* layer,
138                            const char* name = nullptr) override;
139
140     void VisitMaximumLayer(const armnn::IConnectableLayer* layer,
141                            const char* name = nullptr) override;
142
143     void VisitMergeLayer(const armnn::IConnectableLayer* layer,
144                          const char* name = nullptr) override;
145
146     ARMNN_DEPRECATED_MSG("Use VisitConcatLayer instead")
147     void VisitMergerLayer(const armnn::IConnectableLayer* layer,
148                           const armnn::MergerDescriptor& mergerDescriptor,
149                           const char* name = nullptr) override;
150
151     void VisitMultiplicationLayer(const armnn::IConnectableLayer* layer,
152                                   const char* name = nullptr) override;
153
154     void VisitOutputLayer(const armnn::IConnectableLayer* layer,
155                           armnn::LayerBindingId id,
156                           const char* name = nullptr) override;
157
158     void VisitPadLayer(const armnn::IConnectableLayer* layer,
159                        const armnn::PadDescriptor& PadDescriptor,
160                        const char* name = nullptr) override;
161
162     void VisitPermuteLayer(const armnn::IConnectableLayer* layer,
163                            const armnn::PermuteDescriptor& PermuteDescriptor,
164                            const char* name = nullptr) override;
165
166     void VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
167                              const armnn::Pooling2dDescriptor& pooling2dDescriptor,
168                              const char* name = nullptr) override;
169
170     void VisitPreluLayer(const armnn::IConnectableLayer* layer,
171                          const char* name = nullptr) override;
172
173     void VisitQuantizeLayer(const armnn::IConnectableLayer* layer,
174                             const char* name = nullptr) override;
175
176     void VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
177                                  const armnn::QuantizedLstmInputParams& params,
178                                  const char* name = nullptr) override;
179
180     void VisitReshapeLayer(const armnn::IConnectableLayer* layer,
181                            const armnn::ReshapeDescriptor& reshapeDescriptor,
182                            const char* name = nullptr) override;
183
184     void VisitResizeLayer(const armnn::IConnectableLayer* layer,
185                           const armnn::ResizeDescriptor& resizeDescriptor,
186                           const char* name = nullptr) override;
187
188     ARMNN_DEPRECATED_MSG("Use VisitResizeLayer instead")
189     void VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
190                                   const armnn::ResizeBilinearDescriptor& resizeDescriptor,
191                                   const char* name = nullptr) override;
192
193     void VisitRsqrtLayer(const armnn::IConnectableLayer* layer,
194                          const char* name = nullptr) override;
195
196     void VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
197                            const armnn::SoftmaxDescriptor& softmaxDescriptor,
198                            const char* name = nullptr) override;
199
200     void VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
201                                   const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
202                                   const char* name = nullptr) override;
203
204     void VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
205                                 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
206                                 const char* name = nullptr) override;
207
208     void VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
209                                  const armnn::NormalizationDescriptor& normalizationDescriptor,
210                                  const char* name = nullptr) override;
211
212     void VisitSplitterLayer(const armnn::IConnectableLayer* layer,
213                             const armnn::ViewsDescriptor& viewsDescriptor,
214                             const char* name = nullptr) override;
215
216     void VisitStackLayer(const armnn::IConnectableLayer* layer,
217                          const armnn::StackDescriptor& stackDescriptor,
218                          const char* name = nullptr) override;
219
220     void VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
221                                 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
222                                 const char* name = nullptr) override;
223
224     void VisitSubtractionLayer(const armnn::IConnectableLayer* layer,
225                                const char* name = nullptr) override;
226
227     void VisitSwitchLayer(const armnn::IConnectableLayer* layer,
228                           const char* name = nullptr) override;
229
230     void VisitTransposeConvolution2dLayer(const armnn::IConnectableLayer* layer,
231                                           const armnn::TransposeConvolution2dDescriptor& descriptor,
232                                           const armnn::ConstTensor& weights,
233                                           const armnn::Optional<armnn::ConstTensor>& biases,
234                                           const char* = nullptr) override;
235
236 private:
237
238     /// Creates the Input Slots and Output Slots and LayerBase for the layer.
239     flatbuffers::Offset<armnnSerializer::LayerBase> CreateLayerBase(
240             const armnn::IConnectableLayer* layer,
241             const armnnSerializer::LayerType layerType);
242
243     /// Creates the serializer AnyLayer for the layer and adds it to m_serializedLayers.
244     void CreateAnyLayer(const flatbuffers::Offset<void>& layer, const armnnSerializer::Layer serializerLayer);
245
246     /// Creates the serializer ConstTensor for the armnn ConstTensor.
247     flatbuffers::Offset<armnnSerializer::ConstTensor> CreateConstTensorInfo(
248             const armnn::ConstTensor& constTensor);
249
250     template <typename T>
251     flatbuffers::Offset<flatbuffers::Vector<T>> CreateDataVector(const void* memory, unsigned int size);
252
253     ///Function which maps Guid to an index
254     uint32_t GetSerializedId(unsigned int guid);
255
256     /// Creates the serializer InputSlots for the layer.
257     std::vector<flatbuffers::Offset<armnnSerializer::InputSlot>> CreateInputSlots(
258             const armnn::IConnectableLayer* layer);
259
260     /// Creates the serializer OutputSlots for the layer.
261     std::vector<flatbuffers::Offset<armnnSerializer::OutputSlot>> CreateOutputSlots(
262             const armnn::IConnectableLayer* layer);
263
264     /// FlatBufferBuilder to create our layers' FlatBuffers.
265     flatbuffers::FlatBufferBuilder m_flatBufferBuilder;
266
267     /// AnyLayers required by the SerializedGraph.
268     std::vector<flatbuffers::Offset<armnnSerializer::AnyLayer>> m_serializedLayers;
269
270     /// Guids of all Input Layers required by the SerializedGraph.
271     std::vector<unsigned int> m_inputIds;
272
273     /// Guids of all Output Layers required by the SerializedGraph.
274     std::vector<unsigned int> m_outputIds;
275
276     /// Mapped Guids of all Layers to match our index.
277     std::unordered_map<unsigned int, uint32_t > m_guidMap;
278
279     /// layer within our FlatBuffer index.
280     uint32_t m_layerId;
281 };
282
283 class Serializer : public ISerializer
284 {
285 public:
286     Serializer() {}
287     ~Serializer() {}
288
289     /// Serializes the network to ArmNN SerializedGraph.
290     /// @param [in] inNetwork The network to be serialized.
291     void Serialize(const armnn::INetwork& inNetwork) override;
292
293     /// Serializes the SerializedGraph to the stream.
294     /// @param [stream] the stream to save to
295     /// @return true if graph is Serialized to the Stream, false otherwise
296     bool SaveSerializedToStream(std::ostream& stream) override;
297
298 private:
299
300     /// Visitor to contruct serialized network
301     SerializerVisitor m_SerializerVisitor;
302 };
303
304 } //namespace armnnSerializer