IVGCVSW-1900 : CL backend folder structure
[platform/upstream/armnn.git] / src / backends / cl / ClLayerSupport.cpp
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "LayerSupportCommon.hpp"
7
8 #include "ClLayerSupport.hpp"
9 #include "InternalTypes.hpp"
10 #include <armnn/Descriptors.hpp>
11 #include <armnn/Types.hpp>
12 #include <armnn/Tensor.hpp>
13
14 #include <boost/core/ignore_unused.hpp>
15
16 #ifdef ARMCOMPUTECL_ENABLED
17 #include "workloads/ClAdditionWorkload.hpp"
18 #include "workloads/ClActivationFloatWorkload.hpp"
19 #include "workloads/ClBatchNormalizationFloatWorkload.hpp"
20 #include "workloads/ClConvertFp16ToFp32Workload.hpp"
21 #include "workloads/ClConvertFp32ToFp16Workload.hpp"
22 #include "workloads/ClConvolution2dBaseWorkload.hpp"
23 #include "workloads/ClDepthwiseConvolutionBaseWorkload.hpp"
24 #include "workloads/ClDivisionFloatWorkload.hpp"
25 #include "workloads/ClL2NormalizationFloatWorkload.hpp"
26 #include "workloads/ClMultiplicationFloatWorkload.hpp"
27 #include "workloads/ClFullyConnectedWorkload.hpp"
28 #include "workloads/ClPadWorkload.hpp"
29 #include "workloads/ClPooling2dBaseWorkload.hpp"
30 #include "workloads/ClPermuteWorkload.hpp"
31 #include "workloads/ClNormalizationFloatWorkload.hpp"
32 #include "workloads/ClSoftmaxBaseWorkload.hpp"
33 #include "workloads/ClSubtractionWorkload.hpp"
34 #include "workloads/ClLstmFloatWorkload.hpp"
35 #endif
36
37 using namespace boost;
38
39 namespace armnn
40 {
41 namespace
42 {
43 template<unsigned int FilterSize>
44 bool IsMatchingSize2d(const TensorInfo& weightInfo)
45 {
46     // Width & Height must match.
47     return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
48 }
49
// Returns true when actualStride equals the single ValidStride template argument.
template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

// Returns true when actualStride matches any of the listed strides.
// Recurses pairwise: tests FirstStride, then the remainder of the pack.
template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) ||
           IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}
61
// Reports whether this armnn build includes the CL (Arm Compute Library) backend.
// When CL support was compiled out, optionally writes an explanation into
// reasonIfUnsupported (ignored when null) and returns false.
bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}
74
#if ARMCOMPUTECL_ENABLED
// With CL compiled in, evaluate the wrapped validate expression as-is.
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
// Without CL, ignore the expression and report "backend unavailable".
// Relies on a 'reasonIfUnsupported' variable being in scope at the call site.
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif
80
#if ARMCOMPUTECL_ENABLED
// Runs an ACL workload validate function and converts its arm_compute::Status
// into this file's bool + optional reason-string convention. The reason string
// is only written on failure, and only when the caller supplied a pointer.
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

// Expands to a return statement: delegate to the given ACL validate function.
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
// Without CL, expands to a return statement reporting "backend unavailable".
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
100
101 } //namespace
102
103 template<typename FloatFunc, typename Uint8Func, typename ... Params>
104 bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
105                               DataType dataType,
106                               FloatFunc floatFuncPtr,
107                               Uint8Func uint8FuncPtr,
108                               Params&&... params)
109 {
110     return IsClBackendSupported(reasonIfUnsupported) &&
111         IsSupportedForDataTypeGeneric(reasonIfUnsupported,
112                                       dataType,
113                                       floatFuncPtr,
114                                       floatFuncPtr,
115                                       uint8FuncPtr,
116                                       std::forward<Params>(params)...);
117 }
118
// Activation support: delegates to the ACL validate function (or reports
// "backend unavailable" when built without CL — see the macro definition).
bool IsActivationSupportedCl(const TensorInfo& input,
                             const TensorInfo& output,
                             const ActivationDescriptor& descriptor,
                             std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}
130
// Addition support: ClAdditionValidate writes its own reason string, so this
// uses FORWARD_CL_LAYER_SUPPORT_FUNC rather than the Status-converting macro.
bool IsAdditionSupportedCl(const TensorInfo& input0,
                           const TensorInfo& input1,
                           const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
        input1,
        output,
        reasonIfUnsupported));
}
141
// Batch normalization support: forwards all statistics tensors (mean, variance,
// beta, gamma) plus the descriptor to the ACL validate function.
bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const TensorInfo& mean,
                                     const TensorInfo& var,
                                     const TensorInfo& beta,
                                     const TensorInfo& gamma,
                                     const BatchNormalizationDescriptor& descriptor,
                                     std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}
161
// Constant layer support: float types are accepted, Uint8 is rejected
// (TrueFunc / FalseFuncU8 select the answer per data type).
bool IsConstantSupportedCl(const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}
170
171 bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
172 {
173     bool isSupported = false;
174
175     bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
176     bool strideXIsThree    = IsMatchingStride<3>(desc.m_StrideX);
177
178     bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
179     bool strideYIsThree    = IsMatchingStride<3>(desc.m_StrideY);
180
181     bool strideIsOneOrTwo        = strideXIsOneOrTwo && strideYIsOneOrTwo;
182     bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
183
184     // 1x1 convolution with strides of 1,2,3.
185     isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
186
187     // 3x3 convolution with strides of 1,2.
188     isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
189
190     // 5x5 convolution with strides of 1,2
191     isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
192
193     //Fall back to normal convolution for the asymmetric padding case.
194     if (desc.m_PadLeft != desc.m_PadRight ||
195         desc.m_PadTop != desc.m_PadBottom)
196     {
197         //Direct convolution does not support asymmetric padding yet.
198         isSupported = false;
199     }
200
201     return isSupported;
202 }
203
204 bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
205                                             const Convolution2dDescriptor& parameters,
206                                             const TensorInfo& weightInfo)
207 {
208     return IsClDirectConvolution2dSupported(weightInfo, parameters);
209 }
210
// Convolution support: delegates to the ACL validate function. Biases are
// optional and forwarded as-is.
bool IsConvolution2dSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const Convolution2dDescriptor& descriptor,
                                const TensorInfo& weights,
                                const boost::optional<TensorInfo>& biases,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}
226
// Depthwise convolution support: delegates to the ACL validate function.
// Biases are optional and forwarded as-is.
bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
                                       const TensorInfo& output,
                                       const DepthwiseConvolution2dDescriptor& descriptor,
                                       const TensorInfo& weights,
                                       const boost::optional<TensorInfo>& biases,
                                       std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}
242
// Division support: delegates to the ACL validate function.
bool IsDivisionSupportedCl(const TensorInfo& input0,
                           const TensorInfo& input1,
                           const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDivisionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}
254
// Subtraction support: ClSubtractionValidate writes its own reason string,
// so this uses FORWARD_CL_LAYER_SUPPORT_FUNC (cf. IsAdditionSupportedCl).
bool IsSubtractionSupportedCl(const TensorInfo& input0,
                              const TensorInfo& input1,
                              const TensorInfo& output,
                              std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClSubtractionValidate(input0,
                                         input1,
                                         output,
                                         reasonIfUnsupported));
}
265
// Fully connected support: delegates to the ACL validate function.
bool IsFullyConnectedSupportedCl(const TensorInfo& input,
                                 const TensorInfo& output,
                                 const TensorInfo& weights,
                                 const TensorInfo& biases,
                                 const FullyConnectedDescriptor& descriptor,
                                 std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}
281
// Input layer support: all data types accepted (both predicates are TrueFunc),
// provided the CL backend is compiled in.
bool IsInputSupportedCl(const TensorInfo& input,
    std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}
290
// L2 normalization support: delegates to the ACL validate function.
bool IsL2NormalizationSupportedCl(const TensorInfo& input,
                                  const TensorInfo& output,
                                  const L2NormalizationDescriptor& descriptor,
                                  std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
298
// Merger (concat) support: decided solely from the first input's data type
// (float accepted, Uint8 rejected); the descriptor is not consulted.
// NOTE(review): inputs[0] is dereferenced without an empty/null check —
// callers presumably guarantee at least one valid input; confirm upstream.
// NOTE(review): the vector is taken by const value — matches the declaration
// in the header, so the signature is left unchanged here.
bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
                         const OriginsDescriptor& descriptor,
                         std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    inputs[0]->GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}
309
// Multiplication support: delegates to the ACL validate function.
bool IsMultiplicationSupportedCl(const TensorInfo& input0,
                                 const TensorInfo& input1,
                                 const TensorInfo& output,
                                 std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}
321
// Normalization support: delegates to the ACL validate function.
bool IsNormalizationSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const NormalizationDescriptor& descriptor,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
329
// Output layer support: all data types accepted (both predicates are TrueFunc),
// provided the CL backend is compiled in.
bool IsOutputSupportedCl(const TensorInfo& output,
                         std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}
338
// Pad support: ClPadValidate writes its own reason string, so this uses
// FORWARD_CL_LAYER_SUPPORT_FUNC rather than the Status-converting macro.
bool IsPadSupportedCl(const TensorInfo& input,
                      const TensorInfo& output,
                      const PadDescriptor& descriptor,
                      std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClPadValidate(input, output, descriptor, reasonIfUnsupported));
}
346
// Permute support: only the permutation descriptor is validated; the tensor
// infos are deliberately unused by the underlying check.
bool IsPermuteSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const PermuteDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}
356
// Pooling support: delegates to the ACL validate function.
bool IsPooling2dSupportedCl(const TensorInfo& input,
                            const TensorInfo& output,
                            const Pooling2dDescriptor& descriptor,
                            std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}
364
// Resize-bilinear support: float types accepted, Uint8 rejected.
bool IsResizeBilinearSupportedCl(const TensorInfo& input,
                                 std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}
373
// Softmax support: only the tensor infos are validated; the descriptor
// (beta parameter) is deliberately not consulted here.
bool IsSoftmaxSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const SoftmaxDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}
382
// Splitter support: all data types accepted; the views descriptor is not
// consulted — only the input's data type and backend availability matter.
bool IsSplitterSupportedCl(const TensorInfo& input,
                           const ViewsDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}
393
394 bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
395                                    const FakeQuantizationDescriptor& descriptor,
396                                    std::string* reasonIfUnsupported)
397 {
398     ignore_unused(input);
399     ignore_unused(descriptor);
400     return false;
401 }
402
// Reshape support: unconditionally supported; the input info is not consulted.
bool IsReshapeSupportedCl(const TensorInfo& input,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    return true;
}
409
410 bool IsFloorSupportedCl(const TensorInfo& input,
411                         const TensorInfo& output,
412                         std::string* reasonIfUnsupported)
413 {
414     ignore_unused(output);
415     return IsClBackendSupported(reasonIfUnsupported) &&
416            IsSupportedForDataTypeGeneric(reasonIfUnsupported,
417                                          input.GetDataType(),
418                                          &FalseFuncF16<>,
419                                          &TrueFunc<>,
420                                          &FalseFuncU8<>);
421 }
422
// LSTM support: forwards the full parameter set to the float LSTM validate
// function. Pointer parameters (inputToInputWeights, projectionWeights, etc.)
// are the optional CIFG/projection/peephole tensors and may be null.
bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
                       const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                       const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                       const TensorInfo& output, const LstmDescriptor& descriptor,
                       const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
                       const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
                       const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
                       const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
                       const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
                       const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
                       const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
                       const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
                       const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, reasonIfUnsupported,
                                   input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
                                   output, descriptor, inputToForgetWeights, inputToCellWeights,
                                   inputToOutputWeights, recurrentToForgetWeights,
                                   recurrentToCellWeights, recurrentToOutputWeights,
                                   forgetGateBias, cellBias, outputGateBias,
                                   inputToInputWeights, recurrentToInputWeights,
                                   cellToInputWeights, inputGateBias, projectionWeights,
                                   projectionBias, cellToForgetWeights, cellToOutputWeights);
}
447
// Fp16->Fp32 conversion support. reasonIfUnsupported appears twice on
// purpose: once for the macro's Status-to-bool conversion, and once as a
// trailing argument because this validate function also takes the reason
// pointer directly.
bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}
458
// Fp32->Fp16 conversion support. reasonIfUnsupported is passed both to the
// macro and to the validate function itself (see IsConvertFp16ToFp32SupportedCl).
bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}
469
470 bool IsMeanSupportedCl(const TensorInfo& input,
471                        const TensorInfo& output,
472                        const MeanDescriptor& descriptor,
473                        std::string* reasonIfUnsupported)
474 {
475     return false;
476 }
477
478 }