56d313f97b4568ac75eb614c484ec0eda3986fd3
[platform/upstream/armnn.git] / src / armnnSerializer / Serializer.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "Serializer.hpp"
7
8 #include "SerializerUtils.hpp"
9
10 #include <armnn/ArmNN.hpp>
11
12 #include <iostream>
13
14 #include <boost/numeric/conversion/cast.hpp>
15
16 #include <flatbuffers/util.h>
17
18 using namespace armnn;
19 namespace fb = flatbuffers;
20 namespace serializer = armnnSerializer;
21
22 namespace armnnSerializer
23 {
24
25 serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::ActivationFunction function)
26 {
27     switch (function)
28     {
29         case armnn::ActivationFunction::Sigmoid:
30             return serializer::ActivationFunction::ActivationFunction_Sigmoid;
31         case armnn::ActivationFunction::TanH:
32             return serializer::ActivationFunction::ActivationFunction_TanH;
33         case armnn::ActivationFunction::Linear:
34             return serializer::ActivationFunction::ActivationFunction_Linear;
35         case armnn::ActivationFunction::ReLu:
36             return serializer::ActivationFunction::ActivationFunction_ReLu;
37         case armnn::ActivationFunction::BoundedReLu:
38             return serializer::ActivationFunction::ActivationFunction_BoundedReLu;
39         case armnn::ActivationFunction::LeakyReLu:
40             return serializer::ActivationFunction::ActivationFunction_LeakyReLu;
41         case armnn::ActivationFunction::Abs:
42             return serializer::ActivationFunction::ActivationFunction_Abs;
43         case armnn::ActivationFunction::Sqrt:
44             return serializer::ActivationFunction::ActivationFunction_Sqrt;
45         case armnn::ActivationFunction::Square:
46             return serializer::ActivationFunction::ActivationFunction_Square;
47         default:
48             return serializer::ActivationFunction::ActivationFunction_Sigmoid;
49     }
50 }
51
52 uint32_t SerializerVisitor::GetSerializedId(unsigned int guid)
53 {
54     std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId);
55
56     if (m_guidMap.empty())
57     {
58         m_guidMap.insert(guidPair);
59     }
60     else if (m_guidMap.find(guid) == m_guidMap.end())
61     {
62         guidPair.second = ++m_layerId;
63         m_guidMap.insert(guidPair);
64         return m_layerId;
65     }
66     return m_guidMap[guid];
67 }
68
69 // Build FlatBuffer for Input Layer
70 void SerializerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
71 {
72     // Create FlatBuffer BaseLayer
73     auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
74
75     // Create FlatBuffer BindableBaseLayer
76     auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
77                                                                                 flatBufferInputBaseLayer,
78                                                                                 id);
79     // Push layer Guid to outputIds.
80     m_inputIds.push_back(GetSerializedId(layer->GetGuid()));
81
82     // Create the FlatBuffer InputLayer
83     auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
84
85     // Add the AnyLayer to the FlatBufferLayers
86     CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
87 }
88
89 // Build FlatBuffer for Output Layer
90 void SerializerVisitor::VisitOutputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
91 {
92     // Create FlatBuffer BaseLayer
93     auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
94
95     // Create FlatBuffer BindableBaseLayer
96     auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
97                                                                                  flatBufferOutputBaseLayer,
98                                                                                  id);
99     // Push layer Guid to outputIds.
100     m_outputIds.push_back(GetSerializedId(layer->GetGuid()));
101
102     // Create the FlatBuffer OutputLayer
103     auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
104     // Add the AnyLayer to the FlatBufferLayers
105     CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
106 }
107
108 void SerializerVisitor::VisitAbsLayer(const armnn::IConnectableLayer* layer, const char* name)
109 {
110     throw UnimplementedException("SerializerVisitor::VisitAbsLayer is not implemented");
111 }
112
113 // Build FlatBuffer for Activation Layer
114 void SerializerVisitor::VisitActivationLayer(const armnn::IConnectableLayer* layer,
115                                              const armnn::ActivationDescriptor& descriptor,
116                                              const char* name)
117 {
118     // Create FlatBuffer BaseLayer
119     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Activation);
120
121     // Create the FlatBuffer ActivationDescriptor
122     auto flatBufferDescriptor = CreateActivationDescriptor(m_flatBufferBuilder,
123                                                            GetFlatBufferActivationFunction(descriptor.m_Function),
124                                                            descriptor.m_A,
125                                                            descriptor.m_B);
126
127     // Create the FlatBuffer ActivationLayer
128     auto flatBufferAdditionLayer = CreateActivationLayer(m_flatBufferBuilder,
129                                                          flatBufferBaseLayer,
130                                                          flatBufferDescriptor);
131
132     // Add the AnyLayer to the FlatBufferLayers
133     CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_ActivationLayer);
134 }
135
136 // Build FlatBuffer for Addition Layer
137 void SerializerVisitor::VisitAdditionLayer(const armnn::IConnectableLayer* layer, const char* name)
138 {
139     // Create FlatBuffer BaseLayer
140     auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
141
142     // Create the FlatBuffer AdditionLayer
143     auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
144
145     // Add the AnyLayer to the FlatBufferLayers
146     CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
147 }
148
149 // Build FlatBuffer for BatchToSpaceNd Layer
150 void SerializerVisitor::VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
151                                                  const armnn::BatchToSpaceNdDescriptor& descriptor,
152                                                  const char* name)
153 {
154     // Create FlatBuffer BaseLayer
155     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchToSpaceNd);
156
157     std::vector<unsigned int> crops;
158     crops.reserve(descriptor.m_Crops.size() * 2);
159     for (auto& crop : descriptor.m_Crops)
160     {
161         crops.push_back(crop.first);
162         crops.push_back(crop.second);
163     }
164
165     auto flatBufferDescriptor =
166         CreateBatchToSpaceNdDescriptor(m_flatBufferBuilder,
167                                        m_flatBufferBuilder.CreateVector(descriptor.m_BlockShape),
168                                        m_flatBufferBuilder.CreateVector(crops),
169                                        GetFlatBufferDataLayout(descriptor.m_DataLayout));
170
171     auto flatBufferLayer = serializer::CreateBatchToSpaceNdLayer(m_flatBufferBuilder,
172                                                                  flatBufferBaseLayer,
173                                                                  flatBufferDescriptor);
174
175     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_BatchToSpaceNdLayer);
176 }
177
178 void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
179                                                      const armnn::BatchNormalizationDescriptor& batchNormDescriptor,
180                                                      const armnn::ConstTensor& mean,
181                                                      const armnn::ConstTensor& variance,
182                                                      const armnn::ConstTensor& beta,
183                                                      const armnn::ConstTensor& gamma,
184                                                      const char* name)
185 {
186     auto fbBatchNormalizationBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchNormalization);
187     auto fbBatchNormalizationDescriptor = serializer::CreateBatchNormalizationDescriptor(
188                                                   m_flatBufferBuilder,
189                                                   batchNormDescriptor.m_Eps,
190                                                   GetFlatBufferDataLayout(batchNormDescriptor.m_DataLayout));
191
192     auto fbMeanConstTensorInfo     = CreateConstTensorInfo(mean);
193     auto fbVarianceConstTensorInfo = CreateConstTensorInfo(variance);
194     auto fbBetaConstTensorInfo     = CreateConstTensorInfo(beta);
195     auto fbGammaConstTensorInfo    = CreateConstTensorInfo(gamma);
196     auto fbBatchNormalizationLayer = serializer::CreateBatchNormalizationLayer(m_flatBufferBuilder,
197                                                                                fbBatchNormalizationBaseLayer,
198                                                                                fbBatchNormalizationDescriptor,
199                                                                                fbMeanConstTensorInfo,
200                                                                                fbVarianceConstTensorInfo,
201                                                                                fbBetaConstTensorInfo,
202                                                                                fbGammaConstTensorInfo);
203
204     CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
205 }
206
207 // Build FlatBuffer for Constant Layer
208 void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
209                                            const armnn::ConstTensor& input,
210                                            const char* name)
211 {
212     // Create FlatBuffer BaseLayer
213     auto flatBufferConstantBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Constant);
214
215     auto flatBufferConstTensorInfo = CreateConstTensorInfo(input);
216
217     // Create the FlatBuffer ConstantLayer
218     auto flatBufferLayer = CreateConstantLayer(m_flatBufferBuilder,
219                                                flatBufferConstantBaseLayer,
220                                                flatBufferConstTensorInfo);
221
222     // Add the AnyLayer to the FlatBufferLayers
223     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConstantLayer);
224 }
225
226 // Build FlatBuffer for Convolution2dLayer
227 void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
228                                                 const armnn::Convolution2dDescriptor& descriptor,
229                                                 const armnn::ConstTensor& weights,
230                                                 const armnn::Optional<armnn::ConstTensor>& biases,
231                                                 const char* name)
232 {
233     // Create FlatBuffer BaseLayer
234     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
235
236     auto flatBufferDescriptor = CreateConvolution2dDescriptor(m_flatBufferBuilder,
237                                                               descriptor.m_PadLeft,
238                                                               descriptor.m_PadRight,
239                                                               descriptor.m_PadTop,
240                                                               descriptor.m_PadBottom,
241                                                               descriptor.m_StrideX,
242                                                               descriptor.m_StrideY,
243                                                               descriptor.m_DilationX,
244                                                               descriptor.m_DilationY,
245                                                               descriptor.m_BiasEnabled,
246                                                               GetFlatBufferDataLayout(descriptor.m_DataLayout));
247     auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights);
248     flatbuffers::Offset<serializer::ConstTensor> flatBufferBiasesConstTensorInfo;
249
250     if (biases.has_value())
251     {
252         flatBufferBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
253     }
254
255     // Create the FlatBuffer Convolution2dLayer
256     auto flatBufferLayer = CreateConvolution2dLayer(m_flatBufferBuilder,
257                                                     flatBufferBaseLayer,
258                                                     flatBufferDescriptor,
259                                                     flatBufferWeightsConstTensorInfo,
260                                                     flatBufferBiasesConstTensorInfo);
261
262     // Add the AnyLayer to the FlatBufferLayers
263     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer);
264 }
265
266 void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
267                                                          const armnn::DepthwiseConvolution2dDescriptor& descriptor,
268                                                          const armnn::ConstTensor& weights,
269                                                          const armnn::Optional<armnn::ConstTensor>& biases,
270                                                          const char* name)
271 {
272     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthwiseConvolution2d);
273     auto fbDescriptor = CreateDepthwiseConvolution2dDescriptor(m_flatBufferBuilder,
274                                                                descriptor.m_PadLeft,
275                                                                descriptor.m_PadRight,
276                                                                descriptor.m_PadTop,
277                                                                descriptor.m_PadBottom,
278                                                                descriptor.m_StrideX,
279                                                                descriptor.m_StrideY,
280                                                                descriptor.m_DilationX,
281                                                                descriptor.m_DilationY,
282                                                                descriptor.m_BiasEnabled,
283                                                                GetFlatBufferDataLayout(descriptor.m_DataLayout));
284
285     flatbuffers::Offset<serializer::ConstTensor> fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
286     flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
287     if (biases.has_value())
288     {
289         fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
290     }
291
292     auto flatBufferLayer = CreateDepthwiseConvolution2dLayer(m_flatBufferBuilder,
293                                                              fbBaseLayer,
294                                                              fbDescriptor,
295                                                              fbWeightsConstTensorInfo,
296                                                              fbBiasesConstTensorInfo);
297
298     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DepthwiseConvolution2dLayer);
299 }
300
301 void SerializerVisitor::VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
302                                              const char* name)
303 {
304     auto fbDequantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Dequantize);
305     auto fbDequantizeLayer     = serializer::CreateDequantizeLayer(m_flatBufferBuilder, fbDequantizeBaseLayer);
306
307     CreateAnyLayer(fbDequantizeLayer.o, serializer::Layer::Layer_DequantizeLayer);
308 }
309
310 void SerializerVisitor::VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
311                                                        const armnn::DetectionPostProcessDescriptor& descriptor,
312                                                        const armnn::ConstTensor& anchors,
313                                                        const char* name)
314 {
315     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_DetectionPostProcess);
316     auto fbDescriptor = CreateDetectionPostProcessDescriptor(m_flatBufferBuilder,
317                                                              descriptor.m_MaxDetections,
318                                                              descriptor.m_MaxClassesPerDetection,
319                                                              descriptor.m_DetectionsPerClass,
320                                                              descriptor.m_NmsScoreThreshold,
321                                                              descriptor.m_NmsIouThreshold,
322                                                              descriptor.m_NumClasses,
323                                                              descriptor.m_UseRegularNms,
324                                                              descriptor.m_ScaleX,
325                                                              descriptor.m_ScaleY,
326                                                              descriptor.m_ScaleW,
327                                                              descriptor.m_ScaleH);
328
329     flatbuffers::Offset<serializer::ConstTensor> fbAnchorsConstTensorInfo = CreateConstTensorInfo(anchors);
330
331     auto flatBufferLayer = CreateDetectionPostProcessLayer(m_flatBufferBuilder,
332                                                            fbBaseLayer,
333                                                            fbDescriptor,
334                                                            fbAnchorsConstTensorInfo);
335
336     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DetectionPostProcessLayer);
337 }
338
339 void SerializerVisitor::VisitDivisionLayer(const armnn::IConnectableLayer* layer, const char* name)
340 {
341     auto fbDivisionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Division);
342     auto fbDivisionLayer     = serializer::CreateDivisionLayer(m_flatBufferBuilder, fbDivisionBaseLayer);
343
344     CreateAnyLayer(fbDivisionLayer.o, serializer::Layer::Layer_DivisionLayer);
345 }
346
347 void SerializerVisitor::VisitEqualLayer(const armnn::IConnectableLayer* layer, const char* name)
348 {
349     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Equal);
350     auto fbEqualLayer = serializer::CreateEqualLayer(m_flatBufferBuilder, fbBaseLayer);
351
352     CreateAnyLayer(fbEqualLayer.o, serializer::Layer::Layer_EqualLayer);
353 }
354
355 void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name)
356 {
357     auto flatBufferFloorBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Floor);
358     auto flatBufferFloorLayer = serializer::CreateFloorLayer(m_flatBufferBuilder, flatBufferFloorBaseLayer);
359
360     CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
361 }
362
363 void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
364 {
365     auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
366     auto flatBufferLayer   = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
367
368     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
369 }
370
371 void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name)
372 {
373     auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater);
374     auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer);
375
376     CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer);
377 }
378
379 void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
380                                                   const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
381                                                   const char* name)
382 {
383     // Create FlatBuffer BaseLayer
384     auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_L2Normalization);
385
386     // Create the FlatBuffer L2Normalization Descriptor
387     auto fbDescriptor = serializer::CreateL2NormalizationDescriptor(
388             m_flatBufferBuilder,
389             GetFlatBufferDataLayout(l2NormalizationDescriptor.m_DataLayout),
390             l2NormalizationDescriptor.m_Eps);
391
392     // Create FlatBuffer layer
393     auto fbLayer = serializer::CreateL2NormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
394
395     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
396 }
397
398 void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor,
399                                        const armnn::LstmInputParams& params, const char* name)
400 {
401     auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
402
403     auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
404         m_flatBufferBuilder,
405         descriptor.m_ActivationFunc,
406         descriptor.m_ClippingThresCell,
407         descriptor.m_ClippingThresProj,
408         descriptor.m_CifgEnabled,
409         descriptor.m_PeepholeEnabled,
410         descriptor.m_ProjectionEnabled,
411         descriptor.m_LayerNormEnabled);
412
413     // Get mandatory input parameters
414     auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
415     auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
416     auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
417     auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
418     auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
419     auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
420     auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
421     auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
422     auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
423
424     //Define optional parameters, these will be set depending on configuration in Lstm descriptor
425     flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
426     flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
427     flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
428     flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
429     flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
430     flatbuffers::Offset<serializer::ConstTensor> projectionBias;
431     flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
432     flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
433     flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
434     flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
435     flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
436     flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
437
438     if (!descriptor.m_CifgEnabled)
439     {
440         inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
441         recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
442         cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
443         inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
444     }
445
446     if (descriptor.m_ProjectionEnabled)
447     {
448         projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
449         projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
450     }
451
452     if (descriptor.m_PeepholeEnabled)
453     {
454         cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
455         cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
456     }
457
458     if (descriptor.m_LayerNormEnabled)
459     {
460         if (!descriptor.m_CifgEnabled)
461         {
462             inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
463         }
464         forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
465         cellLayerNormWeights   = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
466         outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
467     }
468
469     auto fbLstmParams = serializer::CreateLstmInputParams(
470         m_flatBufferBuilder,
471         inputToForgetWeights,
472         inputToCellWeights,
473         inputToOutputWeights,
474         recurrentToForgetWeights,
475         recurrentToCellWeights,
476         recurrentToOutputWeights,
477         forgetGateBias,
478         cellBias,
479         outputGateBias,
480         inputToInputWeights,
481         recurrentToInputWeights,
482         cellToInputWeights,
483         inputGateBias,
484         projectionWeights,
485         projectionBias,
486         cellToForgetWeights,
487         cellToOutputWeights,
488         inputLayerNormWeights,
489         forgetLayerNormWeights,
490         cellLayerNormWeights,
491         outputLayerNormWeights);
492
493     auto fbLstmLayer = serializer::CreateLstmLayer(
494         m_flatBufferBuilder,
495         fbLstmBaseLayer,
496         fbLstmDescriptor,
497         fbLstmParams);
498
499     CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
500 }
501
502 void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
503 {
504     auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
505     auto fbMaximumLayer     = serializer::CreateMaximumLayer(m_flatBufferBuilder, fbMaximumBaseLayer);
506
507     CreateAnyLayer(fbMaximumLayer.o, serializer::Layer::Layer_MaximumLayer);
508 }
509
510 void SerializerVisitor::VisitMeanLayer(const armnn::IConnectableLayer* layer,
511                                        const armnn::MeanDescriptor& descriptor,
512                                        const char* name)
513 {
514     auto fbMeanBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Mean);
515     auto fbMeanDescriptor = serializer::CreateMeanDescriptor(m_flatBufferBuilder,
516                                                              m_flatBufferBuilder.CreateVector(descriptor.m_Axis),
517                                                              descriptor.m_KeepDims);
518
519     auto fbMeanLayer = serializer::CreateMeanLayer(m_flatBufferBuilder,
520                                                    fbMeanBaseLayer,
521                                                    fbMeanDescriptor);
522
523     CreateAnyLayer(fbMeanLayer.o, serializer::Layer::Layer_MeanLayer);
524 }
525
526 void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, const char* name)
527 {
528     auto fbMinimumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Minimum);
529     auto fbMinimumLayer     = serializer::CreateMinimumLayer(m_flatBufferBuilder, fbMinimumBaseLayer);
530
531     CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
532 }
533
534 void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
535 {
536     auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
537     auto fbMergeLayer     = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
538
539     CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
540 }
541
542 void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
543                                          const armnn::MergerDescriptor& mergerDescriptor,
544                                          const char* name)
545 {
546     VisitConcatLayer(layer, mergerDescriptor, name);
547 }
548
549 void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
550                                          const armnn::ConcatDescriptor& concatDescriptor,
551                                          const char* name)
552 {
553     auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
554
555     std::vector<flatbuffers::Offset<UintVector>> views;
556     for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
557     {
558         const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
559         std::vector<uint32_t> origins;
560         for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
561         {
562             origins.push_back(origin[d]);
563         }
564         auto view = m_flatBufferBuilder.CreateVector(origins);
565         auto uintVector = CreateUintVector(m_flatBufferBuilder, view);
566         views.push_back(uintVector);
567     }
568
569     auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
570                                                               concatDescriptor.GetConcatAxis(),
571                                                               concatDescriptor.GetNumViews(),
572                                                               concatDescriptor.GetNumDimensions(),
573                                                               m_flatBufferBuilder.CreateVector(views));
574
575     auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
576                                              flatBufferConcatBaseLayer,
577                                              flatBufferConcatDescriptor);
578
579     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
580 }
581
582 void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)
583 {
584     auto fbMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
585     auto fbMultiplicationLayer     = serializer::CreateMultiplicationLayer(m_flatBufferBuilder,
586                                                                            fbMultiplicationBaseLayer);
587
588     CreateAnyLayer(fbMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
589 }
590
591 void SerializerVisitor::VisitPadLayer(const armnn::IConnectableLayer* layer,
592                                       const armnn::PadDescriptor& padDescriptor,
593                                       const char* name)
594 {
595     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pad);
596
597     std::vector<unsigned int> padList;
598     for (auto& p: padDescriptor.m_PadList)
599     {
600         padList.push_back(p.first);
601         padList.push_back(p.second);
602     }
603
604     auto flatBufferPadDesc = serializer::CreatePadDescriptor(m_flatBufferBuilder,
605                                                              m_flatBufferBuilder.CreateVector(padList),
606                                                              padDescriptor.m_PadValue);
607
608     auto flatBufferPadLayer = serializer::CreatePadLayer(m_flatBufferBuilder,
609                                                          flatBufferBaseLayer,
610                                                          flatBufferPadDesc);
611
612     CreateAnyLayer(flatBufferPadLayer.o, serializer::Layer::Layer_PadLayer);
613 }
614
615 void SerializerVisitor::VisitPermuteLayer(const armnn::IConnectableLayer* layer,
616                                           const armnn::PermuteDescriptor& permuteDescriptor,
617                                           const char* name)
618 {
619     // Create FlatBuffer BaseLayer
620     auto flatBufferPermuteBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Permute);
621
622     std::vector<unsigned int> dimMappings;
623     for (unsigned int i=0; i<permuteDescriptor.m_DimMappings.GetSize(); ++i)
624     {
625         dimMappings.push_back(permuteDescriptor.m_DimMappings[i]);
626     }
627
628     auto flatBufferPermuteDesc = serializer::CreatePermuteDescriptor(m_flatBufferBuilder,
629                                                                      m_flatBufferBuilder.CreateVector(dimMappings));
630
631     // Create the FlatBuffer PermuteLayer
632     auto flatBufferPermuteLayer = serializer::CreatePermuteLayer(m_flatBufferBuilder,
633                                                                  flatBufferPermuteBaseLayer,
634                                                                  flatBufferPermuteDesc);
635
636     // Add the AnyLayer to the FlatBufferLayers
637     CreateAnyLayer(flatBufferPermuteLayer.o, serializer::Layer::Layer_PermuteLayer);
638 }
639
640 // Build FlatBuffer for Reshape Layer
641 void SerializerVisitor::VisitReshapeLayer(const armnn::IConnectableLayer* layer,
642                                           const armnn::ReshapeDescriptor& reshapeDescriptor,
643                                           const char* name)
644 {
645     // Create FlatBuffer BaseLayer
646     auto flatBufferReshapeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Reshape);
647
648     std::vector<unsigned int> targetShape;
649     for (unsigned int i =0; i < reshapeDescriptor.m_TargetShape.GetNumDimensions(); i++)
650     {
651         targetShape.push_back(reshapeDescriptor.m_TargetShape[i]);
652     }
653
654     auto flatBufferReshapeDesc = serializer::CreateReshapeDescriptor(m_flatBufferBuilder,
655                                                                      m_flatBufferBuilder.CreateVector(targetShape));
656
657     // Create the FlatBuffer ReshapeLayer
658     auto flatBufferReshapeLayer = serializer::CreateReshapeLayer(m_flatBufferBuilder, flatBufferReshapeBaseLayer,
659                                                                  flatBufferReshapeDesc);
660
661     // Add the AnyLayer to the FlatBufferLayers
662     CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
663 }
664
665 void SerializerVisitor::VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
666                                                  const armnn::ResizeBilinearDescriptor& resizeDescriptor,
667                                                  const char* name)
668 {
669     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ResizeBilinear);
670
671     auto flatBufferDescriptor =
672         CreateResizeBilinearDescriptor(m_flatBufferBuilder,
673                                        resizeDescriptor.m_TargetWidth,
674                                        resizeDescriptor.m_TargetHeight,
675                                        GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
676
677     auto flatBufferLayer = serializer::CreateResizeBilinearLayer(m_flatBufferBuilder,
678                                                                  flatBufferBaseLayer,
679                                                                  flatBufferDescriptor);
680
681     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeBilinearLayer);
682 }
683
684 void SerializerVisitor::VisitResizeLayer(const armnn::IConnectableLayer* layer,
685                                          const armnn::ResizeDescriptor& resizeDescriptor,
686                                          const char* name)
687 {
688     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Resize);
689
690     auto flatBufferDescriptor =
691             CreateResizeDescriptor(m_flatBufferBuilder,
692                                    resizeDescriptor.m_TargetHeight,
693                                    resizeDescriptor.m_TargetWidth,
694                                    GetFlatBufferResizeMethod(resizeDescriptor.m_Method),
695                                    GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
696
697     auto flatBufferLayer = serializer::CreateResizeLayer(m_flatBufferBuilder,
698                                                          flatBufferBaseLayer,
699                                                          flatBufferDescriptor);
700
701     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeLayer);
702 }
703
704 void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
705 {
706     auto fbRsqrtBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rsqrt);
707     auto fbRsqrtLayer     = serializer::CreateRsqrtLayer(m_flatBufferBuilder, fbRsqrtBaseLayer);
708
709     CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
710 }
711
712 // Build FlatBuffer for Softmax Layer
713 void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
714                                           const armnn::SoftmaxDescriptor& softmaxDescriptor,
715                                           const char* name)
716 {
717     // Create FlatBuffer BaseLayer
718     auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
719
720     // Create the FlatBuffer SoftmaxDescriptor
721     auto flatBufferSoftmaxDesc =
722         serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
723
724     // Create the FlatBuffer SoftmaxLayer
725     auto flatBufferSoftmaxLayer =
726         serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
727                                        flatBufferSoftmaxBaseLayer,
728                                        flatBufferSoftmaxDesc);
729
730     CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
731 }
732
733 void SerializerVisitor::VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
734                                             const armnn::Pooling2dDescriptor& pooling2dDescriptor,
735                                             const char* name)
736 {
737     auto fbPooling2dBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
738     auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
739         m_flatBufferBuilder,
740         GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
741         pooling2dDescriptor.m_PadLeft,
742         pooling2dDescriptor.m_PadRight,
743         pooling2dDescriptor.m_PadTop,
744         pooling2dDescriptor.m_PadBottom,
745         pooling2dDescriptor.m_PoolWidth,
746         pooling2dDescriptor.m_PoolHeight,
747         pooling2dDescriptor.m_StrideX,
748         pooling2dDescriptor.m_StrideY,
749         GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
750         GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
751         GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
752
753     auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
754                                                              fbPooling2dBaseLayer,
755                                                              fbPooling2dDescriptor);
756
757     CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
758 }
759
760 void SerializerVisitor::VisitPreluLayer(const armnn::IConnectableLayer* layer,
761                                         const char* name)
762 {
763     // Create FlatBuffer BaseLayer
764     auto flatBufferPreluBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Prelu);
765
766     // Create the FlatBuffer AdditionLayer
767     auto flatBufferPreluLayer = serializer::CreatePreluLayer(m_flatBufferBuilder, flatBufferPreluBaseLayer);
768
769     // Add the AnyLayer to the FlatBufferLayers
770     CreateAnyLayer(flatBufferPreluLayer.o, serializer::Layer::Layer_PreluLayer);
771 }
772
773 void SerializerVisitor::VisitQuantizeLayer(const armnn::IConnectableLayer *layer, const char *name)
774 {
775     auto fbQuantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Quantize);
776     auto fbQuantizeLayer = serializer::CreateQuantizeLayer(m_flatBufferBuilder,
777                                                            fbQuantizeBaseLayer);
778     CreateAnyLayer(fbQuantizeLayer.o, serializer::Layer::Layer_QuantizeLayer);
779 }
780
781 // Build FlatBuffer for FullyConnected Layer
782 void SerializerVisitor::VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
783                                                  const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
784                                                  const armnn::ConstTensor& weights,
785                                                  const armnn::Optional<armnn::ConstTensor>& biases,
786                                                  const char* name)
787 {
788     // Create FlatBuffer BaseLayer
789     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_FullyConnected);
790
791     // Create FlatBuffer FullyConnectedDescriptor
792     auto flatBufferDescriptor =
793         serializer::CreateFullyConnectedDescriptor(m_flatBufferBuilder,
794                                                    fullyConnectedDescriptor.m_BiasEnabled,
795                                                    fullyConnectedDescriptor.m_TransposeWeightMatrix);
796
797     // Create FlatBuffer weights data
798     auto flatBufferWeights = CreateConstTensorInfo(weights);
799
800     // Create FlatBuffer bias data
801     flatbuffers::Offset<serializer::ConstTensor> flatBufferBiases;
802     if (fullyConnectedDescriptor.m_BiasEnabled)
803     {
804         flatBufferBiases = CreateConstTensorInfo(biases.value());
805     }
806
807     // Create FlatBuffer FullyConnectedLayer
808     auto flatBufferLayer = serializer::CreateFullyConnectedLayer(m_flatBufferBuilder,
809                                                                  flatBufferBaseLayer,
810                                                                  flatBufferDescriptor,
811                                                                  flatBufferWeights,
812                                                                  flatBufferBiases);
813
814     // Add created FullyConnectedLayer to the FlatBufferLayers
815     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_FullyConnectedLayer);
816 }
817
818 // Build FlatBuffer for SpaceToBatchNd Layer
819 void SerializerVisitor::VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
820                                                  const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
821                                                  const char* name)
822 {
823     // Create FlatBuffer BaseLayer
824     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToBatchNd);
825
826     std::vector<unsigned int> padList;
827     padList.reserve(spaceToBatchNdDescriptor.m_PadList.size()*2);
828     for (auto& pad : spaceToBatchNdDescriptor.m_PadList)
829     {
830         padList.push_back(pad.first);
831         padList.push_back(pad.second);
832     }
833
834     auto flatBufferDescriptor =
835         CreateSpaceToBatchNdDescriptor(m_flatBufferBuilder,
836                                        m_flatBufferBuilder.CreateVector(spaceToBatchNdDescriptor.m_BlockShape),
837                                        m_flatBufferBuilder.CreateVector(padList),
838                                        GetFlatBufferDataLayout(spaceToBatchNdDescriptor.m_DataLayout));
839
840     auto flatBufferLayer = serializer::CreateSpaceToBatchNdLayer(m_flatBufferBuilder,
841                                                                  flatBufferBaseLayer,
842                                                                  flatBufferDescriptor);
843
844     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer);
845 }
846
847 // Build FlatBuffer for SpaceToDepthLayer
848 void SerializerVisitor::VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
849                                                const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
850                                                const char* name)
851 {
852     auto flatBufferBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToDepth);
853     auto flatBufferDescriptor =
854         CreateSpaceToDepthDescriptor(m_flatBufferBuilder,
855                                      spaceToDepthDescriptor.m_BlockSize,
856                                      GetFlatBufferDataLayout(spaceToDepthDescriptor.m_DataLayout));
857
858     auto flatBufferLayer = serializer::CreateSpaceToDepthLayer(m_flatBufferBuilder,
859                                                                flatBufferBaseLayer,
860                                                                flatBufferDescriptor);
861
862     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToDepthLayer);
863 }
864
865 // Build FlatBuffer for Splitter Layer
866 void SerializerVisitor::VisitSplitterLayer(const armnn::IConnectableLayer* layer,
867                                            const armnn::ViewsDescriptor& viewsDescriptor,
868                                            const char* name)
869 {
870     // Create FlatBuffer ViewOrigins
871     std::vector<flatbuffers::Offset<UintVector>> flatBufferViewOrigins;
872     flatBufferViewOrigins.reserve(viewsDescriptor.GetNumViews());
873
874     for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
875     {
876         std::vector<uint32_t> viewOrigin;
877         viewOrigin.reserve(viewsDescriptor.GetNumDimensions());
878
879         // Copy vector
880         for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
881         {
882             viewOrigin.push_back(viewsDescriptor.GetViewOrigin(vIdx)[dIdx]);
883         }
884
885         flatBufferViewOrigins.push_back(CreateUintVector(m_flatBufferBuilder,
886                                                          m_flatBufferBuilder.CreateVector(viewOrigin)));
887     }
888
889     // Create FlatBuffer OriginsDescriptor
890     auto flatBufferOriginDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
891                                                               viewsDescriptor.GetOrigins().GetConcatAxis(),
892                                                               viewsDescriptor.GetOrigins().GetNumViews(),
893                                                               viewsDescriptor.GetOrigins().GetNumDimensions(),
894                                                               m_flatBufferBuilder.CreateVector(flatBufferViewOrigins));
895
896     // Create FlatBuffer ViewOrigins
897     std::vector<flatbuffers::Offset<UintVector>> flatBufferViewSizes;
898     flatBufferViewSizes.reserve(viewsDescriptor.GetNumViews());
899
900     for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
901     {
902         std::vector<uint32_t> viewSize;
903         viewSize.reserve(viewsDescriptor.GetNumDimensions());
904
905         // Copy vector
906         for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
907         {
908             viewSize.push_back(viewsDescriptor.GetViewSizes(vIdx)[dIdx]);
909         }
910
911         flatBufferViewSizes.push_back(CreateUintVector(m_flatBufferBuilder,
912                                                        m_flatBufferBuilder.CreateVector(viewSize)));
913     }
914
915     // Create FlatBuffer ViewsDescriptor
916     auto flatBufferViewsDescriptor = CreateViewsDescriptor(m_flatBufferBuilder,
917                                                            flatBufferOriginDescriptor,
918                                                            m_flatBufferBuilder.CreateVector(flatBufferViewSizes));
919
920     // Create FlatBuffer BaseLayer
921     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Splitter);
922
923     auto flatBufferSplitterLayer = serializer::CreateSplitterLayer(m_flatBufferBuilder,
924                                                                    flatBufferBaseLayer,
925                                                                    flatBufferViewsDescriptor);
926
927     CreateAnyLayer(flatBufferSplitterLayer.o, serializer::Layer::Layer_SplitterLayer);
928 }
929
930 void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
931                                                 const armnn::NormalizationDescriptor& descriptor,
932                                                 const char* name)
933 {
934     auto fbNormalizationBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Normalization);
935
936     auto fbNormalizationDescriptor = serializer::CreateNormalizationDescriptor(
937         m_flatBufferBuilder,
938         GetFlatBufferNormalizationAlgorithmChannel(descriptor.m_NormChannelType),
939         GetFlatBufferNormalizationAlgorithmMethod(descriptor.m_NormMethodType),
940         descriptor.m_NormSize,
941         descriptor.m_Alpha,
942         descriptor.m_Beta,
943         descriptor.m_K,
944         GetFlatBufferDataLayout(descriptor.m_DataLayout));
945
946     auto flatBufferLayer = serializer::CreateNormalizationLayer(m_flatBufferBuilder,
947                                                                 fbNormalizationBaseLayer,
948                                                                 fbNormalizationDescriptor);
949
950     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
951 }
952
953 void SerializerVisitor::VisitStackLayer(const armnn::IConnectableLayer* layer,
954                                         const armnn::StackDescriptor& stackDescriptor,
955                                         const char* name)
956 {
957     auto stackBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Stack);
958
959     std::vector<unsigned int> inputShape;
960     for (unsigned int i =0; i < stackDescriptor.m_InputShape.GetNumDimensions(); i++)
961     {
962         inputShape.push_back(stackDescriptor.m_InputShape[i]);
963     }
964
965     auto flatBufferStackDescriptor = CreateStackDescriptor(m_flatBufferBuilder,
966                                                            stackDescriptor.m_Axis,
967                                                            stackDescriptor.m_NumInputs,
968                                                            m_flatBufferBuilder.CreateVector(inputShape));
969
970     auto stackLayer = serializer::CreateStackLayer(m_flatBufferBuilder, stackBaseLayer, flatBufferStackDescriptor);
971     CreateAnyLayer(stackLayer.o, serializer::Layer::Layer_StackLayer);
972 }
973
974 void SerializerVisitor::VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
975                                                const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
976                                                const char* name)
977 {
978     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StridedSlice);
979
980     auto flatBufferDescriptor =
981         CreateStridedSliceDescriptor(m_flatBufferBuilder,
982                                      m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Begin),
983                                      m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_End),
984                                      m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Stride),
985                                      stridedSliceDescriptor.m_BeginMask,
986                                      stridedSliceDescriptor.m_EndMask,
987                                      stridedSliceDescriptor.m_ShrinkAxisMask,
988                                      stridedSliceDescriptor.m_EllipsisMask,
989                                      stridedSliceDescriptor.m_NewAxisMask,
990                                      GetFlatBufferDataLayout(stridedSliceDescriptor.m_DataLayout));
991
992     auto flatBufferLayer = serializer::CreateStridedSliceLayer(m_flatBufferBuilder,
993                                                                flatBufferBaseLayer,
994                                                                flatBufferDescriptor);
995
996     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_StridedSliceLayer);
997 }
998
999 void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* layer, const char* name)
1000 {
1001     auto fbSubtractionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Subtraction);
1002     auto fbSubtractionLayer = serializer::CreateSubtractionLayer(m_flatBufferBuilder, fbSubtractionBaseLayer);
1003
1004     CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
1005 }
1006
1007 void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
1008 {
1009     auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
1010     auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
1011
1012     CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
1013 }
1014
1015 void SerializerVisitor::VisitTransposeConvolution2dLayer(
1016     const armnn::IConnectableLayer* layer,
1017     const armnn::TransposeConvolution2dDescriptor& descriptor,
1018     const armnn::ConstTensor& weights,
1019     const armnn::Optional<armnn::ConstTensor>& biases,
1020     const char* name)
1021 {
1022     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
1023     auto fbDescriptor = CreateTransposeConvolution2dDescriptor(m_flatBufferBuilder,
1024                                                                descriptor.m_PadLeft,
1025                                                                descriptor.m_PadRight,
1026                                                                descriptor.m_PadTop,
1027                                                                descriptor.m_PadBottom,
1028                                                                descriptor.m_StrideX,
1029                                                                descriptor.m_StrideY,
1030                                                                descriptor.m_BiasEnabled,
1031                                                                GetFlatBufferDataLayout(descriptor.m_DataLayout));
1032
1033     // weights & biases
1034     auto fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
1035     flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
1036     if (biases.has_value())
1037     {
1038         fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
1039     }
1040
1041     auto fbLayer = CreateTransposeConvolution2dLayer(m_flatBufferBuilder,
1042                                                      fbBaseLayer,
1043                                                      fbDescriptor,
1044                                                      fbWeightsConstTensorInfo,
1045                                                      fbBiasesConstTensorInfo);
1046
1047     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_TransposeConvolution2dLayer);
1048 }
1049
1050 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
1051                                                 const armnn::QuantizedLstmInputParams& params,
1052                                                 const char* name)
1053 {
1054     auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
1055
1056     // Get input parameters
1057     auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
1058     auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
1059     auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
1060     auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
1061
1062     auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
1063     auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
1064     auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
1065     auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
1066
1067     auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
1068     auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
1069     auto cellBias = CreateConstTensorInfo(params.GetCellBias());
1070     auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
1071
1072     auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
1073         m_flatBufferBuilder,
1074         inputToInputWeights,
1075         inputToForgetWeights,
1076         inputToCellWeights,
1077         inputToOutputWeights,
1078         recurrentToInputWeights,
1079         recurrentToForgetWeights,
1080         recurrentToCellWeights,
1081         recurrentToOutputWeights,
1082         inputGateBias,
1083         forgetGateBias,
1084         cellBias,
1085         outputGateBias);
1086
1087     auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
1088         m_flatBufferBuilder,
1089         fbQuantizedLstmBaseLayer,
1090         fbQuantizedLstmParams);
1091
1092     CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
1093 }
1094
1095 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
1096                                                                      const serializer::LayerType layerType)
1097 {
1098     uint32_t fbIndex = GetSerializedId(layer->GetGuid());
1099
1100     std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
1101     std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
1102
1103     return serializer::CreateLayerBase(m_flatBufferBuilder,
1104                                        fbIndex,
1105                                        m_flatBufferBuilder.CreateString(layer->GetName()),
1106                                        layerType,
1107                                        m_flatBufferBuilder.CreateVector(inputSlots),
1108                                        m_flatBufferBuilder.CreateVector(outputSlots));
1109 }
1110
1111 void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
1112 {
1113     auto anyLayer = armnnSerializer::CreateAnyLayer(m_flatBufferBuilder, serializerLayer, layer);
1114     m_serializedLayers.push_back(anyLayer);
1115 }
1116
1117 template <typename T>
1118 flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector(const void* memory, unsigned int size)
1119 {
1120     const T* buffer = reinterpret_cast<const T*>(memory);
1121     std::vector<T> vector(buffer, buffer + (size / sizeof(T)));
1122     auto fbVector = m_flatBufferBuilder.CreateVector(vector);
1123     return fbVector;
1124 }
1125
1126 flatbuffers::Offset<serializer::ConstTensor>
1127     SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor)
1128 {
1129     armnn::TensorInfo tensorInfo = constTensor.GetInfo();
1130
1131     // Get the dimensions
1132     std::vector<unsigned int> shape;
1133
1134     for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1135     {
1136         shape.push_back(tensorInfo.GetShape()[dim]);
1137     }
1138
1139     // Create FlatBuffer TensorInfo
1140     auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1141                                                              m_flatBufferBuilder.CreateVector(shape),
1142                                                              GetFlatBufferDataType(tensorInfo.GetDataType()),
1143                                                              tensorInfo.GetQuantizationScale(),
1144                                                              tensorInfo.GetQuantizationOffset());
1145     flatbuffers::Offset<void> fbPayload;
1146
1147     switch (tensorInfo.GetDataType())
1148     {
1149         case armnn::DataType::Float32:
1150         case armnn::DataType::Signed32:
1151         {
1152             auto fbVector = CreateDataVector<int32_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1153             flatbuffers::Offset<serializer::IntData> flatBuffersData = serializer::CreateIntData(
1154                     m_flatBufferBuilder,
1155                     fbVector);
1156             fbPayload = flatBuffersData.o;
1157             break;
1158         }
1159         case armnn::DataType::Float16:
1160         {
1161             auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1162             flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1163                     m_flatBufferBuilder,
1164                     fbVector);
1165             fbPayload = flatBuffersData.o;
1166             break;
1167         }
1168         case armnn::DataType::QuantisedSymm16:
1169         {
1170             auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1171             flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1172                     m_flatBufferBuilder,
1173                     fbVector);
1174             fbPayload = flatBuffersData.o;
1175             break;
1176         }
1177         case armnn::DataType::QuantisedAsymm8:
1178         case armnn::DataType::Boolean:
1179         default:
1180         {
1181             auto fbVector = CreateDataVector<int8_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1182             flatbuffers::Offset<serializer::ByteData> flatBuffersData = serializer::CreateByteData(
1183                     m_flatBufferBuilder,
1184                     fbVector);
1185             fbPayload = flatBuffersData.o;
1186         }
1187     }
1188     flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor(
1189             m_flatBufferBuilder,
1190             flatBufferTensorInfo,
1191             GetFlatBufferConstTensorData(tensorInfo.GetDataType()),
1192             fbPayload);
1193     return flatBufferConstTensor;
1194 }
1195
1196 std::vector<fb::Offset<serializer::InputSlot>>
1197     SerializerVisitor::CreateInputSlots(const armnn::IConnectableLayer* layer)
1198 {
1199     std::vector<fb::Offset<serializer::InputSlot>> inputSlots;
1200
1201     // Get the InputSlots
1202     for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
1203     {
1204         const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
1205
1206         // Get the Connection for the InputSlot
1207         const IOutputSlot* connection = inputSlot.GetConnection();
1208
1209         // Create FlatBuffer Connection
1210         serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
1211                                     connection->CalculateIndexOnOwner());
1212         // Create FlatBuffer InputSlot
1213         inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
1214     }
1215     return inputSlots;
1216 }
1217
1218 std::vector<fb::Offset<serializer::OutputSlot>>
1219     SerializerVisitor::CreateOutputSlots(const armnn::IConnectableLayer* layer)
1220 {
1221     std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
1222
1223     // Get the OutputSlots
1224     for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
1225     {
1226         const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
1227         const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1228
1229         // Get the dimensions
1230         std::vector<unsigned int> shape;
1231         for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1232         {
1233             shape.push_back(tensorInfo.GetShape()[dim]);
1234         }
1235
1236         // Create FlatBuffer TensorInfo
1237         auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1238                                                                  m_flatBufferBuilder.CreateVector(shape),
1239                                                                  GetFlatBufferDataType(tensorInfo.GetDataType()),
1240                                                                  tensorInfo.GetQuantizationScale(),
1241                                                                  tensorInfo.GetQuantizationOffset());
1242
1243         // Create FlatBuffer Outputslot
1244         outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
1245                                                            slotIndex,
1246                                                            flatBufferTensorInfo));
1247     }
1248     return outputSlots;
1249 }
1250
1251
1252 ISerializer* ISerializer::CreateRaw()
1253 {
1254     return new Serializer();
1255 }
1256
1257 ISerializerPtr ISerializer::Create()
1258 {
1259     return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
1260 }
1261
1262 void ISerializer::Destroy(ISerializer* serializer)
1263 {
1264     delete serializer;
1265 }
1266
1267 void Serializer::Serialize(const INetwork& inNetwork)
1268 {
1269     // Iterate through to network
1270     inNetwork.Accept(m_SerializerVisitor);
1271     flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1272
1273     // Create FlatBuffer SerializedGraph
1274     auto serializedGraph = serializer::CreateSerializedGraph(
1275         fbBuilder,
1276         fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
1277         fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
1278         fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()));
1279
1280     // Serialize the graph
1281     fbBuilder.Finish(serializedGraph);
1282 }
1283
1284 bool Serializer::SaveSerializedToStream(std::ostream& stream)
1285 {
1286     flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1287
1288     auto bytesToWrite = boost::numeric_cast<std::streamsize>(fbBuilder.GetSize());
1289     stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
1290     return !stream.bad();
1291 }
1292
1293 } // namespace armnnSerializer