2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
6 #include "Serializer.hpp"
8 #include "SerializerUtils.hpp"
10 #include <armnn/ArmNN.hpp>
14 #include <boost/numeric/conversion/cast.hpp>
16 #include <flatbuffers/util.h>
18 using namespace armnn;
19 namespace fb = flatbuffers;
20 namespace serializer = armnnSerializer;
22 namespace armnnSerializer
25 serializer::ActivationFunction GetFlatBufferActivationFunction(armnn::ActivationFunction function)
29 case armnn::ActivationFunction::Sigmoid:
30 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
31 case armnn::ActivationFunction::TanH:
32 return serializer::ActivationFunction::ActivationFunction_TanH;
33 case armnn::ActivationFunction::Linear:
34 return serializer::ActivationFunction::ActivationFunction_Linear;
35 case armnn::ActivationFunction::ReLu:
36 return serializer::ActivationFunction::ActivationFunction_ReLu;
37 case armnn::ActivationFunction::BoundedReLu:
38 return serializer::ActivationFunction::ActivationFunction_BoundedReLu;
39 case armnn::ActivationFunction::LeakyReLu:
40 return serializer::ActivationFunction::ActivationFunction_LeakyReLu;
41 case armnn::ActivationFunction::Abs:
42 return serializer::ActivationFunction::ActivationFunction_Abs;
43 case armnn::ActivationFunction::Sqrt:
44 return serializer::ActivationFunction::ActivationFunction_Sqrt;
45 case armnn::ActivationFunction::Square:
46 return serializer::ActivationFunction::ActivationFunction_Square;
48 return serializer::ActivationFunction::ActivationFunction_Sigmoid;
52 uint32_t SerializerVisitor::GetSerializedId(unsigned int guid)
54 std::pair<unsigned int, uint32_t> guidPair(guid, m_layerId);
56 if (m_guidMap.empty())
58 m_guidMap.insert(guidPair);
60 else if (m_guidMap.find(guid) == m_guidMap.end())
62 guidPair.second = ++m_layerId;
63 m_guidMap.insert(guidPair);
66 return m_guidMap[guid];
69 // Build FlatBuffer for Input Layer
70 void SerializerVisitor::VisitInputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
72 // Create FlatBuffer BaseLayer
73 auto flatBufferInputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Input);
75 // Create FlatBuffer BindableBaseLayer
76 auto flatBufferInputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
77 flatBufferInputBaseLayer,
79 // Push layer Guid to outputIds.
80 m_inputIds.push_back(GetSerializedId(layer->GetGuid()));
82 // Create the FlatBuffer InputLayer
83 auto flatBufferInputLayer = serializer::CreateInputLayer(m_flatBufferBuilder, flatBufferInputBindableBaseLayer);
85 // Add the AnyLayer to the FlatBufferLayers
86 CreateAnyLayer(flatBufferInputLayer.o, serializer::Layer::Layer_InputLayer);
89 // Build FlatBuffer for Output Layer
90 void SerializerVisitor::VisitOutputLayer(const armnn::IConnectableLayer* layer, LayerBindingId id, const char* name)
92 // Create FlatBuffer BaseLayer
93 auto flatBufferOutputBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Output);
95 // Create FlatBuffer BindableBaseLayer
96 auto flatBufferOutputBindableBaseLayer = serializer::CreateBindableLayerBase(m_flatBufferBuilder,
97 flatBufferOutputBaseLayer,
99 // Push layer Guid to outputIds.
100 m_outputIds.push_back(GetSerializedId(layer->GetGuid()));
102 // Create the FlatBuffer OutputLayer
103 auto flatBufferOutputLayer = serializer::CreateOutputLayer(m_flatBufferBuilder, flatBufferOutputBindableBaseLayer);
104 // Add the AnyLayer to the FlatBufferLayers
105 CreateAnyLayer(flatBufferOutputLayer.o, serializer::Layer::Layer_OutputLayer);
108 void SerializerVisitor::VisitAbsLayer(const armnn::IConnectableLayer* layer, const char* name)
110 throw UnimplementedException("SerializerVisitor::VisitAbsLayer is not implemented");
113 // Build FlatBuffer for Activation Layer
114 void SerializerVisitor::VisitActivationLayer(const armnn::IConnectableLayer* layer,
115 const armnn::ActivationDescriptor& descriptor,
118 // Create FlatBuffer BaseLayer
119 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Activation);
121 // Create the FlatBuffer ActivationDescriptor
122 auto flatBufferDescriptor = CreateActivationDescriptor(m_flatBufferBuilder,
123 GetFlatBufferActivationFunction(descriptor.m_Function),
127 // Create the FlatBuffer ActivationLayer
128 auto flatBufferAdditionLayer = CreateActivationLayer(m_flatBufferBuilder,
130 flatBufferDescriptor);
132 // Add the AnyLayer to the FlatBufferLayers
133 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_ActivationLayer);
136 // Build FlatBuffer for Addition Layer
137 void SerializerVisitor::VisitAdditionLayer(const armnn::IConnectableLayer* layer, const char* name)
139 // Create FlatBuffer BaseLayer
140 auto flatBufferAdditionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Addition);
142 // Create the FlatBuffer AdditionLayer
143 auto flatBufferAdditionLayer = serializer::CreateAdditionLayer(m_flatBufferBuilder, flatBufferAdditionBaseLayer);
145 // Add the AnyLayer to the FlatBufferLayers
146 CreateAnyLayer(flatBufferAdditionLayer.o, serializer::Layer::Layer_AdditionLayer);
149 // Build FlatBuffer for ArgMinMax Layer
150 void SerializerVisitor::VisitArgMinMaxLayer(const armnn::IConnectableLayer *layer,
151 const armnn::ArgMinMaxDescriptor& descriptor,
154 // This will be implemented in IVGCVSW-3724
155 throw UnimplementedException("SerializerVisitor::VisitArgMinMaxLayer is not implemented");
158 // Build FlatBuffer for BatchToSpaceNd Layer
159 void SerializerVisitor::VisitBatchToSpaceNdLayer(const armnn::IConnectableLayer* layer,
160 const armnn::BatchToSpaceNdDescriptor& descriptor,
163 // Create FlatBuffer BaseLayer
164 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchToSpaceNd);
166 std::vector<unsigned int> crops;
167 crops.reserve(descriptor.m_Crops.size() * 2);
168 for (auto& crop : descriptor.m_Crops)
170 crops.push_back(crop.first);
171 crops.push_back(crop.second);
174 auto flatBufferDescriptor =
175 CreateBatchToSpaceNdDescriptor(m_flatBufferBuilder,
176 m_flatBufferBuilder.CreateVector(descriptor.m_BlockShape),
177 m_flatBufferBuilder.CreateVector(crops),
178 GetFlatBufferDataLayout(descriptor.m_DataLayout));
180 auto flatBufferLayer = serializer::CreateBatchToSpaceNdLayer(m_flatBufferBuilder,
182 flatBufferDescriptor);
184 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_BatchToSpaceNdLayer);
187 void SerializerVisitor::VisitBatchNormalizationLayer(const armnn::IConnectableLayer* layer,
188 const armnn::BatchNormalizationDescriptor& batchNormDescriptor,
189 const armnn::ConstTensor& mean,
190 const armnn::ConstTensor& variance,
191 const armnn::ConstTensor& beta,
192 const armnn::ConstTensor& gamma,
195 auto fbBatchNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_BatchNormalization);
196 auto fbBatchNormalizationDescriptor = serializer::CreateBatchNormalizationDescriptor(
198 batchNormDescriptor.m_Eps,
199 GetFlatBufferDataLayout(batchNormDescriptor.m_DataLayout));
201 auto fbMeanConstTensorInfo = CreateConstTensorInfo(mean);
202 auto fbVarianceConstTensorInfo = CreateConstTensorInfo(variance);
203 auto fbBetaConstTensorInfo = CreateConstTensorInfo(beta);
204 auto fbGammaConstTensorInfo = CreateConstTensorInfo(gamma);
205 auto fbBatchNormalizationLayer = serializer::CreateBatchNormalizationLayer(m_flatBufferBuilder,
206 fbBatchNormalizationBaseLayer,
207 fbBatchNormalizationDescriptor,
208 fbMeanConstTensorInfo,
209 fbVarianceConstTensorInfo,
210 fbBetaConstTensorInfo,
211 fbGammaConstTensorInfo);
213 CreateAnyLayer(fbBatchNormalizationLayer.o, serializer::Layer::Layer_BatchNormalizationLayer);
216 // Build FlatBuffer for Constant Layer
217 void SerializerVisitor::VisitConstantLayer(const armnn::IConnectableLayer* layer,
218 const armnn::ConstTensor& input,
221 // Create FlatBuffer BaseLayer
222 auto flatBufferConstantBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Constant);
224 auto flatBufferConstTensorInfo = CreateConstTensorInfo(input);
226 // Create the FlatBuffer ConstantLayer
227 auto flatBufferLayer = CreateConstantLayer(m_flatBufferBuilder,
228 flatBufferConstantBaseLayer,
229 flatBufferConstTensorInfo);
231 // Add the AnyLayer to the FlatBufferLayers
232 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConstantLayer);
235 // Build FlatBuffer for Convolution2dLayer
236 void SerializerVisitor::VisitConvolution2dLayer(const armnn::IConnectableLayer* layer,
237 const armnn::Convolution2dDescriptor& descriptor,
238 const armnn::ConstTensor& weights,
239 const armnn::Optional<armnn::ConstTensor>& biases,
242 // Create FlatBuffer BaseLayer
243 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
245 auto flatBufferDescriptor = CreateConvolution2dDescriptor(m_flatBufferBuilder,
246 descriptor.m_PadLeft,
247 descriptor.m_PadRight,
249 descriptor.m_PadBottom,
250 descriptor.m_StrideX,
251 descriptor.m_StrideY,
252 descriptor.m_DilationX,
253 descriptor.m_DilationY,
254 descriptor.m_BiasEnabled,
255 GetFlatBufferDataLayout(descriptor.m_DataLayout));
256 auto flatBufferWeightsConstTensorInfo = CreateConstTensorInfo(weights);
257 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiasesConstTensorInfo;
259 if (biases.has_value())
261 flatBufferBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
264 // Create the FlatBuffer Convolution2dLayer
265 auto flatBufferLayer = CreateConvolution2dLayer(m_flatBufferBuilder,
267 flatBufferDescriptor,
268 flatBufferWeightsConstTensorInfo,
269 flatBufferBiasesConstTensorInfo);
271 // Add the AnyLayer to the FlatBufferLayers
272 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_Convolution2dLayer);
275 void SerializerVisitor::VisitDepthwiseConvolution2dLayer(const armnn::IConnectableLayer* layer,
276 const armnn::DepthwiseConvolution2dDescriptor& descriptor,
277 const armnn::ConstTensor& weights,
278 const armnn::Optional<armnn::ConstTensor>& biases,
281 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DepthwiseConvolution2d);
282 auto fbDescriptor = CreateDepthwiseConvolution2dDescriptor(m_flatBufferBuilder,
283 descriptor.m_PadLeft,
284 descriptor.m_PadRight,
286 descriptor.m_PadBottom,
287 descriptor.m_StrideX,
288 descriptor.m_StrideY,
289 descriptor.m_DilationX,
290 descriptor.m_DilationY,
291 descriptor.m_BiasEnabled,
292 GetFlatBufferDataLayout(descriptor.m_DataLayout));
294 flatbuffers::Offset<serializer::ConstTensor> fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
295 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
296 if (biases.has_value())
298 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
301 auto flatBufferLayer = CreateDepthwiseConvolution2dLayer(m_flatBufferBuilder,
304 fbWeightsConstTensorInfo,
305 fbBiasesConstTensorInfo);
307 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DepthwiseConvolution2dLayer);
310 void SerializerVisitor::VisitDequantizeLayer(const armnn::IConnectableLayer* layer,
313 auto fbDequantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Dequantize);
314 auto fbDequantizeLayer = serializer::CreateDequantizeLayer(m_flatBufferBuilder, fbDequantizeBaseLayer);
316 CreateAnyLayer(fbDequantizeLayer.o, serializer::Layer::Layer_DequantizeLayer);
319 void SerializerVisitor::VisitDetectionPostProcessLayer(const armnn::IConnectableLayer* layer,
320 const armnn::DetectionPostProcessDescriptor& descriptor,
321 const armnn::ConstTensor& anchors,
324 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_DetectionPostProcess);
325 auto fbDescriptor = CreateDetectionPostProcessDescriptor(m_flatBufferBuilder,
326 descriptor.m_MaxDetections,
327 descriptor.m_MaxClassesPerDetection,
328 descriptor.m_DetectionsPerClass,
329 descriptor.m_NmsScoreThreshold,
330 descriptor.m_NmsIouThreshold,
331 descriptor.m_NumClasses,
332 descriptor.m_UseRegularNms,
336 descriptor.m_ScaleH);
338 flatbuffers::Offset<serializer::ConstTensor> fbAnchorsConstTensorInfo = CreateConstTensorInfo(anchors);
340 auto flatBufferLayer = CreateDetectionPostProcessLayer(m_flatBufferBuilder,
343 fbAnchorsConstTensorInfo);
345 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_DetectionPostProcessLayer);
348 void SerializerVisitor::VisitDivisionLayer(const armnn::IConnectableLayer* layer, const char* name)
350 auto fbDivisionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Division);
351 auto fbDivisionLayer = serializer::CreateDivisionLayer(m_flatBufferBuilder, fbDivisionBaseLayer);
353 CreateAnyLayer(fbDivisionLayer.o, serializer::Layer::Layer_DivisionLayer);
356 void SerializerVisitor::VisitEqualLayer(const armnn::IConnectableLayer* layer, const char* name)
358 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Equal);
359 auto fbEqualLayer = serializer::CreateEqualLayer(m_flatBufferBuilder, fbBaseLayer);
361 CreateAnyLayer(fbEqualLayer.o, serializer::Layer::Layer_EqualLayer);
364 void SerializerVisitor::VisitFloorLayer(const armnn::IConnectableLayer *layer, const char *name)
366 auto flatBufferFloorBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Floor);
367 auto flatBufferFloorLayer = serializer::CreateFloorLayer(m_flatBufferBuilder, flatBufferFloorBaseLayer);
369 CreateAnyLayer(flatBufferFloorLayer.o, serializer::Layer::Layer_FloorLayer);
372 void SerializerVisitor::VisitGatherLayer(const armnn::IConnectableLayer* layer, const char* name)
374 auto fbGatherBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Gather);
375 auto flatBufferLayer = serializer::CreateGatherLayer(m_flatBufferBuilder, fbGatherBaseLayer);
377 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_GatherLayer);
380 void SerializerVisitor::VisitGreaterLayer(const armnn::IConnectableLayer* layer, const char* name)
382 auto fbGreaterBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Greater);
383 auto fbGreaterLayer = serializer::CreateGreaterLayer(m_flatBufferBuilder, fbGreaterBaseLayer);
385 CreateAnyLayer(fbGreaterLayer.o, serializer::Layer::Layer_GreaterLayer);
388 void SerializerVisitor::VisitL2NormalizationLayer(const armnn::IConnectableLayer* layer,
389 const armnn::L2NormalizationDescriptor& l2NormalizationDescriptor,
392 // Create FlatBuffer BaseLayer
393 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_L2Normalization);
395 // Create the FlatBuffer L2Normalization Descriptor
396 auto fbDescriptor = serializer::CreateL2NormalizationDescriptor(
398 GetFlatBufferDataLayout(l2NormalizationDescriptor.m_DataLayout),
399 l2NormalizationDescriptor.m_Eps);
401 // Create FlatBuffer layer
402 auto fbLayer = serializer::CreateL2NormalizationLayer(m_flatBufferBuilder, fbBaseLayer, fbDescriptor);
404 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_L2NormalizationLayer);
407 void SerializerVisitor::VisitLstmLayer(const armnn::IConnectableLayer* layer, const armnn::LstmDescriptor& descriptor,
408 const armnn::LstmInputParams& params, const char* name)
410 auto fbLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Lstm);
412 auto fbLstmDescriptor = serializer::CreateLstmDescriptor(
414 descriptor.m_ActivationFunc,
415 descriptor.m_ClippingThresCell,
416 descriptor.m_ClippingThresProj,
417 descriptor.m_CifgEnabled,
418 descriptor.m_PeepholeEnabled,
419 descriptor.m_ProjectionEnabled,
420 descriptor.m_LayerNormEnabled);
422 // Get mandatory input parameters
423 auto inputToForgetWeights = CreateConstTensorInfo(*params.m_InputToForgetWeights);
424 auto inputToCellWeights = CreateConstTensorInfo(*params.m_InputToCellWeights);
425 auto inputToOutputWeights = CreateConstTensorInfo(*params.m_InputToOutputWeights);
426 auto recurrentToForgetWeights = CreateConstTensorInfo(*params.m_RecurrentToForgetWeights);
427 auto recurrentToCellWeights = CreateConstTensorInfo(*params.m_RecurrentToCellWeights);
428 auto recurrentToOutputWeights = CreateConstTensorInfo(*params.m_RecurrentToOutputWeights);
429 auto forgetGateBias = CreateConstTensorInfo(*params.m_ForgetGateBias);
430 auto cellBias = CreateConstTensorInfo(*params.m_CellBias);
431 auto outputGateBias = CreateConstTensorInfo(*params.m_OutputGateBias);
433 //Define optional parameters, these will be set depending on configuration in Lstm descriptor
434 flatbuffers::Offset<serializer::ConstTensor> inputToInputWeights;
435 flatbuffers::Offset<serializer::ConstTensor> recurrentToInputWeights;
436 flatbuffers::Offset<serializer::ConstTensor> cellToInputWeights;
437 flatbuffers::Offset<serializer::ConstTensor> inputGateBias;
438 flatbuffers::Offset<serializer::ConstTensor> projectionWeights;
439 flatbuffers::Offset<serializer::ConstTensor> projectionBias;
440 flatbuffers::Offset<serializer::ConstTensor> cellToForgetWeights;
441 flatbuffers::Offset<serializer::ConstTensor> cellToOutputWeights;
442 flatbuffers::Offset<serializer::ConstTensor> inputLayerNormWeights;
443 flatbuffers::Offset<serializer::ConstTensor> forgetLayerNormWeights;
444 flatbuffers::Offset<serializer::ConstTensor> cellLayerNormWeights;
445 flatbuffers::Offset<serializer::ConstTensor> outputLayerNormWeights;
447 if (!descriptor.m_CifgEnabled)
449 inputToInputWeights = CreateConstTensorInfo(*params.m_InputToInputWeights);
450 recurrentToInputWeights = CreateConstTensorInfo(*params.m_RecurrentToInputWeights);
451 cellToInputWeights = CreateConstTensorInfo(*params.m_CellToInputWeights);
452 inputGateBias = CreateConstTensorInfo(*params.m_InputGateBias);
455 if (descriptor.m_ProjectionEnabled)
457 projectionWeights = CreateConstTensorInfo(*params.m_ProjectionWeights);
458 projectionBias = CreateConstTensorInfo(*params.m_ProjectionBias);
461 if (descriptor.m_PeepholeEnabled)
463 cellToForgetWeights = CreateConstTensorInfo(*params.m_CellToForgetWeights);
464 cellToOutputWeights = CreateConstTensorInfo(*params.m_CellToOutputWeights);
467 if (descriptor.m_LayerNormEnabled)
469 if (!descriptor.m_CifgEnabled)
471 inputLayerNormWeights = CreateConstTensorInfo((*params.m_InputLayerNormWeights));
473 forgetLayerNormWeights = CreateConstTensorInfo(*params.m_ForgetLayerNormWeights);
474 cellLayerNormWeights = CreateConstTensorInfo(*params.m_CellLayerNormWeights);
475 outputLayerNormWeights = CreateConstTensorInfo(*params.m_OutputLayerNormWeights);
478 auto fbLstmParams = serializer::CreateLstmInputParams(
480 inputToForgetWeights,
482 inputToOutputWeights,
483 recurrentToForgetWeights,
484 recurrentToCellWeights,
485 recurrentToOutputWeights,
490 recurrentToInputWeights,
497 inputLayerNormWeights,
498 forgetLayerNormWeights,
499 cellLayerNormWeights,
500 outputLayerNormWeights);
502 auto fbLstmLayer = serializer::CreateLstmLayer(
508 CreateAnyLayer(fbLstmLayer.o, serializer::Layer::Layer_LstmLayer);
511 void SerializerVisitor::VisitMaximumLayer(const armnn::IConnectableLayer* layer, const char* name)
513 auto fbMaximumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Maximum);
514 auto fbMaximumLayer = serializer::CreateMaximumLayer(m_flatBufferBuilder, fbMaximumBaseLayer);
516 CreateAnyLayer(fbMaximumLayer.o, serializer::Layer::Layer_MaximumLayer);
519 void SerializerVisitor::VisitMeanLayer(const armnn::IConnectableLayer* layer,
520 const armnn::MeanDescriptor& descriptor,
523 auto fbMeanBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Mean);
524 auto fbMeanDescriptor = serializer::CreateMeanDescriptor(m_flatBufferBuilder,
525 m_flatBufferBuilder.CreateVector(descriptor.m_Axis),
526 descriptor.m_KeepDims);
528 auto fbMeanLayer = serializer::CreateMeanLayer(m_flatBufferBuilder,
532 CreateAnyLayer(fbMeanLayer.o, serializer::Layer::Layer_MeanLayer);
535 void SerializerVisitor::VisitMinimumLayer(const armnn::IConnectableLayer* layer, const char* name)
537 auto fbMinimumBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Minimum);
538 auto fbMinimumLayer = serializer::CreateMinimumLayer(m_flatBufferBuilder, fbMinimumBaseLayer);
540 CreateAnyLayer(fbMinimumLayer.o, serializer::Layer::Layer_MinimumLayer);
543 void SerializerVisitor::VisitMergeLayer(const armnn::IConnectableLayer* layer, const char* name)
545 auto fbMergeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Merge);
546 auto fbMergeLayer = serializer::CreateMergeLayer(m_flatBufferBuilder, fbMergeBaseLayer);
548 CreateAnyLayer(fbMergeLayer.o, serializer::Layer::Layer_MergeLayer);
551 void SerializerVisitor::VisitMergerLayer(const armnn::IConnectableLayer* layer,
552 const armnn::MergerDescriptor& mergerDescriptor,
555 VisitConcatLayer(layer, mergerDescriptor, name);
558 void SerializerVisitor::VisitConcatLayer(const armnn::IConnectableLayer* layer,
559 const armnn::ConcatDescriptor& concatDescriptor,
562 auto flatBufferConcatBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Concat);
564 std::vector<flatbuffers::Offset<UintVector>> views;
565 for (unsigned int v = 0; v < concatDescriptor.GetNumViews(); ++v)
567 const uint32_t* origin = concatDescriptor.GetViewOrigin(v);
568 std::vector<uint32_t> origins;
569 for (unsigned int d = 0; d < concatDescriptor.GetNumDimensions(); ++d)
571 origins.push_back(origin[d]);
573 auto view = m_flatBufferBuilder.CreateVector(origins);
574 auto uintVector = CreateUintVector(m_flatBufferBuilder, view);
575 views.push_back(uintVector);
578 auto flatBufferConcatDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
579 concatDescriptor.GetConcatAxis(),
580 concatDescriptor.GetNumViews(),
581 concatDescriptor.GetNumDimensions(),
582 m_flatBufferBuilder.CreateVector(views));
584 auto flatBufferLayer = CreateConcatLayer(m_flatBufferBuilder,
585 flatBufferConcatBaseLayer,
586 flatBufferConcatDescriptor);
588 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ConcatLayer);
591 void SerializerVisitor::VisitMultiplicationLayer(const armnn::IConnectableLayer* layer, const char* name)
593 auto fbMultiplicationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Multiplication);
594 auto fbMultiplicationLayer = serializer::CreateMultiplicationLayer(m_flatBufferBuilder,
595 fbMultiplicationBaseLayer);
597 CreateAnyLayer(fbMultiplicationLayer.o, serializer::Layer::Layer_MultiplicationLayer);
600 void SerializerVisitor::VisitPadLayer(const armnn::IConnectableLayer* layer,
601 const armnn::PadDescriptor& padDescriptor,
604 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pad);
606 std::vector<unsigned int> padList;
607 for (auto& p: padDescriptor.m_PadList)
609 padList.push_back(p.first);
610 padList.push_back(p.second);
613 auto flatBufferPadDesc = serializer::CreatePadDescriptor(m_flatBufferBuilder,
614 m_flatBufferBuilder.CreateVector(padList),
615 padDescriptor.m_PadValue);
617 auto flatBufferPadLayer = serializer::CreatePadLayer(m_flatBufferBuilder,
621 CreateAnyLayer(flatBufferPadLayer.o, serializer::Layer::Layer_PadLayer);
624 void SerializerVisitor::VisitPermuteLayer(const armnn::IConnectableLayer* layer,
625 const armnn::PermuteDescriptor& permuteDescriptor,
628 // Create FlatBuffer BaseLayer
629 auto flatBufferPermuteBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Permute);
631 std::vector<unsigned int> dimMappings;
632 for (unsigned int i=0; i<permuteDescriptor.m_DimMappings.GetSize(); ++i)
634 dimMappings.push_back(permuteDescriptor.m_DimMappings[i]);
637 auto flatBufferPermuteDesc = serializer::CreatePermuteDescriptor(m_flatBufferBuilder,
638 m_flatBufferBuilder.CreateVector(dimMappings));
640 // Create the FlatBuffer PermuteLayer
641 auto flatBufferPermuteLayer = serializer::CreatePermuteLayer(m_flatBufferBuilder,
642 flatBufferPermuteBaseLayer,
643 flatBufferPermuteDesc);
645 // Add the AnyLayer to the FlatBufferLayers
646 CreateAnyLayer(flatBufferPermuteLayer.o, serializer::Layer::Layer_PermuteLayer);
649 // Build FlatBuffer for Reshape Layer
650 void SerializerVisitor::VisitReshapeLayer(const armnn::IConnectableLayer* layer,
651 const armnn::ReshapeDescriptor& reshapeDescriptor,
654 // Create FlatBuffer BaseLayer
655 auto flatBufferReshapeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Reshape);
657 std::vector<unsigned int> targetShape;
658 for (unsigned int i =0; i < reshapeDescriptor.m_TargetShape.GetNumDimensions(); i++)
660 targetShape.push_back(reshapeDescriptor.m_TargetShape[i]);
663 auto flatBufferReshapeDesc = serializer::CreateReshapeDescriptor(m_flatBufferBuilder,
664 m_flatBufferBuilder.CreateVector(targetShape));
666 // Create the FlatBuffer ReshapeLayer
667 auto flatBufferReshapeLayer = serializer::CreateReshapeLayer(m_flatBufferBuilder, flatBufferReshapeBaseLayer,
668 flatBufferReshapeDesc);
670 // Add the AnyLayer to the FlatBufferLayers
671 CreateAnyLayer(flatBufferReshapeLayer.o, serializer::Layer::Layer_ReshapeLayer);
674 void SerializerVisitor::VisitResizeBilinearLayer(const armnn::IConnectableLayer* layer,
675 const armnn::ResizeBilinearDescriptor& resizeDescriptor,
678 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_ResizeBilinear);
680 auto flatBufferDescriptor =
681 CreateResizeBilinearDescriptor(m_flatBufferBuilder,
682 resizeDescriptor.m_TargetWidth,
683 resizeDescriptor.m_TargetHeight,
684 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
686 auto flatBufferLayer = serializer::CreateResizeBilinearLayer(m_flatBufferBuilder,
688 flatBufferDescriptor);
690 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeBilinearLayer);
693 void SerializerVisitor::VisitResizeLayer(const armnn::IConnectableLayer* layer,
694 const armnn::ResizeDescriptor& resizeDescriptor,
697 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Resize);
699 auto flatBufferDescriptor =
700 CreateResizeDescriptor(m_flatBufferBuilder,
701 resizeDescriptor.m_TargetHeight,
702 resizeDescriptor.m_TargetWidth,
703 GetFlatBufferResizeMethod(resizeDescriptor.m_Method),
704 GetFlatBufferDataLayout(resizeDescriptor.m_DataLayout));
706 auto flatBufferLayer = serializer::CreateResizeLayer(m_flatBufferBuilder,
708 flatBufferDescriptor);
710 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_ResizeLayer);
713 void SerializerVisitor::VisitRsqrtLayer(const armnn::IConnectableLayer* layer, const char* name)
715 auto fbRsqrtBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Rsqrt);
716 auto fbRsqrtLayer = serializer::CreateRsqrtLayer(m_flatBufferBuilder, fbRsqrtBaseLayer);
718 CreateAnyLayer(fbRsqrtLayer.o, serializer::Layer::Layer_RsqrtLayer);
721 // Build FlatBuffer for Softmax Layer
722 void SerializerVisitor::VisitSoftmaxLayer(const armnn::IConnectableLayer* layer,
723 const armnn::SoftmaxDescriptor& softmaxDescriptor,
726 // Create FlatBuffer BaseLayer
727 auto flatBufferSoftmaxBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Softmax);
729 // Create the FlatBuffer SoftmaxDescriptor
730 auto flatBufferSoftmaxDesc =
731 serializer::CreateSoftmaxDescriptor(m_flatBufferBuilder, softmaxDescriptor.m_Beta);
733 // Create the FlatBuffer SoftmaxLayer
734 auto flatBufferSoftmaxLayer =
735 serializer::CreateSoftmaxLayer(m_flatBufferBuilder,
736 flatBufferSoftmaxBaseLayer,
737 flatBufferSoftmaxDesc);
739 CreateAnyLayer(flatBufferSoftmaxLayer.o, serializer::Layer::Layer_SoftmaxLayer);
742 void SerializerVisitor::VisitPooling2dLayer(const armnn::IConnectableLayer* layer,
743 const armnn::Pooling2dDescriptor& pooling2dDescriptor,
746 auto fbPooling2dBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Pooling2d);
747 auto fbPooling2dDescriptor = serializer::CreatePooling2dDescriptor(
749 GetFlatBufferPoolingAlgorithm(pooling2dDescriptor.m_PoolType),
750 pooling2dDescriptor.m_PadLeft,
751 pooling2dDescriptor.m_PadRight,
752 pooling2dDescriptor.m_PadTop,
753 pooling2dDescriptor.m_PadBottom,
754 pooling2dDescriptor.m_PoolWidth,
755 pooling2dDescriptor.m_PoolHeight,
756 pooling2dDescriptor.m_StrideX,
757 pooling2dDescriptor.m_StrideY,
758 GetFlatBufferOutputShapeRounding(pooling2dDescriptor.m_OutputShapeRounding),
759 GetFlatBufferPaddingMethod(pooling2dDescriptor.m_PaddingMethod),
760 GetFlatBufferDataLayout(pooling2dDescriptor.m_DataLayout));
762 auto fbPooling2dLayer = serializer::CreatePooling2dLayer(m_flatBufferBuilder,
763 fbPooling2dBaseLayer,
764 fbPooling2dDescriptor);
766 CreateAnyLayer(fbPooling2dLayer.o, serializer::Layer::Layer_Pooling2dLayer);
769 void SerializerVisitor::VisitPreluLayer(const armnn::IConnectableLayer* layer,
772 // Create FlatBuffer BaseLayer
773 auto flatBufferPreluBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Prelu);
775 // Create the FlatBuffer AdditionLayer
776 auto flatBufferPreluLayer = serializer::CreatePreluLayer(m_flatBufferBuilder, flatBufferPreluBaseLayer);
778 // Add the AnyLayer to the FlatBufferLayers
779 CreateAnyLayer(flatBufferPreluLayer.o, serializer::Layer::Layer_PreluLayer);
782 void SerializerVisitor::VisitQuantizeLayer(const armnn::IConnectableLayer *layer, const char *name)
784 auto fbQuantizeBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Quantize);
785 auto fbQuantizeLayer = serializer::CreateQuantizeLayer(m_flatBufferBuilder,
786 fbQuantizeBaseLayer);
787 CreateAnyLayer(fbQuantizeLayer.o, serializer::Layer::Layer_QuantizeLayer);
790 // Build FlatBuffer for FullyConnected Layer
791 void SerializerVisitor::VisitFullyConnectedLayer(const armnn::IConnectableLayer* layer,
792 const armnn::FullyConnectedDescriptor& fullyConnectedDescriptor,
793 const armnn::ConstTensor& weights,
794 const armnn::Optional<armnn::ConstTensor>& biases,
797 // Create FlatBuffer BaseLayer
798 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_FullyConnected);
800 // Create FlatBuffer FullyConnectedDescriptor
801 auto flatBufferDescriptor =
802 serializer::CreateFullyConnectedDescriptor(m_flatBufferBuilder,
803 fullyConnectedDescriptor.m_BiasEnabled,
804 fullyConnectedDescriptor.m_TransposeWeightMatrix);
806 // Create FlatBuffer weights data
807 auto flatBufferWeights = CreateConstTensorInfo(weights);
809 // Create FlatBuffer bias data
810 flatbuffers::Offset<serializer::ConstTensor> flatBufferBiases;
811 if (fullyConnectedDescriptor.m_BiasEnabled)
813 flatBufferBiases = CreateConstTensorInfo(biases.value());
816 // Create FlatBuffer FullyConnectedLayer
817 auto flatBufferLayer = serializer::CreateFullyConnectedLayer(m_flatBufferBuilder,
819 flatBufferDescriptor,
823 // Add created FullyConnectedLayer to the FlatBufferLayers
824 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_FullyConnectedLayer);
827 // Build FlatBuffer for SpaceToBatchNd Layer
828 void SerializerVisitor::VisitSpaceToBatchNdLayer(const armnn::IConnectableLayer* layer,
829 const armnn::SpaceToBatchNdDescriptor& spaceToBatchNdDescriptor,
832 // Create FlatBuffer BaseLayer
833 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToBatchNd);
835 std::vector<unsigned int> padList;
836 padList.reserve(spaceToBatchNdDescriptor.m_PadList.size()*2);
837 for (auto& pad : spaceToBatchNdDescriptor.m_PadList)
839 padList.push_back(pad.first);
840 padList.push_back(pad.second);
843 auto flatBufferDescriptor =
844 CreateSpaceToBatchNdDescriptor(m_flatBufferBuilder,
845 m_flatBufferBuilder.CreateVector(spaceToBatchNdDescriptor.m_BlockShape),
846 m_flatBufferBuilder.CreateVector(padList),
847 GetFlatBufferDataLayout(spaceToBatchNdDescriptor.m_DataLayout));
849 auto flatBufferLayer = serializer::CreateSpaceToBatchNdLayer(m_flatBufferBuilder,
851 flatBufferDescriptor);
853 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToBatchNdLayer);
856 // Build FlatBuffer for SpaceToDepthLayer
857 void SerializerVisitor::VisitSpaceToDepthLayer(const armnn::IConnectableLayer* layer,
858 const armnn::SpaceToDepthDescriptor& spaceToDepthDescriptor,
861 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_SpaceToDepth);
862 auto flatBufferDescriptor =
863 CreateSpaceToDepthDescriptor(m_flatBufferBuilder,
864 spaceToDepthDescriptor.m_BlockSize,
865 GetFlatBufferDataLayout(spaceToDepthDescriptor.m_DataLayout));
867 auto flatBufferLayer = serializer::CreateSpaceToDepthLayer(m_flatBufferBuilder,
869 flatBufferDescriptor);
871 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_SpaceToDepthLayer);
874 // Build FlatBuffer for Splitter Layer
875 void SerializerVisitor::VisitSplitterLayer(const armnn::IConnectableLayer* layer,
876 const armnn::ViewsDescriptor& viewsDescriptor,
879 // Create FlatBuffer ViewOrigins
880 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewOrigins;
881 flatBufferViewOrigins.reserve(viewsDescriptor.GetNumViews());
883 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
885 std::vector<uint32_t> viewOrigin;
886 viewOrigin.reserve(viewsDescriptor.GetNumDimensions());
889 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
891 viewOrigin.push_back(viewsDescriptor.GetViewOrigin(vIdx)[dIdx]);
894 flatBufferViewOrigins.push_back(CreateUintVector(m_flatBufferBuilder,
895 m_flatBufferBuilder.CreateVector(viewOrigin)));
898 // Create FlatBuffer OriginsDescriptor
899 auto flatBufferOriginDescriptor = CreateOriginsDescriptor(m_flatBufferBuilder,
900 viewsDescriptor.GetOrigins().GetConcatAxis(),
901 viewsDescriptor.GetOrigins().GetNumViews(),
902 viewsDescriptor.GetOrigins().GetNumDimensions(),
903 m_flatBufferBuilder.CreateVector(flatBufferViewOrigins));
905 // Create FlatBuffer ViewOrigins
906 std::vector<flatbuffers::Offset<UintVector>> flatBufferViewSizes;
907 flatBufferViewSizes.reserve(viewsDescriptor.GetNumViews());
909 for(unsigned int vIdx = 0; vIdx < viewsDescriptor.GetNumViews(); ++vIdx)
911 std::vector<uint32_t> viewSize;
912 viewSize.reserve(viewsDescriptor.GetNumDimensions());
915 for(unsigned int dIdx = 0; dIdx < viewsDescriptor.GetNumDimensions(); ++dIdx)
917 viewSize.push_back(viewsDescriptor.GetViewSizes(vIdx)[dIdx]);
920 flatBufferViewSizes.push_back(CreateUintVector(m_flatBufferBuilder,
921 m_flatBufferBuilder.CreateVector(viewSize)));
924 // Create FlatBuffer ViewsDescriptor
925 auto flatBufferViewsDescriptor = CreateViewsDescriptor(m_flatBufferBuilder,
926 flatBufferOriginDescriptor,
927 m_flatBufferBuilder.CreateVector(flatBufferViewSizes));
929 // Create FlatBuffer BaseLayer
930 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Splitter);
932 auto flatBufferSplitterLayer = serializer::CreateSplitterLayer(m_flatBufferBuilder,
934 flatBufferViewsDescriptor);
936 CreateAnyLayer(flatBufferSplitterLayer.o, serializer::Layer::Layer_SplitterLayer);
939 void SerializerVisitor::VisitNormalizationLayer(const armnn::IConnectableLayer* layer,
940 const armnn::NormalizationDescriptor& descriptor,
943 auto fbNormalizationBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Normalization);
945 auto fbNormalizationDescriptor = serializer::CreateNormalizationDescriptor(
947 GetFlatBufferNormalizationAlgorithmChannel(descriptor.m_NormChannelType),
948 GetFlatBufferNormalizationAlgorithmMethod(descriptor.m_NormMethodType),
949 descriptor.m_NormSize,
953 GetFlatBufferDataLayout(descriptor.m_DataLayout));
955 auto flatBufferLayer = serializer::CreateNormalizationLayer(m_flatBufferBuilder,
956 fbNormalizationBaseLayer,
957 fbNormalizationDescriptor);
959 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_NormalizationLayer);
962 void SerializerVisitor::VisitStackLayer(const armnn::IConnectableLayer* layer,
963 const armnn::StackDescriptor& stackDescriptor,
966 auto stackBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Stack);
968 std::vector<unsigned int> inputShape;
969 for (unsigned int i =0; i < stackDescriptor.m_InputShape.GetNumDimensions(); i++)
971 inputShape.push_back(stackDescriptor.m_InputShape[i]);
974 auto flatBufferStackDescriptor = CreateStackDescriptor(m_flatBufferBuilder,
975 stackDescriptor.m_Axis,
976 stackDescriptor.m_NumInputs,
977 m_flatBufferBuilder.CreateVector(inputShape));
979 auto stackLayer = serializer::CreateStackLayer(m_flatBufferBuilder, stackBaseLayer, flatBufferStackDescriptor);
980 CreateAnyLayer(stackLayer.o, serializer::Layer::Layer_StackLayer);
983 void SerializerVisitor::VisitStridedSliceLayer(const armnn::IConnectableLayer* layer,
984 const armnn::StridedSliceDescriptor& stridedSliceDescriptor,
987 auto flatBufferBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_StridedSlice);
989 auto flatBufferDescriptor =
990 CreateStridedSliceDescriptor(m_flatBufferBuilder,
991 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Begin),
992 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_End),
993 m_flatBufferBuilder.CreateVector(stridedSliceDescriptor.m_Stride),
994 stridedSliceDescriptor.m_BeginMask,
995 stridedSliceDescriptor.m_EndMask,
996 stridedSliceDescriptor.m_ShrinkAxisMask,
997 stridedSliceDescriptor.m_EllipsisMask,
998 stridedSliceDescriptor.m_NewAxisMask,
999 GetFlatBufferDataLayout(stridedSliceDescriptor.m_DataLayout));
1001 auto flatBufferLayer = serializer::CreateStridedSliceLayer(m_flatBufferBuilder,
1002 flatBufferBaseLayer,
1003 flatBufferDescriptor);
1005 CreateAnyLayer(flatBufferLayer.o, serializer::Layer::Layer_StridedSliceLayer);
1008 void SerializerVisitor::VisitSubtractionLayer(const armnn::IConnectableLayer* layer, const char* name)
1010 auto fbSubtractionBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Subtraction);
1011 auto fbSubtractionLayer = serializer::CreateSubtractionLayer(m_flatBufferBuilder, fbSubtractionBaseLayer);
1013 CreateAnyLayer(fbSubtractionLayer.o, serializer::Layer::Layer_SubtractionLayer);
1016 void SerializerVisitor::VisitSwitchLayer(const armnn::IConnectableLayer* layer, const char* name)
1018 auto fbSwitchBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Switch);
1019 auto fbSwitchLayer = serializer::CreateSwitchLayer(m_flatBufferBuilder, fbSwitchBaseLayer);
1021 CreateAnyLayer(fbSwitchLayer.o, serializer::Layer::Layer_SwitchLayer);
1024 void SerializerVisitor::VisitTransposeConvolution2dLayer(
1025 const armnn::IConnectableLayer* layer,
1026 const armnn::TransposeConvolution2dDescriptor& descriptor,
1027 const armnn::ConstTensor& weights,
1028 const armnn::Optional<armnn::ConstTensor>& biases,
1031 auto fbBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_Convolution2d);
1032 auto fbDescriptor = CreateTransposeConvolution2dDescriptor(m_flatBufferBuilder,
1033 descriptor.m_PadLeft,
1034 descriptor.m_PadRight,
1035 descriptor.m_PadTop,
1036 descriptor.m_PadBottom,
1037 descriptor.m_StrideX,
1038 descriptor.m_StrideY,
1039 descriptor.m_BiasEnabled,
1040 GetFlatBufferDataLayout(descriptor.m_DataLayout));
1043 auto fbWeightsConstTensorInfo = CreateConstTensorInfo(weights);
1044 flatbuffers::Offset<serializer::ConstTensor> fbBiasesConstTensorInfo;
1045 if (biases.has_value())
1047 fbBiasesConstTensorInfo = CreateConstTensorInfo(biases.value());
1050 auto fbLayer = CreateTransposeConvolution2dLayer(m_flatBufferBuilder,
1053 fbWeightsConstTensorInfo,
1054 fbBiasesConstTensorInfo);
1056 CreateAnyLayer(fbLayer.o, serializer::Layer::Layer_TransposeConvolution2dLayer);
1059 void SerializerVisitor::VisitQuantizedLstmLayer(const armnn::IConnectableLayer* layer,
1060 const armnn::QuantizedLstmInputParams& params,
1063 auto fbQuantizedLstmBaseLayer = CreateLayerBase(layer, serializer::LayerType::LayerType_QuantizedLstm);
1065 // Get input parameters
1066 auto inputToInputWeights = CreateConstTensorInfo(params.GetInputToInputWeights());
1067 auto inputToForgetWeights = CreateConstTensorInfo(params.GetInputToForgetWeights());
1068 auto inputToCellWeights = CreateConstTensorInfo(params.GetInputToCellWeights());
1069 auto inputToOutputWeights = CreateConstTensorInfo(params.GetInputToOutputWeights());
1071 auto recurrentToInputWeights = CreateConstTensorInfo(params.GetRecurrentToInputWeights());
1072 auto recurrentToForgetWeights = CreateConstTensorInfo(params.GetRecurrentToForgetWeights());
1073 auto recurrentToCellWeights = CreateConstTensorInfo(params.GetRecurrentToCellWeights());
1074 auto recurrentToOutputWeights = CreateConstTensorInfo(params.GetRecurrentToOutputWeights());
1076 auto inputGateBias = CreateConstTensorInfo(params.GetInputGateBias());
1077 auto forgetGateBias = CreateConstTensorInfo(params.GetForgetGateBias());
1078 auto cellBias = CreateConstTensorInfo(params.GetCellBias());
1079 auto outputGateBias = CreateConstTensorInfo(params.GetOutputGateBias());
1081 auto fbQuantizedLstmParams = serializer::CreateQuantizedLstmInputParams(
1082 m_flatBufferBuilder,
1083 inputToInputWeights,
1084 inputToForgetWeights,
1086 inputToOutputWeights,
1087 recurrentToInputWeights,
1088 recurrentToForgetWeights,
1089 recurrentToCellWeights,
1090 recurrentToOutputWeights,
1096 auto fbQuantizedLstmLayer = serializer::CreateQuantizedLstmLayer(
1097 m_flatBufferBuilder,
1098 fbQuantizedLstmBaseLayer,
1099 fbQuantizedLstmParams);
1101 CreateAnyLayer(fbQuantizedLstmLayer.o, serializer::Layer::Layer_QuantizedLstmLayer);
1104 fb::Offset<serializer::LayerBase> SerializerVisitor::CreateLayerBase(const IConnectableLayer* layer,
1105 const serializer::LayerType layerType)
1107 uint32_t fbIndex = GetSerializedId(layer->GetGuid());
1109 std::vector<fb::Offset<serializer::InputSlot>> inputSlots = CreateInputSlots(layer);
1110 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots = CreateOutputSlots(layer);
1112 return serializer::CreateLayerBase(m_flatBufferBuilder,
1114 m_flatBufferBuilder.CreateString(layer->GetName()),
1116 m_flatBufferBuilder.CreateVector(inputSlots),
1117 m_flatBufferBuilder.CreateVector(outputSlots));
1120 void SerializerVisitor::CreateAnyLayer(const flatbuffers::Offset<void>& layer, const serializer::Layer serializerLayer)
1122 auto anyLayer = armnnSerializer::CreateAnyLayer(m_flatBufferBuilder, serializerLayer, layer);
1123 m_serializedLayers.push_back(anyLayer);
1126 template <typename T>
1127 flatbuffers::Offset<flatbuffers::Vector<T>> SerializerVisitor::CreateDataVector(const void* memory, unsigned int size)
1129 const T* buffer = reinterpret_cast<const T*>(memory);
1130 std::vector<T> vector(buffer, buffer + (size / sizeof(T)));
1131 auto fbVector = m_flatBufferBuilder.CreateVector(vector);
1135 flatbuffers::Offset<serializer::ConstTensor>
1136 SerializerVisitor::CreateConstTensorInfo(const armnn::ConstTensor& constTensor)
1138 armnn::TensorInfo tensorInfo = constTensor.GetInfo();
1140 // Get the dimensions
1141 std::vector<unsigned int> shape;
1143 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1145 shape.push_back(tensorInfo.GetShape()[dim]);
1148 // Create FlatBuffer TensorInfo
1149 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1150 m_flatBufferBuilder.CreateVector(shape),
1151 GetFlatBufferDataType(tensorInfo.GetDataType()),
1152 tensorInfo.GetQuantizationScale(),
1153 tensorInfo.GetQuantizationOffset());
1154 flatbuffers::Offset<void> fbPayload;
1156 switch (tensorInfo.GetDataType())
1158 case armnn::DataType::Float32:
1159 case armnn::DataType::Signed32:
1161 auto fbVector = CreateDataVector<int32_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1162 flatbuffers::Offset<serializer::IntData> flatBuffersData = serializer::CreateIntData(
1163 m_flatBufferBuilder,
1165 fbPayload = flatBuffersData.o;
1168 case armnn::DataType::Float16:
1170 auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1171 flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1172 m_flatBufferBuilder,
1174 fbPayload = flatBuffersData.o;
1177 case armnn::DataType::QuantisedSymm16:
1179 auto fbVector = CreateDataVector<int16_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1180 flatbuffers::Offset<serializer::ShortData> flatBuffersData = serializer::CreateShortData(
1181 m_flatBufferBuilder,
1183 fbPayload = flatBuffersData.o;
1186 case armnn::DataType::QuantisedAsymm8:
1187 case armnn::DataType::Boolean:
1190 auto fbVector = CreateDataVector<int8_t>(constTensor.GetMemoryArea(), constTensor.GetNumBytes());
1191 flatbuffers::Offset<serializer::ByteData> flatBuffersData = serializer::CreateByteData(
1192 m_flatBufferBuilder,
1194 fbPayload = flatBuffersData.o;
1197 flatbuffers::Offset<serializer::ConstTensor> flatBufferConstTensor = serializer::CreateConstTensor(
1198 m_flatBufferBuilder,
1199 flatBufferTensorInfo,
1200 GetFlatBufferConstTensorData(tensorInfo.GetDataType()),
1202 return flatBufferConstTensor;
1205 std::vector<fb::Offset<serializer::InputSlot>>
1206 SerializerVisitor::CreateInputSlots(const armnn::IConnectableLayer* layer)
1208 std::vector<fb::Offset<serializer::InputSlot>> inputSlots;
1210 // Get the InputSlots
1211 for (unsigned int slotIndex = 0; slotIndex<layer->GetNumInputSlots(); ++slotIndex)
1213 const IInputSlot& inputSlot = layer->GetInputSlot(slotIndex);
1215 // Get the Connection for the InputSlot
1216 const IOutputSlot* connection = inputSlot.GetConnection();
1218 // Create FlatBuffer Connection
1219 serializer::Connection conn(GetSerializedId(inputSlot.GetConnection()->GetOwningLayerGuid()),
1220 connection->CalculateIndexOnOwner());
1221 // Create FlatBuffer InputSlot
1222 inputSlots.push_back(serializer::CreateInputSlot(m_flatBufferBuilder, slotIndex, &conn));
1227 std::vector<fb::Offset<serializer::OutputSlot>>
1228 SerializerVisitor::CreateOutputSlots(const armnn::IConnectableLayer* layer)
1230 std::vector<fb::Offset<serializer::OutputSlot>> outputSlots;
1232 // Get the OutputSlots
1233 for (unsigned int slotIndex = 0; slotIndex < layer->GetNumOutputSlots(); ++slotIndex)
1235 const IOutputSlot& outputSlot = layer->GetOutputSlot(slotIndex);
1236 const armnn::TensorInfo& tensorInfo = outputSlot.GetTensorInfo();
1238 // Get the dimensions
1239 std::vector<unsigned int> shape;
1240 for(unsigned int dim = 0; dim < tensorInfo.GetShape().GetNumDimensions(); ++dim)
1242 shape.push_back(tensorInfo.GetShape()[dim]);
1245 // Create FlatBuffer TensorInfo
1246 auto flatBufferTensorInfo = serializer::CreateTensorInfo(m_flatBufferBuilder,
1247 m_flatBufferBuilder.CreateVector(shape),
1248 GetFlatBufferDataType(tensorInfo.GetDataType()),
1249 tensorInfo.GetQuantizationScale(),
1250 tensorInfo.GetQuantizationOffset());
1252 // Create FlatBuffer Outputslot
1253 outputSlots.push_back(serializer::CreateOutputSlot(m_flatBufferBuilder,
1255 flatBufferTensorInfo));
1261 ISerializer* ISerializer::CreateRaw()
1263 return new Serializer();
1266 ISerializerPtr ISerializer::Create()
1268 return ISerializerPtr(CreateRaw(), &ISerializer::Destroy);
1271 void ISerializer::Destroy(ISerializer* serializer)
1276 void Serializer::Serialize(const INetwork& inNetwork)
1278 // Iterate through to network
1279 inNetwork.Accept(m_SerializerVisitor);
1280 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1282 // Create FlatBuffer SerializedGraph
1283 auto serializedGraph = serializer::CreateSerializedGraph(
1285 fbBuilder.CreateVector(m_SerializerVisitor.GetSerializedLayers()),
1286 fbBuilder.CreateVector(m_SerializerVisitor.GetInputIds()),
1287 fbBuilder.CreateVector(m_SerializerVisitor.GetOutputIds()));
1289 // Serialize the graph
1290 fbBuilder.Finish(serializedGraph);
1293 bool Serializer::SaveSerializedToStream(std::ostream& stream)
1295 flatbuffers::FlatBufferBuilder& fbBuilder = m_SerializerVisitor.GetFlatBufferBuilder();
1297 auto bytesToWrite = boost::numeric_cast<std::streamsize>(fbBuilder.GetSize());
1298 stream.write(reinterpret_cast<const char*>(fbBuilder.GetBufferPointer()), bytesToWrite);
1299 return !stream.bad();
1302 } // namespace armnnSerializer