Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClLayerSupport.cpp
index 8905adf..72594ac 100644 (file)
@@ -7,7 +7,6 @@
 
 #include "ClLayerSupport.hpp"
 #include "InternalTypes.hpp"
-
 #include <armnn/Descriptors.hpp>
 #include <armnn/Types.hpp>
 #include <armnn/Tensor.hpp>
 
 #ifdef ARMCOMPUTECL_ENABLED
 #include "ClWorkloads/ClAdditionFloat32Workload.hpp"
+#include "ClWorkloads/ClActivationFloat32Workload.hpp"
+#include "ClWorkloads/ClBatchNormalizationFloat32Workload.hpp"
+
+#include "ClWorkloads/ClConvertFp16ToFp32Workload.hpp"
+#include "ClWorkloads/ClConvertFp32ToFp16Workload.hpp"
 #include "ClWorkloads/ClConvolution2dBaseWorkload.hpp"
+#include "ClWorkloads/ClDepthwiseConvolutionBaseWorkload.hpp"
+#include "ClWorkloads/ClL2NormalizationFloat32Workload.hpp"
+#include "ClWorkloads/ClMultiplicationFloat32Workload.hpp"
+#include "ClWorkloads/ClFullyConnectedFloat32Workload.hpp"
 #include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
 #include "ClWorkloads/ClPermuteWorkload.hpp"
 #include "ClWorkloads/ClNormalizationFloat32Workload.hpp"
+#include "ClWorkloads/ClSoftmaxBaseWorkload.hpp"
+#include "ClWorkloads/ClLstmFloat32Workload.hpp"
 #endif
 
 using namespace boost;
@@ -31,7 +41,7 @@ namespace
 template<unsigned int FilterSize>
 bool IsMatchingSize2d(const TensorInfo& weightInfo)
 {
-    // Width & Height must match
+    // Width & Height must match.
     return (weightInfo.GetShape()[3] == FilterSize) && (weightInfo.GetShape()[2] == FilterSize);
 }
 
@@ -88,58 +98,10 @@ inline bool IsWorkloadSupported(FuncType&& func, std::string* reasonIfUnsupporte
 
 } //namespace
 
-bool IsClActivationUint8Supported(std::string* reasonIfUnsupported, const ActivationDescriptor& parameters)
-{
-    if (parameters.m_Function != ActivationFunction::BoundedReLu)
-    {
-        if (reasonIfUnsupported)
-        {
-            *reasonIfUnsupported = "Unsupported activation function, only BoundedReLu is supported";
-        }
-
-        return false;
-    }
-
-    return true;
-}
-
-bool IsClDepthwiseConvolution2dDescParamsSupported(std::string* reasonIfUnsupported,
-                                                   const DepthwiseConvolution2dDescriptor& parameters,
-                                                   const TensorInfo& weights)
-{
-    if (weights.GetNumDimensions() != 4)
-    {
-        if (reasonIfUnsupported)
-        {
-            *reasonIfUnsupported = "Depthwise convolution Weight tensor needs to be 4d";
-        }
-        return false;
-    }
-    // weights.GetShape()[0] = channel multiplier
-    if (weights.GetShape()[0] != 1)
-    {
-        if (reasonIfUnsupported)
-        {
-            *reasonIfUnsupported = "Channel multiplier only supports the value 1 in the CL backend";
-        }
-        return false;
-    }
-    else if ((weights.GetDataType() == armnn::DataType::QuantisedAsymm8) && !IsMatchingSize2d<3>(weights))
-    {
-        if (reasonIfUnsupported)
-        {
-            *reasonIfUnsupported = "CL backend only supports 3x3 filtering for Depthwise Convolution on 8-bit";
-        }
-        return false;
-    }
-
-    return true;
-}
-
-template<typename Float32Func, typename Uint8Func, typename ... Params>
+template<typename FloatFunc, typename Uint8Func, typename ... Params>
 bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
                               DataType dataType,
-                              Float32Func floatFuncPtr,
+                              FloatFunc floatFuncPtr,
                               Uint8Func uint8FuncPtr,
                               Params&&... params)
 {
@@ -147,19 +109,21 @@ bool IsSupportedForDataTypeCl(std::string* reasonIfUnsupported,
         IsSupportedForDataTypeGeneric(reasonIfUnsupported,
                                       dataType,
                                       floatFuncPtr,
+                                      floatFuncPtr,
                                       uint8FuncPtr,
                                       std::forward<Params>(params)...);
 }
 
 bool IsActivationSupportedCl(const TensorInfo& input,
+                             const TensorInfo& output,
                              const ActivationDescriptor& descriptor,
                              std::string* reasonIfUnsupported)
 {
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &TrueFunc<const ActivationDescriptor&>,
-                                    &IsClActivationUint8Supported,
-                                    descriptor);
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClActivationWorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   output,
+                                   descriptor);
 }
 
 bool IsAdditionSupportedCl(const TensorInfo& input0,
@@ -167,21 +131,30 @@ bool IsAdditionSupportedCl(const TensorInfo& input0,
                            const TensorInfo& output,
                            std::string* reasonIfUnsupported)
 {
-    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionFloat32Workload::IsSupported(input0,
+    return FORWARD_CL_LAYER_SUPPORT_FUNC(ClAdditionValidate(input0,
         input1,
         output,
         reasonIfUnsupported));
 }
 
 bool IsBatchNormalizationSupportedCl(const TensorInfo& input,
+                                     const TensorInfo& output,
+                                     const TensorInfo& mean,
+                                     const TensorInfo& var,
+                                     const TensorInfo& beta,
+                                     const TensorInfo& gamma,
                                      const BatchNormalizationDescriptor& descriptor,
                                      std::string* reasonIfUnsupported)
 {
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &TrueFunc<const BatchNormalizationDescriptor&>,
-                                    &FalseFuncU8<const BatchNormalizationDescriptor&>,
-                                    descriptor);
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClBatchNormalizationValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   output,
+                                   mean,
+                                   var,
+                                   beta,
+                                   gamma,
+                                   descriptor);
 }
 
 bool IsConstantSupportedCl(const TensorInfo& output,
@@ -206,20 +179,20 @@ bool IsClDirectConvolution2dSupported(const TensorInfo& weightInfo, const Convol
     bool strideIsOneOrTwo        = strideXIsOneOrTwo && strideYIsOneOrTwo;
     bool strideIsOneOrTwoOrThree = ( strideXIsOneOrTwo || strideXIsThree ) && ( strideYIsOneOrTwo || strideYIsThree );
 
-    // 1x1 convolution with strides of 1,2,3
+    // 1x1 convolution with strides of 1,2,3.
     isSupported |= IsMatchingSize2d<1>(weightInfo) && ( strideIsOneOrTwoOrThree );
 
-    // 3x3 convolution with strides of 1,2
+    // 3x3 convolution with strides of 1,2.
     isSupported |= IsMatchingSize2d<3>(weightInfo) && ( strideIsOneOrTwo );
 
     // 5x5 convolution with strides of 1,2
     isSupported |= IsMatchingSize2d<5>(weightInfo) && ( strideIsOneOrTwo );
 
-    //fall back to normal convolution for the asymmetric padding case.
+    //Fall back to normal convolution for the asymmetric padding case.
     if (desc.m_PadLeft != desc.m_PadRight ||
         desc.m_PadTop != desc.m_PadBottom)
     {
-        //direct convolution does not support asymmetric padding yet.
+        //Direct convolution does not support asymmetric padding yet.
         isSupported = false;
     }
 
@@ -250,27 +223,40 @@ bool IsConvolution2dSupportedCl(const TensorInfo& input,
 }
 
 bool IsDepthwiseConvolutionSupportedCl(const TensorInfo& input,
+                                       const TensorInfo& output,
                                        const DepthwiseConvolution2dDescriptor& descriptor,
                                        const TensorInfo& weights,
+                                       const TensorInfo& biases,
                                        std::string* reasonIfUnsupported)
 {
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &IsClDepthwiseConvolution2dDescParamsSupported,
-                                    &IsClDepthwiseConvolution2dDescParamsSupported,
-                                    descriptor,
-                                    weights);
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClDepthwiseConvolutionWorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   output,
+                                   descriptor,
+                                   weights,
+                                   biases);
 }
 
 bool IsFullyConnectedSupportedCl(const TensorInfo& input,
+                                 const TensorInfo& output,
+                                 const TensorInfo& weights,
+                                 const TensorInfo& biases,
                                  const FullyConnectedDescriptor& descriptor,
                                  std::string* reasonIfUnsupported)
 {
-    ignore_unused(descriptor);
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &TrueFunc<>,
-                                    &FalseFuncU8<>);
+    // At the moment U8 is unsupported
+    if (input.GetDataType() == DataType::QuantisedAsymm8)
+    {
+        return false;
+    }
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   output,
+                                   weights,
+                                   biases,
+                                   descriptor);
 }
 
 bool IsInputSupportedCl(const TensorInfo& input,
@@ -283,12 +269,10 @@ bool IsInputSupportedCl(const TensorInfo& input,
 }
 
 bool IsL2NormalizationSupportedCl(const TensorInfo& input,
+                                  const TensorInfo& output,
                                   std::string* reasonIfUnsupported)
 {
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &TrueFunc<>,
-                                    &FalseFuncU8<>);
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClL2NormalizationWorkloadValidate, reasonIfUnsupported, input, output);
 }
 
 bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
@@ -304,13 +288,14 @@ bool IsMergerSupportedCl(const std::vector<const TensorInfo*> inputs,
 
 bool IsMultiplicationSupportedCl(const TensorInfo& input0,
                                  const TensorInfo& input1,
+                                 const TensorInfo& output,
                                  std::string* reasonIfUnsupported)
 {
-    ignore_unused(input1);
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input0.GetDataType(),
-                                    &TrueFunc<>,
-                                    &FalseFuncU8<>);
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClMultiplicationWorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input0,
+                                   input1,
+                                   output);
 }
 
 bool IsNormalizationSupportedCl(const TensorInfo& input,
@@ -358,14 +343,12 @@ bool IsResizeBilinearSupportedCl(const TensorInfo& input,
 }
 
 bool IsSoftmaxSupportedCl(const TensorInfo& input,
+                          const TensorInfo& output,
                           const SoftmaxDescriptor& descriptor,
                           std::string* reasonIfUnsupported)
 {
     ignore_unused(descriptor);
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &TrueFunc<>,
-                                    &TrueFunc<>);
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSoftmaxWorkloadValidate, reasonIfUnsupported, input, output);
 }
 
 bool IsSplitterSupportedCl(const TensorInfo& input,
@@ -400,10 +383,59 @@ bool IsFloorSupportedCl(const TensorInfo& input,
                         std::string* reasonIfUnsupported)
 {
     ignore_unused(output);
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &TrueFunc<>,
-                                    &FalseFuncU8<>);
+    return IsClBackendSupported(reasonIfUnsupported) &&
+           IsSupportedForDataTypeGeneric(reasonIfUnsupported,
+                                         input.GetDataType(),
+                                         &FalseFuncF16<>,
+                                         &TrueFunc<>,
+                                         &FalseFuncU8<>);
+}
+
+bool IsLstmSupportedCl(const TensorInfo& input, const TensorInfo& outputStateIn,
+                       const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
+                       const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
+                       const TensorInfo& output, const LstmDescriptor& descriptor,
+                       const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
+                       const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
+                       const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
+                       const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
+                       const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
+                       const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
+                       const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
+                       const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
+                       const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported)
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloat32WorkloadValidate, reasonIfUnsupported,
+                                   input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut,
+                                   output, descriptor, inputToForgetWeights, inputToCellWeights,
+                                   inputToOutputWeights, recurrentToForgetWeights,
+                                   recurrentToCellWeights, recurrentToOutputWeights,
+                                   forgetGateBias, cellBias, outputGateBias,
+                                   inputToInputWeights, recurrentToInputWeights,
+                                   cellToInputWeights, inputGateBias, projectionWeights,
+                                   projectionBias, cellToForgetWeights, cellToOutputWeights);
+}
+
+bool IsConvertFp16ToFp32SupportedCl(const TensorInfo& input,
+                                    const TensorInfo& output,
+                                    std::string* reasonIfUnsupported)
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp16ToFp32WorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   output,
+                                   reasonIfUnsupported);
+}
+
+bool IsConvertFp32ToFp16SupportedCl(const TensorInfo& input,
+                                    const TensorInfo& output,
+                                    std::string* reasonIfUnsupported)
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClConvertFp32ToFp16WorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   output,
+                                   reasonIfUnsupported);
 }
 
 }