IVGCVSW-3296 Add CL backend support for ResizeNearestNeighbour
[platform/upstream/armnn.git] / src / backends / cl / ClLayerSupport.cpp
index 12c2efe..7eb1dcf 100644 (file)
 #include "workloads/ClPadWorkload.hpp"
 #include "workloads/ClPermuteWorkload.hpp"
 #include "workloads/ClPooling2dWorkload.hpp"
+#include "workloads/ClPreluWorkload.hpp"
+#include "workloads/ClResizeWorkload.hpp"
 #include "workloads/ClQuantizeWorkload.hpp"
 #include "workloads/ClSoftmaxBaseWorkload.hpp"
 #include "workloads/ClSpaceToBatchNdWorkload.hpp"
+#include "workloads/ClSpaceToDepthWorkload.hpp"
 #include "workloads/ClSplitterWorkload.hpp"
 #include "workloads/ClStridedSliceWorkload.hpp"
 #include "workloads/ClSubtractionWorkload.hpp"
@@ -403,23 +406,7 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
                                      const TensorInfo& cellStateOut,
                                      const TensorInfo& output,
                                      const LstmDescriptor& descriptor,
-                                     const TensorInfo& inputToForgetWeights,
-                                     const TensorInfo& inputToCellWeights,
-                                     const TensorInfo& inputToOutputWeights,
-                                     const TensorInfo& recurrentToForgetWeights,
-                                     const TensorInfo& recurrentToCellWeights,
-                                     const TensorInfo& recurrentToOutputWeights,
-                                     const TensorInfo& forgetGateBias,
-                                     const TensorInfo& cellBias,
-                                     const TensorInfo& outputGateBias,
-                                     const TensorInfo* inputToInputWeights,
-                                     const TensorInfo* recurrentToInputWeights,
-                                     const TensorInfo* cellToInputWeights,
-                                     const TensorInfo* inputGateBias,
-                                     const TensorInfo* projectionWeights,
-                                     const TensorInfo* projectionBias,
-                                     const TensorInfo* cellToForgetWeights,
-                                     const TensorInfo* cellToOutputWeights,
+                                     const LstmInputParamsInfo& paramsInfo,
                                      Optional<std::string&> reasonIfUnsupported) const
 {
     FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate,
@@ -432,23 +419,7 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input,
                                    cellStateOut,
                                    output,
                                    descriptor,
-                                   inputToForgetWeights,
-                                   inputToCellWeights,
-                                   inputToOutputWeights,
-                                   recurrentToForgetWeights,
-                                   recurrentToCellWeights,
-                                   recurrentToOutputWeights,
-                                   forgetGateBias,
-                                   cellBias,
-                                   outputGateBias,
-                                   inputToInputWeights,
-                                   recurrentToInputWeights,
-                                   cellToInputWeights,
-                                   inputGateBias,
-                                   projectionWeights,
-                                   projectionBias,
-                                   cellToForgetWeights,
-                                   cellToOutputWeights);
+                                   paramsInfo);
 }
 
 bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0,
@@ -567,6 +538,14 @@ bool ClLayerSupport::IsPooling2dSupported(const TensorInfo& input,
     FORWARD_WORKLOAD_VALIDATE_FUNC(ClPooling2dWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
 }
 
+bool ClLayerSupport::IsPreluSupported(const armnn::TensorInfo &input,
+                                      const armnn::TensorInfo &alpha,
+                                      const armnn::TensorInfo &output,
+                                      armnn::Optional<std::string &> reasonIfUnsupported) const
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClPreluWorkloadValidate, reasonIfUnsupported, input, alpha, output);
+}
+
 bool ClLayerSupport::IsQuantizeSupported(const TensorInfo& input,
                                          const TensorInfo& output,
                                          Optional<std::string&> reasonIfUnsupported) const
@@ -587,15 +566,27 @@ bool ClLayerSupport::IsReshapeSupported(const TensorInfo& input,
     return true;
 }
 
+bool ClLayerSupport::IsResizeSupported(const TensorInfo& input,
+                                       const TensorInfo& output,
+                                       const ResizeDescriptor& descriptor,
+                                       Optional<std::string&> reasonIfUnsupported) const
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClResizeWorkloadValidate, reasonIfUnsupported, input, output, descriptor);
+}
+
 bool ClLayerSupport::IsResizeBilinearSupported(const TensorInfo& input,
                                                const TensorInfo& output,
                                                Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(output);
-    return IsSupportedForDataTypeCl(reasonIfUnsupported,
-                                    input.GetDataType(),
-                                    &TrueFunc<>,
-                                    &FalseFuncU8<>);
+    ResizeDescriptor descriptor;
+    descriptor.m_Method     = ResizeMethod::Bilinear;
+    descriptor.m_DataLayout = DataLayout::NCHW;
+
+    const TensorShape& outputShape = output.GetShape();
+    descriptor.m_TargetHeight = outputShape[2];
+    descriptor.m_TargetWidth  = outputShape[3];
+
+    return IsResizeSupported(input, output, descriptor, reasonIfUnsupported);
 }
 
 bool ClLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
@@ -619,6 +610,18 @@ bool ClLayerSupport::IsSpaceToBatchNdSupported(const TensorInfo& input,
                                    descriptor);
 }
 
+bool ClLayerSupport::IsSpaceToDepthSupported(const TensorInfo& input,
+                                             const TensorInfo& output,
+                                             const SpaceToDepthDescriptor& descriptor,
+                                             Optional<std::string&> reasonIfUnsupported) const
+{
+    FORWARD_WORKLOAD_VALIDATE_FUNC(ClSpaceToDepthWorkloadValidate,
+                                   reasonIfUnsupported,
+                                   input,
+                                   output,
+                                   descriptor);
+}
+
 bool ClLayerSupport::IsSplitterSupported(const TensorInfo& input,
                                          const ViewsDescriptor& descriptor,
                                          Optional<std::string&> reasonIfUnsupported) const