//
// Copyright © 2017 Arm Ltd. All rights reserved.
// See LICENSE file in the project root for full license information.
//

#include "LayerSupportCommon.hpp"

#include "ClLayerSupport.hpp"
#include "InternalTypes.hpp"
#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>
#include <armnn/Tensor.hpp>

#include <boost/core/ignore_unused.hpp>

#ifdef ARMCOMPUTECL_ENABLED
#include "ClWorkloads/ClAdditionFloat32Workload.hpp"
#include "ClWorkloads/ClActivationFloat32Workload.hpp"
#include "ClWorkloads/ClBatchNormalizationFloat32Workload.hpp"

#include "ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
#include "ClWorkloads/ClConvertFp32ToFp16Workload.hpp"
#include "ClWorkloads/ClConvolution2dBaseWorkload.hpp"
#include "ClWorkloads/ClDepthwiseConvolutionBaseWorkload.hpp"
#include "ClWorkloads/ClL2NormalizationFloat32Workload.hpp"
#include "ClWorkloads/ClMultiplicationFloat32Workload.hpp"
#include "ClWorkloads/ClFullyConnectedFloat32Workload.hpp"
#include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
#include "ClWorkloads/ClPermuteWorkload.hpp"
#include "ClWorkloads/ClNormalizationFloat32Workload.hpp"
#include "ClWorkloads/ClSoftmaxBaseWorkload.hpp"
#include "ClWorkloads/ClLstmFloat32Workload.hpp"
#endif

using namespace boost;

namespace armnn
{
namespace
{
template<unsigned int FilterSize>
bool IsMatchingSize2d(const TensorInfo& weightInfo)
{
    // Width & Height must match.
    return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
}

template<uint32_t ValidStride>
bool IsMatchingStride(uint32_t actualStride)
{
    return ValidStride == actualStride;
}

template<uint32_t FirstStride, uint32_t SecondStride, uint32_t... ValidStrides>
bool IsMatchingStride(uint32_t actualStride)
{
    return IsMatchingStride<FirstStride>(actualStride) || IsMatchingStride<SecondStride, ValidStrides...>(actualStride);
}
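// Illustrative note (not part of the original source): the variadic overload above unfolds
// recursively, so a call such as
//     IsMatchingStride<1, 2, 3>(stride)
// evaluates as
//     IsMatchingStride<1>(stride) || IsMatchingStride<2>(stride) || IsMatchingStride<3>(stride)
// i.e. the actual stride is checked against each of the listed valid strides in turn.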

bool IsClBackendSupported(std::string* reasonIfUnsupported)
{
#if ARMCOMPUTECL_ENABLED
    return true;
#else
    if (reasonIfUnsupported != nullptr)
    {
        *reasonIfUnsupported = "The armnn library has been built without CL support";
    }
    return false;
#endif
}

#if ARMCOMPUTECL_ENABLED
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) (expr)
#else
#define FORWARD_CL_LAYER_SUPPORT_FUNC(expr) IsClBackendSupported(reasonIfUnsupported)
#endif
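// Illustrative note (not part of the original source): when ARMCOMPUTECL_ENABLED is defined,
//     return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0, input1, output, reasonIfUnsupported));
// expands to the validation call itself; in a build without CL support it expands to
//     return IsClBackendSupported(reasonIfUnsupported);
// so the enclosing function must have a 'reasonIfUnsupported' pointer in scope for the
// fallback expansion to compile.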

#if ARMCOMPUTECL_ENABLED
template<class FuncType, class... Args>
inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupported, Args&&... args)
{
    arm_compute::Status aclStatus = func(std::forward<Args>(args)...);
    const bool supported = (aclStatus.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = aclStatus.error_description();
    }
    return supported;
}

#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsWorkloadSupported(func, reasonIfUnsupported, __VA_ARGS__);
#else
#define FORWARD_WORKLOAD_VALIDATE_FUNC(func, reasonIfUnsupported, ...) \
    return IsClBackendSupported(reasonIfUnsupported);
#endif
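// Illustrative note (not part of the original source): the macro supplies the 'return'
// itself, so a layer-support function such as
//     bool IsPooling2dSupportedCl(...)
//     {
//         FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
//     }
// effectively becomes, in a CL-enabled build,
//     return IsWorkloadSupported(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
// which converts the arm_compute::Status returned by the ACL validate function into a bool
// and copies the error description into *reasonIfUnsupported on failure.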

} //namespace

template<typename FloatFunc, typename Uint8Func, typename ... Params>
bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
                              DataType dataType,
                              FloatFunc floatFuncPtr,
                              Uint8Func uint8FuncPtr,
                              Params&&... params)
{
    return IsClBackendSupported(reasonIfUnsupported) &&
        IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                      dataType,
                                      floatFuncPtr,
                                      floatFuncPtr,
                                      uint8FuncPtr,
                                      std::forward<Params>(params)...);
}
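// Illustrative note (not part of the original source): floatFuncPtr is passed twice because
// IsSupportedForDataTypeGeneric (from LayerSupportCommon.hpp) takes separate Float16 and
// Float32 callbacks, and the CL checks below treat both float types the same way. For example,
//     IsSupportedForDataTypeCl(reason, DataType::QuantisedAsymm8, &TrueFunc<>, &FalseFuncU8<>)
// returns false for quantised 8-bit data while accepting either float type.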

bool IsActivationSupportedCl(const TensorInfo& input,
                             const TensorInfo& output,
                             const ActivationDescriptor& descriptor,
                             std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor);
}

bool IsAdditionSupportedCl(const TensorInfo& input0,
                           const TensorInfo& input1,
                           const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
        input1,
        output,
        reasonIfUnsupported));
}

bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
                                     const TensorInfo& output,
                                     const TensorInfo& mean,
                                     const TensorInfo& var,
                                     const TensorInfo& beta,
                                     const TensorInfo& gamma,
                                     const BatchNormalizationDescriptor& descriptor,
                                     std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   mean,
                                   var,
                                   beta,
                                   gamma,
                                   descriptor);
}

bool IsConstantSupportedCl(const TensorInfo& output,
                           std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convolution2dDescriptor& desc)
{
    bool isSupported = false;

    bool strideXIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideX);
    bool strideXIsThree    = IsMatchingStride<3>(desc.m_StrideX);

    bool strideYIsOneOrTwo = IsMatchingStride<1, 2>(desc.m_StrideY);
    bool strideYIsThree    = IsMatchingStride<3>(desc.m_StrideY);

    bool strideIsOneOrTwo        = strideXIsOneOrTwo && strideYIsOneOrTwo;
    bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );

    // 1x1 convolution with strides of 1,2,3.
    isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );

    // 3x3 convolution with strides of 1,2.
    isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );

    // 5x5 convolution with strides of 1,2.
    isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );

    // Fall back to normal convolution for the asymmetric padding case.
    if (desc.m_PadLeft != desc.m_PadRight ||
        desc.m_PadTop != desc.m_PadBottom)
    {
        // Direct convolution does not support asymmetric padding yet.
        isSupported = false;
    }

    return isSupported;
}
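// Illustrative note (not part of the original source): under these rules a 3x3 kernel
// (weight shape [..][..][3][3]) with symmetric padding and strides of 1 or 2 is reported as a
// direct-convolution candidate, whereas the same kernel with a stride of 3, a 7x7 kernel, or
// any asymmetric padding (e.g. padLeft = 0, padRight = 1) is not.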

bool IsDirectConvolution2dParamsSupportedCl(std::string* reasonIfUnsupported,
                                            const Convolution2dDescriptor& parameters,
                                            const TensorInfo& weightInfo)
{
    return IsClDirectConvolution2dSupported(weightInfo, parameters);
}

bool IsConvolution2dSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const Convolution2dDescriptor& descriptor,
                                const TensorInfo& weights,
                                const TensorInfo& biases,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvolution2dWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
                                       const TensorInfo& output,
                                       const DepthwiseConvolution2dDescriptor& descriptor,
                                       const TensorInfo& weights,
                                       const TensorInfo& biases,
                                       std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   descriptor,
                                   weights,
                                   biases);
}

bool IsFullyConnectedSupportedCl(const TensorInfo& input,
                                 const TensorInfo& output,
                                 const TensorInfo& weights,
                                 const TensorInfo& biases,
                                 const FullyConnectedDescriptor& descriptor,
                                 std::string* reasonIfUnsupported)
{
    // At the moment U8 is unsupported.
    if (input.GetDataType() == DataType::QuantisedAsymm8)
    {
        return false;
    }
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   weights,
                                   biases,
                                   descriptor);
}

bool IsInputSupportedCl(const TensorInfo& input,
                        std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsL2NormalizationSupportedCl(const TensorInfo& input,
                                  const TensorInfo& output,
                                  std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output);
}

bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
                         const OriginsDescriptor& descriptor,
                         std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    inputs[0]->GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsMultiplicationSupportedCl(const TensorInfo& input0,
                                 const TensorInfo& input1,
                                 const TensorInfo& output,
                                 std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
                                   reasonIfUnsupported,
                                   input0,
                                   input1,
                                   output);
}

bool IsNormalizationSupportedCl(const TensorInfo& input,
                                const TensorInfo& output,
                                const NormalizationDescriptor& descriptor,
                                std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClNormalizationWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsOutputSupportedCl(const TensorInfo& output,
                         std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    output.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsPermuteSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const PermuteDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(output);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPermuteWorkloadValidate, reasonIfUnsupported, descriptor);
}

bool IsPooling2dSupportedCl(const TensorInfo& input,
                            const TensorInfo& output,
                            const Pooling2dDescriptor& descriptor,
                            std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
}

bool IsResizeBilinearSupportedCl(const TensorInfo& input,
                                 std::string* reasonIfUnsupported)
{
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &FalseFuncU8<>);
}

bool IsSoftmaxSupportedCl(const TensorInfo& input,
                          const TensorInfo& output,
                          const SoftmaxDescriptor& descriptor,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
}

bool IsSplitterSupportedCl(const TensorInfo& input,
                           const ViewsDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
{
    ignore_unused(descriptor);
    return IsSupportedForDataTypeCl(reasonIfUnsupported,
                                    input.GetDataType(),
                                    &TrueFunc<>,
                                    &TrueFunc<>);
}

bool IsFakeQuantizationSupportedCl(const TensorInfo& input,
                                   const FakeQuantizationDescriptor& descriptor,
                                   std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    ignore_unused(descriptor);
    return false;
}

bool IsReshapeSupportedCl(const TensorInfo& input,
                          std::string* reasonIfUnsupported)
{
    ignore_unused(input);
    return true;
}

bool IsFloorSupportedCl(const TensorInfo& input,
                        const TensorInfo& output,
                        std::string* reasonIfUnsupported)
{
    ignore_unused(output);
    return IsClBackendSupported(reasonIfUnsupported) &&
           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                         input.GetDataType(),
                                         &FalseFuncF16<>,
                                         &TrueFunc<>,
                                         &FalseFuncU8<>);
}
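// Illustrative note (not part of the original source): the three callbacks above map to the
// Float16/Float32/Uint8 cases of IsSupportedForDataTypeGeneric, so Floor is reported as
// supported on the CL backend only for Float32 tensors.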

bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
                       const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
                       const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
                       const TensorInfo& output, const LstmDescriptor& descriptor,
                       const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
                       const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
                       const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
                       const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
                       const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
                       const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
                       const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
                       const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
                       const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloat32WorkloadValidate, reasonIfUnsupported,
                                   input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
                                   output, descriptor, inputToForgetWeights, inputToCellWeights,
                                   inputToOutputWeights, recurrentToForgetWeights,
                                   recurrentToCellWeights, recurrentToOutputWeights,
                                   forgetGateBias, cellBias, outputGateBias,
                                   inputToInputWeights, recurrentToInputWeights,
                                   cellToInputWeights, inputGateBias, projectionWeights,
                                   projectionBias, cellToForgetWeights, cellToOutputWeights);
}

bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
                                    const TensorInfo& output,
                                    std::string* reasonIfUnsupported)
{
    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
                                   reasonIfUnsupported,
                                   input,
                                   output,
                                   reasonIfUnsupported);
}

} // namespace armnn
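
// Illustrative usage sketch (not part of the original source): 'inputInfo' and 'outputInfo'
// are assumed to be pre-built armnn::TensorInfo objects describing the layer's tensors.
//
//     std::string reason;
//     armnn::ActivationDescriptor reluDesc;
//     reluDesc.m_Function = armnn::ActivationFunction::ReLu;
//     if (!armnn::IsActivationSupportedCl(inputInfo, outputInfo, reluDesc, &reason))
//     {
//         // The CL backend cannot run this layer; 'reason' holds the failure description.
//     }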