2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
5 #include <backends/WorkloadFactory.hpp>
6 #include <backends/LayerSupportRegistry.hpp>
8 #include <armnn/Types.hpp>
9 #include <armnn/LayerSupport.hpp>
11 #include <LayersFwd.hpp>
12 #include "CpuTensorHandle.hpp"
14 #include <boost/cast.hpp>
16 #include <boost/iterator/transform_iterator.hpp>
// Returns a copy of 'info' with its element data type replaced by 'type',
// preserving the shape and the quantization scale/offset.
// NOTE(review): the visible code calls type.value() unconditionally; the
// handling for an empty Optional (presumably an early "return info;") is not
// visible in this listing — confirm before relying on it.
24 const TensorInfo OverrideDataType(const TensorInfo& info, Optional<DataType> type)
31 return TensorInfo(info.GetShape(), type.value(), info.GetQuantizationScale(), info.GetQuantizationOffset())
// Maps a weights data type to the data type a bias tensor must have:
// QuantisedAsymm8 weights require a Signed32 bias; for Float16/Float32 the
// return is on a line not visible here (presumably the weights type itself —
// confirm). Unknown types assert in debug builds and yield an empty Optional.
34 Optional<DataType> GetBiasTypeFromWeightsType(Optional<DataType> weightsType)
41 switch(weightsType.value())
43 case DataType::Float16:
44 case DataType::Float32:
46 case DataType::QuantisedAsymm8:
47 return DataType::Signed32;
// Unsupported weights type: debug-assert, then signal "no bias type".
49 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
51 return EmptyOptional();
54 } // anonymous namespace
/// Queries whether 'connectableLayer' is supported by the backend identified
/// by 'backendId'.
///
/// @param backendId              Backend to query (e.g. CpuRef, CpuAcc, GpuAcc).
/// @param connectableLayer       Layer to check; must actually be a Layer
///                               (downcast below).
/// @param dataType               When set, the layer's tensor infos are re-typed
///                               to this DataType via OverrideDataType before the
///                               support query.
/// @param outReasonIfUnsupported Receives a human-readable reason on failure.
/// @return true when the backend's layer-support object reports support.
56 bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId,
57 const IConnectableLayer& connectableLayer,
58 Optional<DataType> dataType,
59 std::string& outReasonIfUnsupported)
// Alias the caller's string as an Optional reference so the ILayerSupport
// calls can write the failure reason straight into it.
61 Optional<std::string&> reason = outReasonIfUnsupported;
63 const Layer& layer = *(boost::polymorphic_downcast<const Layer*>(&connectableLayer));
// Obtain the backend-specific layer-support object from the registry.
65 auto const& layerSupportRegistry = LayerSupportRegistryInstance();
66 auto layerSupportFactory = layerSupportRegistry.GetFactory(backendId);
67 auto layerSupportObject = layerSupportFactory(EmptyInitializer());
// Dispatch on the concrete layer type: each case gathers the relevant tensor
// infos (re-typed via OverrideDataType) and calls the matching
// Is<LayerType>Supported query on the layer-support object.
69 switch(layer.GetType())
71 case LayerType::Activation:
73 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
74 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
75 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
76 result = layerSupportObject->IsActivationSupported(
77 OverrideDataType(input, dataType),
78 OverrideDataType(output, dataType),
79 cLayer->GetParameters(),
83 case LayerType::Addition:
85 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
86 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
87 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
88 result = layerSupportObject->IsAdditionSupported(
89 OverrideDataType(input0, dataType),
90 OverrideDataType(input1, dataType),
91 OverrideDataType(output, dataType),
95 case LayerType::BatchNormalization:
97 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
98 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
99 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
// Batch-norm statistics and affine parameters are stored on the layer itself.
100 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
101 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
102 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
103 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
104 result = layerSupportObject->IsBatchNormalizationSupported(
105 OverrideDataType(input, dataType),
106 OverrideDataType(output, dataType),
107 OverrideDataType(mean, dataType),
108 OverrideDataType(var, dataType),
109 OverrideDataType(beta, dataType),
110 OverrideDataType(gamma, dataType),
111 cLayer->GetParameters(),
115 case LayerType::Constant:
117 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
118 result = layerSupportObject->IsConstantSupported(OverrideDataType(output, dataType), reason);
// The two Fp16<->Fp32 conversion layers deliberately pass the tensor infos
// through untouched: their data types are fixed by the conversion itself.
121 case LayerType::ConvertFp16ToFp32:
123 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
124 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
125 result = layerSupportObject->IsConvertFp16ToFp32Supported(input, output, reason);
128 case LayerType::ConvertFp32ToFp16:
130 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
131 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
132 result = layerSupportObject->IsConvertFp32ToFp16Supported(input, output, reason);
135 case LayerType::Convolution2d:
137 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
139 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
141 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
// Weights must have been set on the layer before support can be queried.
142 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
144 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
146 // Construct optional biases object based on the value of m_BiasEnabled
// The bias data type is derived from the (possibly overridden) weights type
// via GetBiasTypeFromWeightsType (e.g. Signed32 for quantized weights).
147 Optional<TensorInfo> biases;
148 if (descriptor.m_BiasEnabled)
151 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
154 result = layerSupportObject->IsConvolution2dSupported(
158 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
163 case LayerType::MemCopy:
165 // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
166 // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
167 result = backendId == Compute::CpuRef || backendId == Compute::Undefined
168 || backendId == Compute::CpuAcc || backendId == Compute::GpuAcc;
// NOTE(review): 'reason' is written even when 'result' is true here; callers
// presumably only read it on failure — confirm before tightening this.
169 reason.value() = "Unsupported backend type";
172 case LayerType::DepthwiseConvolution2d:
174 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
175 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
177 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
178 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
180 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
182 // Construct optional biases object based on the value of m_BiasEnabled
// Same bias-type derivation as the Convolution2d case above.
183 Optional<TensorInfo> biases;
184 if (descriptor.m_BiasEnabled)
187 OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
190 result = layerSupportObject->IsDepthwiseConvolutionSupported(
194 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
199 case LayerType::FakeQuantization:
201 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
202 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
203 result = layerSupportObject->IsFakeQuantizationSupported(OverrideDataType(input, dataType),
204 cLayer->GetParameters(),
208 case LayerType::Floor:
210 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
211 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
212 result = layerSupportObject->IsFloorSupported(OverrideDataType(input, dataType),
213 OverrideDataType(output, dataType),
217 case LayerType::FullyConnected:
219 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
220 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
221 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
222 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
// The support query always receives a bias TensorInfo: the real one when
// biases are enabled, otherwise a static 1x1x1x1 dummy whose data type is
// chosen to match the input's data type (Signed32 for QuantisedAsymm8).
225 const TensorInfo * biasInfoPtr = nullptr;
226 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
227 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
228 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
230 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
231 if (descriptor.m_BiasEnabled)
233 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
234 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
235 biasInfoPtr = &biasInfo;
239 // If biases are not enabled pass a dummy tensorinfo for the validation
240 switch(input.GetDataType())
242 case DataType::Float16:
244 biasInfoPtr = &dummyFloat16Bias;
247 case DataType::Float32:
249 biasInfoPtr = &dummyFloat32Bias;
252 case DataType::QuantisedAsymm8:
254 biasInfoPtr = &dummyQA8Bias;
// Input data type has no matching dummy bias: programmer error.
259 BOOST_ASSERT_MSG(false, "Unexpected bias type");
264 result = layerSupportObject->IsFullyConnectedSupported(
265 OverrideDataType(input, dataType),
266 OverrideDataType(output, dataType),
267 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
// Input layers have no input slots; their tensor info lives on the output slot.
273 case LayerType::Input:
275 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
276 result = layerSupportObject->IsInputSupported(OverrideDataType(input, dataType), reason);
279 case LayerType::L2Normalization:
281 auto cLayer = boost::polymorphic_downcast<const L2NormalizationLayer*>(&layer);
282 const L2NormalizationDescriptor& descriptor = cLayer->GetParameters();
284 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
285 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
287 result = layerSupportObject->IsL2NormalizationSupported(
288 OverrideDataType(input, dataType),
289 OverrideDataType(output, dataType),
294 case LayerType::Lstm:
296 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
297 const LstmDescriptor& descriptor = cLayer->GetParameters();
// LSTM inputs: data, previous output state, previous cell state.
300 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
302 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
304 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
// LSTM outputs: scratch buffer, new output state, new cell state, output.
307 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
308 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
309 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
310 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
// Mandatory LSTM weights/biases (always present in m_BasicParameters).
313 const TensorInfo& inputToForgetWeights
314 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
315 const TensorInfo& inputToCellWeights
316 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
317 const TensorInfo& inputToOutputWeights
318 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
319 const TensorInfo& recurrentToForgetWeights
320 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
321 const TensorInfo& recurrentToCellWeights
322 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
323 const TensorInfo& recurrentToOutputWeights
324 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
325 const TensorInfo& forgetGateBias
326 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
327 const TensorInfo& cellBias
328 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
329 const TensorInfo& outputGateBias
330 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
332 // Optional parameters
// Each optional tensor is exposed to IsLstmSupported as a pointer that stays
// null unless the corresponding feature (CIFG-disabled, projection, peephole)
// is active; the 'opt*' locals give those pointers stable storage.
333 const TensorInfo* inputToInputWeights = nullptr;
334 const TensorInfo* recurrentToInputWeights = nullptr;
335 const TensorInfo* cellToInputWeights = nullptr;
336 const TensorInfo* inputGateBias = nullptr;
337 const TensorInfo* projectionWeights = nullptr;
338 const TensorInfo* projectionBias = nullptr;
339 const TensorInfo* cellToForgetWeights = nullptr;
340 const TensorInfo* cellToOutputWeights = nullptr;
342 TensorInfo optInputToInputWeights;
343 TensorInfo optRecurrentToInputWeights;
344 TensorInfo optCellToInputWeights;
345 TensorInfo optInputGateBias;
346 TensorInfo optProjectionWeights;
347 TensorInfo optProjectionBias;
348 TensorInfo optCellToForgetWeights;
349 TensorInfo optCellToOutputWeights;
// With CIFG (coupled input-forget gate) disabled, the input gate has its own
// weights and bias; cell-to-input weights are additionally optional.
351 if(!descriptor.m_CifgEnabled)
353 optInputToInputWeights =
354 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
355 inputToInputWeights = &optInputToInputWeights;
357 optRecurrentToInputWeights =
358 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
359 recurrentToInputWeights = &optRecurrentToInputWeights;
360 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
362 optCellToInputWeights =
363 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
364 cellToInputWeights = &optCellToInputWeights;
367 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
368 inputGateBias = &optInputGateBias;
// Projection layer: weights mandatory when enabled, bias optional.
371 if(descriptor.m_ProjectionEnabled)
373 optProjectionWeights =
374 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
375 projectionWeights = &optProjectionWeights;
376 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
379 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
380 projectionBias = &optProjectionBias;
// Peephole connections: cell state feeds the forget and output gates.
384 if(descriptor.m_PeepholeEnabled)
386 optCellToForgetWeights =
387 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
388 cellToForgetWeights = &optCellToForgetWeights;
389 optCellToOutputWeights =
390 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
391 cellToOutputWeights = &optCellToOutputWeights;
394 result = layerSupportObject->IsLstmSupported(
403 inputToForgetWeights,
405 inputToOutputWeights,
406 recurrentToForgetWeights,
407 recurrentToCellWeights,
408 recurrentToOutputWeights,
413 recurrentToInputWeights,
423 case LayerType::Merger:
425 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
427 // Get vector of all inputs.
// First pass: materialize (possibly re-typed) TensorInfo values for every
// input slot via a transform iterator over the slot range...
428 auto getTensorInfo = [&dataType](const InputSlot& slot)
430 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
432 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
433 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
434 std::vector<TensorInfo> inputs(beginI, endI);
// ...second pass: build the pointer vector IsMergerSupported expects. The
// pointers target 'inputs', which must outlive this call (it does).
436 auto getTensorInfoPtr = [](const TensorInfo& info)
440 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
441 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
442 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
444 result = layerSupportObject->IsMergerSupported(inputPtrs, cLayer->GetParameters(), reason);
447 case LayerType::Multiplication:
449 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
450 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
451 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
452 result = layerSupportObject->IsMultiplicationSupported(
453 OverrideDataType(input0, dataType),
454 OverrideDataType(input1, dataType),
455 OverrideDataType(output, dataType),
459 case LayerType::Normalization:
461 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
462 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
463 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
464 result = layerSupportObject->IsNormalizationSupported(OverrideDataType(input, dataType),
465 OverrideDataType(output, dataType),
466 cLayer->GetParameters(),
// Output layers have no output slots; their tensor info comes from the input.
470 case LayerType::Output:
472 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
473 result = layerSupportObject->IsOutputSupported(OverrideDataType(output, dataType), reason);
476 case LayerType::Permute:
478 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
479 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
480 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
481 result = layerSupportObject->IsPermuteSupported(OverrideDataType(input, dataType),
482 OverrideDataType(output, dataType),
483 cLayer->GetParameters(),
489 auto cLayer = boost::polymorphic_downcast<const PadLayer*>(&layer);
490 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
491 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
492 result = layerSupportObject->IsPadSupported(
493 OverrideDataType(input, dataType),
494 OverrideDataType(output, dataType),
495 cLayer->GetParameters(),
499 case LayerType::Pooling2d:
501 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
502 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
503 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
504 result = layerSupportObject->IsPooling2dSupported(OverrideDataType(input, dataType),
505 OverrideDataType(output, dataType),
506 cLayer->GetParameters(),
510 case LayerType::Division:
512 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
513 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
514 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
515 result = layerSupportObject->IsDivisionSupported(
516 OverrideDataType(input0, dataType),
517 OverrideDataType(input1, dataType),
518 OverrideDataType(output, dataType),
522 case LayerType::Reshape:
524 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
525 result = layerSupportObject->IsReshapeSupported(OverrideDataType(input, dataType), reason);
528 case LayerType::ResizeBilinear:
530 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
531 result = layerSupportObject->IsResizeBilinearSupported(OverrideDataType(input, dataType), reason);
534 case LayerType::Softmax:
536 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
537 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
538 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
539 result = layerSupportObject->IsSoftmaxSupported(OverrideDataType(input, dataType),
540 OverrideDataType(output, dataType),
541 cLayer->GetParameters(),
545 case LayerType::SpaceToBatchNd:
547 auto cLayer = boost::polymorphic_downcast<const SpaceToBatchNdLayer*>(&layer);
548 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
549 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
550 result = layerSupportObject->IsSpaceToBatchNdSupported(OverrideDataType(input, dataType),
551 OverrideDataType(output, dataType),
552 cLayer->GetParameters(),
556 case LayerType::Splitter:
558 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
559 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
560 result = layerSupportObject->IsSplitterSupported(OverrideDataType(input, dataType),
561 cLayer->GetParameters(),
565 case LayerType::Subtraction:
567 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
568 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
569 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
570 result = layerSupportObject->IsSubtractionSupported(
571 OverrideDataType(input0, dataType),
572 OverrideDataType(input1, dataType),
573 OverrideDataType(output, dataType),
577 case LayerType::Mean:
579 auto cLayer = boost::polymorphic_downcast<const MeanLayer*>(&layer);
580 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
581 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
582 result = layerSupportObject->IsMeanSupported(
583 OverrideDataType(input, dataType),
584 OverrideDataType(output, dataType),
585 cLayer->GetParameters(),
// Default: a LayerType this function has not been taught about. Debug-assert
// so new layer types are noticed during development; report as unsupported.
591 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
592 reason.value() = "Unrecognised layer type";
/// Convenience overload: derives the backend id from the layer itself
/// (Layer::GetBackendId()) and forwards to the backend-explicit overload above.
600 bool IWorkloadFactory::IsLayerSupported(const IConnectableLayer& connectableLayer,
601 Optional<DataType> dataType,
602 std::string& outReasonIfUnsupported)
// Safe downcast: every IConnectableLayer handed out by the graph is a Layer.
604 auto layer = boost::polymorphic_downcast<const Layer*>(&connectableLayer);
605 return IsLayerSupported(layer->GetBackendId(), connectableLayer, dataType, outReasonIfUnsupported);