// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
#include "WorkloadFactory.hpp"
#include "RefWorkloadFactory.hpp"
#include "NeonWorkloadFactory.hpp"
#include "ClWorkloadFactory.hpp"

#include "armnn/Types.hpp"
#include "armnn/LayerSupport.hpp"

#include "LayersFwd.hpp"
#include "CpuTensorHandle.hpp"

#include <boost/cast.hpp>
#include <boost/iterator/transform_iterator.hpp>

#include <cstring>
25 const TensorInfo OverrideDataType(const TensorInfo& info, boost::optional<DataType> type)
27 if (type == boost::none)
32 return TensorInfo(info.GetShape(), type.get(), info.GetQuantizationScale(), info.GetQuantizationOffset());
35 boost::optional<DataType> GetBiasTypeFromWeightsType(boost::optional<DataType> weightsType)
37 if (weightsType == boost::none)
42 switch(weightsType.get())
44 case DataType::Float16:
45 case DataType::Float32:
47 case DataType::QuantisedAsymm8:
48 return DataType::Signed32;
50 BOOST_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
56 bool IWorkloadFactory::IsLayerSupported(Compute compute, const Layer& layer, boost::optional<DataType> dataType,
57 std::string& outReasonIfUnsupported)
59 constexpr size_t reasonCapacity = 1024;
60 char reason[reasonCapacity];
62 switch(layer.GetType())
64 case LayerType::Activation:
66 auto cLayer = boost::polymorphic_downcast<const ActivationLayer*>(&layer);
67 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
68 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
69 result = IsActivationSupported(compute,
70 OverrideDataType(input, dataType),
71 OverrideDataType(output, dataType),
72 cLayer->GetParameters(),
77 case LayerType::Addition:
79 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
80 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
81 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
82 result = IsAdditionSupported(compute,
83 OverrideDataType(input0, dataType),
84 OverrideDataType(input1, dataType),
85 OverrideDataType(output, dataType),
90 case LayerType::BatchNormalization:
92 auto cLayer = boost::polymorphic_downcast<const BatchNormalizationLayer*>(&layer);
93 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
94 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
95 const TensorInfo& mean = cLayer->m_Mean->GetTensorInfo();
96 const TensorInfo& var = cLayer->m_Variance->GetTensorInfo();
97 const TensorInfo& beta = cLayer->m_Beta->GetTensorInfo();
98 const TensorInfo& gamma = cLayer->m_Gamma->GetTensorInfo();
99 result = IsBatchNormalizationSupported(compute,
100 OverrideDataType(input, dataType),
101 OverrideDataType(output, dataType),
102 OverrideDataType(mean, dataType),
103 OverrideDataType(var, dataType),
104 OverrideDataType(beta, dataType),
105 OverrideDataType(gamma, dataType),
106 cLayer->GetParameters(),
107 reason, reasonCapacity);
110 case LayerType::Constant:
112 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
113 result = IsConstantSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
116 case LayerType::ConvertFp16ToFp32:
118 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
119 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
120 result = IsConvertFp16ToFp32Supported(compute, input, output, reason, reasonCapacity);
123 case LayerType::ConvertFp32ToFp16:
125 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
126 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
127 result = IsConvertFp32ToFp16Supported(compute, input, output, reason, reasonCapacity);
130 case LayerType::Convolution2d:
132 auto cLayer = boost::polymorphic_downcast<const Convolution2dLayer*>(&layer);
133 const TensorInfo input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(), dataType);
134 const TensorInfo output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
135 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
138 const TensorInfo * biasInfoPtr = nullptr;
139 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
140 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
141 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
143 const Convolution2dDescriptor& descriptor = cLayer->GetParameters();
145 if (descriptor.m_BiasEnabled)
147 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
148 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
149 biasInfoPtr = &biasInfo;
153 // If biases are not enabled pass a dummy tensorinfo for the validation.
154 switch(input.GetDataType())
156 case DataType::Float16:
158 biasInfoPtr = &dummyFloat16Bias;
161 case DataType::Float32:
163 biasInfoPtr = &dummyFloat32Bias;
166 case DataType::QuantisedAsymm8:
168 biasInfoPtr = &dummyQA8Bias;
173 BOOST_ASSERT_MSG(false, "Unexpected input type");
178 result = IsConvolution2dSupported(compute,
182 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
188 case LayerType::MemCopy:
190 // MemCopy supported for CpuRef, CpuAcc and GpuAcc backends,
191 // (also treat Undefined as CpuRef to avoid breaking lots of Unit tests).
192 result = compute == Compute::CpuRef || compute == Compute::Undefined
193 || compute == Compute::CpuAcc || compute == Compute::GpuAcc;
194 strcpy(reason, "Unsupported backend type");
197 case LayerType::DepthwiseConvolution2d:
199 auto cLayer = boost::polymorphic_downcast<const DepthwiseConvolution2dLayer*>(&layer);
200 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
202 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
203 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
206 const TensorInfo * biasInfoPtr = nullptr;
207 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
208 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
209 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
211 const DepthwiseConvolution2dDescriptor& descriptor = cLayer->GetParameters();
212 if (descriptor.m_BiasEnabled)
214 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
215 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
216 biasInfoPtr = &biasInfo;
220 // If biases are not enabled pass a dummy tensorinfo for the validation
221 switch(input.GetDataType())
223 case DataType::Float16:
225 biasInfoPtr = &dummyFloat16Bias;
228 case DataType::Float32:
230 biasInfoPtr = &dummyFloat32Bias;
233 case DataType::QuantisedAsymm8:
235 biasInfoPtr = &dummyQA8Bias;
240 BOOST_ASSERT_MSG(false, "Unexpected bias type");
246 result = IsDepthwiseConvolutionSupported(compute,
250 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
256 case LayerType::FakeQuantization:
258 auto cLayer = boost::polymorphic_downcast<const FakeQuantizationLayer*>(&layer);
259 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
260 result = IsFakeQuantizationSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(),
261 reason, reasonCapacity);
264 case LayerType::Floor:
266 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
267 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
268 result = IsFloorSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
269 reason, reasonCapacity);
272 case LayerType::FullyConnected:
274 auto cLayer = boost::polymorphic_downcast<const FullyConnectedLayer*>(&layer);
275 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
276 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
277 BOOST_ASSERT(cLayer->m_Weight.get() != nullptr);
280 const TensorInfo * biasInfoPtr = nullptr;
281 static const TensorInfo dummyFloat16Bias(TensorShape({1,1,1,1}), DataType::Float16);
282 static const TensorInfo dummyFloat32Bias(TensorShape({1,1,1,1}), DataType::Float32);
283 static const TensorInfo dummyQA8Bias(TensorShape({1,1,1,1}), DataType::Signed32);
285 const FullyConnectedDescriptor& descriptor = cLayer->GetParameters();
286 if (descriptor.m_BiasEnabled)
288 BOOST_ASSERT(cLayer->m_Bias.get() != nullptr);
289 biasInfo = OverrideDataType(cLayer->m_Bias->GetTensorInfo(), GetBiasTypeFromWeightsType(dataType));
290 biasInfoPtr = &biasInfo;
294 // If biases are not enabled pass a dummy tensorinfo for the validation
295 switch(input.GetDataType())
297 case DataType::Float16:
299 biasInfoPtr = &dummyFloat16Bias;
302 case DataType::Float32:
304 biasInfoPtr = &dummyFloat32Bias;
307 case DataType::QuantisedAsymm8:
309 biasInfoPtr = &dummyQA8Bias;
314 BOOST_ASSERT_MSG(false, "Unexpected bias type");
319 result = IsFullyConnectedSupported(compute,
320 OverrideDataType(input, dataType),
321 OverrideDataType(output, dataType),
322 OverrideDataType(cLayer->m_Weight->GetTensorInfo(), dataType),
329 case LayerType::Input:
331 const TensorInfo& input = layer.GetOutputSlot(0).GetTensorInfo();
332 result = IsInputSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
335 case LayerType::L2Normalization:
337 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
338 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
339 result = IsL2NormalizationSupported(compute, OverrideDataType(input, dataType),
340 OverrideDataType(output, dataType), reason, reasonCapacity);
343 case LayerType::Lstm:
345 auto cLayer = boost::polymorphic_downcast<const LstmLayer*>(&layer);
346 const LstmDescriptor& descriptor = cLayer->GetParameters();
349 const TensorInfo& input = OverrideDataType(layer.GetInputSlot(0).GetConnection()->GetTensorInfo(),
351 const TensorInfo& outputStateIn = OverrideDataType(layer.GetInputSlot(1).GetConnection()->GetTensorInfo(),
353 const TensorInfo& cellStateIn = OverrideDataType(layer.GetInputSlot(2).GetConnection()->GetTensorInfo(),
356 const TensorInfo& scratchBuffer = OverrideDataType(layer.GetOutputSlot(0).GetTensorInfo(), dataType);
357 const TensorInfo& outputStateOut = OverrideDataType(layer.GetOutputSlot(1).GetTensorInfo(), dataType);
358 const TensorInfo& cellStateOut = OverrideDataType(layer.GetOutputSlot(2).GetTensorInfo(), dataType);
359 const TensorInfo& output = OverrideDataType(layer.GetOutputSlot(3).GetTensorInfo(), dataType);
362 const TensorInfo& inputToForgetWeights
363 = OverrideDataType(cLayer->m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), dataType);
364 const TensorInfo& inputToCellWeights
365 = OverrideDataType(cLayer->m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), dataType);
366 const TensorInfo& inputToOutputWeights
367 = OverrideDataType(cLayer->m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), dataType);
368 const TensorInfo& recurrentToForgetWeights
369 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), dataType);
370 const TensorInfo& recurrentToCellWeights
371 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), dataType);
372 const TensorInfo& recurrentToOutputWeights
373 = OverrideDataType(cLayer->m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), dataType);
374 const TensorInfo& forgetGateBias
375 = OverrideDataType(cLayer->m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), dataType);
376 const TensorInfo& cellBias
377 = OverrideDataType(cLayer->m_BasicParameters.m_CellBias->GetTensorInfo(), dataType);
378 const TensorInfo& outputGateBias
379 = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType);
381 // Optional parameters
382 const TensorInfo* inputToInputWeights = nullptr;
383 const TensorInfo* recurrentToInputWeights = nullptr;
384 const TensorInfo* cellToInputWeights = nullptr;
385 const TensorInfo* inputGateBias = nullptr;
386 const TensorInfo* projectionWeights = nullptr;
387 const TensorInfo* projectionBias = nullptr;
388 const TensorInfo* cellToForgetWeights = nullptr;
389 const TensorInfo* cellToOutputWeights = nullptr;
391 TensorInfo optInputToInputWeights;
392 TensorInfo optRecurrentToInputWeights;
393 TensorInfo optCellToInputWeights;
394 TensorInfo optInputGateBias;
395 TensorInfo optProjectionWeights;
396 TensorInfo optProjectionBias;
397 TensorInfo optCellToForgetWeights;
398 TensorInfo optCellToOutputWeights;
400 if(!descriptor.m_CifgEnabled)
402 optInputToInputWeights =
403 OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType);
404 inputToInputWeights = &optInputToInputWeights;
406 optRecurrentToInputWeights =
407 OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType);
408 recurrentToInputWeights = &optRecurrentToInputWeights;
409 if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr)
411 optCellToInputWeights =
412 OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType);
413 cellToInputWeights = &optCellToInputWeights;
416 OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType);
417 inputGateBias = &optInputGateBias;
420 if(descriptor.m_ProjectionEnabled)
422 optProjectionWeights =
423 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType);
424 projectionWeights = &optProjectionWeights;
425 if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr)
428 OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType);
429 projectionBias = &optProjectionBias;
433 if(descriptor.m_PeepholeEnabled)
435 optCellToForgetWeights =
436 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType);
437 cellToForgetWeights = &optCellToForgetWeights;
438 optCellToOutputWeights =
439 OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType);
440 cellToOutputWeights = &optCellToOutputWeights;
443 result = IsLstmSupported(compute,
452 inputToForgetWeights,
454 inputToOutputWeights,
455 recurrentToForgetWeights,
456 recurrentToCellWeights,
457 recurrentToOutputWeights,
462 recurrentToInputWeights,
473 case LayerType::Merger:
475 auto cLayer = boost::polymorphic_downcast<const MergerLayer*>(&layer);
477 // Get vector of all inputs.
478 auto getTensorInfo = [&dataType](const InputSlot& slot)
480 return OverrideDataType(slot.GetConnectedOutputSlot()->GetTensorInfo(), dataType);
482 auto beginI = boost::make_transform_iterator(layer.GetInputSlots().begin(), getTensorInfo);
483 auto endI = boost::make_transform_iterator(layer.GetInputSlots().end(), getTensorInfo);
484 std::vector<TensorInfo> inputs(beginI, endI);
486 auto getTensorInfoPtr = [](const TensorInfo& info)
490 auto beginPtr = boost::make_transform_iterator(inputs.begin(), getTensorInfoPtr);
491 auto endPtr = boost::make_transform_iterator(inputs.end(), getTensorInfoPtr);
492 std::vector<const TensorInfo*> inputPtrs(beginPtr, endPtr);
494 result = IsMergerSupported(compute, inputPtrs, cLayer->GetParameters(), reason, reasonCapacity);
497 case LayerType::Multiplication:
499 const TensorInfo& input0 = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
500 const TensorInfo& input1 = layer.GetInputSlot(1).GetConnection()->GetTensorInfo();
501 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
502 result = IsMultiplicationSupported(compute,
503 OverrideDataType(input0, dataType),
504 OverrideDataType(input1, dataType),
505 OverrideDataType(output, dataType),
510 case LayerType::Normalization:
512 auto cLayer = boost::polymorphic_downcast<const NormalizationLayer*>(&layer);
513 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
514 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
515 result = IsNormalizationSupported(compute, OverrideDataType(input, dataType),
516 OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
520 case LayerType::Output:
522 const TensorInfo& output = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
523 result = IsOutputSupported(compute, OverrideDataType(output, dataType), reason, reasonCapacity);
526 case LayerType::Permute:
528 auto cLayer = boost::polymorphic_downcast<const PermuteLayer*>(&layer);
529 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
530 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
531 result = IsPermuteSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
532 cLayer->GetParameters(), reason, reasonCapacity);
535 case LayerType::Pooling2d:
537 auto cLayer = boost::polymorphic_downcast<const Pooling2dLayer*>(&layer);
538 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
539 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
540 result = IsPooling2dSupported(compute, OverrideDataType(input, dataType),
541 OverrideDataType(output, dataType), cLayer->GetParameters(), reason,
545 case LayerType::Reshape:
547 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
548 result = IsReshapeSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
551 case LayerType::ResizeBilinear:
553 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
554 result = IsResizeBilinearSupported(compute, OverrideDataType(input, dataType), reason, reasonCapacity);
557 case LayerType::Softmax:
559 auto cLayer = boost::polymorphic_downcast<const SoftmaxLayer*>(&layer);
560 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
561 const TensorInfo& output = layer.GetOutputSlot(0).GetTensorInfo();
562 result = IsSoftmaxSupported(compute, OverrideDataType(input, dataType), OverrideDataType(output, dataType),
563 cLayer->GetParameters(), reason, reasonCapacity);
566 case LayerType::Splitter:
568 auto cLayer = boost::polymorphic_downcast<const SplitterLayer*>(&layer);
569 const TensorInfo& input = layer.GetInputSlot(0).GetConnection()->GetTensorInfo();
570 result = IsSplitterSupported(compute, OverrideDataType(input, dataType), cLayer->GetParameters(), reason,
576 BOOST_ASSERT_MSG(false, "WorkloadFactory did not recognise type of layer.");
577 strcpy(reason, "Unrecognised layer type");
582 outReasonIfUnsupported = reason;
586 bool IWorkloadFactory::IsLayerSupported(const Layer& layer, boost::optional<DataType> dataType,
587 std::string& outReasonIfUnsupported)
589 return IsLayerSupported(layer.GetComputeDevice(), layer, dataType, outReasonIfUnsupported);