IVGCVSW-4777 Add QLstm serialization support
[platform/upstream/armnn.git] / src / armnnSerializer / Serializer.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "Serializer.hpp"
7
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/QuantizedLstmParams.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
12
13 #include <iostream>
14
15 #include <boost/numeric/conversion/cast.hpp>
16 #include <flatbuffers/util.h>
17
18 #include "SerializerUtils.hpp"
19
20 using namespace armnn;
21 namespace fb = flatbuffers;
22 namespace serializer = armnnSerializer;
23
24 namespace armnnSerializer
25 {
26
27 serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::ActivationFunction function)
28 {
29     switch (function)
30     {
31         case armnn::ActivationFunction::Sigmoid:
32             return serializer::ActivationFunction::ActivationFunction_Sigmoid;
33         case armnn::ActivationFunction::TanH:
34             return serializer::ActivationFunction::ActivationFunction_TanH;
35         case armnn::ActivationFunction::Linear:
36             return serializer::ActivationFunction::ActivationFunction_Linear;
37         case armnn::ActivationFunction::ReLu:
38             return serializer::ActivationFunction::ActivationFunction_ReLu;
39         case armnn::ActivationFunction::BoundedReLu:
40             return serializer::ActivationFunction::ActivationFunction_BoundedReLu;
41         case armnn::ActivationFunction::LeakyReLu:
42             return serializer::ActivationFunction::ActivationFunction_LeakyReLu;
43         case armnn::ActivationFunction::Abs:
44             return serializer::ActivationFunction::ActivationFunction_Abs;
45         case armnn::ActivationFunction::Sqrt:
46             return serializer::ActivationFunction::ActivationFunction_Sqrt;
47         case armnn::ActivationFunction::Square:
48             return serializer::ActivationFunction::ActivationFunction_Square;
49         case armnn::ActivationFunction::Elu:
50             return serializer::ActivationFunction::ActivationFunction_Elu;
51         case armnn::ActivationFunction::HardSwish:
52             return serializer::ActivationFunction::ActivationFunction_HardSwish;
53         default:
54             return serializer::ActivationFunction::ActivationFunction_Sigmoid;
55     }
56 }
57
58 serializer::ArgMinMaxFunction GetFlatBufferArgMinMaxFunction(armnn::ArgMinMaxFunction function)
59 {
60     switch (function)
61     {
62         case armnn::ArgMinMaxFunction::Max:
63             return serializer::ArgMinMaxFunction::ArgMinMaxFunction_Max;
64         case armnn::ArgMinMaxFunction::Min:
65         default:
66             return serializer::ArgMinMaxFunction::ArgMinMaxFunction_Min;
67     }
68 }
69
70 uint32_t SerializerVisitor::GetSerializedId(armnn::LayerGuid guid)
71 {
72     if (m_guidMap.empty())
73     {
74         m_guidMap.insert(std::make_pair(guid, m_layerId));
75     }
76     else if (m_guidMap.find(guid) == m_guidMap.end())
77     {
78         ++m_layerId;
79         m_guidMap.insert(std::make_pair(guid, m_layerId));
80
81         return m_layerId;
82     }
83     return m_guidMap[guid];
84 }
85
86 // Build FlatBuffer for Input Layer
87 void SerializerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
88 {
89     IgnoreUnused(name);
90
91     // Create FlatBuffer BaseLayer
92     auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
93
94     // Create FlatBuffer BindableBaseLayer
95     auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
96                                                                                 flatBufferInputBaseLayer,
97                                                                                 id);
98     // Push layer binding id to outputIds.
99     m_inputIds.push_back(id);
100
101     // Create the FlatBuffer InputLayer
102     auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
103
104     // Add the AnyLayer to the FlatBufferLayers
105     CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
106 }
107
108 // Build FlatBuffer for Output Layer
109 void SerializerVisitor::VisitOutputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
110 {
111     IgnoreUnused(name);
112
113     // Create FlatBuffer BaseLayer
114     auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
115
116     // Create FlatBuffer BindableBaseLayer
117     auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
118                                                                                  flatBufferOutputBaseLayer,
119                                                                                  id);
120     // Push layer binding id to outputIds.
121     m_outputIds.push_back(id);
122
123     // Create the FlatBuffer OutputLayer
124     auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
125     // Add the AnyLayer to the FlatBufferLayers
126     CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
127 }
128
129 void SerializerVisitor::VisitAbsLayer(const armnn::IConnectableLayer* layer, const char* name)
130 {
131     IgnoreUnused(name);
132     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Abs);
133     auto flatBufferAbsLayer  = serializer::CreateAbsLayer(m_flatBufferBuilder, flatBufferBaseLayer);
134
135     CreateAnyLayer(flatBufferAbsLayer.o, serializer::Layer::Layer_AbsLayer);
136 }
137
138 // Build FlatBuffer for Activation Layer
139 void SerializerVisitor::VisitActivationLayer(const armnn::IConnectableLayer* layer,
140                                              const armnn::ActivationDescriptor& descriptor,
141                                              const char* name)
142 {
143     IgnoreUnused(name);
144
145     // Create FlatBuffer BaseLayer
146     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Activation);
147
148     // Create the FlatBuffer ActivationDescriptor
149     auto flatBufferDescriptor = CreateActivationDescriptor(m_flatBufferBuilder,
150                                                            GetFlatBufferActivationFunction(descriptor.m_Function),
151                                                            descriptor.m_A,
152                                                            descriptor.m_B);
153
154     // Create the FlatBuffer ActivationLayer
155     auto flatBufferAdditionLayer = CreateActivationLayer(m_flatBufferBuilder,
156                                                          flatBufferBaseLayer,
157                                                          flatBufferDescriptor);
158
159     // Add the AnyLayer to the FlatBufferLayers
160     CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_ActivationLayer);
161 }
162
163 // Build FlatBuffer for Addition Layer
164 void SerializerVisitor::VisitAdditionLayer(const armnn::IConnectableLayer* layer, const char* name)
165 {
166     IgnoreUnused(name);
167
168     // Create FlatBuffer BaseLayer
169     auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
170
171     // Create the FlatBuffer AdditionLayer
172     auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
173
174     // Add the AnyLayer to the FlatBufferLayers
175     CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
176 }
177
178 // Build FlatBuffer for ArgMinMax Layer
179 void SerializerVisitor::VisitArgMinMaxLayer(const armnn::IConnectableLayer *layer,
180                                             const armnn::ArgMinMaxDescriptor& descriptor,
181                                             const char *name)
182 {
183     IgnoreUnused(name);
184
185     // Create FlatBuffer BaseLayer
186     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ArgMinMax);
187
188     // Create FlatBuffer Descriptor
189     auto flatBufferDescriptor = CreateArgMinMaxDescriptor(m_flatBufferBuilder,
190                                                           GetFlatBufferArgMinMaxFunction(descriptor.m_Function),
191                                                           descriptor.m_Axis);
192
193     // Create FlatBuffer ArgMinMaxLayer
194     auto flatBufferLayer = CreateArgMinMaxLayer(m_flatBufferBuilder,
195                                                 flatBufferBaseLayer,
196                                                 flatBufferDescriptor);
197
198     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ArgMinMaxLayer);
199 }
200
201 // Build FlatBuffer for BatchToSpaceNd Layer
202 void SerializerVisitor::VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
203                                                  const armnn::BatchToSpaceNdDescriptor& descriptor,
204                                                  const char* name)
205 {
206     IgnoreUnused(name);
207
208     // Create FlatBuffer BaseLayer
209     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchToSpaceNd);
210
211     std::vector<unsigned int> crops;
212     crops.reserve(descriptor.m_Crops.size() * 2);
213     for (auto& crop : descriptor.m_Crops)
214     {
215         crops.push_back(crop.first);
216         crops.push_back(crop.second);
217     }
218
219     auto flatBufferDescriptor =
220         CreateBatchToSpaceNdDescriptor(m_flatBufferBuilder,
221                                        m_flatBufferBuilder.CreateVector(descriptor.m_BlockShape),
222                                        m_flatBufferBuilder.CreateVector(crops),
223                                        GetFlatBufferDataLayout(descriptor.m_DataLayout));
224
225     auto flatBufferLayer = serializer::CreateBatchToSpaceNdLayer(m_flatBufferBuilder,
226                                                                  flatBufferBaseLayer,
227                                                                  flatBufferDescriptor);
228
229     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_BatchToSpaceNdLayer);
230 }
231
232 void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
233                                                      const armnn::BatchNormalizationDescriptor& batchNormDescriptor,
234                                                      const armnn::ConstTensor& mean,
235                                                      const armnn::ConstTensor& variance,
236                                                      const armnn::ConstTensor& beta,
237                                                      const armnn::ConstTensor& gamma,
238                                                      const char* name)
239 {
240     IgnoreUnused(name);
241
242     auto fbBatchNormalizationBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchNormalization);
243     auto fbBatchNormalizationDescriptor = serializer::CreateBatchNormalizationDescriptor(
244                                                   m_flatBufferBuilder,
245                                                   batchNormDescriptor.m_Eps,
246                                                   GetFlatBufferDataLayout(batchNormDescriptor.m_DataLayout));
247
248     auto fbMeanConstTensorInfo     = CreateConstTensorInfo(mean);
249     auto fbVarianceConstTensorInfo = CreateConstTensorInfo(variance);
250     auto fbBetaConstTensorInfo     = CreateConstTensorInfo(beta);
251     auto fbGammaConstTensorInfo    = CreateConstTensorInfo(gamma);
252     auto fbBatchNormalizationLayer = serializer::CreateBatchNormalizationLayer(m_flatBufferBuilder,
253                                                                                fbBatchNormalizationBaseLayer,
254                                                                                fbBatchNormalizationDescriptor,
255                                                                                fbMeanConstTensorInfo,
256                                                                                fbVarianceConstTensorInfo,
257                                                                                fbBetaConstTensorInfo,
258                                                                                fbGammaConstTensorInfo);
259
260     CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
261 }
262
263 void SerializerVisitor::VisitComparisonLayer(const armnn::IConnectableLayer* layer,
264                                              const armnn::ComparisonDescriptor& descriptor,
265                                              const char* name)
266 {
267     IgnoreUnused(name);
268
269     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Comparison);
270     auto fbDescriptor = serializer::CreateComparisonDescriptor(
271         m_flatBufferBuilder,
272         GetFlatBufferComparisonOperation(descriptor.m_Operation));
273
274     auto fbLayer = serializer::CreateComparisonLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
275     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_ComparisonLayer);
276 }
277
278 // Build FlatBuffer for Constant Layer
279 void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
280                                            const armnn::ConstTensor& input,
281                                            const char* name)
282 {
283     IgnoreUnused(name);
284
285     // Create FlatBuffer BaseLayer
286     auto flatBufferConstantBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Constant);
287
288     auto flatBufferConstTensorInfo = CreateConstTensorInfo(input);
289
290     // Create the FlatBuffer ConstantLayer
291     auto flatBufferLayer = CreateConstantLayer(m_flatBufferBuilder,
292                                                flatBufferConstantBaseLayer,
293                                                flatBufferConstTensorInfo);
294
295     // Add the AnyLayer to the FlatBufferLayers
296     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConstantLayer);
297 }
298
299 // Build FlatBuffer for Convolution2dLayer
300 void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
301                                                 const armnn::Convolution2dDescriptor& descriptor,
302                                                 const armnn::ConstTensor& weights,
303                                                 const armnn::Optional<armnn::ConstTensor>& biases,
304                                                 const char* name)
305 {
306     IgnoreUnused(name);
307
308     // Create FlatBuffer BaseLayer
309     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
310
311     auto flatBufferDescriptor = CreateConvolution2dDescriptor(m_flatBufferBuilder,
312                                                               descriptor.m_PadLeft,
313                                                               descriptor.m_PadRight,
314                                                               descriptor.m_PadTop,
315                                                               descriptor.m_PadBottom,
316                                                               descriptor.m_StrideX,
317                                                               descriptor.m_StrideY,
318                                                               descriptor.m_DilationX,
319                                                               descriptor.m_DilationY,
320                                                               descriptor.m_BiasEnabled,
321                                                               GetFlatBufferDataLayout(descriptor.m_DataLayout));
322     auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights);
323     flatbuffers::Offset<serializer::ConstTensor> flatBufferBiasesConstTensorInfo;
324
325     if (biases.has_value())
326     {
327         flatBufferBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
328     }
329
330     // Create the FlatBuffer Convolution2dLayer
331     auto flatBufferLayer = CreateConvolution2dLayer(m_flatBufferBuilder,
332                                                     flatBufferBaseLayer,
333                                                     flatBufferDescriptor,
334                                                     flatBufferWeightsConstTensorInfo,
335                                                     flatBufferBiasesConstTensorInfo);
336
337     // Add the AnyLayer to the FlatBufferLayers
338     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer);
339 }
340
341 void SerializerVisitor::VisitDepthToSpaceLayer(const armnn::IConnectableLayer* layer,
342                                                const armnn::DepthToSpaceDescriptor& descriptor,
343                                                const char* name)
344 {
345     IgnoreUnused(name);
346
347     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthToSpace);
348     auto fbDescriptor = CreateDepthToSpaceDescriptor(m_flatBufferBuilder,
349                                                      descriptor.m_BlockSize,
350                                                      GetFlatBufferDataLayout(descriptor.m_DataLayout));
351
352     auto fbLayer = serializer::CreateDepthToSpaceLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
353
354     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_DepthToSpaceLayer);
355 }
356
357 void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
358                                                          const armnn::DepthwiseConvolution2dDescriptor& descriptor,
359                                                          const armnn::ConstTensor& weights,
360                                                          const armnn::Optional<armnn::ConstTensor>& biases,
361                                                          const char* name)
362 {
363     IgnoreUnused(name);
364
365     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthwiseConvolution2d);
366     auto fbDescriptor = CreateDepthwiseConvolution2dDescriptor(m_flatBufferBuilder,
367                                                                descriptor.m_PadLeft,
368                                                                descriptor.m_PadRight,
369                                                                descriptor.m_PadTop,
370                                                                descriptor.m_PadBottom,
371                                                                descriptor.m_StrideX,
372                                                                descriptor.m_StrideY,
373                                                                descriptor.m_DilationX,
374                                                                descriptor.m_DilationY,
375                                                                descriptor.m_BiasEnabled,
376                                                                GetFlatBufferDataLayout(descriptor.m_DataLayout));
377
378     flatbuffers::Offset<serializer::ConstTensor> fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
379     flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
380     if (biases.has_value())
381     {
382         fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
383     }
384
385     auto flatBufferLayer = CreateDepthwiseConvolution2dLayer(m_flatBufferBuilder,
386                                                              fbBaseLayer,
387                                                              fbDescriptor,
388                                                              fbWeightsConstTensorInfo,
389                                                              fbBiasesConstTensorInfo);
390
391     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DepthwiseConvolution2dLayer);
392 }
393
394 void SerializerVisitor::VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
395                                              const char* name)
396 {
397     IgnoreUnused(name);
398
399     auto fbDequantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Dequantize);
400     auto fbDequantizeLayer     = serializer::CreateDequantizeLayer(m_flatBufferBuilder, fbDequantizeBaseLayer);
401
402     CreateAnyLayer(fbDequantizeLayer.o, serializer::Layer::Layer_DequantizeLayer);
403 }
404
405 void SerializerVisitor::VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
406                                                        const armnn::DetectionPostProcessDescriptor& descriptor,
407                                                        const armnn::ConstTensor& anchors,
408                                                        const char* name)
409 {
410     IgnoreUnused(name);
411
412     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_DetectionPostProcess);
413     auto fbDescriptor = CreateDetectionPostProcessDescriptor(m_flatBufferBuilder,
414                                                              descriptor.m_MaxDetections,
415                                                              descriptor.m_MaxClassesPerDetection,
416                                                              descriptor.m_DetectionsPerClass,
417                                                              descriptor.m_NmsScoreThreshold,
418                                                              descriptor.m_NmsIouThreshold,
419                                                              descriptor.m_NumClasses,
420                                                              descriptor.m_UseRegularNms,
421                                                              descriptor.m_ScaleX,
422                                                              descriptor.m_ScaleY,
423                                                              descriptor.m_ScaleW,
424                                                              descriptor.m_ScaleH);
425
426     flatbuffers::Offset<serializer::ConstTensor> fbAnchorsConstTensorInfo = CreateConstTensorInfo(anchors);
427
428     auto flatBufferLayer = CreateDetectionPostProcessLayer(m_flatBufferBuilder,
429                                                            fbBaseLayer,
430                                                            fbDescriptor,
431                                                            fbAnchorsConstTensorInfo);
432
433     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DetectionPostProcessLayer);
434 }
435
436 void SerializerVisitor::VisitDivisionLayer(const armnn::IConnectableLayer* layer, const char* name)
437 {
438     IgnoreUnused(name);
439
440     auto fbDivisionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Division);
441     auto fbDivisionLayer     = serializer::CreateDivisionLayer(m_flatBufferBuilder, fbDivisionBaseLayer);
442
443     CreateAnyLayer(fbDivisionLayer.o, serializer::Layer::Layer_DivisionLayer);
444 }
445
446 void SerializerVisitor::VisitElementwiseUnaryLayer(const armnn::IConnectableLayer* layer,
447                                                    const armnn::ElementwiseUnaryDescriptor& descriptor,
448                                                    const char* name)
449 {
450     IgnoreUnused(name);
451
452     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_ElementwiseUnary);
453     auto fbDescriptor = serializer::CreateElementwiseUnaryDescriptor(
454         m_flatBufferBuilder,
455         GetFlatBufferUnaryOperation(descriptor.m_Operation));
456
457     auto fbLayer = serializer::CreateElementwiseUnaryLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
458     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_ElementwiseUnaryLayer);
459 }
460
461 void SerializerVisitor::VisitEqualLayer(const armnn::IConnectableLayer* layer, const char* name)
462 {
463     IgnoreUnused(name);
464
465     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Equal);
466     auto fbEqualLayer = serializer::CreateEqualLayer(m_flatBufferBuilder, fbBaseLayer);
467
468     CreateAnyLayer(fbEqualLayer.o, serializer::Layer::Layer_EqualLayer);
469 }
470
471 void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name)
472 {
473     IgnoreUnused(name);
474
475     auto flatBufferFloorBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Floor);
476     auto flatBufferFloorLayer = serializer::CreateFloorLayer(m_flatBufferBuilder, flatBufferFloorBaseLayer);
477
478     CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
479 }
480
481 void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
482 {
483     IgnoreUnused(name);
484
485     auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
486     auto flatBufferLayer   = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
487
488     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
489 }
490
491 void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name)
492 {
493     IgnoreUnused(name);
494
495     auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater);
496     auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer);
497
498     CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer);
499 }
500
501 void SerializerVisitor::VisitInstanceNormalizationLayer(
502     const armnn::IConnectableLayer* layer,
503     const armnn::InstanceNormalizationDescriptor& instanceNormalizationDescriptor,
504     const char* name)
505 {
506     IgnoreUnused(name);
507
508     auto fbDescriptor = serializer::CreateInstanceNormalizationDescriptor(
509             m_flatBufferBuilder,
510             instanceNormalizationDescriptor.m_Gamma,
511             instanceNormalizationDescriptor.m_Beta,
512             instanceNormalizationDescriptor.m_Eps,
513             GetFlatBufferDataLayout(instanceNormalizationDescriptor.m_DataLayout));
514
515     auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_InstanceNormalization);
516     auto fbLayer     = serializer::CreateInstanceNormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
517
518     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_InstanceNormalizationLayer);
519 }
520
521 void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
522                                                   const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
523                                                   const char* name)
524 {
525     IgnoreUnused(name);
526
527     // Create FlatBuffer BaseLayer
528     auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_L2Normalization);
529
530     // Create the FlatBuffer L2Normalization Descriptor
531     auto fbDescriptor = serializer::CreateL2NormalizationDescriptor(
532             m_flatBufferBuilder,
533             GetFlatBufferDataLayout(l2NormalizationDescriptor.m_DataLayout),
534             l2NormalizationDescriptor.m_Eps);
535
536     // Create FlatBuffer layer
537     auto fbLayer = serializer::CreateL2NormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
538
539     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
540 }
541
542 void SerializerVisitor::VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer,
543                                              const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor,
544                                              const char* name)
545 {
546     IgnoreUnused(name);
547
548     // Create FlatBuffer BaseLayer
549     auto flatBufferLogSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_LogSoftmax);
550
551     // Create the FlatBuffer LogSoftmaxDescriptor
552     auto flatBufferLogSoftmaxDesc =
553         serializer::CreateLogSoftmaxDescriptor(m_flatBufferBuilder,
554                                                logSoftmaxDescriptor.m_Beta,
555                                                logSoftmaxDescriptor.m_Axis);
556
557     // Create the FlatBuffer LogSoftmaxLayer
558     auto flatBufferLogSoftmaxLayer =
559         serializer::CreateLogSoftmaxLayer(m_flatBufferBuilder,
560                                           flatBufferLogSoftmaxBaseLayer,
561                                           flatBufferLogSoftmaxDesc);
562
563     CreateAnyLayer(flatBufferLogSoftmaxLayer.o, serializer::Layer::Layer_LogSoftmaxLayer);
564 }
565
566 void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer,
567                                        const armnn::LstmDescriptor& descriptor,
568                                        const armnn::LstmInputParams& params,
569                                        const char* name)
570 {
571     IgnoreUnused(name);
572
573     auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
574
575     auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
576         m_flatBufferBuilder,
577         descriptor.m_ActivationFunc,
578         descriptor.m_ClippingThresCell,
579         descriptor.m_ClippingThresProj,
580         descriptor.m_CifgEnabled,
581         descriptor.m_PeepholeEnabled,
582         descriptor.m_ProjectionEnabled,
583         descriptor.m_LayerNormEnabled);
584
585     // Get mandatory input parameters
586     auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
587     auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
588     auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
589     auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
590     auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
591     auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
592     auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
593     auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
594     auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
595
596     //Define optional parameters, these will be set depending on configuration in Lstm descriptor
597     flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
598     flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
599     flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
600     flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
601     flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
602     flatbuffers::Offset<serializer::ConstTensor> projectionBias;
603     flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
604     flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
605     flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
606     flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
607     flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
608     flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
609
610     if (!descriptor.m_CifgEnabled)
611     {
612         inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
613         recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
614         cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
615         inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
616     }
617
618     if (descriptor.m_ProjectionEnabled)
619     {
620         projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
621         projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
622     }
623
624     if (descriptor.m_PeepholeEnabled)
625     {
626         cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
627         cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
628     }
629
630     if (descriptor.m_LayerNormEnabled)
631     {
632         if (!descriptor.m_CifgEnabled)
633         {
634             inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
635         }
636         forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
637         cellLayerNormWeights   = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
638         outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
639     }
640
641     auto fbLstmParams = serializer::CreateLstmInputParams(
642         m_flatBufferBuilder,
643         inputToForgetWeights,
644         inputToCellWeights,
645         inputToOutputWeights,
646         recurrentToForgetWeights,
647         recurrentToCellWeights,
648         recurrentToOutputWeights,
649         forgetGateBias,
650         cellBias,
651         outputGateBias,
652         inputToInputWeights,
653         recurrentToInputWeights,
654         cellToInputWeights,
655         inputGateBias,
656         projectionWeights,
657         projectionBias,
658         cellToForgetWeights,
659         cellToOutputWeights,
660         inputLayerNormWeights,
661         forgetLayerNormWeights,
662         cellLayerNormWeights,
663         outputLayerNormWeights);
664
665     auto fbLstmLayer = serializer::CreateLstmLayer(
666         m_flatBufferBuilder,
667         fbLstmBaseLayer,
668         fbLstmDescriptor,
669         fbLstmParams);
670
671     CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
672 }
673
674 void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
675 {
676     IgnoreUnused(name);
677
678     auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
679     auto fbMaximumLayer     = serializer::CreateMaximumLayer(m_flatBufferBuilder, fbMaximumBaseLayer);
680
681     CreateAnyLayer(fbMaximumLayer.o, serializer::Layer::Layer_MaximumLayer);
682 }
683
684 void SerializerVisitor::VisitMeanLayer(const armnn::IConnectableLayer* layer,
685                                        const armnn::MeanDescriptor& descriptor,
686                                        const char* name)
687 {
688     IgnoreUnused(name);
689
690     auto fbMeanBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Mean);
691     auto fbMeanDescriptor = serializer::CreateMeanDescriptor(m_flatBufferBuilder,
692                                                              m_flatBufferBuilder.CreateVector(descriptor.m_Axis),
693                                                              descriptor.m_KeepDims);
694
695     auto fbMeanLayer = serializer::CreateMeanLayer(m_flatBufferBuilder,
696                                                    fbMeanBaseLayer,
697                                                    fbMeanDescriptor);
698
699     CreateAnyLayer(fbMeanLayer.o, serializer::Layer::Layer_MeanLayer);
700 }
701
702 void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, const char* name)
703 {
704     IgnoreUnused(name);
705
706     auto fbMinimumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Minimum);
707     auto fbMinimumLayer     = serializer::CreateMinimumLayer(m_flatBufferBuilder, fbMinimumBaseLayer);
708
709     CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
710 }
711
712 void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
713 {
714     IgnoreUnused(name);
715
716     auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
717     auto fbMergeLayer     = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
718
719     CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
720 }
721
722 void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
723                                          const armnn::MergerDescriptor& mergerDescriptor,
724                                          const char* name)
725 {
726     VisitConcatLayer(layer, mergerDescriptor, name);
727 }
728
729 void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
730                                          const armnn::ConcatDescriptor& concatDescriptor,
731                                          const char* name)
732 {
733     IgnoreUnused(name);
734
735     auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
736
737     std::vector<flatbuffers::Offset<UintVector>> views;
738     for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
739     {
740         const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
741         std::vector<uint32_t> origins;
742         for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
743         {
744             origins.push_back(origin[d]);
745         }
746         auto view = m_flatBufferBuilder.CreateVector(origins);
747         auto uintVector = CreateUintVector(m_flatBufferBuilder, view);
748         views.push_back(uintVector);
749     }
750
751     auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
752                                                               concatDescriptor.GetConcatAxis(),
753                                                               concatDescriptor.GetNumViews(),
754                                                               concatDescriptor.GetNumDimensions(),
755                                                               m_flatBufferBuilder.CreateVector(views));
756
757     auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
758                                              flatBufferConcatBaseLayer,
759                                              flatBufferConcatDescriptor);
760
761     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
762 }
763
764 void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)
765 {
766     IgnoreUnused(name);
767
768     auto fbMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
769     auto fbMultiplicationLayer     = serializer::CreateMultiplicationLayer(m_flatBufferBuilder,
770                                                                            fbMultiplicationBaseLayer);
771
772     CreateAnyLayer(fbMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
773 }
774
775 void SerializerVisitor::VisitPadLayer(const armnn::IConnectableLayer* layer,
776                                       const armnn::PadDescriptor& padDescriptor,
777                                       const char* name)
778 {
779     IgnoreUnused(name);
780
781     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pad);
782
783     std::vector<unsigned int> padList;
784     for (auto& p: padDescriptor.m_PadList)
785     {
786         padList.push_back(p.first);
787         padList.push_back(p.second);
788     }
789
790     auto flatBufferPadDesc = serializer::CreatePadDescriptor(m_flatBufferBuilder,
791                                                              m_flatBufferBuilder.CreateVector(padList),
792                                                              padDescriptor.m_PadValue);
793
794     auto flatBufferPadLayer = serializer::CreatePadLayer(m_flatBufferBuilder,
795                                                          flatBufferBaseLayer,
796                                                          flatBufferPadDesc);
797
798     CreateAnyLayer(flatBufferPadLayer.o, serializer::Layer::Layer_PadLayer);
799 }
800
801 void SerializerVisitor::VisitPermuteLayer(const armnn::IConnectableLayer* layer,
802                                           const armnn::PermuteDescriptor& permuteDescriptor,
803                                           const char* name)
804 {
805     IgnoreUnused(name);
806
807     // Create FlatBuffer BaseLayer
808     auto flatBufferPermuteBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Permute);
809
810     std::vector<unsigned int> dimMappings;
811     for (unsigned int i=0; i<permuteDescriptor.m_DimMappings.GetSize(); ++i)
812     {
813         dimMappings.push_back(permuteDescriptor.m_DimMappings[i]);
814     }
815
816     auto flatBufferPermuteDesc = serializer::CreatePermuteDescriptor(m_flatBufferBuilder,
817                                                                      m_flatBufferBuilder.CreateVector(dimMappings));
818
819     // Create the FlatBuffer PermuteLayer
820     auto flatBufferPermuteLayer = serializer::CreatePermuteLayer(m_flatBufferBuilder,
821                                                                  flatBufferPermuteBaseLayer,
822                                                                  flatBufferPermuteDesc);
823
824     // Add the AnyLayer to the FlatBufferLayers
825     CreateAnyLayer(flatBufferPermuteLayer.o, serializer::Layer::Layer_PermuteLayer);
826 }
827
828 // Build FlatBuffer for Reshape Layer
829 void SerializerVisitor::VisitReshapeLayer(const armnn::IConnectableLayer* layer,
830                                           const armnn::ReshapeDescriptor& reshapeDescriptor,
831                                           const char* name)
832 {
833     IgnoreUnused(name);
834
835     // Create FlatBuffer BaseLayer
836     auto flatBufferReshapeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Reshape);
837
838     std::vector<unsigned int> targetShape;
839     for (unsigned int i =0; i < reshapeDescriptor.m_TargetShape.GetNumDimensions(); i++)
840     {
841         targetShape.push_back(reshapeDescriptor.m_TargetShape[i]);
842     }
843
844     auto flatBufferReshapeDesc = serializer::CreateReshapeDescriptor(m_flatBufferBuilder,
845                                                                      m_flatBufferBuilder.CreateVector(targetShape));
846
847     // Create the FlatBuffer ReshapeLayer
848     auto flatBufferReshapeLayer = serializer::CreateReshapeLayer(m_flatBufferBuilder, flatBufferReshapeBaseLayer,
849                                                                  flatBufferReshapeDesc);
850
851     // Add the AnyLayer to the FlatBufferLayers
852     CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
853 }
854
855 void SerializerVisitor::VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
856                                                  const armnn::ResizeBilinearDescriptor& resizeDescriptor,
857                                                  const char* name)
858 {
859     IgnoreUnused(name);
860
861     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ResizeBilinear);
862
863     auto flatBufferDescriptor =
864         CreateResizeBilinearDescriptor(m_flatBufferBuilder,
865                                        resizeDescriptor.m_TargetWidth,
866                                        resizeDescriptor.m_TargetHeight,
867                                        GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
868
869     auto flatBufferLayer = serializer::CreateResizeBilinearLayer(m_flatBufferBuilder,
870                                                                  flatBufferBaseLayer,
871                                                                  flatBufferDescriptor);
872
873     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeBilinearLayer);
874 }
875
876 void SerializerVisitor::VisitResizeLayer(const armnn::IConnectableLayer* layer,
877                                          const armnn::ResizeDescriptor& resizeDescriptor,
878                                          const char* name)
879 {
880     IgnoreUnused(name);
881
882     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Resize);
883
884     auto flatBufferDescriptor =
885             CreateResizeDescriptor(m_flatBufferBuilder,
886                                    resizeDescriptor.m_TargetHeight,
887                                    resizeDescriptor.m_TargetWidth,
888                                    GetFlatBufferResizeMethod(resizeDescriptor.m_Method),
889                                    GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
890
891     auto flatBufferLayer = serializer::CreateResizeLayer(m_flatBufferBuilder,
892                                                          flatBufferBaseLayer,
893                                                          flatBufferDescriptor);
894
895     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeLayer);
896 }
897
898 void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
899 {
900     IgnoreUnused(name);
901
902     auto fbRsqrtBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rsqrt);
903     auto fbRsqrtLayer     = serializer::CreateRsqrtLayer(m_flatBufferBuilder, fbRsqrtBaseLayer);
904
905     CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
906 }
907
908 void SerializerVisitor::VisitSliceLayer(const armnn::IConnectableLayer* layer,
909                                         const armnn::SliceDescriptor& sliceDescriptor,
910                                         const char* name)
911 {
912     IgnoreUnused(name);
913
914     auto fbSliceBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Slice);
915     auto fbSliceDescriptor = CreateSliceDescriptor(m_flatBufferBuilder,
916                                                    m_flatBufferBuilder.CreateVector(sliceDescriptor.m_Begin),
917                                                    m_flatBufferBuilder.CreateVector(sliceDescriptor.m_Size));
918
919     auto fbSliceLayer = serializer::CreateSliceLayer(m_flatBufferBuilder, fbSliceBaseLayer, fbSliceDescriptor);
920
921     CreateAnyLayer(fbSliceLayer.o, serializer::Layer::Layer_SliceLayer);
922 }
923
924 // Build FlatBuffer for Softmax Layer
925 void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
926                                           const armnn::SoftmaxDescriptor& softmaxDescriptor,
927                                           const char* name)
928 {
929     IgnoreUnused(name);
930
931     // Create FlatBuffer BaseLayer
932     auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
933
934     // Create the FlatBuffer SoftmaxDescriptor
935     auto flatBufferSoftmaxDesc =
936         serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
937
938     // Create the FlatBuffer SoftmaxLayer
939     auto flatBufferSoftmaxLayer =
940         serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
941                                        flatBufferSoftmaxBaseLayer,
942                                        flatBufferSoftmaxDesc);
943
944     CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
945 }
946
947 void SerializerVisitor::VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
948                                             const armnn::Pooling2dDescriptor& pooling2dDescriptor,
949                                             const char* name)
950 {
951     IgnoreUnused(name);
952
953     auto fbPooling2dBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
954     auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
955         m_flatBufferBuilder,
956         GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
957         pooling2dDescriptor.m_PadLeft,
958         pooling2dDescriptor.m_PadRight,
959         pooling2dDescriptor.m_PadTop,
960         pooling2dDescriptor.m_PadBottom,
961         pooling2dDescriptor.m_PoolWidth,
962         pooling2dDescriptor.m_PoolHeight,
963         pooling2dDescriptor.m_StrideX,
964         pooling2dDescriptor.m_StrideY,
965         GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
966         GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
967         GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
968
969     auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
970                                                              fbPooling2dBaseLayer,
971                                                              fbPooling2dDescriptor);
972
973     CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
974 }
975
976 void SerializerVisitor::VisitPreluLayer(const armnn::IConnectableLayer* layer,
977                                         const char* name)
978 {
979     IgnoreUnused(name);
980
981     // Create FlatBuffer BaseLayer
982     auto flatBufferPreluBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Prelu);
983
984     // Create the FlatBuffer AdditionLayer
985     auto flatBufferPreluLayer = serializer::CreatePreluLayer(m_flatBufferBuilder, flatBufferPreluBaseLayer);
986
987     // Add the AnyLayer to the FlatBufferLayers
988     CreateAnyLayer(flatBufferPreluLayer.o, serializer::Layer::Layer_PreluLayer);
989 }
990
991 void SerializerVisitor::VisitQuantizeLayer(const armnn::IConnectableLayer *layer, const char *name)
992 {
993     IgnoreUnused(name);
994
995     auto fbQuantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Quantize);
996     auto fbQuantizeLayer = serializer::CreateQuantizeLayer(m_flatBufferBuilder,
997                                                            fbQuantizeBaseLayer);
998     CreateAnyLayer(fbQuantizeLayer.o, serializer::Layer::Layer_QuantizeLayer);
999 }
1000
1001 // Build FlatBuffer for FullyConnected Layer
1002 void SerializerVisitor::VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
1003                                                  const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
1004                                                  const armnn::ConstTensor& weights,
1005                                                  const armnn::Optional<armnn::ConstTensor>& biases,
1006                                                  const char* name)
1007 {
1008     IgnoreUnused(name);
1009
1010     // Create FlatBuffer BaseLayer
1011     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_FullyConnected);
1012
1013     // Create FlatBuffer FullyConnectedDescriptor
1014     auto flatBufferDescriptor =
1015         serializer::CreateFullyConnectedDescriptor(m_flatBufferBuilder,
1016                                                    fullyConnectedDescriptor.m_BiasEnabled,
1017                                                    fullyConnectedDescriptor.m_TransposeWeightMatrix);
1018
1019     // Create FlatBuffer weights data
1020     auto flatBufferWeights = CreateConstTensorInfo(weights);
1021
1022     // Create FlatBuffer bias data
1023     flatbuffers::Offset<serializer::ConstTensor> flatBufferBiases;
1024     if (fullyConnectedDescriptor.m_BiasEnabled)
1025     {
1026         flatBufferBiases = CreateConstTensorInfo(biases.value());
1027     }
1028
1029     // Create FlatBuffer FullyConnectedLayer
1030     auto flatBufferLayer = serializer::CreateFullyConnectedLayer(m_flatBufferBuilder,
1031                                                                  flatBufferBaseLayer,
1032                                                                  flatBufferDescriptor,
1033                                                                  flatBufferWeights,
1034                                                                  flatBufferBiases);
1035
1036     // Add created FullyConnectedLayer to the FlatBufferLayers
1037     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_FullyConnectedLayer);
1038 }
1039
1040 // Build FlatBuffer for SpaceToBatchNd Layer
1041 void SerializerVisitor::VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
1042                                                  const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
1043                                                  const char* name)
1044 {
1045     IgnoreUnused(name);
1046
1047     // Create FlatBuffer BaseLayer
1048     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToBatchNd);
1049
1050     std::vector<unsigned int> padList;
1051     padList.reserve(spaceToBatchNdDescriptor.m_PadList.size()*2);
1052     for (auto& pad : spaceToBatchNdDescriptor.m_PadList)
1053     {
1054         padList.push_back(pad.first);
1055         padList.push_back(pad.second);
1056     }
1057
1058     auto flatBufferDescriptor =
1059         CreateSpaceToBatchNdDescriptor(m_flatBufferBuilder,
1060                                        m_flatBufferBuilder.CreateVector(spaceToBatchNdDescriptor.m_BlockShape),
1061                                        m_flatBufferBuilder.CreateVector(padList),
1062                                        GetFlatBufferDataLayout(spaceToBatchNdDescriptor.m_DataLayout));
1063
1064     auto flatBufferLayer = serializer::CreateSpaceToBatchNdLayer(m_flatBufferBuilder,
1065                                                                  flatBufferBaseLayer,
1066                                                                  flatBufferDescriptor);
1067
1068     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer);
1069 }
1070
1071 // Build FlatBuffer for SpaceToDepthLayer
1072 void SerializerVisitor::VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
1073                                                const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
1074                                                const char* name)
1075 {
1076     IgnoreUnused(name);
1077
1078     auto flatBufferBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToDepth);
1079     auto flatBufferDescriptor =
1080         CreateSpaceToDepthDescriptor(m_flatBufferBuilder,
1081                                      spaceToDepthDescriptor.m_BlockSize,
1082                                      GetFlatBufferDataLayout(spaceToDepthDescriptor.m_DataLayout));
1083
1084     auto flatBufferLayer = serializer::CreateSpaceToDepthLayer(m_flatBufferBuilder,
1085                                                                flatBufferBaseLayer,
1086                                                                flatBufferDescriptor);
1087
1088     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToDepthLayer);
1089 }
1090
1091 // Build FlatBuffer for Splitter Layer
1092 void SerializerVisitor::VisitSplitterLayer(const armnn::IConnectableLayer* layer,
1093                                            const armnn::ViewsDescriptor& viewsDescriptor,
1094                                            const char* name)
1095 {
1096     IgnoreUnused(name);
1097
1098     // Create FlatBuffer ViewOrigins
1099     std::vector<flatbuffers::Offset<UintVector>> flatBufferViewOrigins;
1100     flatBufferViewOrigins.reserve(viewsDescriptor.GetNumViews());
1101
1102     for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
1103     {
1104         std::vector<uint32_t> viewOrigin;
1105         viewOrigin.reserve(viewsDescriptor.GetNumDimensions());
1106
1107         // Copy vector
1108         for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
1109         {
1110             viewOrigin.push_back(viewsDescriptor.GetViewOrigin(vIdx)[dIdx]);
1111         }
1112
1113         flatBufferViewOrigins.push_back(CreateUintVector(m_flatBufferBuilder,
1114                                                          m_flatBufferBuilder.CreateVector(viewOrigin)));
1115     }
1116
1117     // Create FlatBuffer OriginsDescriptor
1118     auto flatBufferOriginDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
1119                                                               viewsDescriptor.GetOrigins().GetConcatAxis(),
1120                                                               viewsDescriptor.GetOrigins().GetNumViews(),
1121                                                               viewsDescriptor.GetOrigins().GetNumDimensions(),
1122                                                               m_flatBufferBuilder.CreateVector(flatBufferViewOrigins));
1123
1124     // Create FlatBuffer ViewOrigins
1125     std::vector<flatbuffers::Offset<UintVector>> flatBufferViewSizes;
1126     flatBufferViewSizes.reserve(viewsDescriptor.GetNumViews());
1127
1128     for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
1129     {
1130         std::vector<uint32_t> viewSize;
1131         viewSize.reserve(viewsDescriptor.GetNumDimensions());
1132
1133         // Copy vector
1134         for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
1135         {
1136             viewSize.push_back(viewsDescriptor.GetViewSizes(vIdx)[dIdx]);
1137         }
1138
1139         flatBufferViewSizes.push_back(CreateUintVector(m_flatBufferBuilder,
1140                                                        m_flatBufferBuilder.CreateVector(viewSize)));
1141     }
1142
1143     // Create FlatBuffer ViewsDescriptor
1144     auto flatBufferViewsDescriptor = CreateViewsDescriptor(m_flatBufferBuilder,
1145                                                            flatBufferOriginDescriptor,
1146                                                            m_flatBufferBuilder.CreateVector(flatBufferViewSizes));
1147
1148     // Create FlatBuffer BaseLayer
1149     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Splitter);
1150
1151     auto flatBufferSplitterLayer = serializer::CreateSplitterLayer(m_flatBufferBuilder,
1152                                                                    flatBufferBaseLayer,
1153                                                                    flatBufferViewsDescriptor);
1154
1155     CreateAnyLayer(flatBufferSplitterLayer.o, serializer::Layer::Layer_SplitterLayer);
1156 }
1157
1158 void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
1159                                                 const armnn::NormalizationDescriptor& descriptor,
1160                                                 const char* name)
1161 {
1162     IgnoreUnused(name);
1163
1164     auto fbNormalizationBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Normalization);
1165
1166     auto fbNormalizationDescriptor = serializer::CreateNormalizationDescriptor(
1167         m_flatBufferBuilder,
1168         GetFlatBufferNormalizationAlgorithmChannel(descriptor.m_NormChannelType),
1169         GetFlatBufferNormalizationAlgorithmMethod(descriptor.m_NormMethodType),
1170         descriptor.m_NormSize,
1171         descriptor.m_Alpha,
1172         descriptor.m_Beta,
1173         descriptor.m_K,
1174         GetFlatBufferDataLayout(descriptor.m_DataLayout));
1175
1176     auto flatBufferLayer = serializer::CreateNormalizationLayer(m_flatBufferBuilder,
1177                                                                 fbNormalizationBaseLayer,
1178                                                                 fbNormalizationDescriptor);
1179
1180     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
1181 }
1182
1183 void SerializerVisitor::VisitStackLayer(const armnn::IConnectableLayer* layer,
1184                                         const armnn::StackDescriptor& stackDescriptor,
1185                                         const char* name)
1186 {
1187     IgnoreUnused(name);
1188
1189     auto stackBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Stack);
1190
1191     std::vector<unsigned int> inputShape;
1192     for (unsigned int i =0; i < stackDescriptor.m_InputShape.GetNumDimensions(); i++)
1193     {
1194         inputShape.push_back(stackDescriptor.m_InputShape[i]);
1195     }
1196
1197     auto flatBufferStackDescriptor = CreateStackDescriptor(m_flatBufferBuilder,
1198                                                            stackDescriptor.m_Axis,
1199                                                            stackDescriptor.m_NumInputs,
1200                                                            m_flatBufferBuilder.CreateVector(inputShape));
1201
1202     auto stackLayer = serializer::CreateStackLayer(m_flatBufferBuilder, stackBaseLayer, flatBufferStackDescriptor);
1203     CreateAnyLayer(stackLayer.o, serializer::Layer::Layer_StackLayer);
1204 }
1205
1206 void SerializerVisitor::VisitStandInLayer(const armnn::IConnectableLayer *layer,
1207                                           const armnn::StandInDescriptor& standInDescriptor,
1208                                           const char *name)
1209 {
1210     IgnoreUnused(name);
1211
1212     auto fbDescriptor = serializer::CreateStandInDescriptor(m_flatBufferBuilder,
1213                                                             standInDescriptor.m_NumInputs,
1214                                                             standInDescriptor.m_NumOutputs);
1215
1216     auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StandIn);
1217     auto fbLayer     = serializer::CreateStandInLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
1218
1219     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_StandInLayer);
1220 }
1221
1222 void SerializerVisitor::VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
1223                                                const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
1224                                                const char* name)
1225 {
1226     IgnoreUnused(name);
1227
1228     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StridedSlice);
1229
1230     auto flatBufferDescriptor =
1231         CreateStridedSliceDescriptor(m_flatBufferBuilder,
1232                                      m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Begin),
1233                                      m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_End),
1234                                      m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Stride),
1235                                      stridedSliceDescriptor.m_BeginMask,
1236                                      stridedSliceDescriptor.m_EndMask,
1237                                      stridedSliceDescriptor.m_ShrinkAxisMask,
1238                                      stridedSliceDescriptor.m_EllipsisMask,
1239                                      stridedSliceDescriptor.m_NewAxisMask,
1240                                      GetFlatBufferDataLayout(stridedSliceDescriptor.m_DataLayout));
1241
1242     auto flatBufferLayer = serializer::CreateStridedSliceLayer(m_flatBufferBuilder,
1243                                                                flatBufferBaseLayer,
1244                                                                flatBufferDescriptor);
1245
1246     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_StridedSliceLayer);
1247 }
1248
1249 void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* layer, const char* name)
1250 {
1251     IgnoreUnused(name);
1252
1253     auto fbSubtractionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Subtraction);
1254     auto fbSubtractionLayer = serializer::CreateSubtractionLayer(m_flatBufferBuilder, fbSubtractionBaseLayer);
1255
1256     CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
1257 }
1258
1259 void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
1260 {
1261     IgnoreUnused(name);
1262
1263     auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
1264     auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
1265
1266     CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
1267 }
1268
1269 void SerializerVisitor::VisitTransposeConvolution2dLayer(
1270     const armnn::IConnectableLayer* layer,
1271     const armnn::TransposeConvolution2dDescriptor& descriptor,
1272     const armnn::ConstTensor& weights,
1273     const armnn::Optional<armnn::ConstTensor>& biases,
1274     const char* name)
1275 {
1276     IgnoreUnused(name);
1277
1278     auto fbBaseLayer  = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
1279     auto fbDescriptor = CreateTransposeConvolution2dDescriptor(m_flatBufferBuilder,
1280                                                                descriptor.m_PadLeft,
1281                                                                descriptor.m_PadRight,
1282                                                                descriptor.m_PadTop,
1283                                                                descriptor.m_PadBottom,
1284                                                                descriptor.m_StrideX,
1285                                                                descriptor.m_StrideY,
1286                                                                descriptor.m_BiasEnabled,
1287                                                                GetFlatBufferDataLayout(descriptor.m_DataLayout));
1288
1289     // weights & biases
1290     auto fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
1291     flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
1292     if (biases.has_value())
1293     {
1294         fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
1295     }
1296
1297     auto fbLayer = CreateTransposeConvolution2dLayer(m_flatBufferBuilder,
1298                                                      fbBaseLayer,
1299                                                      fbDescriptor,
1300                                                      fbWeightsConstTensorInfo,
1301                                                      fbBiasesConstTensorInfo);
1302
1303     CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_TransposeConvolution2dLayer);
1304 }
1305
1306 void SerializerVisitor::VisitTransposeLayer(const armnn::IConnectableLayer* layer,
1307                                             const armnn::TransposeDescriptor& descriptor,
1308                                             const char* name)
1309 {
1310     IgnoreUnused(name);
1311
1312     // Create FlatBuffer BaseLayer
1313     auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Transpose);
1314
1315     std::vector<unsigned int> dimMappings;
1316     for (unsigned int i=0; i<descriptor.m_DimMappings.GetSize(); ++i)
1317     {
1318         dimMappings.push_back(descriptor.m_DimMappings[i]);
1319     }
1320
1321     auto flatBufferDesc = serializer::CreateTransposeDescriptor(m_flatBufferBuilder,
1322                                                                 m_flatBufferBuilder.CreateVector(dimMappings));
1323
1324     // Create the FlatBuffer TransposeLayer
1325     auto flatBufferLayer = serializer::CreateTransposeLayer(m_flatBufferBuilder,
1326                                                             flatBufferBaseLayer,
1327                                                             flatBufferDesc);
1328
1329     // Add the AnyLayer to the FlatBufferLayers
1330     CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_TransposeLayer);
1331 }
1332
1333 void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer,
1334                                         const armnn::QLstmDescriptor& descriptor,
1335                                         const armnn::LstmInputParams& params,
1336                                         const char* name)
1337 {
1338     IgnoreUnused(name);
1339
1340     auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm);
1341
1342     auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor(
1343             m_flatBufferBuilder,
1344             descriptor.m_CifgEnabled,
1345             descriptor.m_PeepholeEnabled,
1346             descriptor.m_ProjectionEnabled,
1347             descriptor.m_LayerNormEnabled,
1348             descriptor.m_CellClip,
1349             descriptor.m_ProjectionClip,
1350             descriptor.m_InputIntermediateScale,
1351             descriptor.m_ForgetIntermediateScale,
1352             descriptor.m_CellIntermediateScale,
1353             descriptor.m_OutputIntermediateScale,
1354             descriptor.m_HiddenStateZeroPoint,
1355             descriptor.m_HiddenStateScale
1356             );
1357
1358     // Mandatory params
1359     auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
1360     auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
1361     auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
1362     auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
1363     auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
1364     auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
1365     auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
1366     auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
1367     auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
1368
1369     // CIFG
1370     flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
1371     flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
1372     flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
1373
1374     if (!descriptor.m_CifgEnabled)
1375     {
1376         inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
1377         recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
1378         inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
1379     }
1380
1381     // Projectiom
1382     flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
1383     flatbuffers::Offset<serializer::ConstTensor> projectionBias;
1384
1385     if (descriptor.m_ProjectionEnabled)
1386     {
1387         projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
1388         projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
1389     }
1390
1391     // Peephole
1392     flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
1393     flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
1394     flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
1395
1396     if (descriptor.m_PeepholeEnabled)
1397     {
1398         if (!descriptor.m_CifgEnabled)
1399         {
1400             cellToInputWeights  = CreateConstTensorInfo(*params.m_CellToInputWeights);
1401         }
1402
1403         cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
1404         cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
1405     }
1406
1407     // Layer norm
1408     flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
1409     flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
1410     flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
1411     flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
1412
1413     if (descriptor.m_LayerNormEnabled)
1414     {
1415         if (!descriptor.m_CifgEnabled)
1416         {
1417             inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
1418         }
1419
1420         forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
1421         cellLayerNormWeights   = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
1422         outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
1423     }
1424
1425     auto fbQLstmParams = serializer::CreateQLstmInputParams(
1426             m_flatBufferBuilder,
1427             inputToForgetWeights,
1428             inputToCellWeights,
1429             inputToOutputWeights,
1430             recurrentToForgetWeights,
1431             recurrentToCellWeights,
1432             recurrentToOutputWeights,
1433             forgetGateBias,
1434             cellBias,
1435             outputGateBias,
1436             inputToInputWeights,
1437             recurrentToInputWeights,
1438             inputGateBias,
1439             projectionWeights,
1440             projectionBias,
1441             cellToInputWeights,
1442             cellToForgetWeights,
1443             cellToOutputWeights,
1444             inputLayerNormWeights,
1445             forgetLayerNormWeights,
1446             cellLayerNormWeights,
1447             outputLayerNormWeights);
1448
1449     auto fbQLstmLayer = serializer::CreateQLstmLayer(
1450             m_flatBufferBuilder,
1451             fbQLstmBaseLayer,
1452             fbQLstmDescriptor,
1453             fbQLstmParams);
1454
1455     CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer);
1456 }
1457
1458 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
1459                                                 const armnn::QuantizedLstmInputParams& params,
1460                                                 const char* name)
1461 {
1462     IgnoreUnused(name);
1463
1464     auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
1465
1466     // Get input parameters
1467     auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
1468     auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
1469     auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
1470     auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
1471
1472     auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
1473     auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
1474     auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
1475     auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
1476
1477     auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
1478     auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
1479     auto cellBias = CreateConstTensorInfo(params.GetCellBias());
1480     auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
1481
1482     auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
1483         m_flatBufferBuilder,
1484         inputToInputWeights,
1485         inputToForgetWeights,
1486         inputToCellWeights,
1487         inputToOutputWeights,
1488         recurrentToInputWeights,
1489         recurrentToForgetWeights,
1490         recurrentToCellWeights,
1491         recurrentToOutputWeights,
1492         inputGateBias,
1493         forgetGateBias,
1494         cellBias,
1495         outputGateBias);
1496
1497     auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
1498         m_flatBufferBuilder,
1499         fbQuantizedLstmBaseLayer,
1500         fbQuantizedLstmParams);
1501
1502     CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
1503 }
1504
1505 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
1506                                                                      const serializer::LayerType layerType)
1507 {
1508
1509     uint32_t fbIndex = GetSerializedId(layer->GetGuid());
1510
1511     std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
1512     std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
1513
1514     return serializer::CreateLayerBase(m_flatBufferBuilder,
1515                                        fbIndex,
1516                                        m_flatBufferBuilder.CreateString(layer->GetName()),
1517                                        layerType,
1518                                        m_flatBufferBuilder.CreateVector(inputSlots),
1519                                        m_flatBufferBuilder.CreateVector(outputSlots));
1520 }
1521
1522 void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
1523 {
1524
1525     auto anyLayer = armnnSerializer::CreateAnyLayer(m_flatBufferBuilder, serializerLayer, layer);
1526     m_serializedLayers.push_back(anyLayer);
1527 }
1528
1529 template <typename T>
1530 flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector(const void* memory, unsigned int size)
1531 {
1532     const T* buffer = reinterpret_cast<const T*>(memory);
1533     std::vector<T> vector(buffer, buffer + (size / sizeof(T)));
1534     auto fbVector = m_flatBufferBuilder.CreateVector(vector);
1535     return fbVector;
1536 }
1537
1538 flatbuffers::Offset<TensorInfo>  SerializerVisitor::CreateTensorInfo(const armnn::TensorInfo& tensorInfo)
1539 {
1540     // Get the dimensions
1541     std::vector<unsigned int> shape;
1542     for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1543     {
1544         shape.push_back(tensorInfo.GetShape()[dim]);
1545     }
1546
1547     if (tensorInfo.HasPerAxisQuantization())
1548     {
1549         // Create FlatBuffer TensorInfo
1550         auto flatBufferTensorInfo =
1551             serializer::CreateTensorInfo(m_flatBufferBuilder,
1552                                          m_flatBufferBuilder.CreateVector(shape),
1553                                          GetFlatBufferDataType(tensorInfo.GetDataType()),
1554                                          tensorInfo.GetQuantizationScales()[0],
1555                                          tensorInfo.GetQuantizationOffset(),
1556                                          m_flatBufferBuilder.CreateVector(tensorInfo.GetQuantizationScales()),
1557                                          tensorInfo.GetQuantizationDim().value());
1558         return flatBufferTensorInfo;
1559     }
1560
1561     // Create FlatBuffer TensorInfo
1562     auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1563                                                              m_flatBufferBuilder.CreateVector(shape),
1564                                                              GetFlatBufferDataType(tensorInfo.GetDataType()),
1565                                                              tensorInfo.GetQuantizationScale(),
1566                                                              tensorInfo.GetQuantizationOffset());
1567     return flatBufferTensorInfo;
1568 }
1569
1570 flatbuffers::Offset<serializer::ConstTensor>
1571     SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor)
1572 {
1573     armnn::TensorInfo tensorInfo = constTensor.GetInfo();
1574
1575     flatbuffers::Offset<void> fbPayload;
1576
1577     switch (tensorInfo.GetDataType())
1578     {
1579         case armnn::DataType::Float32:
1580         case armnn::DataType::Signed32:
1581         {
1582             auto fbVector = CreateDataVector<int32_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1583             flatbuffers::Offset<serializer::IntData> flatBuffersData = serializer::CreateIntData(
1584                     m_flatBufferBuilder,
1585                     fbVector);
1586             fbPayload = flatBuffersData.o;
1587             break;
1588         }
1589         case armnn::DataType::Float16:
1590         case armnn::DataType::BFloat16:
1591         case armnn::DataType::QSymmS16:
1592         {
1593             auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1594             flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1595                     m_flatBufferBuilder,
1596                     fbVector);
1597             fbPayload = flatBuffersData.o;
1598             break;
1599         }
1600         case armnn::DataType::QSymmS8:
1601         case armnn::DataType::QAsymmS8:
1602         case armnn::DataType::QAsymmU8:
1603         case armnn::DataType::Boolean:
1604         default:
1605         {
1606             auto fbVector = CreateDataVector<int8_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1607             flatbuffers::Offset<serializer::ByteData> flatBuffersData = serializer::CreateByteData(
1608                     m_flatBufferBuilder,
1609                     fbVector);
1610             fbPayload = flatBuffersData.o;
1611         }
1612     }
1613     flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor(
1614             m_flatBufferBuilder,
1615             CreateTensorInfo(tensorInfo),
1616             GetFlatBufferConstTensorData(tensorInfo.GetDataType()),
1617             fbPayload);
1618     return flatBufferConstTensor;
1619 }
1620
1621 flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> SerializerVisitor::GetVersionTable()
1622 {
1623     flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> versionsTable =
1624         serializer::CreateFeatureCompatibilityVersions(
1625                 m_flatBufferBuilder,
1626                 1 // Binding ids scheme version
1627             );
1628     return versionsTable;
1629 }
1630
1631 std::vector<fb::Offset<serializer::InputSlot>>
1632     SerializerVisitor::CreateInputSlots(const armnn::IConnectableLayer* layer)
1633 {
1634     std::vector<fb::Offset<serializer::InputSlot>> inputSlots;
1635
1636     // Get the InputSlots
1637     for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
1638     {
1639         const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
1640
1641         // Get the Connection for the InputSlot
1642         const IOutputSlot* connection = inputSlot.GetConnection();
1643
1644         // Create FlatBuffer Connection
1645         serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
1646                                     connection->CalculateIndexOnOwner());
1647         // Create FlatBuffer InputSlot
1648         inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
1649     }
1650     return inputSlots;
1651 }
1652
1653 std::vector<fb::Offset<serializer::OutputSlot>>
1654     SerializerVisitor::CreateOutputSlots(const armnn::IConnectableLayer* layer)
1655 {
1656     std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
1657
1658     // Get the OutputSlots
1659     for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
1660     {
1661         const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
1662         const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1663
1664         // Create FlatBuffer Outputslot
1665         outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
1666                                                            slotIndex,
1667                                                            CreateTensorInfo(tensorInfo)));
1668     }
1669     return outputSlots;
1670 }
1671
1672
1673 ISerializer* ISerializer::CreateRaw()
1674 {
1675     return new Serializer();
1676 }
1677
1678 ISerializerPtr ISerializer::Create()
1679 {
1680     return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
1681 }
1682
1683 void ISerializer::Destroy(ISerializer* serializer)
1684 {
1685     delete serializer;
1686 }
1687
1688 void Serializer::Serialize(const INetwork& inNetwork)
1689 {
1690     // Iterate through to network
1691     inNetwork.Accept(m_SerializerVisitor);
1692     flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1693
1694     // Create FlatBuffer SerializedGraph
1695     auto serializedGraph = serializer::CreateSerializedGraph(
1696         fbBuilder,
1697         fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
1698         fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
1699         fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()),
1700         m_SerializerVisitor.GetVersionTable());
1701
1702     // Serialize the graph
1703     fbBuilder.Finish(serializedGraph);
1704 }
1705
1706 bool Serializer::SaveSerializedToStream(std::ostream& stream)
1707 {
1708     flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1709
1710     auto bytesToWrite = boost::numeric_cast<std::streamsize>(fbBuilder.GetSize());
1711     stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
1712     return !stream.bad();
1713 }
1714
1715 } // namespace armnnSerializer