//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//
#include "LayerSupportCommon.hpp"

#include "ClLayerSupport.hpp"
#include "InternalTypes.hpp"

#include <armnn/Descriptors.hpp>
#include <armnn/Tensor.hpp>
#include <armnn/Types.hpp>

#include <boost/core/ignore_unused.hpp>

#ifdef ARMCOMPUTECL_ENABLED
#include "ClWorkloads/ClActivationFloat32Workload.hpp"
#include "ClWorkloads/ClAdditionFloat32Workload.hpp"
#include "ClWorkloads/ClBatchNormalizationFloat32Workload.hpp"
#include "ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
#include "ClWorkloads/ClConvertFp32ToFp16Workload.hpp"
#include "ClWorkloads/ClConvolution2dBaseWorkload.hpp"
#include "ClWorkloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "ClWorkloads/ClFullyConnectedFloat32Workload.hpp"
#include "ClWorkloads/ClL2NormalizationFloat32Workload.hpp"
#include "ClWorkloads/ClLstmFloat32Workload.hpp"
#include "ClWorkloads/ClMultiplicationFloat32Workload.hpp"
#include "ClWorkloads/ClNormalizationFloat32Workload.hpp"
#include "ClWorkloads/ClPermuteWorkload.hpp"
#include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
#include "ClWorkloads/ClSoftmaxBaseWorkload.hpp"
#endif // ARMCOMPUTECL_ENABLED
35 using namespace boost;
41 template<unsigned int FilterSize>
42 bool IsMatchingSize2d(const TensorInfo& weightInfo)
44 // Width & Height must match.
45 return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
// Returns true when actualStride equals the single compile-time ValidStride.
// Base case for the variadic overload below.
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

// Returns true when actualStride matches ANY of the compile-time strides
// {FirstStride, SecondStride, ValidStrides...}. Recurses on the tail of the
// parameter pack; the single-parameter overload is the terminating case.
template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}
// Returns whether the CL backend is compiled into this build.
// When built without ARMCOMPUTECL_ENABLED, the reason is reported through
// reasonIfUnsupported (which may be null).
bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported != nullptr)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}
// Evaluates a CL-specific support expression when the CL backend is compiled
// in; otherwise reports the backend itself as unsupported. The #else branch
// expects a 'reasonIfUnsupported' variable in the caller's scope.
#if ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif
#if ARMCOMPUTECL_ENABLED
// Runs an ACL workload validate function and converts its arm_compute::Status
// into a bool, copying the error description into reasonIfUnsupported (if
// non-null) on failure.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

// NOTE: this macro expands to a 'return' statement in the calling function.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
101 template<typename FloatFunc, typename Uint8Func, typename ... Params>
102 bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
104 FloatFunc floatFuncPtr,
105 Uint8Func uint8FuncPtr,
108 return IsClBackendSupported(reasonIfUnsupported) &&
109 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
114 std::forward<Params>(params)...);
117 bool IsActivationSupportedCl(const TensorInfo& input,
118 const TensorInfo& output,
119 const ActivationDescriptor& descriptor,
120 std::string* reasonIfUnsupported)
122 FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
129 bool IsAdditionSupportedCl(const TensorInfo& input0,
130 const TensorInfo& input1,
131 const TensorInfo& output,
132 std::string* reasonIfUnsupported)
134 return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
137 reasonIfUnsupported));
140 bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
141 const TensorInfo& output,
142 const TensorInfo& mean,
143 const TensorInfo& var,
144 const TensorInfo& beta,
145 const TensorInfo& gamma,
146 const BatchNormalizationDescriptor& descriptor,
147 std::string* reasonIfUnsupported)
149 FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
160 bool IsConstantSupportedCl(const TensorInfo& output,
161 std::string* reasonIfUnsupported)
163 return IsSupportedForDataTypeCl(reasonIfUnsupported,
164 output.GetDataType(),
169 bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
171 bool isSupported = false;
173 bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
174 bool strideXIsThree = IsMatchingStride<3>(desc.m_StrideX);
176 bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
177 bool strideYIsThree = IsMatchingStride<3>(desc.m_StrideY);
179 bool strideIsOneOrTwo = strideXIsOneOrTwo && strideYIsOneOrTwo;
180 bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
182 // 1x1 convolution with strides of 1,2,3.
183 isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
185 // 3x3 convolution with strides of 1,2.
186 isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
188 // 5x5 convolution with strides of 1,2
189 isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
191 //Fall back to normal convolution for the asymmetric padding case.
192 if (desc.m_PadLeft != desc.m_PadRight ||
193 desc.m_PadTop != desc.m_PadBottom)
195 //Direct convolution does not support asymmetric padding yet.
202 bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
203 const Convolution2dDescriptor& parameters,
204 const TensorInfo& weightInfo)
206 return IsClDirectConvolution2dSupported(weightInfo, parameters);
209 bool IsConvolution2dSupportedCl(const TensorInfo& input,
210 const TensorInfo& output,
211 const Convolution2dDescriptor& descriptor,
212 const TensorInfo& weights,
213 const TensorInfo& biases,
214 std::string* reasonIfUnsupported)
216 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
225 bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
226 const TensorInfo& output,
227 const DepthwiseConvolution2dDescriptor& descriptor,
228 const TensorInfo& weights,
229 const TensorInfo& biases,
230 std::string* reasonIfUnsupported)
232 FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
241 bool IsFullyConnectedSupportedCl(const TensorInfo& input,
242 const TensorInfo& output,
243 const TensorInfo& weights,
244 const TensorInfo& biases,
245 const FullyConnectedDescriptor& descriptor,
246 std::string* reasonIfUnsupported)
248 // At the moment U8 is unsupported
249 if (input.GetDataType() == DataType::QuantisedAsymm8)
253 FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
262 bool IsInputSupportedCl(const TensorInfo& input,
263 std::string* reasonIfUnsupported)
265 return IsSupportedForDataTypeCl(reasonIfUnsupported,
271 bool IsL2NormalizationSupportedCl(const TensorInfo& input,
272 const TensorInfo& output,
273 std::string* reasonIfUnsupported)
275 FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output);
278 bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
279 const OriginsDescriptor& descriptor,
280 std::string* reasonIfUnsupported)
282 ignore_unused(descriptor);
283 return IsSupportedForDataTypeCl(reasonIfUnsupported,
284 inputs[0]->GetDataType(),
289 bool IsMultiplicationSupportedCl(const TensorInfo& input0,
290 const TensorInfo& input1,
291 const TensorInfo& output,
292 std::string* reasonIfUnsupported)
294 FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
301 bool IsNormalizationSupportedCl(const TensorInfo& input,
302 const TensorInfo& output,
303 const NormalizationDescriptor& descriptor,
304 std::string* reasonIfUnsupported)
306 FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
309 bool IsOutputSupportedCl(const TensorInfo& output,
310 std::string* reasonIfUnsupported)
312 return IsSupportedForDataTypeCl(reasonIfUnsupported,
313 output.GetDataType(),
318 bool IsPermuteSupportedCl(const TensorInfo& input,
319 const TensorInfo& output,
320 const PermuteDescriptor& descriptor,
321 std::string* reasonIfUnsupported)
323 ignore_unused(input);
324 ignore_unused(output);
325 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
328 bool IsPooling2dSupportedCl(const TensorInfo& input,
329 const TensorInfo& output,
330 const Pooling2dDescriptor& descriptor,
331 std::string* reasonIfUnsupported)
333 FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
336 bool IsResizeBilinearSupportedCl(const TensorInfo& input,
337 std::string* reasonIfUnsupported)
339 return IsSupportedForDataTypeCl(reasonIfUnsupported,
345 bool IsSoftmaxSupportedCl(const TensorInfo& input,
346 const TensorInfo& output,
347 const SoftmaxDescriptor& descriptor,
348 std::string* reasonIfUnsupported)
350 ignore_unused(descriptor);
351 FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
354 bool IsSplitterSupportedCl(const TensorInfo& input,
355 const ViewsDescriptor& descriptor,
356 std::string* reasonIfUnsupported)
358 ignore_unused(descriptor);
359 return IsSupportedForDataTypeCl(reasonIfUnsupported,
365 bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
366 const FakeQuantizationDescriptor& descriptor,
367 std::string* reasonIfUnsupported)
369 ignore_unused(input);
370 ignore_unused(descriptor);
374 bool IsReshapeSupportedCl(const TensorInfo& input,
375 std::string* reasonIfUnsupported)
377 ignore_unused(input);
381 bool IsFloorSupportedCl(const TensorInfo& input,
382 const TensorInfo& output,
383 std::string* reasonIfUnsupported)
385 ignore_unused(output);
386 return IsClBackendSupported(reasonIfUnsupported) &&
387 IsSupportedForDataTypeGeneric(reasonIfUnsupported,
394 bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
395 const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
396 const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
397 const TensorInfo& output, const LstmDescriptor& descriptor,
398 const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
399 const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
400 const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
401 const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
402 const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
403 const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
404 const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
405 const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
406 const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
408 FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloat32WorkloadValidate, reasonIfUnsupported,
409 input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
410 output, descriptor, inputToForgetWeights, inputToCellWeights,
411 inputToOutputWeights, recurrentToForgetWeights,
412 recurrentToCellWeights, recurrentToOutputWeights,
413 forgetGateBias, cellBias, outputGateBias,
414 inputToInputWeights, recurrentToInputWeights,
415 cellToInputWeights, inputGateBias, projectionWeights,
416 projectionBias, cellToForgetWeights, cellToOutputWeights);
419 bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
420 const TensorInfo& output,
421 std::string* reasonIfUnsupported)
423 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
427 reasonIfUnsupported);
430 bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
431 const TensorInfo& output,
432 std::string* reasonIfUnsupported)
434 FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
438 reasonIfUnsupported);