// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
#include "ClLayerSupport.hpp"
#include "InternalTypes.hpp"
#include "LayerSupportCommon.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <boost/core/ignore_unused.hpp>

#include <cstdint>
#include <string>
#include <utility>
#include <vector>

#ifdef ARMCOMPUTECL_ENABLED
#include "workloads/ClActivationFloatWorkload.hpp"
#include "workloads/ClAdditionWorkload.hpp"
#include "workloads/ClBatchNormalizationFloatWorkload.hpp"
#include "workloads/ClConvertFp16ToFp32Workload.hpp"
#include "workloads/ClConvertFp32ToFp16Workload.hpp"
#include "workloads/ClConvolution2dBaseWorkload.hpp"
#include "workloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "workloads/ClDivisionFloatWorkload.hpp"
#include "workloads/ClFullyConnectedWorkload.hpp"
#include "workloads/ClL2NormalizationFloatWorkload.hpp"
#include "workloads/ClLstmFloatWorkload.hpp"
#include "workloads/ClMultiplicationFloatWorkload.hpp"
#include "workloads/ClNormalizationFloatWorkload.hpp"
#include "workloads/ClPadWorkload.hpp"
#include "workloads/ClPermuteWorkload.hpp"
#include "workloads/ClPooling2dBaseWorkload.hpp"
#include "workloads/ClSoftmaxBaseWorkload.hpp"
#include "workloads/ClSubtractionWorkload.hpp"
#endif
37 using namespace boost;
43 template<unsigned int FilterSize>
44 bool IsMatchingSize2d(const TensorInfo& weightInfo)
46 // Width & Height must match.
47 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
50 template<uint32_t ValidStride>
51 bool IsMatchingStride(uint32_t actualStride)
53 return ValidStride == actualStride;
56 template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
57 bool IsMatchingStride(uint32_t actualStride)
59 return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
// Reports whether the CL backend is available in this build.
// When built without CL support, writes an explanation to reasonIfUnsupported
// (if non-null) and returns false.
bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported != nullptr)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}
#if ARMCOMPUTECL_ENABLED
// CL available: evaluate the validation expression directly.
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
// CL unavailable: skip the expression and report the missing backend.
// Relies on a 'reasonIfUnsupported' variable being in scope at the call site.
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif
#if ARMCOMPUTECL_ENABLED
// Runs an ACL validate function and converts its Status into a bool,
// copying the ACL error description into reasonIfUnsupported on failure.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// Without CL support every workload validation fails with the standard reason.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
103 template<typename FloatFunc, typename Uint8Func, typename ... Params>
104 bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
106 FloatFunc floatFuncPtr,
107 Uint8Func uint8FuncPtr,
110 return IsClBackendSupported(reasonIfUnsupported) &&
111 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
116 std::forward<Params>(params)...);
119 bool IsActivationSupportedCl(const TensorInfo& input,
120 const TensorInfo& output,
121 const ActivationDescriptor& descriptor,
122 std::string* reasonIfUnsupported)
124 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
131 bool IsAdditionSupportedCl(const TensorInfo& input0,
132 const TensorInfo& input1,
133 const TensorInfo& output,
134 std::string* reasonIfUnsupported)
136 return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
139 reasonIfUnsupported));
142 bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
143 const TensorInfo& output,
144 const TensorInfo& mean,
145 const TensorInfo& var,
146 const TensorInfo& beta,
147 const TensorInfo& gamma,
148 const BatchNormalizationDescriptor& descriptor,
149 std::string* reasonIfUnsupported)
151 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
162 bool IsConstantSupportedCl(const TensorInfo& output,
163 std::string* reasonIfUnsupported)
165 return IsSupportedForDataTypeCl(reasonIfUnsupported,
166 output.GetDataType(),
171 bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
173 bool isSupported = false;
175 bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
176 bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);
178 bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
179 bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);
181 bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
182 bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
184 // 1x1 convolution with strides of 1,2,3.
185 isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
187 // 3x3 convolution with strides of 1,2.
188 isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
190 // 5x5 convolution with strides of 1,2
191 isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
193 //Fall back to normal convolution for the asymmetric padding case.
194 if (desc.m_PadLeft != desc.m_PadRight ||
195 desc.m_PadTop != desc.m_PadBottom)
197 //Direct convolution does not support asymmetric padding yet.
204 bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
205 const Convolution2dDescriptor& parameters,
206 const TensorInfo& weightInfo)
208 return IsClDirectConvolution2dSupported(weightInfo, parameters);
211 bool IsConvolution2dSupportedCl(const TensorInfo& input,
212 const TensorInfo& output,
213 const Convolution2dDescriptor& descriptor,
214 const TensorInfo& weights,
215 const boost::optional<TensorInfo>& biases,
216 std::string* reasonIfUnsupported)
218 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
227 bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
228 const TensorInfo& output,
229 const DepthwiseConvolution2dDescriptor& descriptor,
230 const TensorInfo& weights,
231 const boost::optional<TensorInfo>& biases,
232 std::string* reasonIfUnsupported)
234 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
243 bool IsDivisionSupportedCl(const TensorInfo& input0,
244 const TensorInfo& input1,
245 const TensorInfo& output,
246 std::string* reasonIfUnsupported)
248 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
255 bool IsSubtractionSupportedCl(const TensorInfo& input0,
256 const TensorInfo& input1,
257 const TensorInfo& output,
258 std::string* reasonIfUnsupported)
260 return FORWARD_CL_LAYER_SUPPORT_FUNC(ClSubtractionValidate(input0,
263 reasonIfUnsupported));
266 bool IsFullyConnectedSupportedCl(const TensorInfo& input,
267 const TensorInfo& output,
268 const TensorInfo& weights,
269 const TensorInfo& biases,
270 const FullyConnectedDescriptor& descriptor,
271 std::string* reasonIfUnsupported)
273 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
282 bool IsInputSupportedCl(const TensorInfo& input,
283 std::string* reasonIfUnsupported)
285 return IsSupportedForDataTypeCl(reasonIfUnsupported,
291 bool IsL2NormalizationSupportedCl(const TensorInfo& input,
292 const TensorInfo& output,
293 const L2NormalizationDescriptor& descriptor,
294 std::string* reasonIfUnsupported)
296 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
299 bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
300 const OriginsDescriptor& descriptor,
301 std::string* reasonIfUnsupported)
303 ignore_unused(descriptor);
304 return IsSupportedForDataTypeCl(reasonIfUnsupported,
305 inputs[0]->GetDataType(),
310 bool IsMultiplicationSupportedCl(const TensorInfo& input0,
311 const TensorInfo& input1,
312 const TensorInfo& output,
313 std::string* reasonIfUnsupported)
315 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
322 bool IsNormalizationSupportedCl(const TensorInfo& input,
323 const TensorInfo& output,
324 const NormalizationDescriptor& descriptor,
325 std::string* reasonIfUnsupported)
327 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
330 bool IsOutputSupportedCl(const TensorInfo& output,
331 std::string* reasonIfUnsupported)
333 return IsSupportedForDataTypeCl(reasonIfUnsupported,
334 output.GetDataType(),
339 bool IsPadSupportedCl(const TensorInfo& input,
340 const TensorInfo& output,
341 const PadDescriptor& descriptor,
342 std::string* reasonIfUnsupported)
344 return FORWARD_CL_LAYER_SUPPORT_FUNC(ClPadValidate(input, output, descriptor, reasonIfUnsupported));
347 bool IsPermuteSupportedCl(const TensorInfo& input,
348 const TensorInfo& output,
349 const PermuteDescriptor& descriptor,
350 std::string* reasonIfUnsupported)
352 ignore_unused(input);
353 ignore_unused(output);
354 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
357 bool IsPooling2dSupportedCl(const TensorInfo& input,
358 const TensorInfo& output,
359 const Pooling2dDescriptor& descriptor,
360 std::string* reasonIfUnsupported)
362 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
365 bool IsResizeBilinearSupportedCl(const TensorInfo& input,
366 std::string* reasonIfUnsupported)
368 return IsSupportedForDataTypeCl(reasonIfUnsupported,
374 bool IsSoftmaxSupportedCl(const TensorInfo& input,
375 const TensorInfo& output,
376 const SoftmaxDescriptor& descriptor,
377 std::string* reasonIfUnsupported)
379 ignore_unused(descriptor);
380 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
383 bool IsSplitterSupportedCl(const TensorInfo& input,
384 const ViewsDescriptor& descriptor,
385 std::string* reasonIfUnsupported)
387 ignore_unused(descriptor);
388 return IsSupportedForDataTypeCl(reasonIfUnsupported,
394 bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
395 const FakeQuantizationDescriptor& descriptor,
396 std::string* reasonIfUnsupported)
398 ignore_unused(input);
399 ignore_unused(descriptor);
403 bool IsReshapeSupportedCl(const TensorInfo& input,
404 std::string* reasonIfUnsupported)
406 ignore_unused(input);
410 bool IsFloorSupportedCl(const TensorInfo& input,
411 const TensorInfo& output,
412 std::string* reasonIfUnsupported)
414 ignore_unused(output);
415 return IsClBackendSupported(reasonIfUnsupported) &&
416 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
423 bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
424 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
425 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
426 const TensorInfo& output, const LstmDescriptor& descriptor,
427 const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
428 const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
429 const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
430 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
431 const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
432 const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
433 const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
434 const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
435 const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
437 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, reasonIfUnsupported,
438 input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
439 output, descriptor, inputToForgetWeights, inputToCellWeights,
440 inputToOutputWeights, recurrentToForgetWeights,
441 recurrentToCellWeights, recurrentToOutputWeights,
442 forgetGateBias, cellBias, outputGateBias,
443 inputToInputWeights, recurrentToInputWeights,
444 cellToInputWeights, inputGateBias, projectionWeights,
445 projectionBias, cellToForgetWeights, cellToOutputWeights);
448 bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
449 const TensorInfo& output,
450 std::string* reasonIfUnsupported)
452 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
456 reasonIfUnsupported);
459 bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
460 const TensorInfo& output,
461 std::string* reasonIfUnsupported)
463 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
467 reasonIfUnsupported);
470 bool IsMeanSupportedCl(const TensorInfo& input,
471 const TensorInfo& output,
472 const MeanDescriptor& descriptor,
473 std::string* reasonIfUnsupported)