2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
5 #include "Serializer.hpp"
7 #include <armnn/Descriptors.hpp>
8 #include <armnn/LstmParams.hpp>
9 #include <armnn/QuantizedLstmParams.hpp>
10 #include <armnn/utility/IgnoreUnused.hpp>
11 #include <armnn/utility/NumericCast.hpp>
15 #include "SerializerUtils.hpp"
17 using namespace armnn;
18 namespace fb = flatbuffers;
19 namespace serializer = armnnSerializer;
21 namespace armnnSerializer
24 ISerializer::ISerializer() : pSerializerImpl(new SerializerImpl())
28 ISerializer::~ISerializer() = default;
30 ISerializer* ISerializer::CreateRaw()
32 return new ISerializer();
35 ISerializerPtr ISerializer::Create()
37 return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
40 void ISerializer::Destroy(ISerializer* serializer)
45 void ISerializer::Serialize(const armnn::INetwork& inNetwork)
47 pSerializerImpl->Serialize(inNetwork);
50 bool ISerializer::SaveSerializedToStream(std::ostream& stream)
52 return pSerializerImpl->SaveSerializedToStream(stream);
55 serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::ActivationFunction function)
59 case armnn::ActivationFunction::Sigmoid:
60 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
61 case armnn::ActivationFunction::TanH:
62 return serializer::ActivationFunction::ActivationFunction_TanH;
63 case armnn::ActivationFunction::Linear:
64 return serializer::ActivationFunction::ActivationFunction_Linear;
65 case armnn::ActivationFunction::ReLu:
66 return serializer::ActivationFunction::ActivationFunction_ReLu;
67 case armnn::ActivationFunction::BoundedReLu:
68 return serializer::ActivationFunction::ActivationFunction_BoundedReLu;
69 case armnn::ActivationFunction::LeakyReLu:
70 return serializer::ActivationFunction::ActivationFunction_LeakyReLu;
71 case armnn::ActivationFunction::Abs:
72 return serializer::ActivationFunction::ActivationFunction_Abs;
73 case armnn::ActivationFunction::Sqrt:
74 return serializer::ActivationFunction::ActivationFunction_Sqrt;
75 case armnn::ActivationFunction::Square:
76 return serializer::ActivationFunction::ActivationFunction_Square;
77 case armnn::ActivationFunction::Elu:
78 return serializer::ActivationFunction::ActivationFunction_Elu;
79 case armnn::ActivationFunction::HardSwish:
80 return serializer::ActivationFunction::ActivationFunction_HardSwish;
82 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
86 serializer::ArgMinMaxFunction GetFlatBufferArgMinMaxFunction(armnn::ArgMinMaxFunction function)
90 case armnn::ArgMinMaxFunction::Max:
91 return serializer::ArgMinMaxFunction::ArgMinMaxFunction_Max;
92 case armnn::ArgMinMaxFunction::Min:
94 return serializer::ArgMinMaxFunction::ArgMinMaxFunction_Min;
98 uint32_t SerializerVisitor::GetSerializedId(armnn::LayerGuid guid)
100 if (m_guidMap.empty())
102 m_guidMap.insert(std::make_pair(guid, m_layerId));
104 else if (m_guidMap.find(guid) == m_guidMap.end())
107 m_guidMap.insert(std::make_pair(guid, m_layerId));
111 return m_guidMap[guid];
114 // Build FlatBuffer for Input Layer
115 void SerializerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
119 // Create FlatBuffer BaseLayer
120 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
122 // Create FlatBuffer BindableBaseLayer
123 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
124 flatBufferInputBaseLayer,
126 // Push layer binding id to outputIds.
127 m_inputIds.push_back(id);
129 // Create the FlatBuffer InputLayer
130 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
132 // Add the AnyLayer to the FlatBufferLayers
133 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
136 // Build FlatBuffer for Output Layer
137 void SerializerVisitor::VisitOutputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
141 // Create FlatBuffer BaseLayer
142 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
144 // Create FlatBuffer BindableBaseLayer
145 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
146 flatBufferOutputBaseLayer,
148 // Push layer binding id to outputIds.
149 m_outputIds.push_back(id);
151 // Create the FlatBuffer OutputLayer
152 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
153 // Add the AnyLayer to the FlatBufferLayers
154 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
157 void SerializerVisitor::VisitAbsLayer(const armnn::IConnectableLayer* layer, const char* name)
160 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Abs);
161 auto flatBufferAbsLayer = serializer::CreateAbsLayer(m_flatBufferBuilder, flatBufferBaseLayer);
163 CreateAnyLayer(flatBufferAbsLayer.o, serializer::Layer::Layer_AbsLayer);
166 // Build FlatBuffer for Activation Layer
167 void SerializerVisitor::VisitActivationLayer(const armnn::IConnectableLayer* layer,
168 const armnn::ActivationDescriptor& descriptor,
173 // Create FlatBuffer BaseLayer
174 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Activation);
176 // Create the FlatBuffer ActivationDescriptor
177 auto flatBufferDescriptor = CreateActivationDescriptor(m_flatBufferBuilder,
178 GetFlatBufferActivationFunction(descriptor.m_Function),
182 // Create the FlatBuffer ActivationLayer
183 auto flatBufferAdditionLayer = CreateActivationLayer(m_flatBufferBuilder,
185 flatBufferDescriptor);
187 // Add the AnyLayer to the FlatBufferLayers
188 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_ActivationLayer);
191 // Build FlatBuffer for Addition Layer
192 void SerializerVisitor::VisitAdditionLayer(const armnn::IConnectableLayer* layer, const char* name)
196 // Create FlatBuffer BaseLayer
197 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
199 // Create the FlatBuffer AdditionLayer
200 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
202 // Add the AnyLayer to the FlatBufferLayers
203 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
206 // Build FlatBuffer for ArgMinMax Layer
207 void SerializerVisitor::VisitArgMinMaxLayer(const armnn::IConnectableLayer *layer,
208 const armnn::ArgMinMaxDescriptor& descriptor,
213 // Create FlatBuffer BaseLayer
214 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ArgMinMax);
216 // Create FlatBuffer Descriptor
217 auto flatBufferDescriptor = CreateArgMinMaxDescriptor(m_flatBufferBuilder,
218 GetFlatBufferArgMinMaxFunction(descriptor.m_Function),
221 // Create FlatBuffer ArgMinMaxLayer
222 auto flatBufferLayer = CreateArgMinMaxLayer(m_flatBufferBuilder,
224 flatBufferDescriptor);
226 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ArgMinMaxLayer);
229 // Build FlatBuffer for BatchToSpaceNd Layer
230 void SerializerVisitor::VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
231 const armnn::BatchToSpaceNdDescriptor& descriptor,
236 // Create FlatBuffer BaseLayer
237 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchToSpaceNd);
239 std::vector<unsigned int> crops;
240 crops.reserve(descriptor.m_Crops.size() * 2);
241 for (auto& crop : descriptor.m_Crops)
243 crops.push_back(crop.first);
244 crops.push_back(crop.second);
247 auto flatBufferDescriptor =
248 CreateBatchToSpaceNdDescriptor(m_flatBufferBuilder,
249 m_flatBufferBuilder.CreateVector(descriptor.m_BlockShape),
250 m_flatBufferBuilder.CreateVector(crops),
251 GetFlatBufferDataLayout(descriptor.m_DataLayout));
253 auto flatBufferLayer = serializer::CreateBatchToSpaceNdLayer(m_flatBufferBuilder,
255 flatBufferDescriptor);
257 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_BatchToSpaceNdLayer);
260 void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
261 const armnn::BatchNormalizationDescriptor& batchNormDescriptor,
262 const armnn::ConstTensor& mean,
263 const armnn::ConstTensor& variance,
264 const armnn::ConstTensor& beta,
265 const armnn::ConstTensor& gamma,
270 auto fbBatchNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchNormalization);
271 auto fbBatchNormalizationDescriptor = serializer::CreateBatchNormalizationDescriptor(
273 batchNormDescriptor.m_Eps,
274 GetFlatBufferDataLayout(batchNormDescriptor.m_DataLayout));
276 auto fbMeanConstTensorInfo = CreateConstTensorInfo(mean);
277 auto fbVarianceConstTensorInfo = CreateConstTensorInfo(variance);
278 auto fbBetaConstTensorInfo = CreateConstTensorInfo(beta);
279 auto fbGammaConstTensorInfo = CreateConstTensorInfo(gamma);
280 auto fbBatchNormalizationLayer = serializer::CreateBatchNormalizationLayer(m_flatBufferBuilder,
281 fbBatchNormalizationBaseLayer,
282 fbBatchNormalizationDescriptor,
283 fbMeanConstTensorInfo,
284 fbVarianceConstTensorInfo,
285 fbBetaConstTensorInfo,
286 fbGammaConstTensorInfo);
288 CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
291 void SerializerVisitor::VisitComparisonLayer(const armnn::IConnectableLayer* layer,
292 const armnn::ComparisonDescriptor& descriptor,
297 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Comparison);
298 auto fbDescriptor = serializer::CreateComparisonDescriptor(
300 GetFlatBufferComparisonOperation(descriptor.m_Operation));
302 auto fbLayer = serializer::CreateComparisonLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
303 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_ComparisonLayer);
306 // Build FlatBuffer for Constant Layer
307 void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
308 const armnn::ConstTensor& input,
313 // Create FlatBuffer BaseLayer
314 auto flatBufferConstantBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Constant);
316 auto flatBufferConstTensorInfo = CreateConstTensorInfo(input);
318 // Create the FlatBuffer ConstantLayer
319 auto flatBufferLayer = CreateConstantLayer(m_flatBufferBuilder,
320 flatBufferConstantBaseLayer,
321 flatBufferConstTensorInfo);
323 // Add the AnyLayer to the FlatBufferLayers
324 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConstantLayer);
327 // Build FlatBuffer for Convolution2dLayer
328 void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
329 const armnn::Convolution2dDescriptor& descriptor,
330 const armnn::ConstTensor& weights,
331 const armnn::Optional<armnn::ConstTensor>& biases,
336 // Create FlatBuffer BaseLayer
337 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
339 auto flatBufferDescriptor = CreateConvolution2dDescriptor(m_flatBufferBuilder,
340 descriptor.m_PadLeft,
341 descriptor.m_PadRight,
343 descriptor.m_PadBottom,
344 descriptor.m_StrideX,
345 descriptor.m_StrideY,
346 descriptor.m_DilationX,
347 descriptor.m_DilationY,
348 descriptor.m_BiasEnabled,
349 GetFlatBufferDataLayout(descriptor.m_DataLayout));
350 auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights);
351 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiasesConstTensorInfo;
353 if (biases.has_value())
355 flatBufferBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
358 // Create the FlatBuffer Convolution2dLayer
359 auto flatBufferLayer = CreateConvolution2dLayer(m_flatBufferBuilder,
361 flatBufferDescriptor,
362 flatBufferWeightsConstTensorInfo,
363 flatBufferBiasesConstTensorInfo);
365 // Add the AnyLayer to the FlatBufferLayers
366 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer);
369 void SerializerVisitor::VisitDepthToSpaceLayer(const armnn::IConnectableLayer* layer,
370 const armnn::DepthToSpaceDescriptor& descriptor,
375 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthToSpace);
376 auto fbDescriptor = CreateDepthToSpaceDescriptor(m_flatBufferBuilder,
377 descriptor.m_BlockSize,
378 GetFlatBufferDataLayout(descriptor.m_DataLayout));
380 auto fbLayer = serializer::CreateDepthToSpaceLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
382 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_DepthToSpaceLayer);
385 void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
386 const armnn::DepthwiseConvolution2dDescriptor& descriptor,
387 const armnn::ConstTensor& weights,
388 const armnn::Optional<armnn::ConstTensor>& biases,
393 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthwiseConvolution2d);
394 auto fbDescriptor = CreateDepthwiseConvolution2dDescriptor(m_flatBufferBuilder,
395 descriptor.m_PadLeft,
396 descriptor.m_PadRight,
398 descriptor.m_PadBottom,
399 descriptor.m_StrideX,
400 descriptor.m_StrideY,
401 descriptor.m_DilationX,
402 descriptor.m_DilationY,
403 descriptor.m_BiasEnabled,
404 GetFlatBufferDataLayout(descriptor.m_DataLayout));
406 flatbuffers::Offset<serializer::ConstTensor> fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
407 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
408 if (biases.has_value())
410 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
413 auto flatBufferLayer = CreateDepthwiseConvolution2dLayer(m_flatBufferBuilder,
416 fbWeightsConstTensorInfo,
417 fbBiasesConstTensorInfo);
419 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DepthwiseConvolution2dLayer);
422 void SerializerVisitor::VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
427 auto fbDequantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Dequantize);
428 auto fbDequantizeLayer = serializer::CreateDequantizeLayer(m_flatBufferBuilder, fbDequantizeBaseLayer);
430 CreateAnyLayer(fbDequantizeLayer.o, serializer::Layer::Layer_DequantizeLayer);
433 void SerializerVisitor::VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
434 const armnn::DetectionPostProcessDescriptor& descriptor,
435 const armnn::ConstTensor& anchors,
440 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DetectionPostProcess);
441 auto fbDescriptor = CreateDetectionPostProcessDescriptor(m_flatBufferBuilder,
442 descriptor.m_MaxDetections,
443 descriptor.m_MaxClassesPerDetection,
444 descriptor.m_DetectionsPerClass,
445 descriptor.m_NmsScoreThreshold,
446 descriptor.m_NmsIouThreshold,
447 descriptor.m_NumClasses,
448 descriptor.m_UseRegularNms,
452 descriptor.m_ScaleH);
454 flatbuffers::Offset<serializer::ConstTensor> fbAnchorsConstTensorInfo = CreateConstTensorInfo(anchors);
456 auto flatBufferLayer = CreateDetectionPostProcessLayer(m_flatBufferBuilder,
459 fbAnchorsConstTensorInfo);
461 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DetectionPostProcessLayer);
464 void SerializerVisitor::VisitDivisionLayer(const armnn::IConnectableLayer* layer, const char* name)
468 auto fbDivisionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Division);
469 auto fbDivisionLayer = serializer::CreateDivisionLayer(m_flatBufferBuilder, fbDivisionBaseLayer);
471 CreateAnyLayer(fbDivisionLayer.o, serializer::Layer::Layer_DivisionLayer);
474 void SerializerVisitor::VisitElementwiseUnaryLayer(const armnn::IConnectableLayer* layer,
475 const armnn::ElementwiseUnaryDescriptor& descriptor,
480 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ElementwiseUnary);
481 auto fbDescriptor = serializer::CreateElementwiseUnaryDescriptor(
483 GetFlatBufferUnaryOperation(descriptor.m_Operation));
485 auto fbLayer = serializer::CreateElementwiseUnaryLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
486 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_ElementwiseUnaryLayer);
489 void SerializerVisitor::VisitEqualLayer(const armnn::IConnectableLayer* layer, const char* name)
493 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Equal);
494 auto fbEqualLayer = serializer::CreateEqualLayer(m_flatBufferBuilder, fbBaseLayer);
496 CreateAnyLayer(fbEqualLayer.o, serializer::Layer::Layer_EqualLayer);
499 void SerializerVisitor::VisitFillLayer(const armnn::IConnectableLayer* layer,
500 const armnn::FillDescriptor& fillDescriptor,
505 auto fbFillBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Fill);
507 auto fbDescriptor = serializer::CreateFillDescriptor(m_flatBufferBuilder, fillDescriptor.m_Value);
509 auto fbFillLayer = serializer::CreateFillLayer(m_flatBufferBuilder, fbFillBaseLayer, fbDescriptor);
511 CreateAnyLayer(fbFillLayer.o, serializer::Layer::Layer_FillLayer);
514 void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name)
518 auto flatBufferFloorBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Floor);
519 auto flatBufferFloorLayer = serializer::CreateFloorLayer(m_flatBufferBuilder, flatBufferFloorBaseLayer);
521 CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
524 void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer,
527 armnn::GatherDescriptor gatherDescriptor{};
528 VisitGatherLayer(layer, gatherDescriptor, name);
531 void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer,
532 const armnn::GatherDescriptor& gatherDescriptor,
537 auto fbGatherDescriptor = CreateGatherDescriptor(m_flatBufferBuilder,
538 gatherDescriptor.m_Axis);
539 auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
540 auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer, fbGatherDescriptor);
542 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
545 void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name)
549 auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater);
550 auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer);
552 CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer);
555 void SerializerVisitor::VisitInstanceNormalizationLayer(
556 const armnn::IConnectableLayer* layer,
557 const armnn::InstanceNormalizationDescriptor& instanceNormalizationDescriptor,
562 auto fbDescriptor = serializer::CreateInstanceNormalizationDescriptor(
564 instanceNormalizationDescriptor.m_Gamma,
565 instanceNormalizationDescriptor.m_Beta,
566 instanceNormalizationDescriptor.m_Eps,
567 GetFlatBufferDataLayout(instanceNormalizationDescriptor.m_DataLayout));
569 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_InstanceNormalization);
570 auto fbLayer = serializer::CreateInstanceNormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
572 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_InstanceNormalizationLayer);
575 void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
576 const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
581 // Create FlatBuffer BaseLayer
582 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_L2Normalization);
584 // Create the FlatBuffer L2Normalization Descriptor
585 auto fbDescriptor = serializer::CreateL2NormalizationDescriptor(
587 GetFlatBufferDataLayout(l2NormalizationDescriptor.m_DataLayout),
588 l2NormalizationDescriptor.m_Eps);
590 // Create FlatBuffer layer
591 auto fbLayer = serializer::CreateL2NormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
593 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
596 void SerializerVisitor::VisitLogicalBinaryLayer(const armnn::IConnectableLayer* layer,
597 const armnn::LogicalBinaryDescriptor& descriptor,
602 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_LogicalBinary);
603 auto fbDescriptor = serializer::CreateLogicalBinaryDescriptor(
605 GetFlatBufferLogicalBinaryOperation(descriptor.m_Operation));
607 auto fbLayer = serializer::CreateLogicalBinaryLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
608 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_LogicalBinaryLayer);
611 void SerializerVisitor::VisitLogSoftmaxLayer(const armnn::IConnectableLayer* layer,
612 const armnn::LogSoftmaxDescriptor& logSoftmaxDescriptor,
617 // Create FlatBuffer BaseLayer
618 auto flatBufferLogSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_LogSoftmax);
620 // Create the FlatBuffer LogSoftmaxDescriptor
621 auto flatBufferLogSoftmaxDesc =
622 serializer::CreateLogSoftmaxDescriptor(m_flatBufferBuilder,
623 logSoftmaxDescriptor.m_Beta,
624 logSoftmaxDescriptor.m_Axis);
626 // Create the FlatBuffer LogSoftmaxLayer
627 auto flatBufferLogSoftmaxLayer =
628 serializer::CreateLogSoftmaxLayer(m_flatBufferBuilder,
629 flatBufferLogSoftmaxBaseLayer,
630 flatBufferLogSoftmaxDesc);
632 CreateAnyLayer(flatBufferLogSoftmaxLayer.o, serializer::Layer::Layer_LogSoftmaxLayer);
635 void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer,
636 const armnn::LstmDescriptor& descriptor,
637 const armnn::LstmInputParams& params,
642 auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
644 auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
646 descriptor.m_ActivationFunc,
647 descriptor.m_ClippingThresCell,
648 descriptor.m_ClippingThresProj,
649 descriptor.m_CifgEnabled,
650 descriptor.m_PeepholeEnabled,
651 descriptor.m_ProjectionEnabled,
652 descriptor.m_LayerNormEnabled);
654 // Get mandatory input parameters
655 auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
656 auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
657 auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
658 auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
659 auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
660 auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
661 auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
662 auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
663 auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
665 //Define optional parameters, these will be set depending on configuration in Lstm descriptor
666 flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
667 flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
668 flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
669 flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
670 flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
671 flatbuffers::Offset<serializer::ConstTensor> projectionBias;
672 flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
673 flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
674 flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
675 flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
676 flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
677 flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
679 if (!descriptor.m_CifgEnabled)
681 inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
682 recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
683 cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
684 inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
687 if (descriptor.m_ProjectionEnabled)
689 projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
690 projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
693 if (descriptor.m_PeepholeEnabled)
695 cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
696 cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
699 if (descriptor.m_LayerNormEnabled)
701 if (!descriptor.m_CifgEnabled)
703 inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
705 forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
706 cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
707 outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
710 auto fbLstmParams = serializer::CreateLstmInputParams(
712 inputToForgetWeights,
714 inputToOutputWeights,
715 recurrentToForgetWeights,
716 recurrentToCellWeights,
717 recurrentToOutputWeights,
722 recurrentToInputWeights,
729 inputLayerNormWeights,
730 forgetLayerNormWeights,
731 cellLayerNormWeights,
732 outputLayerNormWeights);
734 auto fbLstmLayer = serializer::CreateLstmLayer(
740 CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
743 void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
747 auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
748 auto fbMaximumLayer = serializer::CreateMaximumLayer(m_flatBufferBuilder, fbMaximumBaseLayer);
750 CreateAnyLayer(fbMaximumLayer.o, serializer::Layer::Layer_MaximumLayer);
753 void SerializerVisitor::VisitMeanLayer(const armnn::IConnectableLayer* layer,
754 const armnn::MeanDescriptor& descriptor,
759 auto fbMeanBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Mean);
760 auto fbMeanDescriptor = serializer::CreateMeanDescriptor(m_flatBufferBuilder,
761 m_flatBufferBuilder.CreateVector(descriptor.m_Axis),
762 descriptor.m_KeepDims);
764 auto fbMeanLayer = serializer::CreateMeanLayer(m_flatBufferBuilder,
768 CreateAnyLayer(fbMeanLayer.o, serializer::Layer::Layer_MeanLayer);
771 void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, const char* name)
775 auto fbMinimumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Minimum);
776 auto fbMinimumLayer = serializer::CreateMinimumLayer(m_flatBufferBuilder, fbMinimumBaseLayer);
778 CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
781 void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
785 auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
786 auto fbMergeLayer = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
788 CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
791 void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
792 const armnn::MergerDescriptor& mergerDescriptor,
795 VisitConcatLayer(layer, mergerDescriptor, name);
798 void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
799 const armnn::ConcatDescriptor& concatDescriptor,
804 auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
806 std::vector<flatbuffers::Offset<UintVector>> views;
807 for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
809 const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
810 std::vector<uint32_t> origins;
811 for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
813 origins.push_back(origin[d]);
815 auto view = m_flatBufferBuilder.CreateVector(origins);
816 auto uintVector = CreateUintVector(m_flatBufferBuilder, view);
817 views.push_back(uintVector);
820 auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
821 concatDescriptor.GetConcatAxis(),
822 concatDescriptor.GetNumViews(),
823 concatDescriptor.GetNumDimensions(),
824 m_flatBufferBuilder.CreateVector(views));
826 auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
827 flatBufferConcatBaseLayer,
828 flatBufferConcatDescriptor);
830 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
833 void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)
837 auto fbMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
838 auto fbMultiplicationLayer = serializer::CreateMultiplicationLayer(m_flatBufferBuilder,
839 fbMultiplicationBaseLayer);
841 CreateAnyLayer(fbMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
844 void SerializerVisitor::VisitPadLayer(const armnn::IConnectableLayer* layer,
845 const armnn::PadDescriptor& padDescriptor,
850 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pad);
852 std::vector<unsigned int> padList;
853 for (auto& p: padDescriptor.m_PadList)
855 padList.push_back(p.first);
856 padList.push_back(p.second);
859 auto flatBufferPadDesc = serializer::CreatePadDescriptor(m_flatBufferBuilder,
860 m_flatBufferBuilder.CreateVector(padList),
861 padDescriptor.m_PadValue);
863 auto flatBufferPadLayer = serializer::CreatePadLayer(m_flatBufferBuilder,
867 CreateAnyLayer(flatBufferPadLayer.o, serializer::Layer::Layer_PadLayer);
870 void SerializerVisitor::VisitPermuteLayer(const armnn::IConnectableLayer* layer,
871 const armnn::PermuteDescriptor& permuteDescriptor,
876 // Create FlatBuffer BaseLayer
877 auto flatBufferPermuteBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Permute);
879 std::vector<unsigned int> dimMappings;
880 for (unsigned int i=0; i<permuteDescriptor.m_DimMappings.GetSize(); ++i)
882 dimMappings.push_back(permuteDescriptor.m_DimMappings[i]);
885 auto flatBufferPermuteDesc = serializer::CreatePermuteDescriptor(m_flatBufferBuilder,
886 m_flatBufferBuilder.CreateVector(dimMappings));
888 // Create the FlatBuffer PermuteLayer
889 auto flatBufferPermuteLayer = serializer::CreatePermuteLayer(m_flatBufferBuilder,
890 flatBufferPermuteBaseLayer,
891 flatBufferPermuteDesc);
893 // Add the AnyLayer to the FlatBufferLayers
894 CreateAnyLayer(flatBufferPermuteLayer.o, serializer::Layer::Layer_PermuteLayer);
897 // Build FlatBuffer for Rank Layer
898 void SerializerVisitor::VisitRankLayer(const armnn::IConnectableLayer* layer,
902 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rank);
903 auto flatBufferRankLayer = serializer::CreateRankLayer(m_flatBufferBuilder, flatBufferBaseLayer);
905 CreateAnyLayer(flatBufferRankLayer.o, serializer::Layer::Layer_RankLayer);
907 // Build FlatBuffer for Reshape Layer
908 void SerializerVisitor::VisitReshapeLayer(const armnn::IConnectableLayer* layer,
909 const armnn::ReshapeDescriptor& reshapeDescriptor,
914 // Create FlatBuffer BaseLayer
915 auto flatBufferReshapeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Reshape);
917 std::vector<unsigned int> targetShape;
918 for (unsigned int i =0; i < reshapeDescriptor.m_TargetShape.GetNumDimensions(); i++)
920 targetShape.push_back(reshapeDescriptor.m_TargetShape[i]);
923 auto flatBufferReshapeDesc = serializer::CreateReshapeDescriptor(m_flatBufferBuilder,
924 m_flatBufferBuilder.CreateVector(targetShape));
926 // Create the FlatBuffer ReshapeLayer
927 auto flatBufferReshapeLayer = serializer::CreateReshapeLayer(m_flatBufferBuilder, flatBufferReshapeBaseLayer,
928 flatBufferReshapeDesc);
930 // Add the AnyLayer to the FlatBufferLayers
931 CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
934 void SerializerVisitor::VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
935 const armnn::ResizeBilinearDescriptor& resizeDescriptor,
940 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ResizeBilinear);
942 auto flatBufferDescriptor =
943 CreateResizeBilinearDescriptor(m_flatBufferBuilder,
944 resizeDescriptor.m_TargetWidth,
945 resizeDescriptor.m_TargetHeight,
946 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout),
947 resizeDescriptor.m_AlignCorners,
948 resizeDescriptor.m_HalfPixelCenters);
950 auto flatBufferLayer = serializer::CreateResizeBilinearLayer(m_flatBufferBuilder,
952 flatBufferDescriptor);
954 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeBilinearLayer);
957 void SerializerVisitor::VisitResizeLayer(const armnn::IConnectableLayer* layer,
958 const armnn::ResizeDescriptor& resizeDescriptor,
963 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Resize);
965 auto flatBufferDescriptor =
966 CreateResizeDescriptor(m_flatBufferBuilder,
967 resizeDescriptor.m_TargetHeight,
968 resizeDescriptor.m_TargetWidth,
969 GetFlatBufferResizeMethod(resizeDescriptor.m_Method),
970 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout),
971 resizeDescriptor.m_AlignCorners,
972 resizeDescriptor.m_HalfPixelCenters);
974 auto flatBufferLayer = serializer::CreateResizeLayer(m_flatBufferBuilder,
976 flatBufferDescriptor);
978 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeLayer);
981 void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
985 auto fbRsqrtBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rsqrt);
986 auto fbRsqrtLayer = serializer::CreateRsqrtLayer(m_flatBufferBuilder, fbRsqrtBaseLayer);
988 CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
991 void SerializerVisitor::VisitSliceLayer(const armnn::IConnectableLayer* layer,
992 const armnn::SliceDescriptor& sliceDescriptor,
997 auto fbSliceBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Slice);
998 auto fbSliceDescriptor = CreateSliceDescriptor(m_flatBufferBuilder,
999 m_flatBufferBuilder.CreateVector(sliceDescriptor.m_Begin),
1000 m_flatBufferBuilder.CreateVector(sliceDescriptor.m_Size));
1002 auto fbSliceLayer = serializer::CreateSliceLayer(m_flatBufferBuilder, fbSliceBaseLayer, fbSliceDescriptor);
1004 CreateAnyLayer(fbSliceLayer.o, serializer::Layer::Layer_SliceLayer);
1007 // Build FlatBuffer for Softmax Layer
1008 void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
1009 const armnn::SoftmaxDescriptor& softmaxDescriptor,
1014 // Create FlatBuffer BaseLayer
1015 auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
1017 // Create the FlatBuffer SoftmaxDescriptor
1018 auto flatBufferSoftmaxDesc =
1019 serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
1021 // Create the FlatBuffer SoftmaxLayer
1022 auto flatBufferSoftmaxLayer =
1023 serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
1024 flatBufferSoftmaxBaseLayer,
1025 flatBufferSoftmaxDesc);
1027 CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
1030 void SerializerVisitor::VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
1031 const armnn::Pooling2dDescriptor& pooling2dDescriptor,
1036 auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
1037 auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
1038 m_flatBufferBuilder,
1039 GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
1040 pooling2dDescriptor.m_PadLeft,
1041 pooling2dDescriptor.m_PadRight,
1042 pooling2dDescriptor.m_PadTop,
1043 pooling2dDescriptor.m_PadBottom,
1044 pooling2dDescriptor.m_PoolWidth,
1045 pooling2dDescriptor.m_PoolHeight,
1046 pooling2dDescriptor.m_StrideX,
1047 pooling2dDescriptor.m_StrideY,
1048 GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
1049 GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
1050 GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
1052 auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
1053 fbPooling2dBaseLayer,
1054 fbPooling2dDescriptor);
1056 CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
1059 void SerializerVisitor::VisitPreluLayer(const armnn::IConnectableLayer* layer,
1064 // Create FlatBuffer BaseLayer
1065 auto flatBufferPreluBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Prelu);
1067 // Create the FlatBuffer AdditionLayer
1068 auto flatBufferPreluLayer = serializer::CreatePreluLayer(m_flatBufferBuilder, flatBufferPreluBaseLayer);
1070 // Add the AnyLayer to the FlatBufferLayers
1071 CreateAnyLayer(flatBufferPreluLayer.o, serializer::Layer::Layer_PreluLayer);
1074 void SerializerVisitor::VisitQuantizeLayer(const armnn::IConnectableLayer *layer, const char *name)
1078 auto fbQuantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Quantize);
1079 auto fbQuantizeLayer = serializer::CreateQuantizeLayer(m_flatBufferBuilder,
1080 fbQuantizeBaseLayer);
1081 CreateAnyLayer(fbQuantizeLayer.o, serializer::Layer::Layer_QuantizeLayer);
1084 // Build FlatBuffer for FullyConnected Layer
1085 void SerializerVisitor::VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
1086 const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
1087 const armnn::ConstTensor& weights,
1088 const armnn::Optional<armnn::ConstTensor>& biases,
1093 // Create FlatBuffer BaseLayer
1094 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_FullyConnected);
1096 // Create FlatBuffer FullyConnectedDescriptor
1097 auto flatBufferDescriptor =
1098 serializer::CreateFullyConnectedDescriptor(m_flatBufferBuilder,
1099 fullyConnectedDescriptor.m_BiasEnabled,
1100 fullyConnectedDescriptor.m_TransposeWeightMatrix);
1102 // Create FlatBuffer weights data
1103 auto flatBufferWeights = CreateConstTensorInfo(weights);
1105 // Create FlatBuffer bias data
1106 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiases;
1107 if (fullyConnectedDescriptor.m_BiasEnabled)
1109 flatBufferBiases = CreateConstTensorInfo(biases.value());
1112 // Create FlatBuffer FullyConnectedLayer
1113 auto flatBufferLayer = serializer::CreateFullyConnectedLayer(m_flatBufferBuilder,
1114 flatBufferBaseLayer,
1115 flatBufferDescriptor,
1119 // Add created FullyConnectedLayer to the FlatBufferLayers
1120 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_FullyConnectedLayer);
1123 // Build FlatBuffer for SpaceToBatchNd Layer
1124 void SerializerVisitor::VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
1125 const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
1130 // Create FlatBuffer BaseLayer
1131 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToBatchNd);
1133 std::vector<unsigned int> padList;
1134 padList.reserve(spaceToBatchNdDescriptor.m_PadList.size()*2);
1135 for (auto& pad : spaceToBatchNdDescriptor.m_PadList)
1137 padList.push_back(pad.first);
1138 padList.push_back(pad.second);
1141 auto flatBufferDescriptor =
1142 CreateSpaceToBatchNdDescriptor(m_flatBufferBuilder,
1143 m_flatBufferBuilder.CreateVector(spaceToBatchNdDescriptor.m_BlockShape),
1144 m_flatBufferBuilder.CreateVector(padList),
1145 GetFlatBufferDataLayout(spaceToBatchNdDescriptor.m_DataLayout));
1147 auto flatBufferLayer = serializer::CreateSpaceToBatchNdLayer(m_flatBufferBuilder,
1148 flatBufferBaseLayer,
1149 flatBufferDescriptor);
1151 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer);
1154 // Build FlatBuffer for SpaceToDepthLayer
1155 void SerializerVisitor::VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
1156 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
1161 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToDepth);
1162 auto flatBufferDescriptor =
1163 CreateSpaceToDepthDescriptor(m_flatBufferBuilder,
1164 spaceToDepthDescriptor.m_BlockSize,
1165 GetFlatBufferDataLayout(spaceToDepthDescriptor.m_DataLayout));
1167 auto flatBufferLayer = serializer::CreateSpaceToDepthLayer(m_flatBufferBuilder,
1168 flatBufferBaseLayer,
1169 flatBufferDescriptor);
1171 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToDepthLayer);
1174 // Build FlatBuffer for Splitter Layer
1175 void SerializerVisitor::VisitSplitterLayer(const armnn::IConnectableLayer* layer,
1176 const armnn::ViewsDescriptor& viewsDescriptor,
1181 // Create FlatBuffer ViewOrigins
1182 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewOrigins;
1183 flatBufferViewOrigins.reserve(viewsDescriptor.GetNumViews());
1185 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
1187 std::vector<uint32_t> viewOrigin;
1188 viewOrigin.reserve(viewsDescriptor.GetNumDimensions());
1191 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
1193 viewOrigin.push_back(viewsDescriptor.GetViewOrigin(vIdx)[dIdx]);
1196 flatBufferViewOrigins.push_back(CreateUintVector(m_flatBufferBuilder,
1197 m_flatBufferBuilder.CreateVector(viewOrigin)));
1200 // Create FlatBuffer OriginsDescriptor
1201 auto flatBufferOriginDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
1202 viewsDescriptor.GetOrigins().GetConcatAxis(),
1203 viewsDescriptor.GetOrigins().GetNumViews(),
1204 viewsDescriptor.GetOrigins().GetNumDimensions(),
1205 m_flatBufferBuilder.CreateVector(flatBufferViewOrigins));
1207 // Create FlatBuffer ViewOrigins
1208 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewSizes;
1209 flatBufferViewSizes.reserve(viewsDescriptor.GetNumViews());
1211 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
1213 std::vector<uint32_t> viewSize;
1214 viewSize.reserve(viewsDescriptor.GetNumDimensions());
1217 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
1219 viewSize.push_back(viewsDescriptor.GetViewSizes(vIdx)[dIdx]);
1222 flatBufferViewSizes.push_back(CreateUintVector(m_flatBufferBuilder,
1223 m_flatBufferBuilder.CreateVector(viewSize)));
1226 // Create FlatBuffer ViewsDescriptor
1227 auto flatBufferViewsDescriptor = CreateViewsDescriptor(m_flatBufferBuilder,
1228 flatBufferOriginDescriptor,
1229 m_flatBufferBuilder.CreateVector(flatBufferViewSizes));
1231 // Create FlatBuffer BaseLayer
1232 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Splitter);
1234 auto flatBufferSplitterLayer = serializer::CreateSplitterLayer(m_flatBufferBuilder,
1235 flatBufferBaseLayer,
1236 flatBufferViewsDescriptor);
1238 CreateAnyLayer(flatBufferSplitterLayer.o, serializer::Layer::Layer_SplitterLayer);
1241 void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
1242 const armnn::NormalizationDescriptor& descriptor,
1247 auto fbNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Normalization);
1249 auto fbNormalizationDescriptor = serializer::CreateNormalizationDescriptor(
1250 m_flatBufferBuilder,
1251 GetFlatBufferNormalizationAlgorithmChannel(descriptor.m_NormChannelType),
1252 GetFlatBufferNormalizationAlgorithmMethod(descriptor.m_NormMethodType),
1253 descriptor.m_NormSize,
1257 GetFlatBufferDataLayout(descriptor.m_DataLayout));
1259 auto flatBufferLayer = serializer::CreateNormalizationLayer(m_flatBufferBuilder,
1260 fbNormalizationBaseLayer,
1261 fbNormalizationDescriptor);
1263 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
1266 void SerializerVisitor::VisitStackLayer(const armnn::IConnectableLayer* layer,
1267 const armnn::StackDescriptor& stackDescriptor,
1272 auto stackBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Stack);
1274 std::vector<unsigned int> inputShape;
1275 for (unsigned int i =0; i < stackDescriptor.m_InputShape.GetNumDimensions(); i++)
1277 inputShape.push_back(stackDescriptor.m_InputShape[i]);
1280 auto flatBufferStackDescriptor = CreateStackDescriptor(m_flatBufferBuilder,
1281 stackDescriptor.m_Axis,
1282 stackDescriptor.m_NumInputs,
1283 m_flatBufferBuilder.CreateVector(inputShape));
1285 auto stackLayer = serializer::CreateStackLayer(m_flatBufferBuilder, stackBaseLayer, flatBufferStackDescriptor);
1286 CreateAnyLayer(stackLayer.o, serializer::Layer::Layer_StackLayer);
1289 void SerializerVisitor::VisitStandInLayer(const armnn::IConnectableLayer *layer,
1290 const armnn::StandInDescriptor& standInDescriptor,
1295 auto fbDescriptor = serializer::CreateStandInDescriptor(m_flatBufferBuilder,
1296 standInDescriptor.m_NumInputs,
1297 standInDescriptor.m_NumOutputs);
1299 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StandIn);
1300 auto fbLayer = serializer::CreateStandInLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
1302 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_StandInLayer);
1305 void SerializerVisitor::VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
1306 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
1311 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StridedSlice);
1313 auto flatBufferDescriptor =
1314 CreateStridedSliceDescriptor(m_flatBufferBuilder,
1315 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Begin),
1316 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_End),
1317 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Stride),
1318 stridedSliceDescriptor.m_BeginMask,
1319 stridedSliceDescriptor.m_EndMask,
1320 stridedSliceDescriptor.m_ShrinkAxisMask,
1321 stridedSliceDescriptor.m_EllipsisMask,
1322 stridedSliceDescriptor.m_NewAxisMask,
1323 GetFlatBufferDataLayout(stridedSliceDescriptor.m_DataLayout));
1325 auto flatBufferLayer = serializer::CreateStridedSliceLayer(m_flatBufferBuilder,
1326 flatBufferBaseLayer,
1327 flatBufferDescriptor);
1329 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_StridedSliceLayer);
1332 void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* layer, const char* name)
1336 auto fbSubtractionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Subtraction);
1337 auto fbSubtractionLayer = serializer::CreateSubtractionLayer(m_flatBufferBuilder, fbSubtractionBaseLayer);
1339 CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
1342 void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
1346 auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
1347 auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
1349 CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
1352 void SerializerVisitor::VisitTransposeConvolution2dLayer(
1353 const armnn::IConnectableLayer* layer,
1354 const armnn::TransposeConvolution2dDescriptor& descriptor,
1355 const armnn::ConstTensor& weights,
1356 const armnn::Optional<armnn::ConstTensor>& biases,
1361 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
1362 auto fbDescriptor = CreateTransposeConvolution2dDescriptor(m_flatBufferBuilder,
1363 descriptor.m_PadLeft,
1364 descriptor.m_PadRight,
1365 descriptor.m_PadTop,
1366 descriptor.m_PadBottom,
1367 descriptor.m_StrideX,
1368 descriptor.m_StrideY,
1369 descriptor.m_BiasEnabled,
1370 GetFlatBufferDataLayout(descriptor.m_DataLayout));
1373 auto fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
1374 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
1375 if (biases.has_value())
1377 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
1380 auto fbLayer = CreateTransposeConvolution2dLayer(m_flatBufferBuilder,
1383 fbWeightsConstTensorInfo,
1384 fbBiasesConstTensorInfo);
1386 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_TransposeConvolution2dLayer);
1389 void SerializerVisitor::VisitTransposeLayer(const armnn::IConnectableLayer* layer,
1390 const armnn::TransposeDescriptor& descriptor,
1395 // Create FlatBuffer BaseLayer
1396 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Transpose);
1398 std::vector<unsigned int> dimMappings;
1399 for (unsigned int i=0; i<descriptor.m_DimMappings.GetSize(); ++i)
1401 dimMappings.push_back(descriptor.m_DimMappings[i]);
1404 auto flatBufferDesc = serializer::CreateTransposeDescriptor(m_flatBufferBuilder,
1405 m_flatBufferBuilder.CreateVector(dimMappings));
1407 // Create the FlatBuffer TransposeLayer
1408 auto flatBufferLayer = serializer::CreateTransposeLayer(m_flatBufferBuilder,
1409 flatBufferBaseLayer,
1412 // Add the AnyLayer to the FlatBufferLayers
1413 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_TransposeLayer);
1416 void SerializerVisitor::VisitQLstmLayer(const armnn::IConnectableLayer* layer,
1417 const armnn::QLstmDescriptor& descriptor,
1418 const armnn::LstmInputParams& params,
1423 auto fbQLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QLstm);
1425 auto fbQLstmDescriptor = serializer::CreateQLstmDescriptor(
1426 m_flatBufferBuilder,
1427 descriptor.m_CifgEnabled,
1428 descriptor.m_PeepholeEnabled,
1429 descriptor.m_ProjectionEnabled,
1430 descriptor.m_LayerNormEnabled,
1431 descriptor.m_CellClip,
1432 descriptor.m_ProjectionClip,
1433 descriptor.m_InputIntermediateScale,
1434 descriptor.m_ForgetIntermediateScale,
1435 descriptor.m_CellIntermediateScale,
1436 descriptor.m_OutputIntermediateScale,
1437 descriptor.m_HiddenStateZeroPoint,
1438 descriptor.m_HiddenStateScale
1442 auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
1443 auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
1444 auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
1445 auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
1446 auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
1447 auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
1448 auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
1449 auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
1450 auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
1453 flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
1454 flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
1455 flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
1457 if (!descriptor.m_CifgEnabled)
1459 inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
1460 recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
1461 inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
1465 flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
1466 flatbuffers::Offset<serializer::ConstTensor> projectionBias;
1468 if (descriptor.m_ProjectionEnabled)
1470 projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
1471 projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
1475 flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
1476 flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
1477 flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
1479 if (descriptor.m_PeepholeEnabled)
1481 if (!descriptor.m_CifgEnabled)
1483 cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
1486 cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
1487 cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
1491 flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
1492 flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
1493 flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
1494 flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
1496 if (descriptor.m_LayerNormEnabled)
1498 if (!descriptor.m_CifgEnabled)
1500 inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
1503 forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
1504 cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
1505 outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
1508 auto fbQLstmParams = serializer::CreateQLstmInputParams(
1509 m_flatBufferBuilder,
1510 inputToForgetWeights,
1512 inputToOutputWeights,
1513 recurrentToForgetWeights,
1514 recurrentToCellWeights,
1515 recurrentToOutputWeights,
1519 inputToInputWeights,
1520 recurrentToInputWeights,
1525 cellToForgetWeights,
1526 cellToOutputWeights,
1527 inputLayerNormWeights,
1528 forgetLayerNormWeights,
1529 cellLayerNormWeights,
1530 outputLayerNormWeights);
1532 auto fbQLstmLayer = serializer::CreateQLstmLayer(
1533 m_flatBufferBuilder,
1538 CreateAnyLayer(fbQLstmLayer.o, serializer::Layer::Layer_QLstmLayer);
1541 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
1542 const armnn::QuantizedLstmInputParams& params,
1547 auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
1549 // Get input parameters
1550 auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
1551 auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
1552 auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
1553 auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
1555 auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
1556 auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
1557 auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
1558 auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
1560 auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
1561 auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
1562 auto cellBias = CreateConstTensorInfo(params.GetCellBias());
1563 auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
1565 auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
1566 m_flatBufferBuilder,
1567 inputToInputWeights,
1568 inputToForgetWeights,
1570 inputToOutputWeights,
1571 recurrentToInputWeights,
1572 recurrentToForgetWeights,
1573 recurrentToCellWeights,
1574 recurrentToOutputWeights,
1580 auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
1581 m_flatBufferBuilder,
1582 fbQuantizedLstmBaseLayer,
1583 fbQuantizedLstmParams);
1585 CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
1588 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
1589 const serializer::LayerType layerType)
1592 uint32_t fbIndex = GetSerializedId(layer->GetGuid());
1594 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
1595 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
1597 return serializer::CreateLayerBase(m_flatBufferBuilder,
1599 m_flatBufferBuilder.CreateString(layer->GetName()),
1601 m_flatBufferBuilder.CreateVector(inputSlots),
1602 m_flatBufferBuilder.CreateVector(outputSlots));
1605 void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
1608 auto anyLayer = armnnSerializer::CreateAnyLayer(m_flatBufferBuilder, serializerLayer, layer);
1609 m_serializedLayers.push_back(anyLayer);
1612 template <typename T>
1613 flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector(const void* memory, unsigned int size)
1615 const T* buffer = reinterpret_cast<const T*>(memory);
1616 std::vector<T> vector(buffer, buffer + (size / sizeof(T)));
1617 auto fbVector = m_flatBufferBuilder.CreateVector(vector);
1621 flatbuffers::Offset<TensorInfo> SerializerVisitor::CreateTensorInfo(const armnn::TensorInfo& tensorInfo)
1623 // Get the dimensions
1624 std::vector<unsigned int> shape;
1625 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1627 shape.push_back(tensorInfo.GetShape()[dim]);
1630 if (tensorInfo.HasPerAxisQuantization())
1632 // Create FlatBuffer TensorInfo
1633 auto flatBufferTensorInfo =
1634 serializer::CreateTensorInfo(m_flatBufferBuilder,
1635 m_flatBufferBuilder.CreateVector(shape),
1636 GetFlatBufferDataType(tensorInfo.GetDataType()),
1637 tensorInfo.GetQuantizationScales()[0],
1638 tensorInfo.GetQuantizationOffset(),
1639 m_flatBufferBuilder.CreateVector(tensorInfo.GetQuantizationScales()),
1640 tensorInfo.GetQuantizationDim().value(),
1641 static_cast<unsigned int>
1642 (tensorInfo.GetShape().GetDimensionality()));
1643 return flatBufferTensorInfo;
1646 // Create FlatBuffer TensorInfo
1647 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1648 m_flatBufferBuilder.CreateVector(shape),
1649 GetFlatBufferDataType(tensorInfo.GetDataType()),
1650 tensorInfo.GetQuantizationScale(),
1651 tensorInfo.GetQuantizationOffset(),
1654 static_cast<unsigned int>
1655 (tensorInfo.GetShape().GetDimensionality()));
1656 return flatBufferTensorInfo;
1659 flatbuffers::Offset<serializer::ConstTensor>
1660 SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor)
1662 armnn::TensorInfo tensorInfo = constTensor.GetInfo();
1664 flatbuffers::Offset<void> fbPayload;
1666 switch (tensorInfo.GetDataType())
1668 case armnn::DataType::Float32:
1669 case armnn::DataType::Signed32:
1671 auto fbVector = CreateDataVector<int32_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1672 flatbuffers::Offset<serializer::IntData> flatBuffersData = serializer::CreateIntData(
1673 m_flatBufferBuilder,
1675 fbPayload = flatBuffersData.o;
1678 case armnn::DataType::Float16:
1679 case armnn::DataType::BFloat16:
1680 case armnn::DataType::QSymmS16:
1682 auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1683 flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1684 m_flatBufferBuilder,
1686 fbPayload = flatBuffersData.o;
1689 case armnn::DataType::QSymmS8:
1690 case armnn::DataType::QAsymmS8:
1691 case armnn::DataType::QAsymmU8:
1692 case armnn::DataType::Boolean:
1695 auto fbVector = CreateDataVector<int8_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1696 flatbuffers::Offset<serializer::ByteData> flatBuffersData = serializer::CreateByteData(
1697 m_flatBufferBuilder,
1699 fbPayload = flatBuffersData.o;
1702 flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor(
1703 m_flatBufferBuilder,
1704 CreateTensorInfo(tensorInfo),
1705 GetFlatBufferConstTensorData(tensorInfo.GetDataType()),
1707 return flatBufferConstTensor;
1710 flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> SerializerVisitor::GetVersionTable()
1712 flatbuffers::Offset<armnnSerializer::FeatureCompatibilityVersions> versionsTable =
1713 serializer::CreateFeatureCompatibilityVersions(
1714 m_flatBufferBuilder,
1715 1 // Binding ids scheme version
1717 return versionsTable;
1720 std::vector<fb::Offset<serializer::InputSlot>>
1721 SerializerVisitor::CreateInputSlots(const armnn::IConnectableLayer* layer)
1723 std::vector<fb::Offset<serializer::InputSlot>> inputSlots;
1725 // Get the InputSlots
1726 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
1728 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
1730 // Get the Connection for the InputSlot
1731 const IOutputSlot* connection = inputSlot.GetConnection();
1733 // Create FlatBuffer Connection
1734 serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
1735 connection->CalculateIndexOnOwner());
1736 // Create FlatBuffer InputSlot
1737 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
1742 std::vector<fb::Offset<serializer::OutputSlot>>
1743 SerializerVisitor::CreateOutputSlots(const armnn::IConnectableLayer* layer)
1745 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
1747 // Get the OutputSlots
1748 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
1750 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
1751 const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1753 // Create FlatBuffer Outputslot
1754 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
1756 CreateTensorInfo(tensorInfo)));
1761 void ISerializer::SerializerImpl::Serialize(const INetwork& inNetwork)
1763 // Iterate through to network
1764 inNetwork.Accept(m_SerializerVisitor);
1765 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1767 // Create FlatBuffer SerializedGraph
1768 auto serializedGraph = serializer::CreateSerializedGraph(
1770 fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
1771 fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
1772 fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()),
1773 m_SerializerVisitor.GetVersionTable());
1775 // Serialize the graph
1776 fbBuilder.Finish(serializedGraph);
1779 bool ISerializer::SerializerImpl::SaveSerializedToStream(std::ostream& stream)
1781 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1783 auto bytesToWrite = armnn::numeric_cast<std::streamsize>(fbBuilder.GetSize());
1784 stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
1785 return !stream.bad();
1789 } // namespace armnnSerializer