2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "Serializer.hpp"
8 #include "SerializerUtils.hpp"
10 #include <armnn/ArmNN.hpp>
14 #include <boost/numeric/conversion/cast.hpp>
16 #include <flatbuffers/util.h>
18 using namespace armnn;
19 namespace fb = flatbuffers;
20 namespace serializer = armnnSerializer;
22 namespace armnnSerializer
25 serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::ActivationFunction function)
29 case armnn::ActivationFunction::Sigmoid:
30 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
31 case armnn::ActivationFunction::TanH:
32 return serializer::ActivationFunction::ActivationFunction_TanH;
33 case armnn::ActivationFunction::Linear:
34 return serializer::ActivationFunction::ActivationFunction_Linear;
35 case armnn::ActivationFunction::ReLu:
36 return serializer::ActivationFunction::ActivationFunction_ReLu;
37 case armnn::ActivationFunction::BoundedReLu:
38 return serializer::ActivationFunction::ActivationFunction_BoundedReLu;
39 case armnn::ActivationFunction::LeakyReLu:
40 return serializer::ActivationFunction::ActivationFunction_LeakyReLu;
41 case armnn::ActivationFunction::Abs:
42 return serializer::ActivationFunction::ActivationFunction_Abs;
43 case armnn::ActivationFunction::Sqrt:
44 return serializer::ActivationFunction::ActivationFunction_Sqrt;
45 case armnn::ActivationFunction::Square:
46 return serializer::ActivationFunction::ActivationFunction_Square;
48 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
52 uint32_t SerializerVisitor::GetSerializedId(unsigned int guid)
54 std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId);
56 if (m_guidMap.empty())
58 m_guidMap.insert(guidPair);
60 else if (m_guidMap.find(guid) == m_guidMap.end())
62 guidPair.second = ++m_layerId;
63 m_guidMap.insert(guidPair);
66 return m_guidMap[guid];
69 // Build FlatBuffer for Input Layer
70 void SerializerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
72 // Create FlatBuffer BaseLayer
73 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
75 // Create FlatBuffer BindableBaseLayer
76 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
77 flatBufferInputBaseLayer,
79 // Push layer Guid to outputIds.
80 m_inputIds.push_back(GetSerializedId(layer->GetGuid()));
82 // Create the FlatBuffer InputLayer
83 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
85 // Add the AnyLayer to the FlatBufferLayers
86 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
89 // Build FlatBuffer for Output Layer
90 void SerializerVisitor::VisitOutputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
92 // Create FlatBuffer BaseLayer
93 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
95 // Create FlatBuffer BindableBaseLayer
96 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
97 flatBufferOutputBaseLayer,
99 // Push layer Guid to outputIds.
100 m_outputIds.push_back(GetSerializedId(layer->GetGuid()));
102 // Create the FlatBuffer OutputLayer
103 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
104 // Add the AnyLayer to the FlatBufferLayers
105 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
108 void SerializerVisitor::VisitAbsLayer(const armnn::IConnectableLayer* layer, const char* name)
110 throw UnimplementedException("SerializerVisitor::VisitAbsLayer is not implemented");
113 // Build FlatBuffer for Activation Layer
114 void SerializerVisitor::VisitActivationLayer(const armnn::IConnectableLayer* layer,
115 const armnn::ActivationDescriptor& descriptor,
118 // Create FlatBuffer BaseLayer
119 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Activation);
121 // Create the FlatBuffer ActivationDescriptor
122 auto flatBufferDescriptor = CreateActivationDescriptor(m_flatBufferBuilder,
123 GetFlatBufferActivationFunction(descriptor.m_Function),
127 // Create the FlatBuffer ActivationLayer
128 auto flatBufferAdditionLayer = CreateActivationLayer(m_flatBufferBuilder,
130 flatBufferDescriptor);
132 // Add the AnyLayer to the FlatBufferLayers
133 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_ActivationLayer);
136 // Build FlatBuffer for Addition Layer
137 void SerializerVisitor::VisitAdditionLayer(const armnn::IConnectableLayer* layer, const char* name)
139 // Create FlatBuffer BaseLayer
140 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
142 // Create the FlatBuffer AdditionLayer
143 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
145 // Add the AnyLayer to the FlatBufferLayers
146 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
149 // Build FlatBuffer for BatchToSpaceNd Layer
150 void SerializerVisitor::VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
151 const armnn::BatchToSpaceNdDescriptor& descriptor,
154 // Create FlatBuffer BaseLayer
155 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchToSpaceNd);
157 std::vector<unsigned int> crops;
158 crops.reserve(descriptor.m_Crops.size() * 2);
159 for (auto& crop : descriptor.m_Crops)
161 crops.push_back(crop.first);
162 crops.push_back(crop.second);
165 auto flatBufferDescriptor =
166 CreateBatchToSpaceNdDescriptor(m_flatBufferBuilder,
167 m_flatBufferBuilder.CreateVector(descriptor.m_BlockShape),
168 m_flatBufferBuilder.CreateVector(crops),
169 GetFlatBufferDataLayout(descriptor.m_DataLayout));
171 auto flatBufferLayer = serializer::CreateBatchToSpaceNdLayer(m_flatBufferBuilder,
173 flatBufferDescriptor);
175 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_BatchToSpaceNdLayer);
178 void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
179 const armnn::BatchNormalizationDescriptor& batchNormDescriptor,
180 const armnn::ConstTensor& mean,
181 const armnn::ConstTensor& variance,
182 const armnn::ConstTensor& beta,
183 const armnn::ConstTensor& gamma,
186 auto fbBatchNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchNormalization);
187 auto fbBatchNormalizationDescriptor = serializer::CreateBatchNormalizationDescriptor(
189 batchNormDescriptor.m_Eps,
190 GetFlatBufferDataLayout(batchNormDescriptor.m_DataLayout));
192 auto fbMeanConstTensorInfo = CreateConstTensorInfo(mean);
193 auto fbVarianceConstTensorInfo = CreateConstTensorInfo(variance);
194 auto fbBetaConstTensorInfo = CreateConstTensorInfo(beta);
195 auto fbGammaConstTensorInfo = CreateConstTensorInfo(gamma);
196 auto fbBatchNormalizationLayer = serializer::CreateBatchNormalizationLayer(m_flatBufferBuilder,
197 fbBatchNormalizationBaseLayer,
198 fbBatchNormalizationDescriptor,
199 fbMeanConstTensorInfo,
200 fbVarianceConstTensorInfo,
201 fbBetaConstTensorInfo,
202 fbGammaConstTensorInfo);
204 CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
207 // Build FlatBuffer for Constant Layer
208 void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
209 const armnn::ConstTensor& input,
212 // Create FlatBuffer BaseLayer
213 auto flatBufferConstantBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Constant);
215 auto flatBufferConstTensorInfo = CreateConstTensorInfo(input);
217 // Create the FlatBuffer ConstantLayer
218 auto flatBufferLayer = CreateConstantLayer(m_flatBufferBuilder,
219 flatBufferConstantBaseLayer,
220 flatBufferConstTensorInfo);
222 // Add the AnyLayer to the FlatBufferLayers
223 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConstantLayer);
226 // Build FlatBuffer for Convolution2dLayer
227 void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
228 const armnn::Convolution2dDescriptor& descriptor,
229 const armnn::ConstTensor& weights,
230 const armnn::Optional<armnn::ConstTensor>& biases,
233 // Create FlatBuffer BaseLayer
234 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
236 auto flatBufferDescriptor = CreateConvolution2dDescriptor(m_flatBufferBuilder,
237 descriptor.m_PadLeft,
238 descriptor.m_PadRight,
240 descriptor.m_PadBottom,
241 descriptor.m_StrideX,
242 descriptor.m_StrideY,
243 descriptor.m_DilationX,
244 descriptor.m_DilationY,
245 descriptor.m_BiasEnabled,
246 GetFlatBufferDataLayout(descriptor.m_DataLayout));
247 auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights);
248 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiasesConstTensorInfo;
250 if (biases.has_value())
252 flatBufferBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
255 // Create the FlatBuffer Convolution2dLayer
256 auto flatBufferLayer = CreateConvolution2dLayer(m_flatBufferBuilder,
258 flatBufferDescriptor,
259 flatBufferWeightsConstTensorInfo,
260 flatBufferBiasesConstTensorInfo);
262 // Add the AnyLayer to the FlatBufferLayers
263 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer);
266 void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
267 const armnn::DepthwiseConvolution2dDescriptor& descriptor,
268 const armnn::ConstTensor& weights,
269 const armnn::Optional<armnn::ConstTensor>& biases,
272 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthwiseConvolution2d);
273 auto fbDescriptor = CreateDepthwiseConvolution2dDescriptor(m_flatBufferBuilder,
274 descriptor.m_PadLeft,
275 descriptor.m_PadRight,
277 descriptor.m_PadBottom,
278 descriptor.m_StrideX,
279 descriptor.m_StrideY,
280 descriptor.m_DilationX,
281 descriptor.m_DilationY,
282 descriptor.m_BiasEnabled,
283 GetFlatBufferDataLayout(descriptor.m_DataLayout));
285 flatbuffers::Offset<serializer::ConstTensor> fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
286 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
287 if (biases.has_value())
289 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
292 auto flatBufferLayer = CreateDepthwiseConvolution2dLayer(m_flatBufferBuilder,
295 fbWeightsConstTensorInfo,
296 fbBiasesConstTensorInfo);
298 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DepthwiseConvolution2dLayer);
301 void SerializerVisitor::VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
304 auto fbDequantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Dequantize);
305 auto fbDequantizeLayer = serializer::CreateDequantizeLayer(m_flatBufferBuilder, fbDequantizeBaseLayer);
307 CreateAnyLayer(fbDequantizeLayer.o, serializer::Layer::Layer_DequantizeLayer);
310 void SerializerVisitor::VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
311 const armnn::DetectionPostProcessDescriptor& descriptor,
312 const armnn::ConstTensor& anchors,
315 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DetectionPostProcess);
316 auto fbDescriptor = CreateDetectionPostProcessDescriptor(m_flatBufferBuilder,
317 descriptor.m_MaxDetections,
318 descriptor.m_MaxClassesPerDetection,
319 descriptor.m_DetectionsPerClass,
320 descriptor.m_NmsScoreThreshold,
321 descriptor.m_NmsIouThreshold,
322 descriptor.m_NumClasses,
323 descriptor.m_UseRegularNms,
327 descriptor.m_ScaleH);
329 flatbuffers::Offset<serializer::ConstTensor> fbAnchorsConstTensorInfo = CreateConstTensorInfo(anchors);
331 auto flatBufferLayer = CreateDetectionPostProcessLayer(m_flatBufferBuilder,
334 fbAnchorsConstTensorInfo);
336 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DetectionPostProcessLayer);
339 void SerializerVisitor::VisitDivisionLayer(const armnn::IConnectableLayer* layer, const char* name)
341 auto fbDivisionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Division);
342 auto fbDivisionLayer = serializer::CreateDivisionLayer(m_flatBufferBuilder, fbDivisionBaseLayer);
344 CreateAnyLayer(fbDivisionLayer.o, serializer::Layer::Layer_DivisionLayer);
347 void SerializerVisitor::VisitEqualLayer(const armnn::IConnectableLayer* layer, const char* name)
349 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Equal);
350 auto fbEqualLayer = serializer::CreateEqualLayer(m_flatBufferBuilder, fbBaseLayer);
352 CreateAnyLayer(fbEqualLayer.o, serializer::Layer::Layer_EqualLayer);
355 void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name)
357 auto flatBufferFloorBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Floor);
358 auto flatBufferFloorLayer = serializer::CreateFloorLayer(m_flatBufferBuilder, flatBufferFloorBaseLayer);
360 CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
363 void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
365 auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
366 auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
368 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
371 void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name)
373 auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater);
374 auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer);
376 CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer);
379 void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
380 const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
383 // Create FlatBuffer BaseLayer
384 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_L2Normalization);
386 // Create the FlatBuffer L2Normalization Descriptor
387 auto fbDescriptor = serializer::CreateL2NormalizationDescriptor(
389 GetFlatBufferDataLayout(l2NormalizationDescriptor.m_DataLayout),
390 l2NormalizationDescriptor.m_Eps);
392 // Create FlatBuffer layer
393 auto fbLayer = serializer::CreateL2NormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
395 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
398 void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor,
399 const armnn::LstmInputParams& params, const char* name)
401 auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
403 auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
405 descriptor.m_ActivationFunc,
406 descriptor.m_ClippingThresCell,
407 descriptor.m_ClippingThresProj,
408 descriptor.m_CifgEnabled,
409 descriptor.m_PeepholeEnabled,
410 descriptor.m_ProjectionEnabled,
411 descriptor.m_LayerNormEnabled);
413 // Get mandatory input parameters
414 auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
415 auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
416 auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
417 auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
418 auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
419 auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
420 auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
421 auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
422 auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
424 //Define optional parameters, these will be set depending on configuration in Lstm descriptor
425 flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
426 flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
427 flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
428 flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
429 flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
430 flatbuffers::Offset<serializer::ConstTensor> projectionBias;
431 flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
432 flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
433 flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
434 flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
435 flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
436 flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
438 if (!descriptor.m_CifgEnabled)
440 inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
441 recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
442 cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
443 inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
446 if (descriptor.m_ProjectionEnabled)
448 projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
449 projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
452 if (descriptor.m_PeepholeEnabled)
454 cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
455 cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
458 if (descriptor.m_LayerNormEnabled)
460 if (!descriptor.m_CifgEnabled)
462 inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
464 forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
465 cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
466 outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
469 auto fbLstmParams = serializer::CreateLstmInputParams(
471 inputToForgetWeights,
473 inputToOutputWeights,
474 recurrentToForgetWeights,
475 recurrentToCellWeights,
476 recurrentToOutputWeights,
481 recurrentToInputWeights,
488 inputLayerNormWeights,
489 forgetLayerNormWeights,
490 cellLayerNormWeights,
491 outputLayerNormWeights);
493 auto fbLstmLayer = serializer::CreateLstmLayer(
499 CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
502 void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
504 auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
505 auto fbMaximumLayer = serializer::CreateMaximumLayer(m_flatBufferBuilder, fbMaximumBaseLayer);
507 CreateAnyLayer(fbMaximumLayer.o, serializer::Layer::Layer_MaximumLayer);
510 void SerializerVisitor::VisitMeanLayer(const armnn::IConnectableLayer* layer,
511 const armnn::MeanDescriptor& descriptor,
514 auto fbMeanBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Mean);
515 auto fbMeanDescriptor = serializer::CreateMeanDescriptor(m_flatBufferBuilder,
516 m_flatBufferBuilder.CreateVector(descriptor.m_Axis),
517 descriptor.m_KeepDims);
519 auto fbMeanLayer = serializer::CreateMeanLayer(m_flatBufferBuilder,
523 CreateAnyLayer(fbMeanLayer.o, serializer::Layer::Layer_MeanLayer);
526 void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, const char* name)
528 auto fbMinimumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Minimum);
529 auto fbMinimumLayer = serializer::CreateMinimumLayer(m_flatBufferBuilder, fbMinimumBaseLayer);
531 CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
534 void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
536 auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
537 auto fbMergeLayer = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
539 CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
542 void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
543 const armnn::MergerDescriptor& mergerDescriptor,
546 VisitConcatLayer(layer, mergerDescriptor, name);
549 void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
550 const armnn::ConcatDescriptor& concatDescriptor,
553 auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
555 std::vector<flatbuffers::Offset<UintVector>> views;
556 for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
558 const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
559 std::vector<uint32_t> origins;
560 for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
562 origins.push_back(origin[d]);
564 auto view = m_flatBufferBuilder.CreateVector(origins);
565 auto uintVector = CreateUintVector(m_flatBufferBuilder, view);
566 views.push_back(uintVector);
569 auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
570 concatDescriptor.GetConcatAxis(),
571 concatDescriptor.GetNumViews(),
572 concatDescriptor.GetNumDimensions(),
573 m_flatBufferBuilder.CreateVector(views));
575 auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
576 flatBufferConcatBaseLayer,
577 flatBufferConcatDescriptor);
579 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
582 void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)
584 auto fbMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
585 auto fbMultiplicationLayer = serializer::CreateMultiplicationLayer(m_flatBufferBuilder,
586 fbMultiplicationBaseLayer);
588 CreateAnyLayer(fbMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
591 void SerializerVisitor::VisitPadLayer(const armnn::IConnectableLayer* layer,
592 const armnn::PadDescriptor& padDescriptor,
595 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pad);
597 std::vector<unsigned int> padList;
598 for (auto& p: padDescriptor.m_PadList)
600 padList.push_back(p.first);
601 padList.push_back(p.second);
604 auto flatBufferPadDesc = serializer::CreatePadDescriptor(m_flatBufferBuilder,
605 m_flatBufferBuilder.CreateVector(padList),
606 padDescriptor.m_PadValue);
608 auto flatBufferPadLayer = serializer::CreatePadLayer(m_flatBufferBuilder,
612 CreateAnyLayer(flatBufferPadLayer.o, serializer::Layer::Layer_PadLayer);
615 void SerializerVisitor::VisitPermuteLayer(const armnn::IConnectableLayer* layer,
616 const armnn::PermuteDescriptor& permuteDescriptor,
619 // Create FlatBuffer BaseLayer
620 auto flatBufferPermuteBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Permute);
622 std::vector<unsigned int> dimMappings;
623 for (unsigned int i=0; i<permuteDescriptor.m_DimMappings.GetSize(); ++i)
625 dimMappings.push_back(permuteDescriptor.m_DimMappings[i]);
628 auto flatBufferPermuteDesc = serializer::CreatePermuteDescriptor(m_flatBufferBuilder,
629 m_flatBufferBuilder.CreateVector(dimMappings));
631 // Create the FlatBuffer PermuteLayer
632 auto flatBufferPermuteLayer = serializer::CreatePermuteLayer(m_flatBufferBuilder,
633 flatBufferPermuteBaseLayer,
634 flatBufferPermuteDesc);
636 // Add the AnyLayer to the FlatBufferLayers
637 CreateAnyLayer(flatBufferPermuteLayer.o, serializer::Layer::Layer_PermuteLayer);
640 // Build FlatBuffer for Reshape Layer
641 void SerializerVisitor::VisitReshapeLayer(const armnn::IConnectableLayer* layer,
642 const armnn::ReshapeDescriptor& reshapeDescriptor,
645 // Create FlatBuffer BaseLayer
646 auto flatBufferReshapeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Reshape);
648 std::vector<unsigned int> targetShape;
649 for (unsigned int i =0; i < reshapeDescriptor.m_TargetShape.GetNumDimensions(); i++)
651 targetShape.push_back(reshapeDescriptor.m_TargetShape[i]);
654 auto flatBufferReshapeDesc = serializer::CreateReshapeDescriptor(m_flatBufferBuilder,
655 m_flatBufferBuilder.CreateVector(targetShape));
657 // Create the FlatBuffer ReshapeLayer
658 auto flatBufferReshapeLayer = serializer::CreateReshapeLayer(m_flatBufferBuilder, flatBufferReshapeBaseLayer,
659 flatBufferReshapeDesc);
661 // Add the AnyLayer to the FlatBufferLayers
662 CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
665 void SerializerVisitor::VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
666 const armnn::ResizeBilinearDescriptor& resizeDescriptor,
669 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ResizeBilinear);
671 auto flatBufferDescriptor =
672 CreateResizeBilinearDescriptor(m_flatBufferBuilder,
673 resizeDescriptor.m_TargetWidth,
674 resizeDescriptor.m_TargetHeight,
675 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
677 auto flatBufferLayer = serializer::CreateResizeBilinearLayer(m_flatBufferBuilder,
679 flatBufferDescriptor);
681 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeBilinearLayer);
684 void SerializerVisitor::VisitResizeLayer(const armnn::IConnectableLayer* layer,
685 const armnn::ResizeDescriptor& resizeDescriptor,
688 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Resize);
690 auto flatBufferDescriptor =
691 CreateResizeDescriptor(m_flatBufferBuilder,
692 resizeDescriptor.m_TargetHeight,
693 resizeDescriptor.m_TargetWidth,
694 GetFlatBufferResizeMethod(resizeDescriptor.m_Method),
695 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
697 auto flatBufferLayer = serializer::CreateResizeLayer(m_flatBufferBuilder,
699 flatBufferDescriptor);
701 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeLayer);
704 void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
706 auto fbRsqrtBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rsqrt);
707 auto fbRsqrtLayer = serializer::CreateRsqrtLayer(m_flatBufferBuilder, fbRsqrtBaseLayer);
709 CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
712 // Build FlatBuffer for Softmax Layer
713 void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
714 const armnn::SoftmaxDescriptor& softmaxDescriptor,
717 // Create FlatBuffer BaseLayer
718 auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
720 // Create the FlatBuffer SoftmaxDescriptor
721 auto flatBufferSoftmaxDesc =
722 serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
724 // Create the FlatBuffer SoftmaxLayer
725 auto flatBufferSoftmaxLayer =
726 serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
727 flatBufferSoftmaxBaseLayer,
728 flatBufferSoftmaxDesc);
730 CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
733 void SerializerVisitor::VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
734 const armnn::Pooling2dDescriptor& pooling2dDescriptor,
737 auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
738 auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
740 GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
741 pooling2dDescriptor.m_PadLeft,
742 pooling2dDescriptor.m_PadRight,
743 pooling2dDescriptor.m_PadTop,
744 pooling2dDescriptor.m_PadBottom,
745 pooling2dDescriptor.m_PoolWidth,
746 pooling2dDescriptor.m_PoolHeight,
747 pooling2dDescriptor.m_StrideX,
748 pooling2dDescriptor.m_StrideY,
749 GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
750 GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
751 GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
753 auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
754 fbPooling2dBaseLayer,
755 fbPooling2dDescriptor);
757 CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
760 void SerializerVisitor::VisitPreluLayer(const armnn::IConnectableLayer* layer,
763 // Create FlatBuffer BaseLayer
764 auto flatBufferPreluBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Prelu);
766 // Create the FlatBuffer AdditionLayer
767 auto flatBufferPreluLayer = serializer::CreatePreluLayer(m_flatBufferBuilder, flatBufferPreluBaseLayer);
769 // Add the AnyLayer to the FlatBufferLayers
770 CreateAnyLayer(flatBufferPreluLayer.o, serializer::Layer::Layer_PreluLayer);
773 void SerializerVisitor::VisitQuantizeLayer(const armnn::IConnectableLayer *layer, const char *name)
775 auto fbQuantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Quantize);
776 auto fbQuantizeLayer = serializer::CreateQuantizeLayer(m_flatBufferBuilder,
777 fbQuantizeBaseLayer);
778 CreateAnyLayer(fbQuantizeLayer.o, serializer::Layer::Layer_QuantizeLayer);
781 // Build FlatBuffer for FullyConnected Layer
782 void SerializerVisitor::VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
783 const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
784 const armnn::ConstTensor& weights,
785 const armnn::Optional<armnn::ConstTensor>& biases,
788 // Create FlatBuffer BaseLayer
789 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_FullyConnected);
791 // Create FlatBuffer FullyConnectedDescriptor
792 auto flatBufferDescriptor =
793 serializer::CreateFullyConnectedDescriptor(m_flatBufferBuilder,
794 fullyConnectedDescriptor.m_BiasEnabled,
795 fullyConnectedDescriptor.m_TransposeWeightMatrix);
797 // Create FlatBuffer weights data
798 auto flatBufferWeights = CreateConstTensorInfo(weights);
800 // Create FlatBuffer bias data
801 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiases;
802 if (fullyConnectedDescriptor.m_BiasEnabled)
804 flatBufferBiases = CreateConstTensorInfo(biases.value());
807 // Create FlatBuffer FullyConnectedLayer
808 auto flatBufferLayer = serializer::CreateFullyConnectedLayer(m_flatBufferBuilder,
810 flatBufferDescriptor,
814 // Add created FullyConnectedLayer to the FlatBufferLayers
815 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_FullyConnectedLayer);
818 // Build FlatBuffer for SpaceToBatchNd Layer
819 void SerializerVisitor::VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
820 const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
823 // Create FlatBuffer BaseLayer
824 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToBatchNd);
826 std::vector<unsigned int> padList;
827 padList.reserve(spaceToBatchNdDescriptor.m_PadList.size()*2);
828 for (auto& pad : spaceToBatchNdDescriptor.m_PadList)
830 padList.push_back(pad.first);
831 padList.push_back(pad.second);
834 auto flatBufferDescriptor =
835 CreateSpaceToBatchNdDescriptor(m_flatBufferBuilder,
836 m_flatBufferBuilder.CreateVector(spaceToBatchNdDescriptor.m_BlockShape),
837 m_flatBufferBuilder.CreateVector(padList),
838 GetFlatBufferDataLayout(spaceToBatchNdDescriptor.m_DataLayout));
840 auto flatBufferLayer = serializer::CreateSpaceToBatchNdLayer(m_flatBufferBuilder,
842 flatBufferDescriptor);
844 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer);
847 // Build FlatBuffer for SpaceToDepthLayer
848 void SerializerVisitor::VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
849 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
852 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToDepth);
853 auto flatBufferDescriptor =
854 CreateSpaceToDepthDescriptor(m_flatBufferBuilder,
855 spaceToDepthDescriptor.m_BlockSize,
856 GetFlatBufferDataLayout(spaceToDepthDescriptor.m_DataLayout));
858 auto flatBufferLayer = serializer::CreateSpaceToDepthLayer(m_flatBufferBuilder,
860 flatBufferDescriptor);
862 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToDepthLayer);
865 // Build FlatBuffer for Splitter Layer
866 void SerializerVisitor::VisitSplitterLayer(const armnn::IConnectableLayer* layer,
867 const armnn::ViewsDescriptor& viewsDescriptor,
870 // Create FlatBuffer ViewOrigins
871 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewOrigins;
872 flatBufferViewOrigins.reserve(viewsDescriptor.GetNumViews());
874 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
876 std::vector<uint32_t> viewOrigin;
877 viewOrigin.reserve(viewsDescriptor.GetNumDimensions());
880 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
882 viewOrigin.push_back(viewsDescriptor.GetViewOrigin(vIdx)[dIdx]);
885 flatBufferViewOrigins.push_back(CreateUintVector(m_flatBufferBuilder,
886 m_flatBufferBuilder.CreateVector(viewOrigin)));
889 // Create FlatBuffer OriginsDescriptor
890 auto flatBufferOriginDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
891 viewsDescriptor.GetOrigins().GetConcatAxis(),
892 viewsDescriptor.GetOrigins().GetNumViews(),
893 viewsDescriptor.GetOrigins().GetNumDimensions(),
894 m_flatBufferBuilder.CreateVector(flatBufferViewOrigins));
896 // Create FlatBuffer ViewOrigins
897 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewSizes;
898 flatBufferViewSizes.reserve(viewsDescriptor.GetNumViews());
900 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
902 std::vector<uint32_t> viewSize;
903 viewSize.reserve(viewsDescriptor.GetNumDimensions());
906 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
908 viewSize.push_back(viewsDescriptor.GetViewSizes(vIdx)[dIdx]);
911 flatBufferViewSizes.push_back(CreateUintVector(m_flatBufferBuilder,
912 m_flatBufferBuilder.CreateVector(viewSize)));
915 // Create FlatBuffer ViewsDescriptor
916 auto flatBufferViewsDescriptor = CreateViewsDescriptor(m_flatBufferBuilder,
917 flatBufferOriginDescriptor,
918 m_flatBufferBuilder.CreateVector(flatBufferViewSizes));
920 // Create FlatBuffer BaseLayer
921 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Splitter);
923 auto flatBufferSplitterLayer = serializer::CreateSplitterLayer(m_flatBufferBuilder,
925 flatBufferViewsDescriptor);
927 CreateAnyLayer(flatBufferSplitterLayer.o, serializer::Layer::Layer_SplitterLayer);
930 void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
931 const armnn::NormalizationDescriptor& descriptor,
934 auto fbNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Normalization);
936 auto fbNormalizationDescriptor = serializer::CreateNormalizationDescriptor(
938 GetFlatBufferNormalizationAlgorithmChannel(descriptor.m_NormChannelType),
939 GetFlatBufferNormalizationAlgorithmMethod(descriptor.m_NormMethodType),
940 descriptor.m_NormSize,
944 GetFlatBufferDataLayout(descriptor.m_DataLayout));
946 auto flatBufferLayer = serializer::CreateNormalizationLayer(m_flatBufferBuilder,
947 fbNormalizationBaseLayer,
948 fbNormalizationDescriptor);
950 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
953 void SerializerVisitor::VisitStackLayer(const armnn::IConnectableLayer* layer,
954 const armnn::StackDescriptor& stackDescriptor,
957 auto stackBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Stack);
959 std::vector<unsigned int> inputShape;
960 for (unsigned int i =0; i < stackDescriptor.m_InputShape.GetNumDimensions(); i++)
962 inputShape.push_back(stackDescriptor.m_InputShape[i]);
965 auto flatBufferStackDescriptor = CreateStackDescriptor(m_flatBufferBuilder,
966 stackDescriptor.m_Axis,
967 stackDescriptor.m_NumInputs,
968 m_flatBufferBuilder.CreateVector(inputShape));
970 auto stackLayer = serializer::CreateStackLayer(m_flatBufferBuilder, stackBaseLayer, flatBufferStackDescriptor);
971 CreateAnyLayer(stackLayer.o, serializer::Layer::Layer_StackLayer);
974 void SerializerVisitor::VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
975 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
978 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StridedSlice);
980 auto flatBufferDescriptor =
981 CreateStridedSliceDescriptor(m_flatBufferBuilder,
982 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Begin),
983 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_End),
984 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Stride),
985 stridedSliceDescriptor.m_BeginMask,
986 stridedSliceDescriptor.m_EndMask,
987 stridedSliceDescriptor.m_ShrinkAxisMask,
988 stridedSliceDescriptor.m_EllipsisMask,
989 stridedSliceDescriptor.m_NewAxisMask,
990 GetFlatBufferDataLayout(stridedSliceDescriptor.m_DataLayout));
992 auto flatBufferLayer = serializer::CreateStridedSliceLayer(m_flatBufferBuilder,
994 flatBufferDescriptor);
996 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_StridedSliceLayer);
999 void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* layer, const char* name)
1001 auto fbSubtractionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Subtraction);
1002 auto fbSubtractionLayer = serializer::CreateSubtractionLayer(m_flatBufferBuilder, fbSubtractionBaseLayer);
1004 CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
1007 void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
1009 auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
1010 auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
1012 CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
1015 void SerializerVisitor::VisitTransposeConvolution2dLayer(
1016 const armnn::IConnectableLayer* layer,
1017 const armnn::TransposeConvolution2dDescriptor& descriptor,
1018 const armnn::ConstTensor& weights,
1019 const armnn::Optional<armnn::ConstTensor>& biases,
1022 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
1023 auto fbDescriptor = CreateTransposeConvolution2dDescriptor(m_flatBufferBuilder,
1024 descriptor.m_PadLeft,
1025 descriptor.m_PadRight,
1026 descriptor.m_PadTop,
1027 descriptor.m_PadBottom,
1028 descriptor.m_StrideX,
1029 descriptor.m_StrideY,
1030 descriptor.m_BiasEnabled,
1031 GetFlatBufferDataLayout(descriptor.m_DataLayout));
1034 auto fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
1035 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
1036 if (biases.has_value())
1038 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
1041 auto fbLayer = CreateTransposeConvolution2dLayer(m_flatBufferBuilder,
1044 fbWeightsConstTensorInfo,
1045 fbBiasesConstTensorInfo);
1047 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_TransposeConvolution2dLayer);
1050 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
1051 const armnn::QuantizedLstmInputParams& params,
1054 auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
1056 // Get input parameters
1057 auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
1058 auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
1059 auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
1060 auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
1062 auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
1063 auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
1064 auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
1065 auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
1067 auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
1068 auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
1069 auto cellBias = CreateConstTensorInfo(params.GetCellBias());
1070 auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
1072 auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
1073 m_flatBufferBuilder,
1074 inputToInputWeights,
1075 inputToForgetWeights,
1077 inputToOutputWeights,
1078 recurrentToInputWeights,
1079 recurrentToForgetWeights,
1080 recurrentToCellWeights,
1081 recurrentToOutputWeights,
1087 auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
1088 m_flatBufferBuilder,
1089 fbQuantizedLstmBaseLayer,
1090 fbQuantizedLstmParams);
1092 CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
1095 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
1096 const serializer::LayerType layerType)
1098 uint32_t fbIndex = GetSerializedId(layer->GetGuid());
1100 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
1101 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
1103 return serializer::CreateLayerBase(m_flatBufferBuilder,
1105 m_flatBufferBuilder.CreateString(layer->GetName()),
1107 m_flatBufferBuilder.CreateVector(inputSlots),
1108 m_flatBufferBuilder.CreateVector(outputSlots));
1111 void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
1113 auto anyLayer = armnnSerializer::CreateAnyLayer(m_flatBufferBuilder, serializerLayer, layer);
1114 m_serializedLayers.push_back(anyLayer);
1117 template <typename T>
1118 flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector(const void* memory, unsigned int size)
1120 const T* buffer = reinterpret_cast<const T*>(memory);
1121 std::vector<T> vector(buffer, buffer + (size / sizeof(T)));
1122 auto fbVector = m_flatBufferBuilder.CreateVector(vector);
1126 flatbuffers::Offset<serializer::ConstTensor>
1127 SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor)
1129 armnn::TensorInfo tensorInfo = constTensor.GetInfo();
1131 // Get the dimensions
1132 std::vector<unsigned int> shape;
1134 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1136 shape.push_back(tensorInfo.GetShape()[dim]);
1139 // Create FlatBuffer TensorInfo
1140 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1141 m_flatBufferBuilder.CreateVector(shape),
1142 GetFlatBufferDataType(tensorInfo.GetDataType()),
1143 tensorInfo.GetQuantizationScale(),
1144 tensorInfo.GetQuantizationOffset());
1145 flatbuffers::Offset<void> fbPayload;
1147 switch (tensorInfo.GetDataType())
1149 case armnn::DataType::Float32:
1150 case armnn::DataType::Signed32:
1152 auto fbVector = CreateDataVector<int32_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1153 flatbuffers::Offset<serializer::IntData> flatBuffersData = serializer::CreateIntData(
1154 m_flatBufferBuilder,
1156 fbPayload = flatBuffersData.o;
1159 case armnn::DataType::Float16:
1161 auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1162 flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1163 m_flatBufferBuilder,
1165 fbPayload = flatBuffersData.o;
1168 case armnn::DataType::QuantisedSymm16:
1170 auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1171 flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1172 m_flatBufferBuilder,
1174 fbPayload = flatBuffersData.o;
1177 case armnn::DataType::QuantisedAsymm8:
1178 case armnn::DataType::Boolean:
1181 auto fbVector = CreateDataVector<int8_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1182 flatbuffers::Offset<serializer::ByteData> flatBuffersData = serializer::CreateByteData(
1183 m_flatBufferBuilder,
1185 fbPayload = flatBuffersData.o;
1188 flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor(
1189 m_flatBufferBuilder,
1190 flatBufferTensorInfo,
1191 GetFlatBufferConstTensorData(tensorInfo.GetDataType()),
1193 return flatBufferConstTensor;
1196 std::vector<fb::Offset<serializer::InputSlot>>
1197 SerializerVisitor::CreateInputSlots(const armnn::IConnectableLayer* layer)
1199 std::vector<fb::Offset<serializer::InputSlot>> inputSlots;
1201 // Get the InputSlots
1202 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
1204 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
1206 // Get the Connection for the InputSlot
1207 const IOutputSlot* connection = inputSlot.GetConnection();
1209 // Create FlatBuffer Connection
1210 serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
1211 connection->CalculateIndexOnOwner());
1212 // Create FlatBuffer InputSlot
1213 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
1218 std::vector<fb::Offset<serializer::OutputSlot>>
1219 SerializerVisitor::CreateOutputSlots(const armnn::IConnectableLayer* layer)
1221 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
1223 // Get the OutputSlots
1224 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
1226 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
1227 const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1229 // Get the dimensions
1230 std::vector<unsigned int> shape;
1231 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1233 shape.push_back(tensorInfo.GetShape()[dim]);
1236 // Create FlatBuffer TensorInfo
1237 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1238 m_flatBufferBuilder.CreateVector(shape),
1239 GetFlatBufferDataType(tensorInfo.GetDataType()),
1240 tensorInfo.GetQuantizationScale(),
1241 tensorInfo.GetQuantizationOffset());
1243 // Create FlatBuffer Outputslot
1244 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
1246 flatBufferTensorInfo));
1252 ISerializer* ISerializer::CreateRaw()
1254 return new Serializer();
1257 ISerializerPtr ISerializer::Create()
1259 return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
1262 void ISerializer::Destroy(ISerializer* serializer)
1267 void Serializer::Serialize(const INetwork& inNetwork)
1269 // Iterate through to network
1270 inNetwork.Accept(m_SerializerVisitor);
1271 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1273 // Create FlatBuffer SerializedGraph
1274 auto serializedGraph = serializer::CreateSerializedGraph(
1276 fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
1277 fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
1278 fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()));
1280 // Serialize the graph
1281 fbBuilder.Finish(serializedGraph);
1284 bool Serializer::SaveSerializedToStream(std::ostream& stream)
1286 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1288 auto bytesToWrite = boost::numeric_cast<std::streamsize>(fbBuilder.GetSize());
1289 stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
1290 return !stream.bad();
1293 } // namespace armnnSerializer