Release 18.08
[platform/upstream/armnn.git] / src / armnn / backends / ClWorkloads / ClFullyConnectedFloat32Workload.cpp
index 5dfab9c..5014dd2 100644 (file)
@@ -7,47 +7,89 @@
 #include "backends/ClTensorHandle.hpp"
 #include "backends/CpuTensorHandle.hpp"
 #include "backends/ArmComputeTensorUtils.hpp"
+#include "backends/ArmComputeUtils.hpp"
+#include "backends/ClLayerSupport.hpp"
 
 namespace armnn
 {
 using namespace armcomputetensorutils;
 
+arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
+                                                     const TensorInfo& output,
+                                                     const TensorInfo& weights,
+                                                     const TensorInfo& biases,
+                                                     const FullyConnectedDescriptor& descriptor)
+{
+    const arm_compute::TensorInfo aclInput = BuildArmComputeTensorInfo(input);
+    const arm_compute::TensorInfo aclOutput = BuildArmComputeTensorInfo(output);
+    const arm_compute::TensorInfo aclWeights = BuildArmComputeTensorInfo(weights);
+
+    arm_compute::TensorInfo aclBiases;
+    arm_compute::TensorInfo *optionalAclBiases = nullptr;
+    if (descriptor.m_BiasEnabled)
+    {
+        aclBiases  = BuildArmComputeTensorInfo(biases);
+        optionalAclBiases = &aclBiases;
+    }
+
+    const arm_compute::FullyConnectedLayerInfo fullyConnectedLayerInfo =
+        ConvertFullyConnectedDescriptorToAclFullyConnectedLayerInfo(descriptor);
+
+    return arm_compute::CLFullyConnectedLayer::validate(&aclInput,
+                                                        &aclWeights,
+                                                        optionalAclBiases,
+                                                        &aclOutput,
+                                                        fullyConnectedLayerInfo);
+}
+
 ClFullyConnectedFloat32Workload::ClFullyConnectedFloat32Workload(const FullyConnectedQueueDescriptor& descriptor,
     const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
-    : Float32Workload<FullyConnectedQueueDescriptor>(descriptor, info)
-    , m_FullyConnected(memoryManager)
+    : FloatWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
+    , m_FullyConnectedLayer(memoryManager)
 {
+    m_WeightsTensor = std::make_unique<arm_compute::CLTensor>();
+    BuildArmComputeTensor(*m_WeightsTensor, m_Data.m_Weight->GetTensorInfo());
 
-    BuildArmComputeTensor(m_WeightsTensor, m_Data.m_Weight->GetTensorInfo());
-
-    arm_compute::CLTensor* optionalBiasTensor = nullptr;
     if (m_Data.m_Parameters.m_BiasEnabled)
     {
-        BuildArmComputeTensor(m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
-        optionalBiasTensor = &m_BiasesTensor;
+        m_BiasesTensor = std::make_unique<arm_compute::CLTensor>();
+        BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
     }
 
     m_Data.ValidateInputsOutputs("ClFullyConnectedFloat32Workload", 1, 1);
 
     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
+
     // Construct
-    m_FullyConnected.configure(
-        &input, &m_WeightsTensor, optionalBiasTensor, &output, m_Data.m_Parameters.m_TransposeWeightMatrix);
+    arm_compute::FullyConnectedLayerInfo fc_info;
+    fc_info.transpose_weights = m_Data.m_Parameters.m_TransposeWeightMatrix;
+    m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
 
     // Allocate
-    InitialiseArmComputeClTensorData(m_WeightsTensor, m_Data.m_Weight->GetConstTensor<float>());
+    InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
 
-    if (optionalBiasTensor)
+    if (m_BiasesTensor)
     {
-        InitialiseArmComputeClTensorData(*optionalBiasTensor, m_Data.m_Bias->GetConstTensor<float>());
+        InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
     }
+
+    // Force Compute Library to perform the necessary copying and reshaping, after which
+    // delete all the input tensors that will no longer be needed
+    m_FullyConnectedLayer.prepare();
+    FreeUnusedTensors();
 }
 
 void ClFullyConnectedFloat32Workload::Execute() const
 {
-    ARMNN_SCOPED_PROFILING_EVENT(Compute::GpuAcc, "ClFullyConnectedFloat32Workload_Execute");
-    m_FullyConnected.run();
+    ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedFloat32Workload_Execute");
+    m_FullyConnectedLayer.run();
+}
+
+void ClFullyConnectedFloat32Workload::FreeUnusedTensors()
+{
+    FreeTensorIfUnused(m_WeightsTensor);
+    FreeTensorIfUnused(m_BiasesTensor);
 }
 
 } //namespace armnn