Release 18.08

diff --git a/src/armnn/backends/RefLayerSupport.hpp b/src/armnn/backends/RefLayerSupport.hpp
index 9db1c14..5e543ac 100644
--- a/src/armnn/backends/RefLayerSupport.hpp
+++ b/src/armnn/backends/RefLayerSupport.hpp
@@ -7,11 +7,14 @@
 #include <armnn/DescriptorsFwd.hpp>
 #include <armnn/Types.hpp>
 #include <armnn/Tensor.hpp>
+#include <layers/LstmLayer.hpp>
+#include <boost/optional.hpp>
 
 namespace armnn
 {
 
 bool IsActivationSupportedRef(const TensorInfo& input,
+                              const TensorInfo& output,
                               const ActivationDescriptor& descriptor,
                               std::string* reasonIfUnsupported = nullptr);
 
@@ -21,6 +24,11 @@ bool IsAdditionSupportedRef(const TensorInfo& input0,
                             std::string* reasonIfUnsupported = nullptr);
 
 bool IsBatchNormalizationSupportedRef(const TensorInfo& input,
+                                      const TensorInfo& output,
+                                      const TensorInfo& mean,
+                                      const TensorInfo& var,
+                                      const TensorInfo& beta,
+                                      const TensorInfo& gamma,
                                       const BatchNormalizationDescriptor& descriptor,
                                       std::string* reasonIfUnsupported = nullptr);
 
@@ -35,11 +43,16 @@ bool IsConvolution2dSupportedRef(const TensorInfo& input,
                                  std::string* reasonIfUnsupported = nullptr);
 
 bool IsDepthwiseConvolutionSupportedRef(const TensorInfo& input,
+                                        const TensorInfo& output,
                                         const DepthwiseConvolution2dDescriptor& descriptor,
                                         const TensorInfo& weights,
+                                        const TensorInfo& biases,
                                         std::string* reasonIfUnsupported = nullptr);
 
 bool IsFullyConnectedSupportedRef(const TensorInfo& input,
+                                  const TensorInfo& output,
+                                  const TensorInfo& weights,
+                                  const TensorInfo& biases,
                                   const FullyConnectedDescriptor& descriptor,
                                   std::string* reasonIfUnsupported = nullptr);
 
@@ -47,14 +60,30 @@ bool IsInputSupportedRef(const TensorInfo& input,
                          std::string* reasonIfUnsupported = nullptr);
 
 bool IsL2NormalizationSupportedRef(const TensorInfo& input,
+                                   const TensorInfo& output,
                                    std::string* reasonIfUnsupported = nullptr);
 
+bool IsLstmSupportedRef(const TensorInfo& input, const TensorInfo& outputStateIn,
+                        const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer,
+                        const TensorInfo& outputStateOut, const TensorInfo& cellStateOut,
+                        const TensorInfo& output, const LstmDescriptor& descriptor,
+                        const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights,
+                        const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights,
+                        const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights,
+                        const TensorInfo& forgetGateBias, const TensorInfo& cellBias,
+                        const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights,
+                        const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights,
+                        const TensorInfo* inputGateBias, const TensorInfo* projectionWeights,
+                        const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights,
+                        const TensorInfo* cellToOutputWeights, std::string* reasonIfUnsupported = nullptr);
+
 bool IsMergerSupportedRef(const std::vector<const TensorInfo*> inputs,
                           const OriginsDescriptor& descriptor,
                           std::string* reasonIfUnsupported = nullptr);
 
 bool IsMultiplicationSupportedRef(const TensorInfo& input0,
                                   const TensorInfo& input1,
+                                  const TensorInfo& output,
                                   std::string* reasonIfUnsupported = nullptr);
 
 bool IsNormalizationSupportedRef(const TensorInfo& input,
@@ -79,6 +108,7 @@ bool IsResizeBilinearSupportedRef(const TensorInfo& input,
                                   std::string* reasonIfUnsupported = nullptr);
 
 bool IsSoftmaxSupportedRef(const TensorInfo& input,
+                           const TensorInfo& output,
                            const SoftmaxDescriptor& descriptor,
                            std::string* reasonIfUnsupported = nullptr);
 
@@ -97,4 +127,12 @@ bool IsFloorSupportedRef(const TensorInfo& input,
                          const TensorInfo& output,
                          std::string* reasonIfUnsupported = nullptr);
 
+bool IsConvertFp16ToFp32SupportedRef(const TensorInfo& input,
+                                     const TensorInfo& output,
+                                     std::string* reasonIfUnsupported = nullptr);
+
+bool IsConvertFp32ToFp16SupportedRef(const TensorInfo& input,
+                                     const TensorInfo& output,
+                                     std::string* reasonIfUnsupported = nullptr);
+
 }
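
For context, a minimal caller-side sketch of the updated IsActivationSupportedRef signature, which now takes the output TensorInfo alongside the input. The include path, tensor shapes, and activation settings below are illustrative assumptions, not code taken from the Arm NN tree:

    // Sketch only: exercises the reference-backend support query after this change.
    #include "backends/RefLayerSupport.hpp"   // the header modified above; internal include path may vary
    #include <armnn/Tensor.hpp>
    #include <armnn/Descriptors.hpp>
    #include <iostream>
    #include <string>

    int main()
    {
        // Hypothetical 1x16 Float32 tensors; activation layers keep the input shape.
        const unsigned int dims[] = { 1, 16 };
        armnn::TensorInfo input(2, dims, armnn::DataType::Float32);
        armnn::TensorInfo output(2, dims, armnn::DataType::Float32);

        armnn::ActivationDescriptor descriptor;
        descriptor.m_Function = armnn::ActivationFunction::ReLu;

        // reasonIfUnsupported is optional; pass a string to get a diagnostic on failure.
        std::string reason;
        const bool supported =
            armnn::IsActivationSupportedRef(input, output, descriptor, &reason);

        std::cout << (supported ? "supported" : reason) << std::endl;
        return 0;
    }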