// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
6 #include "CpuTensorHandle.hpp"
7 #include "WorkloadFactory.hpp"
11 #include <LayersFwd.hpp>
13 #include <armnn/Types.hpp>
14 #include <armnn/LayerSupport.hpp>
15 #include <armnn/ILayerSupport.hpp>
17 #include <backendsCommon/BackendRegistry.hpp>
18 #include <backendsCommon/WorkloadFactory.hpp>
19 #include <backendsCommon/IBackendInternal.hpp>
20 #include <backendsCommon/test/WorkloadTestUtils.hpp>
22 #include <boost/cast.hpp>
23 #include <boost/iterator/transform_iterator.hpp>
34 const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
41 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset());
44 } // anonymous namespace
// Asks the backend registered under 'backendId' whether it can execute the
// given layer. The optional 'dataType' overrides the data type of the tensor
// infos passed to the backend's ILayerSupport queries (see OverrideDataType);
// on failure a human-readable reason is written to 'outReasonIfUnsupported'.
46 bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
47 const IConnectableLayer& connectableLayer,
48 Optional<DataType> dataType,
49 std::string& outReasonIfUnsupported)
51 Optional<std::string&> reason = outReasonIfUnsupported;
53 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
// A layer targeting an unregistered backend is reported as unsupported
// (with an explanatory reason) rather than treated as a hard error.
55 auto const& backendRegistry = BackendRegistryInstance();
56 if (!backendRegistry.IsBackendRegistered(backendId))
59 ss << connectableLayer.GetName() << " is not supported on " << backendId
60 << " because this backend is not registered.";
62 outReasonIfUnsupported = ss.str();
// Instantiate the backend and obtain its ILayerSupport query object.
66 auto backendFactory = backendRegistry.GetFactory(backendId);
67 auto backendObject = backendFactory();
68 auto layerSupportObject = backendObject->GetLayerSupport();
// Each case below collects the layer's input/output TensorInfos (with the
// optional dataType override applied), plus any descriptor/weights the layer
// carries, and forwards them to the matching IsXxxSupported backend query.
70 switch(layer.GetType())
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
75 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
76 result = layerSupportObject->IsAbsSupported(OverrideDataType(input, dataType),
77 OverrideDataType(output, dataType),
81 case LayerType::Activation:
83 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
84 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
85 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
86 result = layerSupportObject->IsActivationSupported(
87 OverrideDataType(input, dataType),
88 OverrideDataType(output, dataType),
89 cLayer->GetParameters(),
93 case LayerType::Addition:
95 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
96 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
97 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
98 result = layerSupportObject->IsAdditionSupported(
99 OverrideDataType(input0, dataType),
100 OverrideDataType(input1, dataType),
101 OverrideDataType(output, dataType),
105 case LayerType::ArgMinMax:
107 auto cLayer = boost::polymorphic_downcast<const ArgMinMaxLayer*>(&layer);
108 const ArgMinMaxDescriptor& descriptor = cLayer->GetParameters();
110 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
111 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
112 result = layerSupportObject->IsArgMinMaxSupported(
113 OverrideDataType(input, dataType),
114 OverrideDataType(output, dataType),
119 case LayerType::BatchNormalization:
// BatchNormalization validates the four constant statistics tensors
// (mean/variance/beta/gamma) in addition to the activation tensors.
121 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
122 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
123 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
124 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
125 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
126 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
127 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
128 result = layerSupportObject->IsBatchNormalizationSupported(
129 OverrideDataType(input, dataType),
130 OverrideDataType(output, dataType),
131 OverrideDataType(mean, dataType),
132 OverrideDataType(var, dataType),
133 OverrideDataType(beta, dataType),
134 OverrideDataType(gamma, dataType),
135 cLayer->GetParameters(),
139 case LayerType::BatchToSpaceNd:
141 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
142 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
143 auto cLayer = boost::polymorphic_downcast<const BatchToSpaceNdLayer*>(&layer);
145 result = layerSupportObject->IsBatchToSpaceNdSupported(OverrideDataType(input, dataType),
146 OverrideDataType(output, dataType),
147 cLayer->GetParameters(),
151 case LayerType::Constant:
// Constant layers have no inputs - only the produced output is checked.
153 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
154 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
157 case LayerType::ConvertFp16ToFp32:
// Conversion layers have fixed source/destination types, so no dataType
// override is applied here.
159 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
160 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
161 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
164 case LayerType::ConvertFp32ToFp16:
166 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
167 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
168 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
171 case LayerType::Convolution2d:
173 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
175 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
177 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
178 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
180 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
182 // Construct optional biases object based on the value of m_BiasEnabled
183 Optional<TensorInfo> biases;
184 if (descriptor.m_BiasEnabled)
// Bias data type is derived from the weights type (e.g. Signed32 for
// quantized weights) via GetBiasTypeFromWeightsType.
187 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
190 result = layerSupportObject->IsConvolution2dSupported(
194 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
199 case LayerType::Debug:
201 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
202 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
204 result = layerSupportObject->IsDebugSupported(OverrideDataType(input, dataType),
205 OverrideDataType(output, dataType),
209 case LayerType::DepthwiseConvolution2d:
211 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
212 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
214 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
215 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
217 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
219 // Construct optional biases object based on the value of m_BiasEnabled
220 Optional<TensorInfo> biases;
221 if (descriptor.m_BiasEnabled)
224 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
227 result = layerSupportObject->IsDepthwiseConvolutionSupported(
231 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
236 case LayerType::Dequantize:
238 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
239 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
// Dequantize output is always validated as Float32, regardless of any
// requested dataType override.
241 result = layerSupportObject->IsDequantizeSupported(OverrideDataType(input, dataType),
242 OverrideDataType(output, DataType::Float32),
246 case LayerType::DetectionPostProcess:
248 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
249 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
250 auto cLayer = boost::polymorphic_downcast<const DetectionPostProcessLayer*>(&layer);
251 const DetectionPostProcessDescriptor& descriptor = cLayer->GetParameters();
252 result = layerSupportObject->IsDetectionPostProcessSupported(input0,
258 case LayerType::Equal:
260 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
261 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
262 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
263 result = layerSupportObject->IsEqualSupported(OverrideDataType(input0, dataType),
264 OverrideDataType(input1, dataType),
265 OverrideDataType(output, dataType),
269 case LayerType::FakeQuantization:
271 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
272 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
273 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
274 cLayer->GetParameters(),
278 case LayerType::Floor:
280 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
281 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
282 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
283 OverrideDataType(output, dataType),
287 case LayerType::FullyConnected:
289 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
290 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
291 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
292 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
// Static dummy bias TensorInfos, one per supported input type; used when the
// descriptor has biases disabled but the backend query still needs a bias info.
295 const TensorInfo * biasInfoPtr = nullptr;
296 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
297 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
298 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
300 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
301 if (descriptor.m_BiasEnabled)
303 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
304 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
305 biasInfoPtr = &biasInfo;
309 // If biases are not enabled pass a dummy tensorinfo for the validation
310 switch(input.GetDataType())
312 case DataType::Float16:
314 biasInfoPtr = &dummyFloat16Bias;
317 case DataType::Float32:
319 biasInfoPtr = &dummyFloat32Bias;
// Both quantized input types use a Signed32 dummy bias.
322 case DataType::QuantisedAsymm8:
323 case DataType::QuantisedSymm16:
325 biasInfoPtr = &dummyQA8Bias;
330 BOOST_ASSERT_MSG(false, "Unexpected bias type");
335 result = layerSupportObject->IsFullyConnectedSupported(
336 OverrideDataType(input, dataType),
337 OverrideDataType(output, dataType),
338 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
344 case LayerType::Gather:
346 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
347 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
348 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
349 result = layerSupportObject->IsGatherSupported(OverrideDataType(input0, dataType),
351 OverrideDataType(output, dataType),
355 case LayerType::Input:
// Input layers have no input slots; the graph input appears on output slot 0.
357 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
358 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
361 case LayerType::L2Normalization:
363 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
364 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
366 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
367 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
369 result = layerSupportObject->IsL2NormalizationSupported(
370 OverrideDataType(input, dataType),
371 OverrideDataType(output, dataType),
376 case LayerType::Lstm:
// LSTM: gather the mandatory inputs/outputs and basic weights first; the
// optional CIFG/projection/peephole/layer-norm parameters are filled into
// paramsInfo below only when the descriptor enables the relevant feature.
378 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
379 const LstmDescriptor& descriptor = cLayer->GetParameters();
382 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
384 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
386 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
389 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
390 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
391 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
392 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
395 const TensorInfo& inputToForgetWeights
396 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
397 const TensorInfo& inputToCellWeights
398 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
399 const TensorInfo& inputToOutputWeights
400 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
401 const TensorInfo& recurrentToForgetWeights
402 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
403 const TensorInfo& recurrentToCellWeights
404 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
405 const TensorInfo& recurrentToOutputWeights
406 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
407 const TensorInfo& forgetGateBias
408 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
409 const TensorInfo& cellBias
410 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
411 const TensorInfo& outputGateBias
412 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
414 LstmInputParamsInfo paramsInfo;
416 paramsInfo.m_InputToForgetWeights = &inputToForgetWeights;
417 paramsInfo.m_InputToCellWeights = &inputToCellWeights;
418 paramsInfo.m_InputToOutputWeights = &inputToOutputWeights;
419 paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights;
420 paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights;
421 paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights;
422 paramsInfo.m_ForgetGateBias = &forgetGateBias;
423 paramsInfo.m_CellBias = &cellBias;
424 paramsInfo.m_OutputGateBias = &outputGateBias;
427 // Optional parameters
// These locals must outlive the IsLstmSupported call below, since paramsInfo
// stores raw pointers into them.
428 TensorInfo optInputToInputWeights;
429 TensorInfo optRecurrentToInputWeights;
430 TensorInfo optCellToInputWeights;
431 TensorInfo optInputGateBias;
432 TensorInfo optProjectionWeights;
433 TensorInfo optProjectionBias;
434 TensorInfo optCellToForgetWeights;
435 TensorInfo optCellToOutputWeights;
436 TensorInfo optInputLayerNormWeights;
437 TensorInfo optForgetLayerNormWeights;
438 TensorInfo optCellLayerNormWeights;
439 TensorInfo optOutputLayerNormWeights;
// CIFG disabled means the input gate exists, so its weights/bias are required.
441 if(!descriptor.m_CifgEnabled)
443 optInputToInputWeights =
444 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
445 paramsInfo.m_InputToInputWeights = &optInputToInputWeights;
447 optRecurrentToInputWeights =
448 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
449 paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights;
450 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
452 optCellToInputWeights =
453 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
454 paramsInfo.m_CellToInputWeights = &optCellToInputWeights;
457 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
458 paramsInfo.m_InputGateBias = &optInputGateBias;
461 if(descriptor.m_ProjectionEnabled)
463 optProjectionWeights =
464 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
465 paramsInfo.m_ProjectionWeights = &optProjectionWeights;
466 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
469 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
470 paramsInfo.m_ProjectionBias = &optProjectionBias;
474 if(descriptor.m_PeepholeEnabled)
476 optCellToForgetWeights =
477 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
478 paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights;
479 optCellToOutputWeights =
480 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
481 paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights;
484 if(descriptor.m_LayerNormEnabled)
// Input-gate layer-norm weights only exist when CIFG is disabled.
486 if (!descriptor.m_CifgEnabled)
488 optInputLayerNormWeights = OverrideDataType(
489 cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType);
490 paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights;
493 optForgetLayerNormWeights = OverrideDataType(
494 cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType);
495 paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights;
497 optCellLayerNormWeights = OverrideDataType(
498 cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType);
499 paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights;
501 optOutputLayerNormWeights = OverrideDataType(
502 cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType);
503 paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights;
506 result = layerSupportObject->IsLstmSupported(
519 case LayerType::Maximum:
521 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
522 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
523 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
525 result = layerSupportObject->IsMaximumSupported(OverrideDataType(input0, dataType),
526 OverrideDataType(input1, dataType),
527 OverrideDataType(output, dataType),
531 case LayerType::MemCopy:
533 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
534 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
536 result = layerSupportObject->IsMemCopySupported(OverrideDataType(input, dataType),
537 OverrideDataType(output, dataType),
541 case LayerType::MemImport:
543 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
544 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
546 result = layerSupportObject->IsMemImportSupported(OverrideDataType(input, dataType),
547 OverrideDataType(output, dataType),
551 case LayerType::Merge:
553 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
554 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
555 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
557 result = layerSupportObject->IsMergeSupported(OverrideDataType(input0, dataType),
558 OverrideDataType(input1, dataType),
559 OverrideDataType(output, dataType),
563 case LayerType::Concat:
// Concat takes a variable number of inputs: collect all input TensorInfos
// (with the override applied), then build the pointer vector the backend
// query expects.
565 auto cLayer = boost::polymorphic_downcast<const ConcatLayer*>(&layer);
567 // Get vector of all inputs.
568 auto getTensorInfo = [&dataType](const InputSlot& slot)
570 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
572 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
573 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
574 std::vector<TensorInfo> inputs(beginI, endI);
576 auto getTensorInfoPtr = [](const TensorInfo& info)
580 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
581 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
582 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
584 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
586 result = layerSupportObject->IsConcatSupported(inputPtrs, output, cLayer->GetParameters(), reason);
591 case LayerType::Multiplication:
593 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
594 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
595 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
596 result = layerSupportObject->IsMultiplicationSupported(
597 OverrideDataType(input0, dataType),
598 OverrideDataType(input1, dataType),
599 OverrideDataType(output, dataType),
603 case LayerType::Normalization:
605 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
606 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
607 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
608 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
609 OverrideDataType(output, dataType),
610 cLayer->GetParameters(),
614 case LayerType::Output:
// Output layers have no output slots; the graph output arrives on input slot 0.
616 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
617 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
620 case LayerType::Permute:
622 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
623 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
624 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
625 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
626 OverrideDataType(output, dataType),
627 cLayer->GetParameters(),
633 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
634 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
635 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
636 result = layerSupportObject->IsPadSupported(
637 OverrideDataType(input, dataType),
638 OverrideDataType(output, dataType),
639 cLayer->GetParameters(),
643 case LayerType::Pooling2d:
645 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
646 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
647 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
648 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
649 OverrideDataType(output, dataType),
650 cLayer->GetParameters(),
654 case LayerType::PreCompiled:
656 auto cLayer = boost::polymorphic_downcast<const PreCompiledLayer*>(&layer);
657 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
658 result = layerSupportObject->IsPreCompiledSupported(OverrideDataType(input, dataType),
659 cLayer->GetParameters(),
663 case LayerType::Quantize:
// Quantize keeps the tensors' own data types - no override applied.
665 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
666 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
667 result = layerSupportObject->IsQuantizeSupported(input, output, reason);
670 case LayerType::QuantizedLstm:
672 auto cLayer = boost::polymorphic_downcast<const QuantizedLstmLayer*>(&layer);
675 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
676 const TensorInfo& previousCellStateIn = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
677 const TensorInfo& previousOutputIn = layer.GetInputSlot(2).GetConnection()->GetTensorInfo();
680 const TensorInfo& cellStateOut = layer.GetOutputSlot(0).GetTensorInfo();
681 const TensorInfo& output = layer.GetOutputSlot(1).GetTensorInfo();
683 // QuantizedLstm parameters
684 QuantizedLstmInputParamsInfo paramsInfo;
686 paramsInfo.m_InputToInputWeights =
687 &cLayer->m_QuantizedLstmParameters.m_InputToInputWeights->GetTensorInfo();
688 paramsInfo.m_InputToForgetWeights =
689 &cLayer->m_QuantizedLstmParameters.m_InputToForgetWeights->GetTensorInfo();
690 paramsInfo.m_InputToCellWeights =
691 &cLayer->m_QuantizedLstmParameters.m_InputToCellWeights->GetTensorInfo();
692 paramsInfo.m_InputToOutputWeights =
693 &cLayer->m_QuantizedLstmParameters.m_InputToOutputWeights->GetTensorInfo();
695 paramsInfo.m_RecurrentToInputWeights =
696 &cLayer->m_QuantizedLstmParameters.m_RecurrentToInputWeights->GetTensorInfo();
697 paramsInfo.m_RecurrentToForgetWeights =
698 &cLayer->m_QuantizedLstmParameters.m_RecurrentToForgetWeights->GetTensorInfo();
699 paramsInfo.m_RecurrentToCellWeights =
700 &cLayer->m_QuantizedLstmParameters.m_RecurrentToCellWeights->GetTensorInfo();
701 paramsInfo.m_RecurrentToOutputWeights =
702 &cLayer->m_QuantizedLstmParameters.m_RecurrentToOutputWeights->GetTensorInfo();
704 paramsInfo.m_InputGateBias =
705 &cLayer->m_QuantizedLstmParameters.m_InputGateBias->GetTensorInfo();
706 paramsInfo.m_ForgetGateBias =
707 &cLayer->m_QuantizedLstmParameters.m_ForgetGateBias->GetTensorInfo();
708 paramsInfo.m_CellBias =
709 &cLayer->m_QuantizedLstmParameters.m_CellBias->GetTensorInfo();
710 paramsInfo.m_OutputGateBias =
711 &cLayer->m_QuantizedLstmParameters.m_OutputGateBias->GetTensorInfo();;
713 result = layerSupportObject->IsQuantizedLstmSupported(input,
722 case LayerType::Division:
724 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
725 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
726 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
727 result = layerSupportObject->IsDivisionSupported(
728 OverrideDataType(input0, dataType),
729 OverrideDataType(input1, dataType),
730 OverrideDataType(output, dataType),
734 case LayerType::Reshape:
736 auto cLayer = boost::polymorphic_downcast<const ReshapeLayer*>(&layer);
737 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
738 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType),
739 cLayer->GetParameters(),
743 case LayerType::Resize:
745 auto cLayer = boost::polymorphic_downcast<const ResizeLayer*>(&layer);
746 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
747 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
748 result = layerSupportObject->IsResizeSupported(OverrideDataType(input, dataType),
749 OverrideDataType(output, dataType),
750 cLayer->GetParameters(),
754 case LayerType::Rsqrt:
756 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
757 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
758 result = layerSupportObject->IsRsqrtSupported(OverrideDataType(input, dataType),
759 OverrideDataType(output, dataType),
763 case LayerType::Softmax:
765 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
766 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
767 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
768 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
769 OverrideDataType(output, dataType),
770 cLayer->GetParameters(),
774 case LayerType::SpaceToBatchNd:
776 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
777 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
778 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
779 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
780 OverrideDataType(output, dataType),
781 cLayer->GetParameters(),
785 case LayerType::SpaceToDepth:
787 auto cLayer = boost::polymorphic_downcast<const SpaceToDepthLayer*>(&layer);
789 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
790 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
792 result = layerSupportObject->IsSpaceToDepthSupported(OverrideDataType(input, dataType),
793 OverrideDataType(output, dataType),
794 cLayer->GetParameters(),
798 case LayerType::Splitter:
// Splitter has one input and a variable number of outputs.
800 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
801 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
803 // Get vector of all outputs.
804 auto getTensorInfo = [&dataType](const OutputSlot& slot)
806 return OverrideDataType(slot.GetTensorInfo(), dataType);
808 auto beginI = boost::make_transform_iterator(layer.GetOutputSlots().begin(), getTensorInfo);
809 auto endI = boost::make_transform_iterator(layer.GetOutputSlots().end(), getTensorInfo);
810 std::vector<TensorInfo> outputs(beginI, endI);
812 const std::vector<std::reference_wrapper<TensorInfo>> outputPtrs(outputs.begin(), outputs.end());
814 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
816 cLayer->GetParameters(),
820 case LayerType::Stack:
822 auto cLayer = boost::polymorphic_downcast<const StackLayer*>(&layer);
824 // Get vector of all inputs.
825 auto getTensorInfo = [&dataType](const InputSlot& slot)
827 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
829 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
830 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
831 std::vector<TensorInfo> inputs(beginI, endI);
833 auto getTensorInfoPtr = [](const TensorInfo& info)
837 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
838 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
839 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
841 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
843 result = layerSupportObject->IsStackSupported(inputPtrs, output, cLayer->GetParameters(), reason);
847 case LayerType::StridedSlice:
849 auto cLayer = boost::polymorphic_downcast<const StridedSliceLayer*>(&layer);
850 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
851 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
852 result = layerSupportObject->IsStridedSliceSupported(OverrideDataType(input, dataType),
853 OverrideDataType(output, dataType),
854 cLayer->GetParameters(),
858 case LayerType::Subtraction:
860 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
861 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
862 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
863 result = layerSupportObject->IsSubtractionSupported(
864 OverrideDataType(input0, dataType),
865 OverrideDataType(input1, dataType),
866 OverrideDataType(output, dataType),
870 case LayerType::Switch:
872 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
873 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
874 const TensorInfo& output0 = layer.GetOutputSlot(0).GetTensorInfo();
875 const TensorInfo& output1 = layer.GetOutputSlot(1).GetTensorInfo();
876 result = layerSupportObject->IsSwitchSupported(OverrideDataType(input0, dataType),
877 OverrideDataType(input1, dataType),
878 OverrideDataType(output0, dataType),
879 OverrideDataType(output1, dataType),
883 case LayerType::Mean:
885 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
886 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
887 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
888 result = layerSupportObject->IsMeanSupported(
889 OverrideDataType(input, dataType),
890 OverrideDataType(output, dataType),
891 cLayer->GetParameters(),
895 case LayerType::Minimum:
897 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
898 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
899 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
900 result = layerSupportObject->IsMinimumSupported(OverrideDataType(input0, dataType),
901 OverrideDataType(input1, dataType),
902 OverrideDataType(output, dataType),
906 case LayerType::Greater:
908 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
909 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
910 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
// Comparison output is always validated as Boolean, not the override type.
911 result = layerSupportObject->IsGreaterSupported(OverrideDataType(input0, dataType),
912 OverrideDataType(input1, dataType),
913 OverrideDataType(output, DataType::Boolean),
917 case LayerType::Prelu:
919 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
920 const TensorInfo& alpha = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
921 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
922 result = layerSupportObject->IsPreluSupported(OverrideDataType(input, dataType),
923 OverrideDataType(alpha, dataType),
924 OverrideDataType(output, dataType),
928 case LayerType::TransposeConvolution2d:
930 auto cLayer = boost::polymorphic_downcast<const TransposeConvolution2dLayer*>(&layer);
932 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
934 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
936 const TransposeConvolution2dDescriptor& descriptor = cLayer->GetParameters();
938 Optional<TensorInfo> biases;
939 if (descriptor.m_BiasEnabled)
941 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
942 biases = OverrideDataType(cLayer->m_Bias->GetTensorInfo(),
943 GetBiasTypeFromWeightsType(dataType));
946 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
947 const TensorInfo weights = OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType);
949 result = layerSupportObject->IsTransposeConvolution2dSupported(input,
// Unknown layer type: assert in debug builds and report as unsupported.
960 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
961 reason.value() = "Unrecognised layer type";
969 bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
970 Optional<DataType> dataType,
971 std::string& outReasonIfUnsupported)
973 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
974 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);
977 // Default Implementations
978 std::unique_ptr<IWorkload> IWorkloadFactory::CreateAbs(const AbsQueueDescriptor& descriptor,
979 const WorkloadInfo& info) const
981 return std::unique_ptr<IWorkload>();
984 std::unique_ptr<IWorkload> IWorkloadFactory::CreateActivation(const ActivationQueueDescriptor& descriptor,
985 const WorkloadInfo& info) const
987 return std::unique_ptr<IWorkload>();
990 std::unique_ptr<IWorkload> IWorkloadFactory::CreateAddition(const AdditionQueueDescriptor& descriptor,
991 const WorkloadInfo& info) const
993 return std::unique_ptr<IWorkload>();
996 std::unique_ptr<IWorkload> IWorkloadFactory::CreateArgMinMax(const ArgMinMaxQueueDescriptor& descriptor,
997 const WorkloadInfo& info) const
999 return std::unique_ptr<IWorkload>();
1002 std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchNormalization(
1003 const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) const
1005 return std::unique_ptr<IWorkload>();
1008 std::unique_ptr<IWorkload> IWorkloadFactory::CreateBatchToSpaceNd(const BatchToSpaceNdQueueDescriptor& descriptor,
1009 const WorkloadInfo& Info) const
1011 return std::unique_ptr<IWorkload>();
1014 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConcat(const ConcatQueueDescriptor& descriptor,
1015 const WorkloadInfo& info) const
1017 return std::unique_ptr<IWorkload>();
1020 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConstant(const ConstantQueueDescriptor& descriptor,
1021 const WorkloadInfo& info) const
1023 return std::unique_ptr<IWorkload>();
1026 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp16ToFp32(const ConvertFp16ToFp32QueueDescriptor& descriptor,
1027 const WorkloadInfo& info) const
1029 return std::unique_ptr<IWorkload>();
1032 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvertFp32ToFp16(const ConvertFp32ToFp16QueueDescriptor& descriptor,
1033 const WorkloadInfo& info) const
1035 return std::unique_ptr<IWorkload>();
1038 std::unique_ptr<IWorkload> IWorkloadFactory::CreateConvolution2d(const Convolution2dQueueDescriptor& descriptor,
1039 const WorkloadInfo& info) const
1041 return std::unique_ptr<IWorkload>();
1044 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDebug(const DebugQueueDescriptor& descriptor,
1045 const WorkloadInfo& info) const
1047 return std::unique_ptr<IWorkload>();
1050 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDepthwiseConvolution2d(
1051 const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) const
1053 return std::unique_ptr<IWorkload>();
1056 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDequantize(
1057 const DequantizeQueueDescriptor& descriptor, const WorkloadInfo& info) const
1059 return std::unique_ptr<IWorkload>();
1062 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDetectionPostProcess(
1063 const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) const
1065 return std::unique_ptr<IWorkload>();
1068 std::unique_ptr<IWorkload> IWorkloadFactory::CreateDivision(const DivisionQueueDescriptor& descriptor,
1069 const WorkloadInfo& info) const
1071 return std::unique_ptr<IWorkload>();
1074 std::unique_ptr<IWorkload> IWorkloadFactory::CreateEqual(const EqualQueueDescriptor& descriptor,
1075 const WorkloadInfo& Info) const
1077 return std::unique_ptr<IWorkload>();
1080 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFakeQuantization(const FakeQuantizationQueueDescriptor& descriptor,
1081 const WorkloadInfo& info) const
1083 return std::unique_ptr<IWorkload>();
1086 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFloor(const FloorQueueDescriptor& descriptor,
1087 const WorkloadInfo& info) const
1089 return std::unique_ptr<IWorkload>();
1092 std::unique_ptr<IWorkload> IWorkloadFactory::CreateFullyConnected(const FullyConnectedQueueDescriptor& descriptor,
1093 const WorkloadInfo& info) const
1095 return std::unique_ptr<IWorkload>();
1098 std::unique_ptr<IWorkload> IWorkloadFactory::CreateGather(const GatherQueueDescriptor& descriptor,
1099 const WorkloadInfo& info) const
1101 return std::unique_ptr<IWorkload>();
1104 std::unique_ptr<IWorkload> IWorkloadFactory::CreateGreater(const GreaterQueueDescriptor& descriptor,
1105 const WorkloadInfo& info) const
1107 return std::unique_ptr<IWorkload>();
1110 std::unique_ptr<IWorkload> IWorkloadFactory::CreateL2Normalization(const L2NormalizationQueueDescriptor& descriptor,
1111 const WorkloadInfo& info) const
1113 return std::unique_ptr<IWorkload>();
1116 std::unique_ptr<IWorkload> IWorkloadFactory::CreateLstm(const LstmQueueDescriptor& descriptor,
1117 const WorkloadInfo& info) const
1119 return std::unique_ptr<IWorkload>();
1122 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMaximum(const MaximumQueueDescriptor& descriptor,
1123 const WorkloadInfo& info) const
1125 return std::unique_ptr<IWorkload>();
1128 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMean(const MeanQueueDescriptor& descriptor,
1129 const WorkloadInfo& Info) const
1131 return std::unique_ptr<IWorkload>();
1134 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemCopy(const MemCopyQueueDescriptor& descriptor,
1135 const WorkloadInfo& info) const
1137 return std::unique_ptr<IWorkload>();
1140 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMemImport(const MemImportQueueDescriptor& descriptor,
1141 const WorkloadInfo& info) const
1143 return std::unique_ptr<IWorkload>();
1146 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerge(const MergeQueueDescriptor& descriptor,
1147 const WorkloadInfo& info) const
1149 return std::unique_ptr<IWorkload>();
1152 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMerger(const MergerQueueDescriptor& descriptor,
1153 const WorkloadInfo& info) const
1155 return std::unique_ptr<IWorkload>();
1158 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMinimum(const MinimumQueueDescriptor& descriptor,
1159 const WorkloadInfo& info) const
1161 return std::unique_ptr<IWorkload>();
1164 std::unique_ptr<IWorkload> IWorkloadFactory::CreateMultiplication(const MultiplicationQueueDescriptor& descriptor,
1165 const WorkloadInfo& info) const
1167 return std::unique_ptr<IWorkload>();
1170 std::unique_ptr<IWorkload> IWorkloadFactory::CreateNormalization(const NormalizationQueueDescriptor& descriptor,
1171 const WorkloadInfo& info) const
1173 return std::unique_ptr<IWorkload>();
1176 std::unique_ptr<IWorkload> IWorkloadFactory::CreateOutput(const OutputQueueDescriptor& descriptor,
1177 const WorkloadInfo& info) const
1179 return std::unique_ptr<IWorkload>();
1182 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePad(const PadQueueDescriptor& descriptor,
1183 const WorkloadInfo& Info) const
1185 return std::unique_ptr<IWorkload>();
1188 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
1189 const WorkloadInfo& info) const
1191 return std::unique_ptr<IWorkload>();
1194 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePooling2d(const Pooling2dQueueDescriptor& descriptor,
1195 const WorkloadInfo& info) const
1197 return std::unique_ptr<IWorkload>();
1200 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePreCompiled(const PreCompiledQueueDescriptor& descriptor,
1201 const WorkloadInfo& info) const
1203 return std::unique_ptr<IWorkload>();
1206 std::unique_ptr<IWorkload> IWorkloadFactory::CreatePrelu(const PreluQueueDescriptor &descriptor,
1207 const WorkloadInfo &info) const
1209 return std::unique_ptr<IWorkload>();
1212 std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantize(const QuantizeQueueDescriptor& descriptor,
1213 const WorkloadInfo& Info) const
1215 return std::unique_ptr<IWorkload>();
1218 std::unique_ptr<IWorkload> IWorkloadFactory::CreateQuantizedLstm(const QuantizedLstmQueueDescriptor& descriptor,
1219 const WorkloadInfo& info) const
1221 return std::unique_ptr<IWorkload>();
1224 std::unique_ptr<IWorkload> IWorkloadFactory::CreateReshape(const ReshapeQueueDescriptor& descriptor,
1225 const WorkloadInfo& info) const
1227 return std::unique_ptr<IWorkload>();
1230 std::unique_ptr<IWorkload> IWorkloadFactory::CreateResizeBilinear(const ResizeBilinearQueueDescriptor& descriptor,
1231 const WorkloadInfo& info) const
1233 return std::unique_ptr<IWorkload>();
1236 std::unique_ptr<IWorkload> IWorkloadFactory::CreateResize(const ResizeQueueDescriptor& descriptor,
1237 const WorkloadInfo& info) const
1239 return std::unique_ptr<IWorkload>();
1242 std::unique_ptr<IWorkload> IWorkloadFactory::CreateRsqrt(const RsqrtQueueDescriptor& descriptor,
1243 const WorkloadInfo& info) const
1245 return std::unique_ptr<IWorkload>();
1248 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSoftmax(const SoftmaxQueueDescriptor& descriptor,
1249 const WorkloadInfo& info) const
1251 return std::unique_ptr<IWorkload>();
1254 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSplitter(const SplitterQueueDescriptor& descriptor,
1255 const WorkloadInfo& info) const
1257 return std::unique_ptr<IWorkload>();
1260 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToBatchNd(const SpaceToBatchNdQueueDescriptor& descriptor,
1261 const WorkloadInfo& info) const
1263 return std::unique_ptr<IWorkload>();
1266 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSpaceToDepth(const SpaceToDepthQueueDescriptor& descriptor,
1267 const WorkloadInfo& info) const
1269 return std::unique_ptr<IWorkload>();
1272 std::unique_ptr<IWorkload> IWorkloadFactory::CreateStack(const StackQueueDescriptor& descriptor,
1273 const WorkloadInfo& info) const
1275 return std::unique_ptr<IWorkload>();
1278 std::unique_ptr<IWorkload> IWorkloadFactory::CreateStridedSlice(const StridedSliceQueueDescriptor& descriptor,
1279 const WorkloadInfo& Info) const
1281 return std::unique_ptr<IWorkload>();
1284 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSubtraction(const SubtractionQueueDescriptor& descriptor,
1285 const WorkloadInfo& info) const
1287 return std::unique_ptr<IWorkload>();
1290 std::unique_ptr<IWorkload> IWorkloadFactory::CreateSwitch(const SwitchQueueDescriptor& descriptor,
1291 const WorkloadInfo& info) const
1293 return std::unique_ptr<IWorkload>();
1296 std::unique_ptr<IWorkload> IWorkloadFactory::CreateTransposeConvolution2d(
1297 const TransposeConvolution2dQueueDescriptor& descriptor,
1298 const WorkloadInfo& info) const
1300 return std::unique_ptr<IWorkload>();
} // namespace armnn