//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
7 #include <LayersFwd.hpp>
9 #include <armnn/Types.hpp>
10 #include <armnn/LayerSupport.hpp>
11 #include <armnn/ILayerSupport.hpp>
12 #include <armnn/BackendRegistry.hpp>
13 #include <armnn/utility/PolymorphicDowncast.hpp>
15 #include <backendsCommon/WorkloadFactory.hpp>
16 #include <armnn/backends/IBackendInternal.hpp>
17 #include <backendsCommon/CpuTensorHandle.hpp>
18 #include <backendsCommon/WorkloadFactory.hpp>
20 #include <backendsCommon/test/WorkloadTestUtils.hpp>
22 #include <boost/iterator/transform_iterator.hpp>
// Returns a copy of 'info' with its element data type replaced by 'type',
// preserving the tensor shape and the per-tensor quantization scale/offset.
// Used throughout IsLayerSupported() to query backend support as if the
// layer's tensors had the network's requested data type.
//
// NOTE(review): this extracted snippet is missing interior lines here — the
// function's opening brace and, presumably, an early return of 'info'
// unchanged when 'type' is an empty Optional (the unconditional
// type.value() below would otherwise throw). Confirm against the full file.
33 const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
40 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
43 } // anonymous namespace
45 bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
46 const IConnectableLayer& connectableLayer,
47 Optional<DataType> dataType,
48 std::string& outReasonIfUnsupported)
50 Optional<std::string&> reason = outReasonIfUnsupported;
52 const Layer& layer = *(PolymorphicDowncast<const Layer*>(&connectableLayer));
54 auto const& backendRegistry = BackendRegistryInstance();
55 if (!backendRegistry.IsBackendRegistered(backendId))
58 ss << connectableLayer.GetName() << " is not supported on " << backendId
59 << " because this backend is not registered.";
61 outReasonIfUnsupported = ss.str();
65 auto backendFactory = backendRegistry.GetFactory(backendId);
66 auto backendObject = backendFactory();
67 auto layerSupportObject = backendObject->GetLayerSupport();
69 switch(layer.GetType())
71 case LayerType::Activation:
73 auto cLayer = PolymorphicDowncast<const ActivationLayer*>(&layer);
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
75 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
76 result = layerSupportObject->IsActivationSupported(
77 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
83 case LayerType::Addition:
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
88 result = layerSupportObject->IsAdditionSupported(
89 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
95 case LayerType::ArgMinMax:
97 auto cLayer = PolymorphicDowncast<const ArgMinMaxLayer*>(&layer);
98 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
100 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
101 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
102 result = layerSupportObject->IsArgMinMaxSupported(
103 OverrideDataType(input, dataType),
104 OverrideDataType(output, DataType::Signed32),
109 case LayerType::BatchNormalization:
111 auto cLayer = PolymorphicDowncast<const BatchNormalizationLayer*>(&layer);
112 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
113 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
114 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
115 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
116 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
117 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
118 result = layerSupportObject->IsBatchNormalizationSupported(
119 OverrideDataType(input, dataType),
120 OverrideDataType(output, dataType),
121 OverrideDataType(mean, dataType),
122 OverrideDataType(var, dataType),
123 OverrideDataType(beta, dataType),
124 OverrideDataType(gamma, dataType),
125 cLayer->GetParameters(),
129 case LayerType::BatchToSpaceNd:
131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
132 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
133 auto cLayer = PolymorphicDowncast<const BatchToSpaceNdLayer*>(&layer);
135 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
136 OverrideDataType(output, dataType),
137 cLayer->GetParameters(),
141 case LayerType::Comparison:
143 auto cLayer = PolymorphicDowncast<const ComparisonLayer*>(&layer);
145 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
146 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
147 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
149 result = layerSupportObject->IsComparisonSupported(OverrideDataType(input0, dataType),
150 OverrideDataType(input1, dataType),
151 OverrideDataType(output, DataType::Boolean),
152 cLayer->GetParameters(),
156 case LayerType::Constant:
158 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
159 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
162 case LayerType::ConvertBf16ToFp32:
164 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
165 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
166 result = layerSupportObject->IsConvertBf16ToFp32Supported(input, output, reason);
169 case LayerType::ConvertFp16ToFp32:
171 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
172 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
173 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
176 case LayerType::ConvertFp32ToBf16:
178 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
179 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
180 result = layerSupportObject->IsConvertFp32ToBf16Supported(input, output, reason);
183 case LayerType::ConvertFp32ToFp16:
185 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
186 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
187 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
190 case LayerType::Convolution2d:
192 auto cLayer = PolymorphicDowncast<const Convolution2dLayer*>(&layer);
194 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
196 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
197 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
199 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
201 // Construct optional biases object based on the value of m_BiasEnabled
202 Optional<TensorInfo> biases;
203 if (descriptor.m_BiasEnabled)
206 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
209 result = layerSupportObject->IsConvolution2dSupported(
213 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
218 case LayerType::Debug:
220 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
221 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
223 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
224 OverrideDataType(output, dataType),
228 case LayerType::DepthToSpace:
230 auto cLayer = PolymorphicDowncast<const DepthToSpaceLayer*>(&layer);
232 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
233 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
235 result = layerSupportObject->IsDepthToSpaceSupported(OverrideDataType(input, dataType),
236 OverrideDataType(output, dataType),
237 cLayer->GetParameters(),
241 case LayerType::DepthwiseConvolution2d:
243 auto cLayer = PolymorphicDowncast<const DepthwiseConvolution2dLayer*>(&layer);
244 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
246 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
247 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
249 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
251 // Construct optional biases object based on the value of m_BiasEnabled
252 Optional<TensorInfo> biases;
253 if (descriptor.m_BiasEnabled)
256 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
259 result = layerSupportObject->IsDepthwiseConvolutionSupported(
263 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
268 case LayerType::Dequantize:
270 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
271 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
273 result = layerSupportObject->IsDequantizeSupported(input,
274 OverrideDataType(output, dataType),
278 case LayerType::DetectionPostProcess:
280 auto cLayer = PolymorphicDowncast<const DetectionPostProcessLayer*>(&layer);
281 const TensorInfo& boxEncodings = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
282 const TensorInfo& scores = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
283 const TensorInfo& anchors = cLayer->m_Anchors->GetTensorInfo();
285 const TensorInfo& detectionBoxes = layer.GetOutputSlot(0).GetTensorInfo();
286 const TensorInfo& detectionClasses = layer.GetOutputSlot(1).GetTensorInfo();
287 const TensorInfo& detectionScores = layer.GetOutputSlot(2).GetTensorInfo();
288 const TensorInfo& numDetections = layer.GetOutputSlot(3).GetTensorInfo();
290 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
291 result = layerSupportObject->IsDetectionPostProcessSupported(boxEncodings,
302 case LayerType::ElementwiseUnary:
304 auto cLayer = PolymorphicDowncast<const ElementwiseUnaryLayer*>(&layer);
306 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
307 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
309 result = layerSupportObject->IsElementwiseUnarySupported(OverrideDataType(input, dataType),
310 OverrideDataType(output, dataType),
311 cLayer->GetParameters(),
315 case LayerType::FakeQuantization:
317 auto cLayer = PolymorphicDowncast<const FakeQuantizationLayer*>(&layer);
318 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
319 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
320 cLayer->GetParameters(),
324 case LayerType::Floor:
326 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
327 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
328 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
329 OverrideDataType(output, dataType),
333 case LayerType::FullyConnected:
335 auto cLayer = PolymorphicDowncast<const FullyConnectedLayer*>(&layer);
336 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
337 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
338 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
341 const TensorInfo * biasInfoPtr = nullptr;
342 static const TensorInfo dummyBFloat16Bias(TensorShape({1,1,1,1}), DataType::BFloat16);
343 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
344 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
345 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
347 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
348 if (descriptor.m_BiasEnabled)
350 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
351 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
352 biasInfoPtr = &biasInfo;
356 // If biases are not enabled pass a dummy tensorinfo for the validation
357 switch(input.GetDataType())
359 case DataType::BFloat16:
361 biasInfoPtr = &dummyBFloat16Bias;
364 case DataType::Float16:
366 biasInfoPtr = &dummyFloat16Bias;
369 case DataType::Float32:
371 biasInfoPtr = &dummyFloat32Bias;
374 case DataType::QAsymmU8:
375 case DataType::QAsymmS8:
376 case DataType::QSymmS8:
377 case DataType::QSymmS16:
379 biasInfoPtr = &dummyQA8Bias;
384 ARMNN_ASSERT_MSG(false, "Unexpected bias type");
389 result = layerSupportObject->IsFullyConnectedSupported(
390 OverrideDataType(input, dataType),
391 OverrideDataType(output, dataType),
392 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
398 case LayerType::Gather:
400 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
401 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
402 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
403 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
405 OverrideDataType(output, dataType),
409 case LayerType::Input:
411 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
412 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
415 case LayerType::InstanceNormalization:
417 auto cLayer = PolymorphicDowncast<const InstanceNormalizationLayer*>(&layer);
418 const InstanceNormalizationDescriptor& descriptor = cLayer->GetParameters();
420 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
421 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
423 result = layerSupportObject->IsInstanceNormalizationSupported(
424 OverrideDataType(input, dataType),
425 OverrideDataType(output, dataType),
430 case LayerType::L2Normalization:
432 auto cLayer = PolymorphicDowncast<const L2NormalizationLayer*>(&layer);
433 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
435 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
436 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
438 result = layerSupportObject->IsL2NormalizationSupported(
439 OverrideDataType(input, dataType),
440 OverrideDataType(output, dataType),
445 case LayerType::LogSoftmax:
447 auto cLayer = PolymorphicDowncast<const LogSoftmaxLayer*>(&layer);
449 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
450 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
452 result = layerSupportObject->IsLogSoftmaxSupported(OverrideDataType(input, dataType),
453 OverrideDataType(output, dataType),
454 cLayer->GetParameters(),
458 case LayerType::Lstm:
460 auto cLayer = PolymorphicDowncast<const LstmLayer*>(&layer);
461 const LstmDescriptor& descriptor = cLayer->GetParameters();
464 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
466 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
468 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
471 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
472 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
473 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
474 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
477 const TensorInfo& inputToForgetWeights
478 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
479 const TensorInfo& inputToCellWeights
480 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
481 const TensorInfo& inputToOutputWeights
482 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
483 const TensorInfo& recurrentToForgetWeights
484 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
485 const TensorInfo& recurrentToCellWeights
486 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
487 const TensorInfo& recurrentToOutputWeights
488 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
489 const TensorInfo& forgetGateBias
490 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
491 const TensorInfo& cellBias
492 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
493 const TensorInfo& outputGateBias
494 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
496 LstmInputParamsInfo paramsInfo;
498 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
499 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
500 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
501 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
502 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
503 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
504 paramsInfo.m_ForgetGateBias = &forgetGateBias;
505 paramsInfo.m_CellBias = &cellBias;
506 paramsInfo.m_OutputGateBias = &outputGateBias;
509 // Optional parameters
510 TensorInfo optInputToInputWeights;
511 TensorInfo optRecurrentToInputWeights;
512 TensorInfo optCellToInputWeights;
513 TensorInfo optInputGateBias;
514 TensorInfo optProjectionWeights;
515 TensorInfo optProjectionBias;
516 TensorInfo optCellToForgetWeights;
517 TensorInfo optCellToOutputWeights;
518 TensorInfo optInputLayerNormWeights;
519 TensorInfo optForgetLayerNormWeights;
520 TensorInfo optCellLayerNormWeights;
521 TensorInfo optOutputLayerNormWeights;
523 if(!descriptor.m_CifgEnabled)
525 optInputToInputWeights =
526 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
527 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
529 optRecurrentToInputWeights =
530 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
531 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
533 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
534 paramsInfo.m_InputGateBias = &optInputGateBias;
537 if(descriptor.m_ProjectionEnabled)
539 optProjectionWeights =
540 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
541 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
542 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
545 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
546 paramsInfo.m_ProjectionBias = &optProjectionBias;
550 if(descriptor.m_PeepholeEnabled)
552 if(!descriptor.m_CifgEnabled)
554 optCellToInputWeights =
555 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo(),
557 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
559 optCellToForgetWeights =
560 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
561 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
562 optCellToOutputWeights =
563 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
564 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
567 if(descriptor.m_LayerNormEnabled)
569 if (!descriptor.m_CifgEnabled)
571 optInputLayerNormWeights = OverrideDataType(
572 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
573 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
576 optForgetLayerNormWeights = OverrideDataType(
577 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
578 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
580 optCellLayerNormWeights = OverrideDataType(
581 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
582 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
584 optOutputLayerNormWeights = OverrideDataType(
585 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
586 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
589 result = layerSupportObject->IsLstmSupported(
602 case LayerType::Maximum:
604 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
605 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
606 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
608 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
609 OverrideDataType(input1, dataType),
610 OverrideDataType(output, dataType),
614 case LayerType::MemCopy:
616 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
617 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
619 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
620 OverrideDataType(output, dataType),
624 case LayerType::MemImport:
626 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
627 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
629 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
630 OverrideDataType(output, dataType),
634 case LayerType::Merge:
636 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
637 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
638 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
640 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
641 OverrideDataType(input1, dataType),
642 OverrideDataType(output, dataType),
646 case LayerType::Concat:
648 auto cLayer = PolymorphicDowncast<const ConcatLayer*>(&layer);
650 // Get vector of all inputs.
651 auto getTensorInfo = [&dataType](const InputSlot& slot)
653 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
655 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
656 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
657 std::vector<TensorInfo> inputs(beginI, endI);
659 auto getTensorInfoPtr = [](const TensorInfo& info)
663 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
664 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
665 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
667 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
669 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
674 case LayerType::Multiplication:
676 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
677 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
678 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
679 result = layerSupportObject->IsMultiplicationSupported(
680 OverrideDataType(input0, dataType),
681 OverrideDataType(input1, dataType),
682 OverrideDataType(output, dataType),
686 case LayerType::Normalization:
688 auto cLayer = PolymorphicDowncast<const NormalizationLayer*>(&layer);
689 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
690 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
691 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
692 OverrideDataType(output, dataType),
693 cLayer->GetParameters(),
697 case LayerType::Output:
699 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
700 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
703 case LayerType::Permute:
705 auto cLayer = PolymorphicDowncast<const PermuteLayer*>(&layer);
706 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
707 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
708 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
709 OverrideDataType(output, dataType),
710 cLayer->GetParameters(),
716 auto cLayer = PolymorphicDowncast<const PadLayer*>(&layer);
717 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
718 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
719 result = layerSupportObject->IsPadSupported(
720 OverrideDataType(input, dataType),
721 OverrideDataType(output, dataType),
722 cLayer->GetParameters(),
726 case LayerType::Pooling2d:
728 auto cLayer = PolymorphicDowncast<const Pooling2dLayer*>(&layer);
729 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
730 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
731 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
732 OverrideDataType(output, dataType),
733 cLayer->GetParameters(),
737 case LayerType::PreCompiled:
739 auto cLayer = PolymorphicDowncast<const PreCompiledLayer*>(&layer);
740 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
741 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
742 cLayer->GetParameters(),
746 case LayerType::Quantize:
748 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
749 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
750 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
753 case LayerType::QLstm:
755 auto cLayer = PolymorphicDowncast<const QLstmLayer*>(&layer);
756 const QLstmDescriptor& descriptor = cLayer->GetParameters();
759 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
760 const TensorInfo& previousOutputIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
761 const TensorInfo& previousCellStateIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
764 const TensorInfo& outputStateOut = layer.GetOutputSlot(0).GetTensorInfo();
765 const TensorInfo& cellStateOut = layer.GetOutputSlot(1).GetTensorInfo();
766 const TensorInfo& output = layer.GetOutputSlot(2).GetTensorInfo();
769 LstmInputParamsInfo paramsInfo;
772 paramsInfo.m_InputToForgetWeights = &cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo();
773 paramsInfo.m_InputToCellWeights = &cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo();
774 paramsInfo.m_InputToOutputWeights = &cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo();
776 paramsInfo.m_RecurrentToForgetWeights =
777 &cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo();
778 paramsInfo.m_RecurrentToCellWeights =
779 &cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo();
780 paramsInfo.m_RecurrentToOutputWeights =
781 &cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo();
783 paramsInfo.m_ForgetGateBias = &cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo();
784 paramsInfo.m_CellBias = &cLayer->m_BasicParameters.m_CellBias->GetTensorInfo();
785 paramsInfo.m_OutputGateBias = &cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo();
787 if(!descriptor.m_CifgEnabled)
789 paramsInfo.m_InputToInputWeights = &cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo();
790 paramsInfo.m_RecurrentToInputWeights =
791 &cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo();
792 paramsInfo.m_InputGateBias = &cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo();
795 if(descriptor.m_ProjectionEnabled)
797 paramsInfo.m_ProjectionWeights = &cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo();
799 // Projection bias is optional even if projection is enabled
800 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
802 paramsInfo.m_ProjectionBias = &cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo();
806 if(descriptor.m_PeepholeEnabled)
808 if (!descriptor.m_CifgEnabled)
810 paramsInfo.m_CellToInputWeights =
811 &cLayer->m_PeepholeParameters.m_CellToInputWeights->GetTensorInfo();
814 paramsInfo.m_CellToForgetWeights =
815 &cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo();
816 paramsInfo.m_CellToOutputWeights = &cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo();
819 if(descriptor.m_LayerNormEnabled)
821 if (!descriptor.m_CifgEnabled)
823 paramsInfo.m_InputLayerNormWeights =
824 &cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo();
827 paramsInfo.m_ForgetLayerNormWeights =
828 &cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo();
829 paramsInfo.m_CellLayerNormWeights =
830 &cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo();
831 paramsInfo.m_OutputLayerNormWeights =
832 &cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo();
835 result = layerSupportObject->IsQLstmSupported(input,
846 case LayerType::QuantizedLstm:
848 auto cLayer = PolymorphicDowncast<const QuantizedLstmLayer*>(&layer);
851 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
852 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
853 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
856 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
857 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
859 // QuantizedLstm parameters
860 QuantizedLstmInputParamsInfo paramsInfo;
862 paramsInfo.m_InputToInputWeights =
863 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
864 paramsInfo.m_InputToForgetWeights =
865 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
866 paramsInfo.m_InputToCellWeights =
867 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
868 paramsInfo.m_InputToOutputWeights =
869 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
871 paramsInfo.m_RecurrentToInputWeights =
872 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
873 paramsInfo.m_RecurrentToForgetWeights =
874 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
875 paramsInfo.m_RecurrentToCellWeights =
876 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
877 paramsInfo.m_RecurrentToOutputWeights =
878 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
880 paramsInfo.m_InputGateBias =
881 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
882 paramsInfo.m_ForgetGateBias =
883 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
884 paramsInfo.m_CellBias =
885 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
886 paramsInfo.m_OutputGateBias =
887 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
889 result = layerSupportObject->IsQuantizedLstmSupported(input,
898 case LayerType::Division:
900 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
901 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
902 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
903 result = layerSupportObject->IsDivisionSupported(
904 OverrideDataType(input0, dataType),
905 OverrideDataType(input1, dataType),
906 OverrideDataType(output, dataType),
910 case LayerType::Reshape:
912 auto cLayer = PolymorphicDowncast<const ReshapeLayer*>(&layer);
913 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
914 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
915 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
916 OverrideDataType(output, dataType),
917 cLayer->GetParameters(),
921 case LayerType::Resize:
923 auto cLayer = PolymorphicDowncast<const ResizeLayer*>(&layer);
924 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
925 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
926 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
927 OverrideDataType(output, dataType),
928 cLayer->GetParameters(),
932 case LayerType::Slice:
934 auto cLayer = PolymorphicDowncast<const SliceLayer*>(&layer);
936 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
937 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
939 result = layerSupportObject->IsSliceSupported(OverrideDataType(input, dataType),
940 OverrideDataType(output, dataType),
941 cLayer->GetParameters(),
945 case LayerType::Softmax:
947 auto cLayer = PolymorphicDowncast<const SoftmaxLayer*>(&layer);
948 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
949 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
950 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
951 OverrideDataType(output, dataType),
952 cLayer->GetParameters(),
956 case LayerType::SpaceToBatchNd:
958 auto cLayer = PolymorphicDowncast<const SpaceToBatchNdLayer*>(&layer);
959 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
960 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
961 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
962 OverrideDataType(output, dataType),
963 cLayer->GetParameters(),
967 case LayerType::SpaceToDepth:
969 auto cLayer = PolymorphicDowncast<const SpaceToDepthLayer*>(&layer);
971 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
972 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
974 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
975 OverrideDataType(output, dataType),
976 cLayer->GetParameters(),
980 case LayerType::Splitter:
982 auto cLayer = PolymorphicDowncast<const SplitterLayer*>(&layer);
983 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
985 // Get vector of all outputs.
986 auto getTensorInfo = [&dataType](const OutputSlot& slot)
988 return OverrideDataType(slot.GetTensorInfo(), dataType);
990 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
991 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
992 std::vector<TensorInfo> outputs(beginI, endI);
994 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
996 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
998 cLayer->GetParameters(),
1002 case LayerType::Stack:
1004 auto cLayer = PolymorphicDowncast<const StackLayer*>(&layer);
1006 // Get vector of all inputs.
1007 auto getTensorInfo = [&dataType](const InputSlot& slot)
1009 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1011 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
1012 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
1013 std::vector<TensorInfo> inputs(beginI, endI);
1015 auto getTensorInfoPtr = [](const TensorInfo& info)
1019 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
1020 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
1021 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
1023 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1025 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
1029 case LayerType::StandIn:
1031 auto cLayer = PolymorphicDowncast<const StandInLayer*>(&layer);
1033 // Get vector of all inputs.
1034 auto getTensorInfoIn = [&dataType](const InputSlot& slot)
1036 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
1038 auto getTensorInfoOut = [&dataType](const OutputSlot& slot)
1040 return OverrideDataType(slot.GetTensorInfo(), dataType);
1042 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfoIn);
1043 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfoIn);
1044 std::vector<TensorInfo> inputs(beginI, endI);
1046 auto beginO = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfoOut);
1047 auto endO = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfoOut);
1048 std::vector<TensorInfo> outputs(beginO, endO);
1051 auto getTensorInfoPtr = [](const TensorInfo& info)
1055 auto beginPtrI = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
1056 auto endPtrI = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
1057 std::vector<const TensorInfo*> inputPtrs(beginPtrI, endPtrI);
1059 auto beginPtrO = boost::make_transform_iterator(outputs.begin(), getTensorInfoPtr);
1060 auto endPtrO = boost::make_transform_iterator(outputs.end(), getTensorInfoPtr);
1061 std::vector<const TensorInfo*> outputPtrs(beginPtrO, endPtrO);
1064 result = layerSupportObject->IsStandInSupported(inputPtrs,
1066 cLayer->GetParameters(),
1070 case LayerType::StridedSlice:
1072 auto cLayer = PolymorphicDowncast<const StridedSliceLayer*>(&layer);
1073 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1074 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1075 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
1076 OverrideDataType(output, dataType),
1077 cLayer->GetParameters(),
1081 case LayerType::Subtraction:
1083 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1084 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1085 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1086 result = layerSupportObject->IsSubtractionSupported(
1087 OverrideDataType(input0, dataType),
1088 OverrideDataType(input1, dataType),
1089 OverrideDataType(output, dataType),
1093 case LayerType::Switch:
1095 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1096 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1097 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
1098 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
1099 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
1100 OverrideDataType(input1, dataType),
1101 OverrideDataType(output0, dataType),
1102 OverrideDataType(output1, dataType),
1106 case LayerType::Mean:
1108 auto cLayer = PolymorphicDowncast<const MeanLayer*>(&layer);
1109 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1110 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1111 result = layerSupportObject->IsMeanSupported(
1112 OverrideDataType(input, dataType),
1113 OverrideDataType(output, dataType),
1114 cLayer->GetParameters(),
1118 case LayerType::Minimum:
1120 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1121 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1122 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1123 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
1124 OverrideDataType(input1, dataType),
1125 OverrideDataType(output, dataType),
1129 case LayerType::Prelu:
1131 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1132 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
1133 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1134 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
1135 OverrideDataType(alpha, dataType),
1136 OverrideDataType(output, dataType),
1140 case LayerType::Transpose:
1142 auto cLayer = PolymorphicDowncast<const TransposeLayer*>(&layer);
1143 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
1144 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
1145 result = layerSupportObject->IsTransposeSupported(OverrideDataType(input, dataType),
1146 OverrideDataType(output, dataType),
1147 cLayer->GetParameters(),
1151 case LayerType::TransposeConvolution2d:
1153 auto cLayer = PolymorphicDowncast<const TransposeConvolution2dLayer*>(&layer);
1155 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
1157 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
1159 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
1161 Optional<TensorInfo> biases;
1162 if (descriptor.m_BiasEnabled)
1164 ARMNN_ASSERT(cLayer->m_Bias.get() != nullptr);
1165 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
1166 GetBiasTypeFromWeightsType(dataType));
1169 ARMNN_ASSERT(cLayer->m_Weight.get() != nullptr);
1170 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
1172 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
1183 ARMNN_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
1184 reason.value() = "Unrecognised layer type";
/// Convenience overload of IsLayerSupported: reads the backend id already
/// assigned to the layer and forwards to the BackendId-qualified overload.
/// @param connectableLayer layer to query; assumed to be a Layer-derived
///        object (PolymorphicDowncast is used, which asserts in debug builds).
/// @param dataType optional data type used to override the layer's tensor infos.
/// @param outReasonIfUnsupported receives a human-readable failure reason
///        when the layer is not supported.
/// @return true if the layer's assigned backend reports the layer as supported.
1192 bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
1193 Optional<DataType> dataType,
1194 std::string& outReasonIfUnsupported)
1196 auto layer = PolymorphicDowncast<const Layer*>(&connectableLayer);
1197 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
1200 // Default Implementations
// Default implementations of the IWorkloadFactory::Create* factory methods
// (CreateAbs .. CreateConvolution2d). Each base-class stub returns a
// default-constructed std::unique_ptr<IWorkload> (i.e. nullptr), signalling
// that this factory does not implement the workload type; concrete backend
// factories override only the Create* methods they support.
1201 std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& /*descriptor*/,
1202 const WorkloadInfo& /*info*/) const
1204 return std::unique_ptr<IWorkload>();
1207 std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& /*descriptor*/,
1208 const WorkloadInfo& /*info*/) const
1210 return std::unique_ptr<IWorkload>();
1213 std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& /*descriptor*/,
1214 const WorkloadInfo& /*info*/) const
1216 return std::unique_ptr<IWorkload>();
1219 std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& /*descriptor*/,
1220 const WorkloadInfo& /*info*/) const
1222 return std::unique_ptr<IWorkload>();
1225 std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
1226 const BatchNormalizationQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1228 return std::unique_ptr<IWorkload>();
1231 std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& /*desc*/,
1232 const WorkloadInfo& /*Info*/) const
1234 return std::unique_ptr<IWorkload>();
1237 std::unique_ptr<IWorkload> IWorkloadFactory::CreateComparison(const ComparisonQueueDescriptor& /*descriptor*/,
1238 const WorkloadInfo& /*info*/) const
1240 return std::unique_ptr<IWorkload>();
1243 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& /*descriptor*/,
1244 const WorkloadInfo& /*info*/) const
1246 return std::unique_ptr<IWorkload>();
1249 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& /*descriptor*/,
1250 const WorkloadInfo& /*info*/) const
1252 return std::unique_ptr<IWorkload>();
1255 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertBf16ToFp32(const ConvertBf16ToFp32QueueDescriptor& /*desc*/,
1256 const WorkloadInfo& /*info*/) const
1258 return std::unique_ptr<IWorkload>();
1261 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& /*desc*/,
1262 const WorkloadInfo& /*info*/) const
1264 return std::unique_ptr<IWorkload>();
1267 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToBf16(const ConvertFp32ToBf16QueueDescriptor& /*desc*/,
1268 const WorkloadInfo& /*info*/) const
1270 return std::unique_ptr<IWorkload>();
1273 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& /*desc*/,
1274 const WorkloadInfo& /*info*/) const
1276 return std::unique_ptr<IWorkload>();
1279 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& /*descriptor*/,
1280 const WorkloadInfo& /*info*/) const
1282 return std::unique_ptr<IWorkload>();
// Default implementations of the IWorkloadFactory::Create* factory methods
// (CreateDebug .. CreateLstm). Each base-class stub returns a
// default-constructed std::unique_ptr<IWorkload> (i.e. nullptr), signalling
// that this factory does not implement the workload type; concrete backend
// factories override only the Create* methods they support.
1285 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& /*descriptor*/,
1286 const WorkloadInfo& /*info*/) const
1288 return std::unique_ptr<IWorkload>();
1291 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthToSpace(const DepthToSpaceQueueDescriptor& /*descriptor*/,
1292 const WorkloadInfo& /*info*/) const
1294 return std::unique_ptr<IWorkload>();
1297 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
1298 const DepthwiseConvolution2dQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1300 return std::unique_ptr<IWorkload>();
1303 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
1304 const DequantizeQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1306 return std::unique_ptr<IWorkload>();
1309 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
1310 const DetectionPostProcessQueueDescriptor& /*descriptor*/, const WorkloadInfo& /*info*/) const
1312 return std::unique_ptr<IWorkload>();
1315 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& /*descriptor*/,
1316 const WorkloadInfo& /*info*/) const
1318 return std::unique_ptr<IWorkload>();
1321 std::unique_ptr<IWorkload> IWorkloadFactory::CreateElementwiseUnary(const ElementwiseUnaryQueueDescriptor& /*desc*/,
1322 const WorkloadInfo& /*info*/) const
1324 return std::unique_ptr<IWorkload>();
1327 std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& /*descriptor*/,
1328 const WorkloadInfo& /*Info*/) const
1330 return std::unique_ptr<IWorkload>();
1333 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& /*desc*/,
1334 const WorkloadInfo& /*info*/) const
1336 return std::unique_ptr<IWorkload>();
1339 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& /*descriptor*/,
1340 const WorkloadInfo& /*info*/) const
1342 return std::unique_ptr<IWorkload>();
1345 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& /*descriptor*/,
1346 const WorkloadInfo& /*info*/) const
1348 return std::unique_ptr<IWorkload>();
1351 std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& /*descriptor*/,
1352 const WorkloadInfo& /*info*/) const
1354 return std::unique_ptr<IWorkload>();
1357 std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& /*descriptor*/,
1358 const WorkloadInfo& /*info*/) const
1360 return std::unique_ptr<IWorkload>();
1363 std::unique_ptr<IWorkload> IWorkloadFactory::CreateInstanceNormalization(
1364 const InstanceNormalizationQueueDescriptor& /*descriptor*/,
1365 const WorkloadInfo& /*info*/) const
1367 return std::unique_ptr<IWorkload>();
1370 std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& /*desc*/,
1371 const WorkloadInfo& /*info*/) const
1373 return std::unique_ptr<IWorkload>();
1376 std::unique_ptr<IWorkload> IWorkloadFactory::CreateLogSoftmax(const LogSoftmaxQueueDescriptor& /*descriptor*/,
1377 const WorkloadInfo& /*info*/) const
1379 return std::unique_ptr<IWorkload>();
1382 std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& /*descriptor*/,
1383 const WorkloadInfo& /*info*/) const
1385 return std::unique_ptr<IWorkload>();
// Default implementations of the IWorkloadFactory::Create* factory methods
// (CreateMaximum .. CreateQuantizedLstm). Each base-class stub returns a
// default-constructed std::unique_ptr<IWorkload> (i.e. nullptr), signalling
// that this factory does not implement the workload type; concrete backend
// factories override only the Create* methods they support.
1388 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& /*descriptor*/,
1389 const WorkloadInfo& /*info*/) const
1391 return std::unique_ptr<IWorkload>();
1394 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& /*descriptor*/,
1395 const WorkloadInfo& /*Info*/) const
1397 return std::unique_ptr<IWorkload>();
1400 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& /*descriptor*/,
1401 const WorkloadInfo& /*info*/) const
1403 return std::unique_ptr<IWorkload>();
1406 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& /*descriptor*/,
1407 const WorkloadInfo& /*info*/) const
1409 return std::unique_ptr<IWorkload>();
1412 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& /*descriptor*/,
1413 const WorkloadInfo& /*info*/) const
1415 return std::unique_ptr<IWorkload>();
1418 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& /*descriptor*/,
1419 const WorkloadInfo& /*info*/) const
1421 return std::unique_ptr<IWorkload>();
1424 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& /*descriptor*/,
1425 const WorkloadInfo& /*info*/) const
1427 return std::unique_ptr<IWorkload>();
1430 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& /*descriptor*/,
1431 const WorkloadInfo& /*info*/) const
1433 return std::unique_ptr<IWorkload>();
1436 std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& /*descriptor*/,
1437 const WorkloadInfo& /*info*/) const
1439 return std::unique_ptr<IWorkload>();
1442 std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& /*descriptor*/,
1443 const WorkloadInfo& /*info*/) const
1445 return std::unique_ptr<IWorkload>();
1448 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& /*descriptor*/,
1449 const WorkloadInfo& /*Info*/) const
1451 return std::unique_ptr<IWorkload>();
1454 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& /*descriptor*/,
1455 const WorkloadInfo& /*info*/) const
1457 return std::unique_ptr<IWorkload>();
1460 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& /*descriptor*/,
1461 const WorkloadInfo& /*info*/) const
1463 return std::unique_ptr<IWorkload>();
1466 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& /*descriptor*/,
1467 const WorkloadInfo& /*info*/) const
1469 return std::unique_ptr<IWorkload>();
1472 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &/*descriptor*/,
1473 const WorkloadInfo &/*info*/) const
1475 return std::unique_ptr<IWorkload>();
1478 std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& /*descriptor*/,
1479 const WorkloadInfo& /*Info*/) const
1481 return std::unique_ptr<IWorkload>();
1484 std::unique_ptr<IWorkload> IWorkloadFactory::CreateQLstm(const QLstmQueueDescriptor& /*descriptor*/,
1485 const WorkloadInfo& /*info*/) const
1487 return std::unique_ptr<IWorkload>();
1490 std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& /*descriptor*/,
1491 const WorkloadInfo& /*info*/) const
1493 return std::unique_ptr<IWorkload>();
// Default implementations of the IWorkloadFactory::Create* factory methods
// (CreateReshape .. CreateTransposeConvolution2d). Each base-class stub
// returns a default-constructed std::unique_ptr<IWorkload> (i.e. nullptr),
// signalling that this factory does not implement the workload type;
// concrete backend factories override only the Create* methods they support.
1496 std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& /*descriptor*/,
1497 const WorkloadInfo& /*info*/) const
1499 return std::unique_ptr<IWorkload>();
1502 std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& /*descriptor*/,
1503 const WorkloadInfo& /*info*/) const
1505 return std::unique_ptr<IWorkload>();
1508 std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& /*descriptor*/,
1509 const WorkloadInfo& /*info*/) const
1511 return std::unique_ptr<IWorkload>();
1514 std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& /*descriptor*/,
1515 const WorkloadInfo& /*info*/) const
1517 return std::unique_ptr<IWorkload>();
1520 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSlice(const SliceQueueDescriptor& /*descriptor*/,
1521 const WorkloadInfo& /*info*/) const
1523 return std::unique_ptr<IWorkload>();
1526 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& /*descriptor*/,
1527 const WorkloadInfo& /*info*/) const
1529 return std::unique_ptr<IWorkload>();
1532 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& /*descriptor*/,
1533 const WorkloadInfo& /*info*/) const
1535 return std::unique_ptr<IWorkload>();
1538 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& /*descriptor*/,
1539 const WorkloadInfo& /*info*/) const
1541 return std::unique_ptr<IWorkload>();
1544 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& /*descriptor*/,
1545 const WorkloadInfo& /*info*/) const
1547 return std::unique_ptr<IWorkload>();
1550 std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& /*descriptor*/,
1551 const WorkloadInfo& /*info*/) const
1553 return std::unique_ptr<IWorkload>();
1556 std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& /*descriptor*/,
1557 const WorkloadInfo& /*info*/) const
1559 return std::unique_ptr<IWorkload>();
1562 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& /*descriptor*/,
1563 const WorkloadInfo& /*info*/) const
1565 return std::unique_ptr<IWorkload>();
1568 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& /*descriptor*/,
1569 const WorkloadInfo& /*info*/) const
1571 return std::unique_ptr<IWorkload>();
1574 std::unique_ptr<IWorkload> IWorkloadFactory::CreateTranspose(const TransposeQueueDescriptor& /*descriptor*/,
1575 const WorkloadInfo& /*info*/) const
1577 return std::unique_ptr<IWorkload>();
1580 std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1581 const TransposeConvolution2dQueueDescriptor& /*descriptor*/,
1582 const WorkloadInfo& /*info*/) const
1584 return std::unique_ptr<IWorkload>();
1587 } // namespace armnn