2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "Serializer.hpp"
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/LstmParams.hpp>
10 #include <armnn/QuantizedLstmParams.hpp>
11 #include <armnn/utility/IgnoreUnused.hpp>
15 #include <boost/numeric/conversion/cast.hpp>
16 #include <flatbuffers/util.h>
18 #include "SerializerUtils.hpp"
20 using namespace armnn;
21 namespace fb = flatbuffers;
22 namespace serializer = armnnSerializer;
24 namespace armnnSerializer
// Translates an armnn::ActivationFunction into the identically named flatbuffer schema enumerator.
// NOTE(review): the trailing return of ActivationFunction_Sigmoid appears to be the fallback for any
// value not matched by a case above; confirm that silently serializing unknown functions as Sigmoid
// (rather than failing) is intended.
27 serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::ActivationFunction function)
31 case armnn::ActivationFunction::Sigmoid:
32 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
33 case armnn::ActivationFunction::TanH:
34 return serializer::ActivationFunction::ActivationFunction_TanH;
35 case armnn::ActivationFunction::Linear:
36 return serializer::ActivationFunction::ActivationFunction_Linear;
37 case armnn::ActivationFunction::ReLu:
38 return serializer::ActivationFunction::ActivationFunction_ReLu;
39 case armnn::ActivationFunction::BoundedReLu:
40 return serializer::ActivationFunction::ActivationFunction_BoundedReLu;
41 case armnn::ActivationFunction::LeakyReLu:
42 return serializer::ActivationFunction::ActivationFunction_LeakyReLu;
43 case armnn::ActivationFunction::Abs:
44 return serializer::ActivationFunction::ActivationFunction_Abs;
45 case armnn::ActivationFunction::Sqrt:
46 return serializer::ActivationFunction::ActivationFunction_Sqrt;
47 case armnn::ActivationFunction::Square:
48 return serializer::ActivationFunction::ActivationFunction_Square;
49 case armnn::ActivationFunction::Elu:
50 return serializer::ActivationFunction::ActivationFunction_Elu;
51 case armnn::ActivationFunction::HardSwish:
52 return serializer::ActivationFunction::ActivationFunction_HardSwish;
54 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
// Translates an armnn::ArgMinMaxFunction (Max / Min) into the flatbuffer schema enumerator.
// NOTE(review): a line between the Min case label and its return is not visible here; confirm
// whether values other than Max/Min are also routed to ArgMinMaxFunction_Min.
58 serializer::ArgMinMaxFunction GetFlatBufferArgMinMaxFunction(armnn::ArgMinMaxFunction function)
62 case armnn::ArgMinMaxFunction::Max:
63 return serializer::ArgMinMaxFunction::ArgMinMaxFunction_Max;
64 case armnn::ArgMinMaxFunction::Min:
66 return serializer::ArgMinMaxFunction::ArgMinMaxFunction_Min;
// Returns the serialized layer id for an armnn layer guid.
// The first time a guid is seen (map empty, or guid not yet present), it is recorded in m_guidMap
// against m_layerId; the id stored for the guid is then returned.
// NOTE(review): statements between the else-if and its insert are not visible in this view; confirm
// how m_layerId is advanced for new guids.
// NOTE(review): the final operator[] repeats the search already done by find(); reusing the
// iterator would avoid a second lookup.
70 uint32_t SerializerVisitor::GetSerializedId(armnn::LayerGuid guid)
72 if (m_guidMap.empty())
74 m_guidMap.insert(std::make_pair(guid, m_layerId));
76 else if (m_guidMap.find(guid) == m_guidMap.end())
79 m_guidMap.insert(std::make_pair(guid, m_layerId));
83 return m_guidMap[guid];
86 // Build FlatBuffer for Input Layer
// The layer base is wrapped in a bindable base layer, and the binding id is recorded in m_inputIds
// before the InputLayer table is registered through CreateAnyLayer.
87 void SerializerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
91 // Create FlatBuffer BaseLayer
92 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
94 // Create FlatBuffer BindableBaseLayer
95 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
96 flatBufferInputBaseLayer,
98 // Push layer binding id to inputIds.
99 m_inputIds.push_back(id);
101 // Create the FlatBuffer InputLayer
102 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
104 // Add the AnyLayer to the FlatBufferLayers
105 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
108 // Build FlatBuffer for Output Layer
// Mirrors VisitInputLayer: the layer base is wrapped in a bindable base layer, and the binding id is
// recorded in m_outputIds before the OutputLayer table is registered through CreateAnyLayer.
109 void SerializerVisitor::VisitOutputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
113 // Create FlatBuffer BaseLayer
114 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
116 // Create FlatBuffer BindableBaseLayer
117 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
118 flatBufferOutputBaseLayer,
120 // Push layer binding id to outputIds.
121 m_outputIds.push_back(id);
123 // Create the FlatBuffer OutputLayer
124 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
125 // Add the AnyLayer to the FlatBufferLayers
126 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
// Build FlatBuffer for Abs Layer: a descriptor-less table holding only the layer base.
129 void SerializerVisitor::VisitAbsLayer(const armnn::IConnectableLayer* layer, const char* name)
132 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Abs);
133 auto flatBufferAbsLayer = serializer::CreateAbsLayer(m_flatBufferBuilder, flatBufferBaseLayer);
135 CreateAnyLayer(flatBufferAbsLayer.o, serializer::Layer::Layer_AbsLayer);
138 // Build FlatBuffer for Activation Layer
// The activation function is mapped through GetFlatBufferActivationFunction into the descriptor.
139 void SerializerVisitor::VisitActivationLayer(const armnn::IConnectableLayer* layer,
140 const armnn::ActivationDescriptor& descriptor,
145 // Create FlatBuffer BaseLayer
146 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Activation);
148 // Create the FlatBuffer ActivationDescriptor
149 auto flatBufferDescriptor = CreateActivationDescriptor(m_flatBufferBuilder,
150 GetFlatBufferActivationFunction(descriptor.m_Function),
154 // Create the FlatBuffer ActivationLayer
// NOTE(review): this local is named flatBufferAdditionLayer but holds the ActivationLayer table
// (copy-paste leftover); consider renaming it to flatBufferActivationLayer.
155 auto flatBufferAdditionLayer = CreateActivationLayer(m_flatBufferBuilder,
157 flatBufferDescriptor);
159 // Add the AnyLayer to the FlatBufferLayers
160 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_ActivationLayer);
163 // Build FlatBuffer for Addition Layer
// Descriptor-less: the AdditionLayer table holds only the layer base.
164 void SerializerVisitor::VisitAdditionLayer(const armnn::IConnectableLayer* layer, const char* name)
168 // Create FlatBuffer BaseLayer
169 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
171 // Create the FlatBuffer AdditionLayer
172 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
174 // Add the AnyLayer to the FlatBufferLayers
175 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
178 // Build FlatBuffer for ArgMinMax Layer
// The Max/Min selector is mapped through GetFlatBufferArgMinMaxFunction into the descriptor.
179 void SerializerVisitor::VisitArgMinMaxLayer(const armnn::IConnectableLayer *layer,
180 const armnn::ArgMinMaxDescriptor& descriptor,
185 // Create FlatBuffer BaseLayer
186 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ArgMinMax);
188 // Create FlatBuffer Descriptor
189 auto flatBufferDescriptor = CreateArgMinMaxDescriptor(m_flatBufferBuilder,
190 GetFlatBufferArgMinMaxFunction(descriptor.m_Function),
193 // Create FlatBuffer ArgMinMaxLayer
194 auto flatBufferLayer = CreateArgMinMaxLayer(m_flatBufferBuilder,
196 flatBufferDescriptor);
198 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ArgMinMaxLayer);
201 // Build FlatBuffer for BatchToSpaceNd Layer
// Serializes the block shape, the flattened crop pairs, and the data layout.
202 void SerializerVisitor::VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
203 const armnn::BatchToSpaceNdDescriptor& descriptor,
208 // Create FlatBuffer BaseLayer
209 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchToSpaceNd);
// Flatten the (first, second) crop pairs into one vector: [first0, second0, first1, second1, ...].
211 std::vector<unsigned int> crops;
212 crops.reserve(descriptor.m_Crops.size() * 2);
213 for (auto& crop : descriptor.m_Crops)
215 crops.push_back(crop.first);
216 crops.push_back(crop.second);
// The block-shape and crops vectors are built as call arguments, so they exist in the buffer before
// the descriptor table that references them.
219 auto flatBufferDescriptor =
220 CreateBatchToSpaceNdDescriptor(m_flatBufferBuilder,
221 m_flatBufferBuilder.CreateVector(descriptor.m_BlockShape),
222 m_flatBufferBuilder.CreateVector(crops),
223 GetFlatBufferDataLayout(descriptor.m_DataLayout));
225 auto flatBufferLayer = serializer::CreateBatchToSpaceNdLayer(m_flatBufferBuilder,
227 flatBufferDescriptor);
229 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_BatchToSpaceNdLayer);
// Build FlatBuffer for BatchNormalization Layer.
// Serializes the descriptor (eps, data layout) together with the mean, variance, beta and gamma
// constant tensors.
232 void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
233 const armnn::BatchNormalizationDescriptor& batchNormDescriptor,
234 const armnn::ConstTensor& mean,
235 const armnn::ConstTensor& variance,
236 const armnn::ConstTensor& beta,
237 const armnn::ConstTensor& gamma,
242 auto fbBatchNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchNormalization);
243 auto fbBatchNormalizationDescriptor = serializer::CreateBatchNormalizationDescriptor(
245 batchNormDescriptor.m_Eps,
246 GetFlatBufferDataLayout(batchNormDescriptor.m_DataLayout));
// The four constant tensors are serialized before the layer table, because FlatBuffers requires
// nested objects to be finished before the table that references them.
248 auto fbMeanConstTensorInfo = CreateConstTensorInfo(mean);
249 auto fbVarianceConstTensorInfo = CreateConstTensorInfo(variance);
250 auto fbBetaConstTensorInfo = CreateConstTensorInfo(beta);
251 auto fbGammaConstTensorInfo = CreateConstTensorInfo(gamma);
252 auto fbBatchNormalizationLayer = serializer::CreateBatchNormalizationLayer(m_flatBufferBuilder,
253 fbBatchNormalizationBaseLayer,
254 fbBatchNormalizationDescriptor,
255 fbMeanConstTensorInfo,
256 fbVarianceConstTensorInfo,
257 fbBetaConstTensorInfo,
258 fbGammaConstTensorInfo);
260 CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
// Build FlatBuffer for Comparison Layer.
// The comparison operation is mapped through GetFlatBufferComparisonOperation into the descriptor.
263 void SerializerVisitor::VisitComparisonLayer(const armnn::IConnectableLayer* layer,
264 const armnn::ComparisonDescriptor& descriptor,
269 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Comparison);
270 auto fbDescriptor = serializer::CreateComparisonDescriptor(
272 GetFlatBufferComparisonOperation(descriptor.m_Operation));
274 auto fbLayer = serializer::CreateComparisonLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
275 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_ComparisonLayer);
278 // Build FlatBuffer for Constant Layer
// The layer's constant tensor (input) is serialized and attached to the ConstantLayer table.
279 void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
280 const armnn::ConstTensor& input,
285 // Create FlatBuffer BaseLayer
286 auto flatBufferConstantBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Constant);
288 auto flatBufferConstTensorInfo = CreateConstTensorInfo(input);
290 // Create the FlatBuffer ConstantLayer
291 auto flatBufferLayer = CreateConstantLayer(m_flatBufferBuilder,
292 flatBufferConstantBaseLayer,
293 flatBufferConstTensorInfo);
295 // Add the AnyLayer to the FlatBufferLayers
296 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConstantLayer);
299 // Build FlatBuffer for Convolution2dLayer
// Serializes the padding/stride/dilation descriptor, the weights, and (when present) the biases.
300 void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
301 const armnn::Convolution2dDescriptor& descriptor,
302 const armnn::ConstTensor& weights,
303 const armnn::Optional<armnn::ConstTensor>& biases,
308 // Create FlatBuffer BaseLayer
309 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
311 auto flatBufferDescriptor = CreateConvolution2dDescriptor(m_flatBufferBuilder,
312 descriptor.m_PadLeft,
313 descriptor.m_PadRight,
315 descriptor.m_PadBottom,
316 descriptor.m_StrideX,
317 descriptor.m_StrideY,
318 descriptor.m_DilationX,
319 descriptor.m_DilationY,
320 descriptor.m_BiasEnabled,
321 GetFlatBufferDataLayout(descriptor.m_DataLayout));
322 auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights);
// Left as a default-constructed (null) offset when no biases are supplied; FlatBuffers then omits
// the field from the layer table.
323 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiasesConstTensorInfo;
325 if (biases.has_value())
327 flatBufferBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
330 // Create the FlatBuffer Convolution2dLayer
331 auto flatBufferLayer = CreateConvolution2dLayer(m_flatBufferBuilder,
333 flatBufferDescriptor,
334 flatBufferWeightsConstTensorInfo,
335 flatBufferBiasesConstTensorInfo);
337 // Add the AnyLayer to the FlatBufferLayers
338 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer);
// Build FlatBuffer for DepthToSpace Layer: serializes the block size and data layout.
341 void SerializerVisitor::VisitDepthToSpaceLayer(const armnn::IConnectableLayer* layer,
342 const armnn::DepthToSpaceDescriptor& descriptor,
347 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthToSpace);
348 auto fbDescriptor = CreateDepthToSpaceDescriptor(m_flatBufferBuilder,
349 descriptor.m_BlockSize,
350 GetFlatBufferDataLayout(descriptor.m_DataLayout));
352 auto fbLayer = serializer::CreateDepthToSpaceLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
354 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_DepthToSpaceLayer);
// Build FlatBuffer for DepthwiseConvolution2d Layer.
// Same shape as VisitConvolution2dLayer: the descriptor, the weights, and the biases when present.
357 void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
358 const armnn::DepthwiseConvolution2dDescriptor& descriptor,
359 const armnn::ConstTensor& weights,
360 const armnn::Optional<armnn::ConstTensor>& biases,
365 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthwiseConvolution2d);
366 auto fbDescriptor = CreateDepthwiseConvolution2dDescriptor(m_flatBufferBuilder,
367 descriptor.m_PadLeft,
368 descriptor.m_PadRight,
370 descriptor.m_PadBottom,
371 descriptor.m_StrideX,
372 descriptor.m_StrideY,
373 descriptor.m_DilationX,
374 descriptor.m_DilationY,
375 descriptor.m_BiasEnabled,
376 GetFlatBufferDataLayout(descriptor.m_DataLayout));
378 flatbuffers::Offset<serializer::ConstTensor> fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
// A null offset (no bias field) unless biases carries a value.
379 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
380 if (biases.has_value())
382 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
385 auto flatBufferLayer = CreateDepthwiseConvolution2dLayer(m_flatBufferBuilder,
388 fbWeightsConstTensorInfo,
389 fbBiasesConstTensorInfo);
391 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DepthwiseConvolution2dLayer);
// Build FlatBuffer for Dequantize Layer: a descriptor-less table holding only the layer base.
394 void SerializerVisitor::VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
399 auto fbDequantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Dequantize);
400 auto fbDequantizeLayer = serializer::CreateDequantizeLayer(m_flatBufferBuilder, fbDequantizeBaseLayer);
402 CreateAnyLayer(fbDequantizeLayer.o, serializer::Layer::Layer_DequantizeLayer);
// Build FlatBuffer for DetectionPostProcess Layer.
// Serializes the detection/NMS descriptor fields and the anchors constant tensor.
405 void SerializerVisitor::VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
406 const armnn::DetectionPostProcessDescriptor& descriptor,
407 const armnn::ConstTensor& anchors,
412 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DetectionPostProcess);
413 auto fbDescriptor = CreateDetectionPostProcessDescriptor(m_flatBufferBuilder,
414 descriptor.m_MaxDetections,
415 descriptor.m_MaxClassesPerDetection,
416 descriptor.m_DetectionsPerClass,
417 descriptor.m_NmsScoreThreshold,
418 descriptor.m_NmsIouThreshold,
419 descriptor.m_NumClasses,
420 descriptor.m_UseRegularNms,
424 descriptor.m_ScaleH);
426 flatbuffers::Offset<serializer::ConstTensor> fbAnchorsConstTensorInfo = CreateConstTensorInfo(anchors);
428 auto flatBufferLayer = CreateDetectionPostProcessLayer(m_flatBufferBuilder,
431 fbAnchorsConstTensorInfo);
433 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DetectionPostProcessLayer);
// Build FlatBuffer for Division Layer: a descriptor-less table holding only the layer base.
436 void SerializerVisitor::VisitDivisionLayer(const armnn::IConnectableLayer* layer, const char* name)
440 auto fbDivisionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Division);
441 auto fbDivisionLayer = serializer::CreateDivisionLayer(m_flatBufferBuilder, fbDivisionBaseLayer);
443 CreateAnyLayer(fbDivisionLayer.o, serializer::Layer::Layer_DivisionLayer);
// Build FlatBuffer for ElementwiseUnary Layer.
// The unary operation is mapped through GetFlatBufferUnaryOperation into the descriptor.
446 void SerializerVisitor::VisitElementwiseUnaryLayer(const armnn::IConnectableLayer* layer,
447 const armnn::ElementwiseUnaryDescriptor& descriptor,
452 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ElementwiseUnary);
453 auto fbDescriptor = serializer::CreateElementwiseUnaryDescriptor(
455 GetFlatBufferUnaryOperation(descriptor.m_Operation));
457 auto fbLayer = serializer::CreateElementwiseUnaryLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
458 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_ElementwiseUnaryLayer);
// Build FlatBuffer for Equal Layer: a descriptor-less table holding only the layer base.
461 void SerializerVisitor::VisitEqualLayer(const armnn::IConnectableLayer* layer, const char* name)
465 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Equal);
466 auto fbEqualLayer = serializer::CreateEqualLayer(m_flatBufferBuilder, fbBaseLayer);
468 CreateAnyLayer(fbEqualLayer.o, serializer::Layer::Layer_EqualLayer);
// Build FlatBuffer for Floor Layer: a descriptor-less table holding only the layer base.
471 void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name)
475 auto flatBufferFloorBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Floor);
476 auto flatBufferFloorLayer = serializer::CreateFloorLayer(m_flatBufferBuilder, flatBufferFloorBaseLayer);
478 CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
// Build FlatBuffer for Gather Layer: a descriptor-less table holding only the layer base.
481 void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
485 auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
486 auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
488 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
// Build FlatBuffer for Greater Layer: a descriptor-less table holding only the layer base.
491 void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name)
495 auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater);
496 auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer);
498 CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer);
// Build FlatBuffer for InstanceNormalization Layer.
// Serializes the scalar gamma, beta and eps fields and the data layout from the descriptor. Unlike
// most visitors, the descriptor is built before the layer base.
501 void SerializerVisitor::VisitInstanceNormalizationLayer(
502 const armnn::IConnectableLayer* layer,
503 const armnn::InstanceNormalizationDescriptor& instanceNormalizationDescriptor,
508 auto fbDescriptor = serializer::CreateInstanceNormalizationDescriptor(
510 instanceNormalizationDescriptor.m_Gamma,
511 instanceNormalizationDescriptor.m_Beta,
512 instanceNormalizationDescriptor.m_Eps,
513 GetFlatBufferDataLayout(instanceNormalizationDescriptor.m_DataLayout));
515 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_InstanceNormalization);
516 auto fbLayer = serializer::CreateInstanceNormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
518 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_InstanceNormalizationLayer);
// Build FlatBuffer for L2Normalization Layer: serializes the data layout followed by eps.
521 void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
522 const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
527 // Create FlatBuffer BaseLayer
528 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_L2Normalization);
530 // Create the FlatBuffer L2Normalization Descriptor
531 auto fbDescriptor = serializer::CreateL2NormalizationDescriptor(
533 GetFlatBufferDataLayout(l2NormalizationDescriptor.m_DataLayout),
534 l2NormalizationDescriptor.m_Eps);
536 // Create FlatBuffer layer
537 auto fbLayer = serializer::CreateL2NormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
539 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
// Build FlatBuffer for LogSoftmax Layer: serializes beta and the axis from the descriptor.
542 void SerializerVisitor::VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer,
543 const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor,
548 // Create FlatBuffer BaseLayer
549 auto flatBufferLogSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_LogSoftmax);
551 // Create the FlatBuffer LogSoftmaxDescriptor
552 auto flatBufferLogSoftmaxDesc =
553 serializer::CreateLogSoftmaxDescriptor(m_flatBufferBuilder,
554 logSoftmaxDescriptor.m_Beta,
555 logSoftmaxDescriptor.m_Axis);
557 // Create the FlatBuffer LogSoftmaxLayer
558 auto flatBufferLogSoftmaxLayer =
559 serializer::CreateLogSoftmaxLayer(m_flatBufferBuilder,
560 flatBufferLogSoftmaxBaseLayer,
561 flatBufferLogSoftmaxDesc);
563 CreateAnyLayer(flatBufferLogSoftmaxLayer.o, serializer::Layer::Layer_LogSoftmaxLayer);
566 void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer,
567 const armnn::LstmDescriptor& descriptor,
568 const armnn::LstmInputParams& params,
573 auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
575 auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
577 descriptor.m_ActivationFunc,
578 descriptor.m_ClippingThresCell,
579 descriptor.m_ClippingThresProj,
580 descriptor.m_CifgEnabled,
581 descriptor.m_PeepholeEnabled,
582 descriptor.m_ProjectionEnabled,
583 descriptor.m_LayerNormEnabled);
585 // Get mandatory input parameters
586 auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
587 auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
588 auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
589 auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
590 auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
591 auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
592 auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
593 auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
594 auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
596 //Define optional parameters, these will be set depending on configuration in Lstm descriptor
597 flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
598 flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
599 flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
600 flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
601 flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
602 flatbuffers::Offset<serializer::ConstTensor> projectionBias;
603 flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
604 flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
605 flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
606 flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
607 flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
608 flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
610 if (!descriptor.m_CifgEnabled)
612 inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
613 recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
614 cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
615 inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
618 if (descriptor.m_ProjectionEnabled)
620 projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
621 projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
624 if (descriptor.m_PeepholeEnabled)
626 cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
627 cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
630 if (descriptor.m_LayerNormEnabled)
632 if (!descriptor.m_CifgEnabled)
634 inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
636 forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
637 cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
638 outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
641 auto fbLstmParams = serializer::CreateLstmInputParams(
643 inputToForgetWeights,
645 inputToOutputWeights,
646 recurrentToForgetWeights,
647 recurrentToCellWeights,
648 recurrentToOutputWeights,
653 recurrentToInputWeights,
660 inputLayerNormWeights,
661 forgetLayerNormWeights,
662 cellLayerNormWeights,
663 outputLayerNormWeights);
665 auto fbLstmLayer = serializer::CreateLstmLayer(
671 CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
// Build FlatBuffer for Maximum Layer: a descriptor-less table holding only the layer base.
674 void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
678 auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
679 auto fbMaximumLayer = serializer::CreateMaximumLayer(m_flatBufferBuilder, fbMaximumBaseLayer);
681 CreateAnyLayer(fbMaximumLayer.o, serializer::Layer::Layer_MaximumLayer);
// Build FlatBuffer for Mean Layer: serializes the reduction axes vector and the keep-dims flag.
684 void SerializerVisitor::VisitMeanLayer(const armnn::IConnectableLayer* layer,
685 const armnn::MeanDescriptor& descriptor,
690 auto fbMeanBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Mean);
691 auto fbMeanDescriptor = serializer::CreateMeanDescriptor(m_flatBufferBuilder,
692 m_flatBufferBuilder.CreateVector(descriptor.m_Axis),
693 descriptor.m_KeepDims);
695 auto fbMeanLayer = serializer::CreateMeanLayer(m_flatBufferBuilder,
699 CreateAnyLayer(fbMeanLayer.o, serializer::Layer::Layer_MeanLayer);
// Build FlatBuffer for Minimum Layer: a descriptor-less table holding only the layer base.
702 void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, const char* name)
706 auto fbMinimumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Minimum);
707 auto fbMinimumLayer = serializer::CreateMinimumLayer(m_flatBufferBuilder, fbMinimumBaseLayer);
709 CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
// Build FlatBuffer for Merge Layer: a descriptor-less table holding only the layer base.
712 void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
716 auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
717 auto fbMergeLayer = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
719 CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
// Merger layers are serialized exactly like Concat layers: this forwards the layer, its descriptor
// and its name to VisitConcatLayer.
722 void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
723 const armnn::MergerDescriptor& mergerDescriptor,
726 VisitConcatLayer(layer, mergerDescriptor, name);
// Build FlatBuffer for Concat Layer.
// Each view's origin (GetNumDimensions() coordinates) is copied into its own UintVector table. These
// are collected into an OriginsDescriptor along with the concat axis, view count and dimension count.
729 void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
730 const armnn::ConcatDescriptor& concatDescriptor,
735 auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
737 std::vector<flatbuffers::Offset<UintVector>> views;
738 for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
740 const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
// NOTE(review): the per-dimension copy below is equivalent to constructing origins from the range
// [origin, origin + GetNumDimensions()).
741 std::vector<uint32_t> origins;
742 for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
744 origins.push_back(origin[d]);
746 auto view = m_flatBufferBuilder.CreateVector(origins);
747 auto uintVector = CreateUintVector(m_flatBufferBuilder, view);
748 views.push_back(uintVector);
751 auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
752 concatDescriptor.GetConcatAxis(),
753 concatDescriptor.GetNumViews(),
754 concatDescriptor.GetNumDimensions(),
755 m_flatBufferBuilder.CreateVector(views));
757 auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
758 flatBufferConcatBaseLayer,
759 flatBufferConcatDescriptor);
761 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
// Build FlatBuffer for Multiplication Layer: a descriptor-less table holding only the layer base.
764 void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)
768 auto fbMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
769 auto fbMultiplicationLayer = serializer::CreateMultiplicationLayer(m_flatBufferBuilder,
770 fbMultiplicationBaseLayer);
772 CreateAnyLayer(fbMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
// Build FlatBuffer for Pad Layer: serializes the flattened pad list and the pad value.
775 void SerializerVisitor::VisitPadLayer(const armnn::IConnectableLayer* layer,
776 const armnn::PadDescriptor& padDescriptor,
781 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pad);
// Flatten the (first, second) padding pairs into [first0, second0, first1, second1, ...].
// NOTE(review): unlike the crops vector in VisitBatchToSpaceNdLayer, this one is not reserved up front.
783 std::vector<unsigned int> padList;
784 for (auto& p: padDescriptor.m_PadList)
786 padList.push_back(p.first);
787 padList.push_back(p.second);
790 auto flatBufferPadDesc = serializer::CreatePadDescriptor(m_flatBufferBuilder,
791 m_flatBufferBuilder.CreateVector(padList),
792 padDescriptor.m_PadValue);
794 auto flatBufferPadLayer = serializer::CreatePadLayer(m_flatBufferBuilder,
798 CreateAnyLayer(flatBufferPadLayer.o, serializer::Layer::Layer_PadLayer);
// Build FlatBuffer for Permute Layer: serializes the dimension mappings as a flat vector.
801 void SerializerVisitor::VisitPermuteLayer(const armnn::IConnectableLayer* layer,
802 const armnn::PermuteDescriptor& permuteDescriptor,
807 // Create FlatBuffer BaseLayer
808 auto flatBufferPermuteBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Permute);
// Copy the m_DimMappings entries (indices 0..GetSize()-1) into a plain vector for CreateVector.
810 std::vector<unsigned int> dimMappings;
811 for (unsigned int i=0; i<permuteDescriptor.m_DimMappings.GetSize(); ++i)
813 dimMappings.push_back(permuteDescriptor.m_DimMappings[i]);
816 auto flatBufferPermuteDesc = serializer::CreatePermuteDescriptor(m_flatBufferBuilder,
817 m_flatBufferBuilder.CreateVector(dimMappings));
819 // Create the FlatBuffer PermuteLayer
820 auto flatBufferPermuteLayer = serializer::CreatePermuteLayer(m_flatBufferBuilder,
821 flatBufferPermuteBaseLayer,
822 flatBufferPermuteDesc);
824 // Add the AnyLayer to the FlatBufferLayers
825 CreateAnyLayer(flatBufferPermuteLayer.o, serializer::Layer::Layer_PermuteLayer);
828 // Build FlatBuffer for Reshape Layer
// The target shape's dimensions are copied into a vector and serialized in the ReshapeDescriptor.
829 void SerializerVisitor::VisitReshapeLayer(const armnn::IConnectableLayer* layer,
830 const armnn::ReshapeDescriptor& reshapeDescriptor,
835 // Create FlatBuffer BaseLayer
836 auto flatBufferReshapeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Reshape);
838 std::vector<unsigned int> targetShape;
839 for (unsigned int i =0; i < reshapeDescriptor.m_TargetShape.GetNumDimensions(); i++)
841 targetShape.push_back(reshapeDescriptor.m_TargetShape[i]);
844 auto flatBufferReshapeDesc = serializer::CreateReshapeDescriptor(m_flatBufferBuilder,
845 m_flatBufferBuilder.CreateVector(targetShape));
847 // Create the FlatBuffer ReshapeLayer
848 auto flatBufferReshapeLayer = serializer::CreateReshapeLayer(m_flatBufferBuilder, flatBufferReshapeBaseLayer,
849 flatBufferReshapeDesc);
851 // Add the AnyLayer to the FlatBufferLayers
852 CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
// Build FlatBuffer for ResizeBilinear Layer: serializes the target size and data layout.
// NOTE(review): target width is passed before target height here, whereas VisitResizeLayer passes
// height first; presumably each matches its own schema table's field order — confirm against the schema.
855 void SerializerVisitor::VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
856 const armnn::ResizeBilinearDescriptor& resizeDescriptor,
861 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ResizeBilinear);
863 auto flatBufferDescriptor =
864 CreateResizeBilinearDescriptor(m_flatBufferBuilder,
865 resizeDescriptor.m_TargetWidth,
866 resizeDescriptor.m_TargetHeight,
867 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
869 auto flatBufferLayer = serializer::CreateResizeBilinearLayer(m_flatBufferBuilder,
871 flatBufferDescriptor);
873 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeBilinearLayer);
// Build FlatBuffer for Resize Layer.
// Serializes the target height and width, the resize method (mapped through
// GetFlatBufferResizeMethod) and the data layout.
876 void SerializerVisitor::VisitResizeLayer(const armnn::IConnectableLayer* layer,
877 const armnn::ResizeDescriptor& resizeDescriptor,
882 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Resize);
884 auto flatBufferDescriptor =
885 CreateResizeDescriptor(m_flatBufferBuilder,
886 resizeDescriptor.m_TargetHeight,
887 resizeDescriptor.m_TargetWidth,
888 GetFlatBufferResizeMethod(resizeDescriptor.m_Method),
889 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
891 auto flatBufferLayer = serializer::CreateResizeLayer(m_flatBufferBuilder,
893 flatBufferDescriptor);
895 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeLayer);
// Build FlatBuffer for Rsqrt Layer: a descriptor-less table holding only the layer base.
898 void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
902 auto fbRsqrtBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rsqrt);
903 auto fbRsqrtLayer = serializer::CreateRsqrtLayer(m_flatBufferBuilder, fbRsqrtBaseLayer);
905 CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
// Build FlatBuffer for Slice Layer: serializes the begin and size vectors.
// The two vectors are created as call arguments. C++ leaves argument evaluation order unspecified, so
// their relative position in the buffer may vary between compilers; their contents do not.
908 void SerializerVisitor::VisitSliceLayer(const armnn::IConnectableLayer* layer,
909 const armnn::SliceDescriptor& sliceDescriptor,
914 auto fbSliceBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Slice);
915 auto fbSliceDescriptor = CreateSliceDescriptor(m_flatBufferBuilder,
916 m_flatBufferBuilder.CreateVector(sliceDescriptor.m_Begin),
917 m_flatBufferBuilder.CreateVector(sliceDescriptor.m_Size));
919 auto fbSliceLayer = serializer::CreateSliceLayer(m_flatBufferBuilder, fbSliceBaseLayer, fbSliceDescriptor);
921 CreateAnyLayer(fbSliceLayer.o, serializer::Layer::Layer_SliceLayer);
924 // Build FlatBuffer for Softmax Layer
// Only m_Beta from the descriptor is serialized.
// NOTE(review): VisitLogSoftmaxLayer also serializes an axis; if SoftmaxDescriptor carries an m_Axis
// too, it is not preserved here — confirm against the descriptor and schema.
925 void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
926 const armnn::SoftmaxDescriptor& softmaxDescriptor,
931 // Create FlatBuffer BaseLayer
932 auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
934 // Create the FlatBuffer SoftmaxDescriptor
935 auto flatBufferSoftmaxDesc =
936 serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
938 // Create the FlatBuffer SoftmaxLayer
939 auto flatBufferSoftmaxLayer =
940 serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
941 flatBufferSoftmaxBaseLayer,
942 flatBufferSoftmaxDesc);
944 CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
947 void SerializerVisitor::VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
948 const armnn::Pooling2dDescriptor& pooling2dDescriptor,
953 auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
954 auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
956 GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
957 pooling2dDescriptor.m_PadLeft,
958 pooling2dDescriptor.m_PadRight,
959 pooling2dDescriptor.m_PadTop,
960 pooling2dDescriptor.m_PadBottom,
961 pooling2dDescriptor.m_PoolWidth,
962 pooling2dDescriptor.m_PoolHeight,
963 pooling2dDescriptor.m_StrideX,
964 pooling2dDescriptor.m_StrideY,
965 GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
966 GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
967 GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
969 auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
970 fbPooling2dBaseLayer,
971 fbPooling2dDescriptor);
973 CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
976 void SerializerVisitor::VisitPreluLayer(const armnn::IConnectableLayer* layer,
981 // Create FlatBuffer BaseLayer
982 auto flatBufferPreluBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Prelu);
984 // Create the FlatBuffer AdditionLayer
985 auto flatBufferPreluLayer = serializer::CreatePreluLayer(m_flatBufferBuilder, flatBufferPreluBaseLayer);
987 // Add the AnyLayer to the FlatBufferLayers
988 CreateAnyLayer(flatBufferPreluLayer.o, serializer::Layer::Layer_PreluLayer);
991 void SerializerVisitor::VisitQuantizeLayer(const armnn::IConnectableLayer *layer, const char *name)
995 auto fbQuantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Quantize);
996 auto fbQuantizeLayer = serializer::CreateQuantizeLayer(m_flatBufferBuilder,
997 fbQuantizeBaseLayer);
998 CreateAnyLayer(fbQuantizeLayer.o, serializer::Layer::Layer_QuantizeLayer);
1001 // Build FlatBuffer for FullyConnected Layer
1002 void SerializerVisitor::VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
1003 const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
1004 const armnn::ConstTensor& weights,
1005 const armnn::Optional<armnn::ConstTensor>& biases,
1010 // Create FlatBuffer BaseLayer
1011 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_FullyConnected);
1013 // Create FlatBuffer FullyConnectedDescriptor
1014 auto flatBufferDescriptor =
1015 serializer::CreateFullyConnectedDescriptor(m_flatBufferBuilder,
1016 fullyConnectedDescriptor.m_BiasEnabled,
1017 fullyConnectedDescriptor.m_TransposeWeightMatrix);
1019 // Create FlatBuffer weights data
1020 auto flatBufferWeights = CreateConstTensorInfo(weights);
1022 // Create FlatBuffer bias data
1023 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiases;
1024 if (fullyConnectedDescriptor.m_BiasEnabled)
1026 flatBufferBiases = CreateConstTensorInfo(biases.value());
1029 // Create FlatBuffer FullyConnectedLayer
1030 auto flatBufferLayer = serializer::CreateFullyConnectedLayer(m_flatBufferBuilder,
1031 flatBufferBaseLayer,
1032 flatBufferDescriptor,
1036 // Add created FullyConnectedLayer to the FlatBufferLayers
1037 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_FullyConnectedLayer);
1040 // Build FlatBuffer for SpaceToBatchNd Layer
1041 void SerializerVisitor::VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
1042 const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
1047 // Create FlatBuffer BaseLayer
1048 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToBatchNd);
1050 std::vector<unsigned int> padList;
1051 padList.reserve(spaceToBatchNdDescriptor.m_PadList.size()*2);
1052 for (auto& pad : spaceToBatchNdDescriptor.m_PadList)
1054 padList.push_back(pad.first);
1055 padList.push_back(pad.second);
1058 auto flatBufferDescriptor =
1059 CreateSpaceToBatchNdDescriptor(m_flatBufferBuilder,
1060 m_flatBufferBuilder.CreateVector(spaceToBatchNdDescriptor.m_BlockShape),
1061 m_flatBufferBuilder.CreateVector(padList),
1062 GetFlatBufferDataLayout(spaceToBatchNdDescriptor.m_DataLayout));
1064 auto flatBufferLayer = serializer::CreateSpaceToBatchNdLayer(m_flatBufferBuilder,
1065 flatBufferBaseLayer,
1066 flatBufferDescriptor);
1068 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer);
1071 // Build FlatBuffer for SpaceToDepthLayer
1072 void SerializerVisitor::VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
1073 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
1078 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToDepth);
1079 auto flatBufferDescriptor =
1080 CreateSpaceToDepthDescriptor(m_flatBufferBuilder,
1081 spaceToDepthDescriptor.m_BlockSize,
1082 GetFlatBufferDataLayout(spaceToDepthDescriptor.m_DataLayout));
1084 auto flatBufferLayer = serializer::CreateSpaceToDepthLayer(m_flatBufferBuilder,
1085 flatBufferBaseLayer,
1086 flatBufferDescriptor);
1088 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToDepthLayer);
1091 // Build FlatBuffer for Splitter Layer
1092 void SerializerVisitor::VisitSplitterLayer(const armnn::IConnectableLayer* layer,
1093 const armnn::ViewsDescriptor& viewsDescriptor,
1098 // Create FlatBuffer ViewOrigins
1099 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewOrigins;
1100 flatBufferViewOrigins.reserve(viewsDescriptor.GetNumViews());
1102 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
1104 std::vector<uint32_t> viewOrigin;
1105 viewOrigin.reserve(viewsDescriptor.GetNumDimensions());
1108 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
1110 viewOrigin.push_back(viewsDescriptor.GetViewOrigin(vIdx)[dIdx]);
1113 flatBufferViewOrigins.push_back(CreateUintVector(m_flatBufferBuilder,
1114 m_flatBufferBuilder.CreateVector(viewOrigin)));
1117 // Create FlatBuffer OriginsDescriptor
1118 auto flatBufferOriginDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
1119 viewsDescriptor.GetOrigins().GetConcatAxis(),
1120 viewsDescriptor.GetOrigins().GetNumViews(),
1121 viewsDescriptor.GetOrigins().GetNumDimensions(),
1122 m_flatBufferBuilder.CreateVector(flatBufferViewOrigins));
1124 // Create FlatBuffer ViewOrigins
1125 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewSizes;
1126 flatBufferViewSizes.reserve(viewsDescriptor.GetNumViews());
1128 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
1130 std::vector<uint32_t> viewSize;
1131 viewSize.reserve(viewsDescriptor.GetNumDimensions());
1134 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
1136 viewSize.push_back(viewsDescriptor.GetViewSizes(vIdx)[dIdx]);
1139 flatBufferViewSizes.push_back(CreateUintVector(m_flatBufferBuilder,
1140 m_flatBufferBuilder.CreateVector(viewSize)));
1143 // Create FlatBuffer ViewsDescriptor
1144 auto flatBufferViewsDescriptor = CreateViewsDescriptor(m_flatBufferBuilder,
1145 flatBufferOriginDescriptor,
1146 m_flatBufferBuilder.CreateVector(flatBufferViewSizes));
1148 // Create FlatBuffer BaseLayer
1149 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Splitter);
1151 auto flatBufferSplitterLayer = serializer::CreateSplitterLayer(m_flatBufferBuilder,
1152 flatBufferBaseLayer,
1153 flatBufferViewsDescriptor);
1155 CreateAnyLayer(flatBufferSplitterLayer.o, serializer::Layer::Layer_SplitterLayer);
1158 void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
1159 const armnn::NormalizationDescriptor& descriptor,
1164 auto fbNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Normalization);
1166 auto fbNormalizationDescriptor = serializer::CreateNormalizationDescriptor(
1167 m_flatBufferBuilder,
1168 GetFlatBufferNormalizationAlgorithmChannel(descriptor.m_NormChannelType),
1169 GetFlatBufferNormalizationAlgorithmMethod(descriptor.m_NormMethodType),
1170 descriptor.m_NormSize,
1174 GetFlatBufferDataLayout(descriptor.m_DataLayout));
1176 auto flatBufferLayer = serializer::CreateNormalizationLayer(m_flatBufferBuilder,
1177 fbNormalizationBaseLayer,
1178 fbNormalizationDescriptor);
1180 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
1183 void SerializerVisitor::VisitStackLayer(const armnn::IConnectableLayer* layer,
1184 const armnn::StackDescriptor& stackDescriptor,
1189 auto stackBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Stack);
1191 std::vector<unsigned int> inputShape;
1192 for (unsigned int i =0; i < stackDescriptor.m_InputShape.GetNumDimensions(); i++)
1194 inputShape.push_back(stackDescriptor.m_InputShape[i]);
1197 auto flatBufferStackDescriptor = CreateStackDescriptor(m_flatBufferBuilder,
1198 stackDescriptor.m_Axis,
1199 stackDescriptor.m_NumInputs,
1200 m_flatBufferBuilder.CreateVector(inputShape));
1202 auto stackLayer = serializer::CreateStackLayer(m_flatBufferBuilder, stackBaseLayer, flatBufferStackDescriptor);
1203 CreateAnyLayer(stackLayer.o, serializer::Layer::Layer_StackLayer);
1206 void SerializerVisitor::VisitStandInLayer(const armnn::IConnectableLayer *layer,
1207 const armnn::StandInDescriptor& standInDescriptor,
1212 auto fbDescriptor = serializer::CreateStandInDescriptor(m_flatBufferBuilder,
1213 standInDescriptor.m_NumInputs,
1214 standInDescriptor.m_NumOutputs);
1216 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StandIn);
1217 auto fbLayer = serializer::CreateStandInLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
1219 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_StandInLayer);
1222 void SerializerVisitor::VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
1223 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
1228 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StridedSlice);
1230 auto flatBufferDescriptor =
1231 CreateStridedSliceDescriptor(m_flatBufferBuilder,
1232 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Begin),
1233 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_End),
1234 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Stride),
1235 stridedSliceDescriptor.m_BeginMask,
1236 stridedSliceDescriptor.m_EndMask,
1237 stridedSliceDescriptor.m_ShrinkAxisMask,
1238 stridedSliceDescriptor.m_EllipsisMask,
1239 stridedSliceDescriptor.m_NewAxisMask,
1240 GetFlatBufferDataLayout(stridedSliceDescriptor.m_DataLayout));
1242 auto flatBufferLayer = serializer::CreateStridedSliceLayer(m_flatBufferBuilder,
1243 flatBufferBaseLayer,
1244 flatBufferDescriptor);
1246 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_StridedSliceLayer);
1249 void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* layer, const char* name)
1253 auto fbSubtractionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Subtraction);
1254 auto fbSubtractionLayer = serializer::CreateSubtractionLayer(m_flatBufferBuilder, fbSubtractionBaseLayer);
1256 CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
1259 void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
1263 auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
1264 auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
1266 CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
1269 void SerializerVisitor::VisitTransposeConvolution2dLayer(
1270 const armnn::IConnectableLayer* layer,
1271 const armnn::TransposeConvolution2dDescriptor& descriptor,
1272 const armnn::ConstTensor& weights,
1273 const armnn::Optional<armnn::ConstTensor>& biases,
1278 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
1279 auto fbDescriptor = CreateTransposeConvolution2dDescriptor(m_flatBufferBuilder,
1280 descriptor.m_PadLeft,
1281 descriptor.m_PadRight,
1282 descriptor.m_PadTop,
1283 descriptor.m_PadBottom,
1284 descriptor.m_StrideX,
1285 descriptor.m_StrideY,
1286 descriptor.m_BiasEnabled,
1287 GetFlatBufferDataLayout(descriptor.m_DataLayout));
1290 auto fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
1291 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
1292 if (biases.has_value())
1294 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
1297 auto fbLayer = CreateTransposeConvolution2dLayer(m_flatBufferBuilder,
1300 fbWeightsConstTensorInfo,
1301 fbBiasesConstTensorInfo);
1303 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_TransposeConvolution2dLayer);
1306 void SerializerVisitor::VisitTransposeLayer(const armnn::IConnectableLayer* layer,
1307 const armnn::TransposeDescriptor& descriptor,
1312 // Create FlatBuffer BaseLayer
1313 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Transpose);
1315 std::vector<unsigned int> dimMappings;
1316 for (unsigned int i=0; i<descriptor.m_DimMappings.GetSize(); ++i)
1318 dimMappings.push_back(descriptor.m_DimMappings[i]);
1321 auto flatBufferDesc = serializer::CreateTransposeDescriptor(m_flatBufferBuilder,
1322 m_flatBufferBuilder.CreateVector(dimMappings));
1324 // Create the FlatBuffer TransposeLayer
1325 auto flatBufferLayer = serializer::CreateTransposeLayer(m_flatBufferBuilder,
1326 flatBufferBaseLayer,
1329 // Add the AnyLayer to the FlatBufferLayers
1330 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_TransposeLayer);
1333 void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer,
1334 const armnn::QLstmDescriptor& descriptor,
1335 const armnn::LstmInputParams& params,
1340 auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm);
1342 auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor(
1343 m_flatBufferBuilder,
1344 descriptor.m_CifgEnabled,
1345 descriptor.m_PeepholeEnabled,
1346 descriptor.m_ProjectionEnabled,
1347 descriptor.m_LayerNormEnabled,
1348 descriptor.m_CellClip,
1349 descriptor.m_ProjectionClip,
1350 descriptor.m_InputIntermediateScale,
1351 descriptor.m_ForgetIntermediateScale,
1352 descriptor.m_CellIntermediateScale,
1353 descriptor.m_OutputIntermediateScale,
1354 descriptor.m_HiddenStateZeroPoint,
1355 descriptor.m_HiddenStateScale
1359 auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
1360 auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
1361 auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
1362 auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
1363 auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
1364 auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
1365 auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
1366 auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
1367 auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
1370 flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
1371 flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
1372 flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
1374 if (!descriptor.m_CifgEnabled)
1376 inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
1377 recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
1378 inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
1382 flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
1383 flatbuffers::Offset<serializer::ConstTensor> projectionBias;
1385 if (descriptor.m_ProjectionEnabled)
1387 projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
1388 projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
1392 flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
1393 flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
1394 flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
1396 if (descriptor.m_PeepholeEnabled)
1398 if (!descriptor.m_CifgEnabled)
1400 cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
1403 cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
1404 cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
1408 flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
1409 flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
1410 flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
1411 flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
1413 if (descriptor.m_LayerNormEnabled)
1415 if (!descriptor.m_CifgEnabled)
1417 inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
1420 forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
1421 cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
1422 outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
1425 auto fbQLstmParams = serializer::CreateQLstmInputParams(
1426 m_flatBufferBuilder,
1427 inputToForgetWeights,
1429 inputToOutputWeights,
1430 recurrentToForgetWeights,
1431 recurrentToCellWeights,
1432 recurrentToOutputWeights,
1436 inputToInputWeights,
1437 recurrentToInputWeights,
1442 cellToForgetWeights,
1443 cellToOutputWeights,
1444 inputLayerNormWeights,
1445 forgetLayerNormWeights,
1446 cellLayerNormWeights,
1447 outputLayerNormWeights);
1449 auto fbQLstmLayer = serializer::CreateQLstmLayer(
1450 m_flatBufferBuilder,
1455 CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer);
1458 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
1459 const armnn::QuantizedLstmInputParams& params,
1464 auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
1466 // Get input parameters
1467 auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
1468 auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
1469 auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
1470 auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
1472 auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
1473 auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
1474 auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
1475 auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
1477 auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
1478 auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
1479 auto cellBias = CreateConstTensorInfo(params.GetCellBias());
1480 auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
1482 auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
1483 m_flatBufferBuilder,
1484 inputToInputWeights,
1485 inputToForgetWeights,
1487 inputToOutputWeights,
1488 recurrentToInputWeights,
1489 recurrentToForgetWeights,
1490 recurrentToCellWeights,
1491 recurrentToOutputWeights,
1497 auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
1498 m_flatBufferBuilder,
1499 fbQuantizedLstmBaseLayer,
1500 fbQuantizedLstmParams);
1502 CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
1505 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
1506 const serializer::LayerType layerType)
1509 uint32_t fbIndex = GetSerializedId(layer->GetGuid());
1511 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
1512 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
1514 return serializer::CreateLayerBase(m_flatBufferBuilder,
1516 m_flatBufferBuilder.CreateString(layer->GetName()),
1518 m_flatBufferBuilder.CreateVector(inputSlots),
1519 m_flatBufferBuilder.CreateVector(outputSlots));
1522 void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
1525 auto anyLayer = armnnSerializer::CreateAnyLayer(m_flatBufferBuilder, serializerLayer, layer);
1526 m_serializedLayers.push_back(anyLayer);
1529 template <typename T>
1530 flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector(const void* memory, unsigned int size)
1532 const T* buffer = reinterpret_cast<const T*>(memory);
1533 std::vector<T> vector(buffer, buffer + (size / sizeof(T)));
1534 auto fbVector = m_flatBufferBuilder.CreateVector(vector);
1538 flatbuffers::Offset<TensorInfo> SerializerVisitor::CreateTensorInfo(const armnn::TensorInfo& tensorInfo)
1540 // Get the dimensions
1541 std::vector<unsigned int> shape;
1542 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1544 shape.push_back(tensorInfo.GetShape()[dim]);
1547 if (tensorInfo.HasPerAxisQuantization())
1549 // Create FlatBuffer TensorInfo
1550 auto flatBufferTensorInfo =
1551 serializer::CreateTensorInfo(m_flatBufferBuilder,
1552 m_flatBufferBuilder.CreateVector(shape),
1553 GetFlatBufferDataType(tensorInfo.GetDataType()),
1554 tensorInfo.GetQuantizationScales()[0],
1555 tensorInfo.GetQuantizationOffset(),
1556 m_flatBufferBuilder.CreateVector(tensorInfo.GetQuantizationScales()),
1557 tensorInfo.GetQuantizationDim().value());
1558 return flatBufferTensorInfo;
1561 // Create FlatBuffer TensorInfo
1562 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1563 m_flatBufferBuilder.CreateVector(shape),
1564 GetFlatBufferDataType(tensorInfo.GetDataType()),
1565 tensorInfo.GetQuantizationScale(),
1566 tensorInfo.GetQuantizationOffset());
1567 return flatBufferTensorInfo;
1570 flatbuffers::Offset<serializer::ConstTensor>
1571 SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor)
1573 armnn::TensorInfo tensorInfo = constTensor.GetInfo();
1575 flatbuffers::Offset<void> fbPayload;
1577 switch (tensorInfo.GetDataType())
1579 case armnn::DataType::Float32:
1580 case armnn::DataType::Signed32:
1582 auto fbVector = CreateDataVector<int32_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1583 flatbuffers::Offset<serializer::IntData> flatBuffersData = serializer::CreateIntData(
1584 m_flatBufferBuilder,
1586 fbPayload = flatBuffersData.o;
1589 case armnn::DataType::Float16:
1590 case armnn::DataType::BFloat16:
1591 case armnn::DataType::QSymmS16:
1593 auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1594 flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1595 m_flatBufferBuilder,
1597 fbPayload = flatBuffersData.o;
1600 case armnn::DataType::QSymmS8:
1601 case armnn::DataType::QAsymmS8:
1602 case armnn::DataType::QAsymmU8:
1603 case armnn::DataType::Boolean:
1606 auto fbVector = CreateDataVector<int8_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1607 flatbuffers::Offset<serializer::ByteData> flatBuffersData = serializer::CreateByteData(
1608 m_flatBufferBuilder,
1610 fbPayload = flatBuffersData.o;
1613 flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor(
1614 m_flatBufferBuilder,
1615 CreateTensorInfo(tensorInfo),
1616 GetFlatBufferConstTensorData(tensorInfo.GetDataType()),
1618 return flatBufferConstTensor;
1621 flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> SerializerVisitor::GetVersionTable()
1623 flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> versionsTable =
1624 serializer::CreateFeatureCompatibilityVersions(
1625 m_flatBufferBuilder,
1626 1 // Binding ids scheme version
1628 return versionsTable;
1631 std::vector<fb::Offset<serializer::InputSlot>>
1632 SerializerVisitor::CreateInputSlots(const armnn::IConnectableLayer* layer)
1634 std::vector<fb::Offset<serializer::InputSlot>> inputSlots;
1636 // Get the InputSlots
1637 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
1639 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
1641 // Get the Connection for the InputSlot
1642 const IOutputSlot* connection = inputSlot.GetConnection();
1644 // Create FlatBuffer Connection
1645 serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
1646 connection->CalculateIndexOnOwner());
1647 // Create FlatBuffer InputSlot
1648 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
1653 std::vector<fb::Offset<serializer::OutputSlot>>
1654 SerializerVisitor::CreateOutputSlots(const armnn::IConnectableLayer* layer)
1656 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
1658 // Get the OutputSlots
1659 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
1661 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
1662 const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1664 // Create FlatBuffer Outputslot
1665 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
1667 CreateTensorInfo(tensorInfo)));
1673 ISerializer* ISerializer::CreateRaw()
1675 return new Serializer();
1678 ISerializerPtr ISerializer::Create()
1680 return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
1683 void ISerializer::Destroy(ISerializer* serializer)
1688 void Serializer::Serialize(const INetwork& inNetwork)
1690 // Iterate through to network
1691 inNetwork.Accept(m_SerializerVisitor);
1692 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1694 // Create FlatBuffer SerializedGraph
1695 auto serializedGraph = serializer::CreateSerializedGraph(
1697 fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
1698 fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
1699 fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()),
1700 m_SerializerVisitor.GetVersionTable());
1702 // Serialize the graph
1703 fbBuilder.Finish(serializedGraph);
1706 bool Serializer::SaveSerializedToStream(std::ostream& stream)
1708 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1710 auto bytesToWrite = boost::numeric_cast<std::streamsize>(fbBuilder.GetSize());
1711 stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
1712 return !stream.bad();
1715 } // namespace armnnSerializer