From d161ba0bc83fa14f7aea4c629ca3e6ea04a2dc34 Mon Sep 17 00:00:00 2001 From: jimfly01 Date: Mon, 28 Jan 2019 12:51:53 +0000 Subject: [PATCH] IVGCVSW-2569 Add implementation of ConstTensor Accept functions * Create the required ConstTensors and pass them to the appropriate visit method. Back fill of dummies added during IVGCVSW-2547 * Moved the VisitDetectionPostProcessLayer function declaration in ILayerVistor to its correct location after the VisitDepthwiseConvolution2dLayer functions. Change-Id: I0bd2f8c3603cbdb933b1216ead96dd8273eb5013 --- include/armnn/ILayerVisitor.hpp | 18 ++-- src/armnn/layers/BatchNormalizationLayer.cpp | 7 +- src/armnn/layers/ConstantLayer.cpp | 4 +- src/armnn/layers/Convolution2dLayer.cpp | 7 +- src/armnn/layers/DepthwiseConvolution2dLayer.cpp | 7 +- src/armnn/layers/FullyConnectedLayer.cpp | 7 +- src/armnn/layers/LstmLayer.cpp | 111 ++++++++++++++++++++++- 7 files changed, 137 insertions(+), 24 deletions(-) diff --git a/include/armnn/ILayerVisitor.hpp b/include/armnn/ILayerVisitor.hpp index c56be81..f30396a 100644 --- a/include/armnn/ILayerVisitor.hpp +++ b/include/armnn/ILayerVisitor.hpp @@ -62,15 +62,6 @@ public: const ConstTensor& weights, const char* name = nullptr) = 0; - /// Function that a Detection PostProcess layer should call back to when its - /// Accept(ILayerVisitor&) function is invoked. - /// @param layer - pointer to the layer which is calling back to this visit function. - /// @param descriptor - Description of the Detection PostProcess layer. - /// @param name - Optional name for the layer. - virtual void VisitDetectionPostProcessLayer(const IConnectableLayer* layer, - const DetectionPostProcessDescriptor& descriptor, - const char* name = nullptr) = 0; - /// Function that a 2D depthwise convolution layer with biases should call back to when its /// Accept(ILayerVisitor&) function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. @@ -84,6 +75,15 @@ public: const ConstTensor& biases, const char* name = nullptr) = 0; + /// Function that a Detection PostProcess layer should call back to when its + /// Accept(ILayerVisitor&) function is invoked. + /// @param layer - pointer to the layer which is calling back to this visit function. + /// @param descriptor - Description of the Detection PostProcess layer. + /// @param name - Optional name for the layer. + virtual void VisitDetectionPostProcessLayer(const IConnectableLayer* layer, + const DetectionPostProcessDescriptor& descriptor, + const char* name = nullptr) = 0; + /// Function that a fully connected layer without biases should call back to when its Accept(ILayerVisitor&) /// function is invoked. /// @param layer - pointer to the layer which is calling back to this visit function. diff --git a/src/armnn/layers/BatchNormalizationLayer.cpp b/src/armnn/layers/BatchNormalizationLayer.cpp index dfba2d7..8513205 100644 --- a/src/armnn/layers/BatchNormalizationLayer.cpp +++ b/src/armnn/layers/BatchNormalizationLayer.cpp @@ -71,8 +71,11 @@ Layer::ConstantTensors BatchNormalizationLayer::GetConstantTensorsByRef() void BatchNormalizationLayer::Accept(ILayerVisitor& visitor) const { - ConstTensor dummy; - visitor.VisitBatchNormalizationLayer(this, GetParameters(), dummy, dummy, dummy, dummy); + ConstTensor meanTensor(m_Mean->GetTensorInfo(), m_Mean->GetTensor()) ; + ConstTensor varianceTensor(m_Variance->GetTensorInfo(), m_Variance->GetTensor()) ; + ConstTensor betaTensor(m_Beta->GetTensorInfo(), m_Beta->GetTensor()) ; + ConstTensor gammaTensor(m_Gamma->GetTensorInfo(), m_Gamma->GetTensor()) ; + visitor.VisitBatchNormalizationLayer(this, GetParameters(), meanTensor, varianceTensor, betaTensor, gammaTensor); } } // namespace armnn diff --git a/src/armnn/layers/ConstantLayer.cpp b/src/armnn/layers/ConstantLayer.cpp index 8b94fdb..919fd61 100644 --- a/src/armnn/layers/ConstantLayer.cpp +++ b/src/armnn/layers/ConstantLayer.cpp @@ -53,8 +53,8 @@ void ConstantLayer::ValidateTensorShapesFromInputs() void ConstantLayer::Accept(ILayerVisitor& visitor) const { - ConstTensor dummy; - visitor.VisitConstantLayer(this, dummy, GetName()); + ConstTensor layerOutputTensor(m_LayerOutput->GetTensorInfo(), m_LayerOutput->GetTensor()) ; + visitor.VisitConstantLayer(this, layerOutputTensor, GetName()); } } // namespace armnn diff --git a/src/armnn/layers/Convolution2dLayer.cpp b/src/armnn/layers/Convolution2dLayer.cpp index 86eb71d..cb90f81 100644 --- a/src/armnn/layers/Convolution2dLayer.cpp +++ b/src/armnn/layers/Convolution2dLayer.cpp @@ -112,14 +112,15 @@ Layer::ConstantTensors Convolution2dLayer::GetConstantTensorsByRef() void Convolution2dLayer::Accept(ILayerVisitor& visitor) const { - ConstTensor dummy; + ConstTensor weightsTensor(m_Weight->GetTensorInfo(), m_Weight->GetTensor()) ; + ConstTensor biasTensor(m_Bias->GetTensorInfo(), m_Bias->GetConstTensor()); if (GetParameters().m_BiasEnabled) { - visitor.VisitConvolution2dLayer(this, GetParameters(), dummy, dummy, GetName()); + visitor.VisitConvolution2dLayer(this, GetParameters(), weightsTensor, biasTensor, GetName()); } else { - visitor.VisitConvolution2dLayer(this, GetParameters(), dummy, GetName()); + visitor.VisitConvolution2dLayer(this, GetParameters(), weightsTensor, GetName()); } } diff --git a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp index b2d9814..dca13f2 100644 --- a/src/armnn/layers/DepthwiseConvolution2dLayer.cpp +++ b/src/armnn/layers/DepthwiseConvolution2dLayer.cpp @@ -122,14 +122,15 @@ Layer::ConstantTensors DepthwiseConvolution2dLayer::GetConstantTensorsByRef() void DepthwiseConvolution2dLayer::Accept(ILayerVisitor& visitor) const { - ConstTensor dummy; + ConstTensor weightsTensor(m_Weight->GetTensorInfo(), m_Weight->GetTensor()) ; + ConstTensor biasTensor(m_Bias->GetTensorInfo(), m_Bias->GetConstTensor()); if (GetParameters().m_BiasEnabled) { - visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), dummy, dummy, GetName()); + visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), weightsTensor, biasTensor, GetName()); } else { - visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), dummy, GetName()); + visitor.VisitDepthwiseConvolution2dLayer(this, GetParameters(), weightsTensor, GetName()); } } diff --git a/src/armnn/layers/FullyConnectedLayer.cpp b/src/armnn/layers/FullyConnectedLayer.cpp index 977c276..783482e 100644 --- a/src/armnn/layers/FullyConnectedLayer.cpp +++ b/src/armnn/layers/FullyConnectedLayer.cpp @@ -88,14 +88,15 @@ Layer::ConstantTensors FullyConnectedLayer::GetConstantTensorsByRef() void FullyConnectedLayer::Accept(ILayerVisitor& visitor) const { - ConstTensor dummy; + ConstTensor weightsTensor(m_Weight->GetTensorInfo(), m_Weight->GetTensor()) ; + ConstTensor biasTensor(m_Bias->GetTensorInfo(), m_Bias->GetConstTensor()); if (GetParameters().m_BiasEnabled) { - visitor.VisitFullyConnectedLayer(this, GetParameters(), dummy, dummy, GetName()); + visitor.VisitFullyConnectedLayer(this, GetParameters(), weightsTensor, biasTensor, GetName()); } else { - visitor.VisitFullyConnectedLayer(this, GetParameters(), dummy, GetName()); + visitor.VisitFullyConnectedLayer(this, GetParameters(), weightsTensor, GetName()); } } diff --git a/src/armnn/layers/LstmLayer.cpp b/src/armnn/layers/LstmLayer.cpp index 942038a..06140c9 100644 --- a/src/armnn/layers/LstmLayer.cpp +++ b/src/armnn/layers/LstmLayer.cpp @@ -251,8 +251,115 @@ Layer::ConstantTensors LstmLayer::GetConstantTensorsByRef() void LstmLayer::Accept(ILayerVisitor& visitor) const { - LstmInputParams dummy; - visitor.VisitLstmLayer(this, GetParameters(), dummy, GetName()); + LstmInputParams inputParams; + if (m_CifgParameters.m_InputToInputWeights != nullptr) + { + ConstTensor inputToInputWeightsTensor(m_CifgParameters.m_InputToInputWeights->GetTensorInfo(), + m_CifgParameters.m_InputToInputWeights->GetConstTensor()); + inputParams.m_InputToInputWeights = &inputToInputWeightsTensor; + } + if (m_BasicParameters.m_InputToForgetWeights != nullptr) + { + ConstTensor inputToForgetWeightsTensor(m_BasicParameters.m_InputToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_InputToForgetWeights->GetConstTensor()); + inputParams.m_InputToForgetWeights = &inputToForgetWeightsTensor; + } + if (m_BasicParameters.m_InputToCellWeights != nullptr) + { + ConstTensor inputToCellWeightsTensor(m_BasicParameters.m_InputToCellWeights->GetTensorInfo(), + m_BasicParameters.m_InputToCellWeights->GetConstTensor()); + inputParams.m_InputToCellWeights = &inputToCellWeightsTensor; + } + if (m_BasicParameters.m_InputToOutputWeights != nullptr) + { + ConstTensor inputToOutputWeightsTensor(m_BasicParameters.m_InputToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_InputToOutputWeights->GetConstTensor()); + inputParams.m_InputToOutputWeights = &inputToOutputWeightsTensor; + } + if (m_CifgParameters.m_RecurrentToInputWeights != nullptr) + { + ConstTensor recurrentToInputWeightsTensor( + m_CifgParameters.m_RecurrentToInputWeights->GetTensorInfo(), + m_CifgParameters.m_RecurrentToInputWeights->GetConstTensor()); + inputParams.m_RecurrentToInputWeights = &recurrentToInputWeightsTensor; + } + if (m_BasicParameters.m_RecurrentToForgetWeights != nullptr) + { + ConstTensor recurrentToForgetWeightsTensor( + m_BasicParameters.m_RecurrentToForgetWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToForgetWeights->GetConstTensor()); + inputParams.m_RecurrentToForgetWeights = &recurrentToForgetWeightsTensor; + } + if (m_BasicParameters.m_RecurrentToCellWeights != nullptr) + { + ConstTensor recurrentToCellWeightsTensor( + m_BasicParameters.m_RecurrentToCellWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToCellWeights->GetConstTensor()); + inputParams.m_RecurrentToCellWeights = &recurrentToCellWeightsTensor; + } + if (m_BasicParameters.m_RecurrentToOutputWeights != nullptr) + { + ConstTensor recurrentToOutputWeightsTensor( + m_BasicParameters.m_RecurrentToOutputWeights->GetTensorInfo(), + m_BasicParameters.m_RecurrentToOutputWeights->GetConstTensor()); + inputParams.m_RecurrentToOutputWeights = &recurrentToOutputWeightsTensor; + } + if (m_CifgParameters.m_CellToInputWeights != nullptr) + { + ConstTensor cellToInputWeightsTensor(m_CifgParameters.m_CellToInputWeights->GetTensorInfo(), + m_CifgParameters.m_CellToInputWeights->GetConstTensor()); + inputParams.m_CellToInputWeights = &cellToInputWeightsTensor; + } + if (m_PeepholeParameters.m_CellToForgetWeights != nullptr) + { + ConstTensor cellToForgetWeightsTensor(m_PeepholeParameters.m_CellToForgetWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToForgetWeights->GetConstTensor()); + inputParams.m_CellToForgetWeights = &cellToForgetWeightsTensor; + } + if (m_PeepholeParameters.m_CellToOutputWeights != nullptr) + { + ConstTensor cellToOutputWeightsTensor(m_PeepholeParameters.m_CellToOutputWeights->GetTensorInfo(), + m_PeepholeParameters.m_CellToOutputWeights->GetConstTensor()); + inputParams.m_CellToOutputWeights = &cellToOutputWeightsTensor; + } + if (m_CifgParameters.m_InputGateBias != nullptr) + { + ConstTensor inputGateBiasTensor(m_CifgParameters.m_InputGateBias->GetTensorInfo(), + m_CifgParameters.m_InputGateBias->GetConstTensor()); + inputParams.m_InputGateBias = &inputGateBiasTensor; + } + if (m_BasicParameters.m_ForgetGateBias != nullptr) + { + ConstTensor forgetGateBiasTensor(m_BasicParameters.m_ForgetGateBias->GetTensorInfo(), + m_BasicParameters.m_ForgetGateBias->GetConstTensor()); + inputParams.m_ForgetGateBias = &forgetGateBiasTensor; + } + if (m_BasicParameters.m_CellBias != nullptr) + { + ConstTensor cellBiasTensor(m_BasicParameters.m_CellBias->GetTensorInfo(), + m_BasicParameters.m_CellBias->GetConstTensor()); + inputParams.m_CellBias = &cellBiasTensor; + } + if (m_BasicParameters.m_OutputGateBias != nullptr) + { + ConstTensor outputGateBias(m_BasicParameters.m_OutputGateBias->GetTensorInfo(), + m_BasicParameters.m_OutputGateBias->GetConstTensor()); + inputParams.m_OutputGateBias = &outputGateBias; + } + if (m_ProjectionParameters.m_ProjectionWeights != nullptr) + { + ConstTensor projectionWeightsTensor(m_ProjectionParameters.m_ProjectionWeights->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionWeights->GetConstTensor()); + inputParams.m_ProjectionWeights = &projectionWeightsTensor; + } + if (m_ProjectionParameters.m_ProjectionBias != nullptr) + { + ConstTensor projectionBiasTensor(m_ProjectionParameters.m_ProjectionBias->GetTensorInfo(), + m_ProjectionParameters.m_ProjectionBias->GetConstTensor()); + inputParams.m_ProjectionBias = &projectionBiasTensor; + } + + visitor.VisitLstmLayer(this, GetParameters(), inputParams, GetName()); } } // namespace armnn -- 2.7.4