From: Jan Eilers Date: Wed, 3 Jul 2019 17:20:40 +0000 (+0100) Subject: IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported X-Git-Tag: submit/tizen/20200316.035456~463 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=d01a83c8de77c44a938a618918d17385da3baa88;p=platform%2Fupstream%2Farmnn.git IVGCVSW-3397 Join lstm parameter infos in a struct for isLstmSupported !android-nn-driver:1461 Change-Id: I9d8fe7adf13832ed0cbcfe98b2353c2f37011d22 Signed-off-by: Jan Eilers --- diff --git a/include/armnn/ILayerSupport.hpp b/include/armnn/ILayerSupport.hpp index 58722fe..53dd29d 100644 --- a/include/armnn/ILayerSupport.hpp +++ b/include/armnn/ILayerSupport.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -153,28 +154,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const = 0; + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const = 0; virtual bool IsMaximumSupported(const TensorInfo& input0, const 
TensorInfo& input1, diff --git a/include/armnn/LayerSupport.hpp b/include/armnn/LayerSupport.hpp index 35336ed..65f9d08 100644 --- a/include/armnn/LayerSupport.hpp +++ b/include/armnn/LayerSupport.hpp @@ -9,6 +9,7 @@ #include #include #include +#include "LstmParams.hpp" namespace armnn { @@ -178,15 +179,7 @@ bool IsLstmSupported(const BackendId& backend, const TensorInfo& input, const Te const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported = nullptr, + const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported = nullptr, size_t reasonIfUnsupportedMaxLength = 1024); /// Deprecated in favor of IBackend and ILayerSupport interfaces diff --git a/include/armnn/LstmParams.hpp b/include/armnn/LstmParams.hpp index a7c57c7..0c8e66d 100644 --- a/include/armnn/LstmParams.hpp +++ b/include/armnn/LstmParams.hpp @@ -5,6 +5,7 @@ #pragma once #include "TensorFwd.hpp" +#include "Exceptions.hpp" namespace armnn { @@ -59,5 +60,149 @@ struct LstmInputParams const ConstTensor* m_OutputLayerNormWeights; }; +struct LstmInputParamsInfo +{ + LstmInputParamsInfo() + : m_InputToInputWeights(nullptr) + , m_InputToForgetWeights(nullptr) + , m_InputToCellWeights(nullptr) + , 
m_InputToOutputWeights(nullptr) + , m_RecurrentToInputWeights(nullptr) + , m_RecurrentToForgetWeights(nullptr) + , m_RecurrentToCellWeights(nullptr) + , m_RecurrentToOutputWeights(nullptr) + , m_CellToInputWeights(nullptr) + , m_CellToForgetWeights(nullptr) + , m_CellToOutputWeights(nullptr) + , m_InputGateBias(nullptr) + , m_ForgetGateBias(nullptr) + , m_CellBias(nullptr) + , m_OutputGateBias(nullptr) + , m_ProjectionWeights(nullptr) + , m_ProjectionBias(nullptr) + , m_InputLayerNormWeights(nullptr) + , m_ForgetLayerNormWeights(nullptr) + , m_CellLayerNormWeights(nullptr) + , m_OutputLayerNormWeights(nullptr) + { + } + const TensorInfo* m_InputToInputWeights; + const TensorInfo* m_InputToForgetWeights; + const TensorInfo* m_InputToCellWeights; + const TensorInfo* m_InputToOutputWeights; + const TensorInfo* m_RecurrentToInputWeights; + const TensorInfo* m_RecurrentToForgetWeights; + const TensorInfo* m_RecurrentToCellWeights; + const TensorInfo* m_RecurrentToOutputWeights; + const TensorInfo* m_CellToInputWeights; + const TensorInfo* m_CellToForgetWeights; + const TensorInfo* m_CellToOutputWeights; + const TensorInfo* m_InputGateBias; + const TensorInfo* m_ForgetGateBias; + const TensorInfo* m_CellBias; + const TensorInfo* m_OutputGateBias; + const TensorInfo* m_ProjectionWeights; + const TensorInfo* m_ProjectionBias; + const TensorInfo* m_InputLayerNormWeights; + const TensorInfo* m_ForgetLayerNormWeights; + const TensorInfo* m_CellLayerNormWeights; + const TensorInfo* m_OutputLayerNormWeights; + + const TensorInfo& deref(const TensorInfo* tensorInfo) const + { + if (tensorInfo != nullptr) + { + const TensorInfo &temp = *tensorInfo; + return temp; + } + throw InvalidArgumentException("Can't dereference a null pointer"); + } + + const TensorInfo& get_InputToInputWeights() const + { + return deref(m_InputToInputWeights); + } + const TensorInfo& get_InputToForgetWeights() const + { + return deref(m_InputToForgetWeights); + } + const TensorInfo& 
get_InputToCellWeights() const + { + return deref(m_InputToCellWeights); + } + const TensorInfo& get_InputToOutputWeights() const + { + return deref(m_InputToOutputWeights); + } + const TensorInfo& get_RecurrentToInputWeights() const + { + return deref(m_RecurrentToInputWeights); + } + const TensorInfo& get_RecurrentToForgetWeights() const + { + return deref(m_RecurrentToForgetWeights); + } + const TensorInfo& get_RecurrentToCellWeights() const + { + return deref(m_RecurrentToCellWeights); + } + const TensorInfo& get_RecurrentToOutputWeights() const + { + return deref(m_RecurrentToOutputWeights); + } + const TensorInfo& get_CellToInputWeights() const + { + return deref(m_CellToInputWeights); + } + const TensorInfo& get_CellToForgetWeights() const + { + return deref(m_CellToForgetWeights); + } + const TensorInfo& get_CellToOutputWeights() const + { + return deref(m_CellToOutputWeights); + } + const TensorInfo& get_InputGateBias() const + { + return deref(m_InputGateBias); + } + const TensorInfo& get_ForgetGateBias() const + { + return deref(m_ForgetGateBias); + } + const TensorInfo& get_CellBias() const + { + return deref(m_CellBias); + } + const TensorInfo& get_OutputGateBias() const + { + return deref(m_OutputGateBias); + } + const TensorInfo& get_ProjectionWeights() const + { + return deref(m_ProjectionWeights); + } + const TensorInfo& get_ProjectionBias() const + { + return deref(m_ProjectionBias); + } + const TensorInfo& get_InputLayerNormWeights() const + { + return deref(m_InputLayerNormWeights); + } + const TensorInfo& get_ForgetLayerNormWeights() const + { + return deref(m_ForgetLayerNormWeights); + } + const TensorInfo& get_CellLayerNormWeights() const + { + return deref(m_CellLayerNormWeights); + } + const TensorInfo& get_OutputLayerNormWeights() const + { + return deref(m_OutputLayerNormWeights); + } +}; + } // namespace armnn diff --git a/src/armnn/LayerSupport.cpp b/src/armnn/LayerSupport.cpp index b2ca85c..a2908aa 100644 --- 
a/src/armnn/LayerSupport.cpp +++ b/src/armnn/LayerSupport.cpp @@ -333,27 +333,13 @@ bool IsLstmSupported(const BackendId& backend, const TensorInfo& input, const Te const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, char* reasonIfUnsupported, + const LstmInputParamsInfo& paramsInfo, char* reasonIfUnsupported, size_t reasonIfUnsupportedMaxLength) { FORWARD_LAYER_SUPPORT_FUNC(backend, IsLstmSupported, input, outputStateIn, cellStateIn, scratchBuffer, outputStateOut, cellStateOut, - output, descriptor, inputToForgetWeights, inputToCellWeights, - inputToOutputWeights, recurrentToForgetWeights, - recurrentToCellWeights, recurrentToOutputWeights, - forgetGateBias, cellBias, outputGateBias, - inputToInputWeights, recurrentToInputWeights, - cellToInputWeights, inputGateBias, projectionWeights, - projectionBias, cellToForgetWeights, cellToOutputWeights); + output, descriptor, paramsInfo); } bool IsMaximumSupported(const BackendId& backend, diff --git a/src/backends/backendsCommon/LayerSupportBase.cpp b/src/backends/backendsCommon/LayerSupportBase.cpp index 4488e25..ea22fac 100644 --- a/src/backends/backendsCommon/LayerSupportBase.cpp +++ 
b/src/backends/backendsCommon/LayerSupportBase.cpp @@ -226,28 +226,8 @@ bool LayerSupportBase::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const { return DefaultLayerSupport(__func__, __FILE__, __LINE__, reasonIfUnsupported); } diff --git a/src/backends/backendsCommon/LayerSupportBase.hpp b/src/backends/backendsCommon/LayerSupportBase.hpp index 03a928a..36b8e77 100644 --- a/src/backends/backendsCommon/LayerSupportBase.hpp +++ b/src/backends/backendsCommon/LayerSupportBase.hpp @@ -140,28 +140,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const 
TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1, diff --git a/src/backends/backendsCommon/WorkloadFactory.cpp b/src/backends/backendsCommon/WorkloadFactory.cpp index 3502c38..1c23e17 100644 --- a/src/backends/backendsCommon/WorkloadFactory.cpp +++ b/src/backends/backendsCommon/WorkloadFactory.cpp @@ -388,20 +388,20 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, const TensorInfo& outputGateBias = OverrideDataType(cLayer->m_BasicParameters.m_OutputGateBias->GetTensorInfo(), dataType); - // Optional parameters - const TensorInfo* inputToInputWeights = nullptr; - const TensorInfo* recurrentToInputWeights = nullptr; - const TensorInfo* cellToInputWeights = nullptr; - const TensorInfo* inputGateBias = nullptr; - const TensorInfo* projectionWeights = nullptr; - const TensorInfo* projectionBias = nullptr; - const TensorInfo* cellToForgetWeights = nullptr; - const TensorInfo* cellToOutputWeights = nullptr; - const TensorInfo* inputLayerNormWeights = nullptr; - const TensorInfo* forgetLayerNormWeights = nullptr; - const TensorInfo* cellLayerNormWeights = nullptr; - const TensorInfo* outputLayerNormWeights = nullptr; + LstmInputParamsInfo paramsInfo; + + paramsInfo.m_InputToForgetWeights = 
&inputToForgetWeights; + paramsInfo.m_InputToCellWeights = &inputToCellWeights; + paramsInfo.m_InputToOutputWeights = &inputToOutputWeights; + paramsInfo.m_RecurrentToForgetWeights = &recurrentToForgetWeights; + paramsInfo.m_RecurrentToCellWeights = &recurrentToCellWeights; + paramsInfo.m_RecurrentToOutputWeights = &recurrentToOutputWeights; + paramsInfo.m_ForgetGateBias = &forgetGateBias; + paramsInfo.m_CellBias = &cellBias; + paramsInfo.m_OutputGateBias = &outputGateBias; + + // Optional parameters TensorInfo optInputToInputWeights; TensorInfo optRecurrentToInputWeights; TensorInfo optCellToInputWeights; @@ -419,32 +419,32 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optInputToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), dataType); - inputToInputWeights = &optInputToInputWeights; + paramsInfo.m_InputToInputWeights = &optInputToInputWeights; optRecurrentToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), dataType); - recurrentToInputWeights = &optRecurrentToInputWeights; + paramsInfo.m_RecurrentToInputWeights = &optRecurrentToInputWeights; if (cLayer->m_CifgParameters.m_CellToInputWeights != nullptr) { optCellToInputWeights = OverrideDataType(cLayer->m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), dataType); - cellToInputWeights = &optCellToInputWeights; + paramsInfo.m_CellToInputWeights = &optCellToInputWeights; } optInputGateBias = OverrideDataType(cLayer->m_CifgParameters.m_InputGateBias->GetTensorInfo(), dataType); - inputGateBias = &optInputGateBias; + paramsInfo.m_InputGateBias = &optInputGateBias; } if(descriptor.m_ProjectionEnabled) { optProjectionWeights = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), dataType); - projectionWeights = &optProjectionWeights; + paramsInfo.m_ProjectionWeights = &optProjectionWeights; if (cLayer->m_ProjectionParameters.m_ProjectionBias != nullptr) { 
optProjectionBias = OverrideDataType(cLayer->m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), dataType); - projectionBias = &optProjectionBias; + paramsInfo.m_ProjectionBias = &optProjectionBias; } } @@ -452,29 +452,29 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, { optCellToForgetWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), dataType); - cellToForgetWeights = &optCellToForgetWeights; + paramsInfo.m_CellToForgetWeights = &optCellToForgetWeights; optCellToOutputWeights = OverrideDataType(cLayer->m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), dataType); - cellToOutputWeights = &optCellToOutputWeights; + paramsInfo.m_CellToOutputWeights = &optCellToOutputWeights; } if(descriptor.m_LayerNormEnabled) { optInputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_InputLayerNormWeights->GetTensorInfo(), dataType); - inputLayerNormWeights = &optInputLayerNormWeights; + paramsInfo.m_InputLayerNormWeights = &optInputLayerNormWeights; optForgetLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_ForgetLayerNormWeights->GetTensorInfo(), dataType); - forgetLayerNormWeights = &optForgetLayerNormWeights; + paramsInfo.m_ForgetLayerNormWeights = &optForgetLayerNormWeights; optCellLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_CellLayerNormWeights->GetTensorInfo(), dataType); - cellLayerNormWeights = &optCellLayerNormWeights; + paramsInfo.m_CellLayerNormWeights = &optCellLayerNormWeights; optOutputLayerNormWeights = OverrideDataType( cLayer->m_LayerNormParameters.m_OutputLayerNormWeights->GetTensorInfo(), dataType); - outputLayerNormWeights = &optOutputLayerNormWeights; + paramsInfo.m_OutputLayerNormWeights = &optOutputLayerNormWeights; } result = layerSupportObject->IsLstmSupported( @@ -486,28 +486,8 @@ bool IWorkloadFactory::IsLayerSupported(const BackendId& backendId, cellStateOut, output, descriptor, - inputToForgetWeights, - 
inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights, - reason, - inputLayerNormWeights, - forgetLayerNormWeights, - cellLayerNormWeights, - outputLayerNormWeights); + paramsInfo, + reason); break; } case LayerType::Maximum: diff --git a/src/backends/cl/ClLayerSupport.cpp b/src/backends/cl/ClLayerSupport.cpp index 497a643..6d9b197 100644 --- a/src/backends/cl/ClLayerSupport.cpp +++ b/src/backends/cl/ClLayerSupport.cpp @@ -405,28 +405,8 @@ bool ClLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const { FORWARD_WORKLOAD_VALIDATE_FUNC(ClLstmFloatWorkloadValidate, reasonIfUnsupported, @@ -438,23 +418,7 @@ bool 
ClLayerSupport::IsLstmSupported(const TensorInfo& input, cellStateOut, output, descriptor, - inputToForgetWeights, - inputToCellWeights, - inputToOutputWeights, - recurrentToForgetWeights, - recurrentToCellWeights, - recurrentToOutputWeights, - forgetGateBias, - cellBias, - outputGateBias, - inputToInputWeights, - recurrentToInputWeights, - cellToInputWeights, - inputGateBias, - projectionWeights, - projectionBias, - cellToForgetWeights, - cellToOutputWeights); + paramsInfo); } bool ClLayerSupport::IsMaximumSupported(const TensorInfo& input0, diff --git a/src/backends/cl/ClLayerSupport.hpp b/src/backends/cl/ClLayerSupport.hpp index 4a55997..63a4daf 100644 --- a/src/backends/cl/ClLayerSupport.hpp +++ b/src/backends/cl/ClLayerSupport.hpp @@ -114,28 +114,8 @@ public: const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const 
TensorInfo& input1, diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp index f4d8974..3dbbbc3 100644 --- a/src/backends/cl/workloads/ClLstmFloatWorkload.cpp +++ b/src/backends/cl/workloads/ClLstmFloatWorkload.cpp @@ -224,22 +224,7 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights) + const LstmInputParamsInfo& paramsInfo) { arm_compute::LSTMParams lstm_params_info; @@ -253,18 +238,21 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T const arm_compute::TensorInfo aclOutputInfo = BuildArmComputeTensorInfo(output); // Basic parameters - const arm_compute::TensorInfo aclInputToForgetWeightsInfo = BuildArmComputeTensorInfo(inputToForgetWeights); - const arm_compute::TensorInfo aclInputToCellWeightsInfo = BuildArmComputeTensorInfo(inputToCellWeights); - const arm_compute::TensorInfo aclInputToOutputWeightsInfo = BuildArmComputeTensorInfo(inputToOutputWeights); + const arm_compute::TensorInfo aclInputToForgetWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToForgetWeights()); + 
const arm_compute::TensorInfo aclInputToCellWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToCellWeights()); + const arm_compute::TensorInfo aclInputToOutputWeightsInfo + = BuildArmComputeTensorInfo(paramsInfo.get_InputToOutputWeights()); const arm_compute::TensorInfo aclRecurrentToForgetWeightsInfo - = BuildArmComputeTensorInfo(recurrentToForgetWeights); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToForgetWeights()); const arm_compute::TensorInfo aclRecurrentToCellWeightsInfo - = BuildArmComputeTensorInfo(recurrentToCellWeights); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToCellWeights()); const arm_compute::TensorInfo aclRecurrentToOutputWeightsInfo - = BuildArmComputeTensorInfo(recurrentToOutputWeights); - const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(forgetGateBias); - const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(cellBias); - const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(outputGateBias); + = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToOutputWeights()); + const arm_compute::TensorInfo aclForgetGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ForgetGateBias()); + const arm_compute::TensorInfo aclCellBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellBias()); + const arm_compute::TensorInfo aclOutputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_OutputGateBias()); arm_compute::TensorInfo aclInputToInputWeightsInfo; arm_compute::TensorInfo aclRecurrentToInputWeightsInfo; @@ -277,43 +265,37 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T if (!descriptor.m_CifgEnabled) { - armnn::TensorInfo inputToInputWInfo = *inputToInputWeights; - aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(inputToInputWInfo); - armnn::TensorInfo recurrentToInputWInfo = *recurrentToInputWeights; - aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(recurrentToInputWInfo); + 
aclInputToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputToInputWeights()); + aclRecurrentToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_RecurrentToInputWeights()); - if (cellToInputWeights != nullptr) + if (paramsInfo.m_CellToInputWeights != nullptr) { - armnn::TensorInfo cellToInputWInfo = *cellToInputWeights; - aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(cellToInputWInfo); + aclCellToInputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToInputWeights()); } - armnn::TensorInfo inputGateBiasInfo = *inputGateBias; - aclInputGateBiasInfo = BuildArmComputeTensorInfo(inputGateBiasInfo); + aclInputGateBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_InputGateBias()); lstm_params_info.set_cifg_params(&aclInputToInputWeightsInfo, &aclRecurrentToInputWeightsInfo, - cellToInputWeights != nullptr ? &aclCellToInputWeightsInfo: nullptr, + paramsInfo.m_CellToInputWeights != nullptr ? + &aclCellToInputWeightsInfo: nullptr, &aclInputGateBiasInfo); } if (descriptor.m_ProjectionEnabled) { - const armnn::TensorInfo& projectionWInfo = *projectionWeights; - aclProjectionWeightsInfo = BuildArmComputeTensorInfo(projectionWInfo); + aclProjectionWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionWeights()); - if (projectionBias != nullptr) + if (paramsInfo.m_ProjectionBias != nullptr) { - const armnn::TensorInfo& projectionBiasInfo = *projectionBias; - aclProjectionBiasInfo = BuildArmComputeTensorInfo(projectionBiasInfo); + aclProjectionBiasInfo = BuildArmComputeTensorInfo(paramsInfo.get_ProjectionBias()); } lstm_params_info.set_projection_params(&aclProjectionWeightsInfo, - projectionBias != nullptr ? &aclProjectionBiasInfo: nullptr); + paramsInfo.m_ProjectionBias != nullptr ? 
+ &aclProjectionBiasInfo: nullptr); } if (descriptor.m_PeepholeEnabled) { - const armnn::TensorInfo& cellToForgetWInfo = *cellToForgetWeights; - aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(cellToForgetWInfo); - const armnn::TensorInfo& cellToOutputWInfo = *cellToOutputWeights; - aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(cellToOutputWInfo); + aclCellToForgetWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToForgetWeights()); + aclCellToOutputWeightsInfo = BuildArmComputeTensorInfo(paramsInfo.get_CellToOutputWeights()); lstm_params_info.set_peephole_params(&aclCellToForgetWeightsInfo, &aclCellToOutputWeightsInfo); } diff --git a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp index 6a0c41f..9a3211a 100644 --- a/src/backends/cl/workloads/ClLstmFloatWorkload.hpp +++ b/src/backends/cl/workloads/ClLstmFloatWorkload.hpp @@ -49,20 +49,5 @@ arm_compute::Status ClLstmFloatWorkloadValidate(const TensorInfo& input, const T const TensorInfo& cellStateIn, const TensorInfo& scratchBuffer, const TensorInfo& outputStateOut, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor &descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights); + const LstmInputParamsInfo& paramsInfo); } //namespace armnn diff --git a/src/backends/reference/RefLayerSupport.cpp 
b/src/backends/reference/RefLayerSupport.cpp index ac7f310..59c14c4 100644 --- a/src/backends/reference/RefLayerSupport.cpp +++ b/src/backends/reference/RefLayerSupport.cpp @@ -924,51 +924,11 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, const TensorInfo& cellStateOut, const TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported, - const TensorInfo* inputLayerNormWeights, - const TensorInfo* forgetLayerNormWeights, - const TensorInfo* cellLayerNormWeights, - const TensorInfo* outputLayerNormWeights) const + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported) const { ignore_unused(descriptor); - ignore_unused(inputToForgetWeights); - ignore_unused(inputToCellWeights); - ignore_unused(inputToOutputWeights); - ignore_unused(recurrentToForgetWeights); - ignore_unused(recurrentToCellWeights); - ignore_unused(recurrentToOutputWeights); - ignore_unused(forgetGateBias); - ignore_unused(cellBias); - ignore_unused(outputGateBias); - ignore_unused(inputToInputWeights); - ignore_unused(recurrentToInputWeights); - ignore_unused(cellToInputWeights); - ignore_unused(inputGateBias); - ignore_unused(projectionWeights); - ignore_unused(projectionBias); - ignore_unused(cellToForgetWeights); - ignore_unused(cellToOutputWeights); - 
ignore_unused(inputLayerNormWeights); - ignore_unused(forgetLayerNormWeights); - ignore_unused(cellLayerNormWeights); - ignore_unused(outputLayerNormWeights); + ignore_unused(paramsInfo); bool supported = true; @@ -977,26 +937,91 @@ bool RefLayerSupport::IsLstmSupported(const TensorInfo& input, DataType::QuantisedSymm16 }; + // check inputs and outputs supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported, "Reference Lstm: input is not a supported type."); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateIn), reasonIfUnsupported, "Reference Lstm: input and outputStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateIn), reasonIfUnsupported, "Reference Lstm: input and cellStateIn types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, scratchBuffer), reasonIfUnsupported, "Reference Lstm: input and scratchBuffer types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, outputStateOut), reasonIfUnsupported, "Reference Lstm: input and outputStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, cellStateOut), reasonIfUnsupported, "Reference Lstm: input and cellStateOut types are mismatched"); - supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported, "Reference Lstm: input and output types are mismatched"); + // check layer parameters + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToOutputWeights types are mismatched"); + 
supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToCellWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToCellWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and RecurrentToOutputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetGateBias()), reasonIfUnsupported, + "Reference Lstm: input and ForgetGateBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellBias()), reasonIfUnsupported, + "Reference Lstm: input and CellBias types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and OutputGateBias types are mismatched"); + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputToInputWeights()), reasonIfUnsupported, + "Reference Lstm: input and InputToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_RecurrentToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and RecurrentToInputWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputGateBias()), reasonIfUnsupported, + "Reference Lstm: input and InputGateBias types are mismatched"); + if (descriptor.m_PeepholeEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToInputWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellToInputWeights types are mismatched"); + } + } + if (descriptor.m_PeepholeEnabled) + { + supported &= 
CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToForgetWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToForgetWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellToOutputWeights()), reasonIfUnsupported, + "Reference Lstm: input and CellToOutputWeights types are mismatched"); + } + if (descriptor.m_ProjectionEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionWeights()), reasonIfUnsupported, + "Reference Lstm: input and ProjectionWeights types are mismatched"); + if (paramsInfo.m_ProjectionBias != nullptr) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ProjectionBias()), reasonIfUnsupported, + "Reference Lstm: input and ProjectionBias types are mismatched"); + } + } + if (descriptor.m_LayerNormEnabled) + { + if (!descriptor.m_CifgEnabled) + { + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_InputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and InputLayerNormWeights types are mismatched"); + } + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_ForgetLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and ForgetLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_CellLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and CellLayerNormWeights types are mismatched"); + supported &= CheckSupportRule(TypesAreEqual(input, paramsInfo.get_OutputLayerNormWeights()), + reasonIfUnsupported, + "Reference Lstm: input and OutputLayerNormWeights types are mismatched"); + } return supported; } diff --git a/src/backends/reference/RefLayerSupport.hpp b/src/backends/reference/RefLayerSupport.hpp index ead4d1c..c0bf188 100644 --- a/src/backends/reference/RefLayerSupport.hpp +++ b/src/backends/reference/RefLayerSupport.hpp @@ -138,28 +138,8 @@ public: const TensorInfo& cellStateOut, const 
TensorInfo& output, const LstmDescriptor& descriptor, - const TensorInfo& inputToForgetWeights, - const TensorInfo& inputToCellWeights, - const TensorInfo& inputToOutputWeights, - const TensorInfo& recurrentToForgetWeights, - const TensorInfo& recurrentToCellWeights, - const TensorInfo& recurrentToOutputWeights, - const TensorInfo& forgetGateBias, - const TensorInfo& cellBias, - const TensorInfo& outputGateBias, - const TensorInfo* inputToInputWeights, - const TensorInfo* recurrentToInputWeights, - const TensorInfo* cellToInputWeights, - const TensorInfo* inputGateBias, - const TensorInfo* projectionWeights, - const TensorInfo* projectionBias, - const TensorInfo* cellToForgetWeights, - const TensorInfo* cellToOutputWeights, - Optional reasonIfUnsupported = EmptyOptional(), - const TensorInfo* inputLayerNormWeights = nullptr, - const TensorInfo* forgetLayerNormWeights = nullptr, - const TensorInfo* cellLayerNormWeights = nullptr, - const TensorInfo* outputLayerNormWeights = nullptr) const override; + const LstmInputParamsInfo& paramsInfo, + Optional reasonIfUnsupported = EmptyOptional()) const override; bool IsMaximumSupported(const TensorInfo& input0, const TensorInfo& input1,