IVGCVSW-949 Add 8-bit fully connected support
author Matthew Bentham <matthew.bentham@arm.com>
Mon, 17 Sep 2018 10:17:41 +0000 (11:17 +0100)
committer Matthew Bentham <matthew.bentham@arm.com>
Wed, 10 Oct 2018 15:16:56 +0000 (16:16 +0100)
Change-Id: I0953bb8dbc4b76001f207e37c8c2742a6ebd888b
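
In short: the CL fully connected workload loses its Float-only specialisation. ClFullyConnectedFloatWorkload is renamed to ClFullyConnectedWorkload and rebased from FloatWorkload onto BaseWorkload, the blanket QuantisedAsymm8 rejection in IsFullyConnectedSupportedCl is removed, weight and bias tensors are initialised according to their data type, and Uint8 test cases are enabled.

A minimal caller-side sketch of the path this enables, assuming the ArmNN public API of this period (shapes, scales and names below are illustrative, not taken from the patch):

    #include <armnn/ArmNN.hpp>
    #include <cstdint>
    #include <vector>

    armnn::INetworkPtr BuildQuantisedFcNetwork()
    {
        using namespace armnn;
        INetworkPtr net = INetwork::Create();

        // QuantisedAsymm8 activations/weights; Signed32 bias at inputScale * weightScale.
        TensorInfo inputInfo ({1, 64}, DataType::QuantisedAsymm8, 0.5f, 10);
        TensorInfo weightInfo({10, 64}, DataType::QuantisedAsymm8, 0.25f, 0);
        TensorInfo biasInfo  ({1, 10}, DataType::Signed32, 0.5f * 0.25f, 0);
        TensorInfo outputInfo({1, 10}, DataType::QuantisedAsymm8, 1.0f, 0);

        std::vector<uint8_t> weightData(10 * 64, 1); // placeholder weights
        std::vector<int32_t> biasData(10, 0);        // placeholder biases
        ConstTensor weights(weightInfo, weightData);
        ConstTensor biases (biasInfo, biasData);

        FullyConnectedDescriptor desc;
        desc.m_BiasEnabled = true;

        IConnectableLayer* input  = net->AddInputLayer(0);
        IConnectableLayer* fc     = net->AddFullyConnectedLayer(desc, weights, biases, "fc");
        IConnectableLayer* output = net->AddOutputLayer(0);

        input->GetOutputSlot(0).Connect(fc->GetInputSlot(0));
        fc->GetOutputSlot(0).Connect(output->GetInputSlot(0));
        input->GetOutputSlot(0).SetTensorInfo(inputInfo);
        fc->GetOutputSlot(0).SetTensorInfo(outputInfo);
        return net;
    }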

Android.mk
CMakeLists.txt
src/armnn/backends/ClLayerSupport.cpp
src/armnn/backends/ClWorkloadFactory.cpp
src/armnn/backends/ClWorkloads.hpp
src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp [moved from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp with 74% similarity]
src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp [moved from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp with 68% similarity]
src/armnn/backends/test/ArmComputeCl.cpp
src/armnn/backends/test/CreateWorkloadCl.cpp

diff --git a/Android.mk b/Android.mk
index c070b28..ad02db9 100644
@@ -61,7 +61,7 @@ LOCAL_SRC_FILES := \
         src/armnn/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.cpp \
         src/armnn/backends/ClWorkloads/ClDivisionFloatWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClFloorFloatWorkload.cpp \
-        src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp \
+        src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClL2NormalizationFloatWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClLstmFloatWorkload.cpp \
         src/armnn/backends/ClWorkloads/ClMergerFloatWorkload.cpp \
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4290461..a5dde68 100644
@@ -506,8 +506,8 @@ if(ARMCOMPUTECL)
         src/armnn/backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.hpp
         src/armnn/backends/ClWorkloads/ClFloorFloatWorkload.cpp
         src/armnn/backends/ClWorkloads/ClFloorFloatWorkload.hpp
-        src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp
-        src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp
+        src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
+        src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
         src/armnn/backends/ClWorkloads/ClL2NormalizationFloatWorkload.cpp
         src/armnn/backends/ClWorkloads/ClL2NormalizationFloatWorkload.hpp
         src/armnn/backends/ClWorkloads/ClLstmFloatWorkload.cpp
diff --git a/src/armnn/backends/ClLayerSupport.cpp b/src/armnn/backends/ClLayerSupport.cpp
index 4664c2e..30a1330 100644
@@ -24,7 +24,7 @@
 #include "ClWorkloads/ClDivisionFloatWorkload.hpp"
 #include "ClWorkloads/ClL2NormalizationFloatWorkload.hpp"
 #include "ClWorkloads/ClMultiplicationFloatWorkload.hpp"
-#include "ClWorkloads/ClFullyConnectedFloatWorkload.hpp"
+#include "ClWorkloads/ClFullyConnectedWorkload.hpp"
 #include "ClWorkloads/ClPooling2dBaseWorkload.hpp"
 #include "ClWorkloads/ClPermuteWorkload.hpp"
 #include "ClWorkloads/ClNormalizationFloatWorkload.hpp"
@@ -269,11 +269,6 @@ bool IsFullyConnectedSupportedCl(const TensorInfo& input,
                                  const FullyConnectedDescriptor& descriptor,
                                  std::string* reasonIfUnsupported)
 {
-    // At the moment U8 is unsupported
-    if (input.GetDataType() == DataType::QuantisedAsymm8)
-    {
-        return false;
-    }
     FORWARD_WORKLOAD_VALIDATE_FUNC(ClFullyConnectedWorkloadValidate,
                                    reasonIfUnsupported,
                                    input,
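
The deleted guard above is the heart of the change: with the early return gone, 8-bit support is decided by Arm Compute Library itself. FORWARD_WORKLOAD_VALIDATE_FUNC runs ClFullyConnectedWorkloadValidate and converts its arm_compute::Status into the bool/reason pair this function returns. A hedged paraphrase of that forwarding (not the actual macro expansion; argument order as in the validate declaration):

    arm_compute::Status status = ClFullyConnectedWorkloadValidate(input, output, weights,
                                                                  biases, descriptor);
    const bool supported = (status.error_code() == arm_compute::ErrorCode::OK);
    if (!supported && reasonIfUnsupported)
    {
        *reasonIfUnsupported = status.error_description();
    }
    return supported;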
diff --git a/src/armnn/backends/ClWorkloadFactory.cpp b/src/armnn/backends/ClWorkloadFactory.cpp
index c35f044..591fb85 100644
@@ -116,8 +116,8 @@ std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateMerger(const MergerQu
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreateFullyConnected(
     const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) const
 {
-    return MakeWorkload<ClFullyConnectedFloatWorkload, NullWorkload>(descriptor, info,
-                                                                       m_MemoryManager.GetIntraLayerManager());
+    return MakeWorkload<ClFullyConnectedWorkload, ClFullyConnectedWorkload>(descriptor, info,
+                                                                            m_MemoryManager.GetIntraLayerManager());
 }
 
 std::unique_ptr<armnn::IWorkload> ClWorkloadFactory::CreatePermute(const PermuteQueueDescriptor& descriptor,
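
Passing ClFullyConnectedWorkload as both template arguments matters because MakeWorkload dispatches on the tensor data type: previously Uint8 inputs hit the NullWorkload slot and creation failed. A hedged, simplified sketch of that dispatch (this era's internal helper, reconstructed from memory, not quoted from the source):

    #include <memory>
    #include <utility>

    template <typename FloatWorkloadT, typename Uint8WorkloadT,
              typename QueueDescriptorT, typename... Args>
    std::unique_ptr<armnn::IWorkload> MakeWorkload(const QueueDescriptorT& descriptor,
                                                   const armnn::WorkloadInfo& info,
                                                   Args&&... args)
    {
        using namespace armnn;
        // Pick the data type from the first input (or output, for input-less layers).
        const DataType dataType = !info.m_InputTensorInfos.empty()
            ? info.m_InputTensorInfos[0].GetDataType()
            : info.m_OutputTensorInfos[0].GetDataType();
        switch (dataType)
        {
            case DataType::Float16:
            case DataType::Float32:
                return std::make_unique<FloatWorkloadT>(descriptor, info, std::forward<Args>(args)...);
            case DataType::QuantisedAsymm8:
                return std::make_unique<Uint8WorkloadT>(descriptor, info, std::forward<Args>(args)...);
            default:
                return nullptr; // unsupported data type
        }
    }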
diff --git a/src/armnn/backends/ClWorkloads.hpp b/src/armnn/backends/ClWorkloads.hpp
index 3472bca..2bbda8a 100644
@@ -18,7 +18,7 @@
 #include "backends/ClWorkloads/ClDepthwiseConvolutionUint8Workload.hpp"
 #include "backends/ClWorkloads/ClDivisionFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClFloorFloatWorkload.hpp"
-#include "backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp"
+#include "backends/ClWorkloads/ClFullyConnectedWorkload.hpp"
 #include "backends/ClWorkloads/ClL2NormalizationFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClLstmFloatWorkload.hpp"
 #include "backends/ClWorkloads/ClMergerFloatWorkload.hpp"
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
similarity index 74%
rename from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.cpp
rename to src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.cpp
@@ -3,7 +3,7 @@
 // SPDX-License-Identifier: MIT
 //
 
-#include "ClFullyConnectedFloatWorkload.hpp"
+#include "ClFullyConnectedWorkload.hpp"
 #include "backends/ClTensorHandle.hpp"
 #include "backends/CpuTensorHandle.hpp"
 #include "backends/ArmComputeTensorUtils.hpp"
@@ -42,9 +42,9 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
                                                         fullyConnectedLayerInfo);
 }
 
-ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnectedQueueDescriptor& descriptor,
+ClFullyConnectedWorkload::ClFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
     const WorkloadInfo& info, std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager)
-    : FloatWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
+    : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
     , m_FullyConnectedLayer(memoryManager)
 {
     m_WeightsTensor = std::make_unique<arm_compute::CLTensor>();
@@ -56,7 +56,7 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
         BuildArmComputeTensor(*m_BiasesTensor, m_Data.m_Bias->GetTensorInfo());
     }
 
-    m_Data.ValidateInputsOutputs("ClFullyConnectedFloatWorkload", 1, 1);
+    m_Data.ValidateInputsOutputs("ClFullyConnectedWorkload", 1, 1);
 
     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
@@ -67,11 +67,25 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
     m_FullyConnectedLayer.configure(&input, m_WeightsTensor.get(), m_BiasesTensor.get(), &output, fc_info);
 
     // Allocate
-    InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+    if (m_Data.m_Weight->GetTensorInfo().GetDataType() == DataType::QuantisedAsymm8)
+    {
+        InitialiseArmComputeClTensorData(*m_WeightsTensor, m_Data.m_Weight->GetConstTensor<uint8_t>());
+    }
+    else
+    {
+        InitializeArmComputeClTensorDataForFloatTypes(*m_WeightsTensor, m_Data.m_Weight);
+    }
 
     if (m_BiasesTensor)
     {
-        InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+        if (m_Data.m_Bias->GetTensorInfo().GetDataType() == DataType::Signed32)
+        {
+            InitialiseArmComputeClTensorData(*m_BiasesTensor, m_Data.m_Bias->GetConstTensor<int32_t>());
+        }
+        else
+        {
+            InitializeArmComputeClTensorDataForFloatTypes(*m_BiasesTensor, m_Data.m_Bias);
+        }
     }
 
     // Force Compute Library to perform the necessary copying and reshaping, after which
@@ -80,13 +94,13 @@ ClFullyConnectedFloatWorkload::ClFullyConnectedFloatWorkload(const FullyConnecte
     FreeUnusedTensors();
 }
 
-void ClFullyConnectedFloatWorkload::Execute() const
+void ClFullyConnectedWorkload::Execute() const
 {
-    ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedFloatWorkload_Execute");
+    ARMNN_SCOPED_PROFILING_EVENT_CL("ClFullyConnectedWorkload_Execute");
     m_FullyConnectedLayer.run();
 }
 
-void ClFullyConnectedFloatWorkload::FreeUnusedTensors()
+void ClFullyConnectedWorkload::FreeUnusedTensors()
 {
     FreeTensorIfUnused(m_WeightsTensor);
     FreeTensorIfUnused(m_BiasesTensor);
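
The new Signed32 branch in the constructor above exists because an 8-bit GEMM accumulates into int32: with real = scale * (q - offset), each product of input and weight values carries scale inputScale * weightScale, so biases must be pre-quantized onto that same scale (with zero offset) before they can be added to the accumulator. A small illustration of that arithmetic (the helper name is ours, not ArmNN's):

    #include <cmath>
    #include <cstdint>

    // Quantize a float bias onto the int32 accumulator scale of an 8-bit GEMM.
    int32_t QuantizeBias(float bias, float inputScale, float weightScale)
    {
        const float biasScale = inputScale * weightScale; // bias zero point is zero
        return static_cast<int32_t>(std::lround(bias / biasScale));
    }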
diff --git a/src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp b/src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
similarity index 68%
rename from src/armnn/backends/ClWorkloads/ClFullyConnectedFloatWorkload.hpp
rename to src/armnn/backends/ClWorkloads/ClFullyConnectedWorkload.hpp
@@ -20,14 +20,14 @@ arm_compute::Status ClFullyConnectedWorkloadValidate(const TensorInfo& input,
                                                      const TensorInfo& biases,
                                                      const FullyConnectedDescriptor& descriptor);
 
-class ClFullyConnectedFloatWorkload : public armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>
+class ClFullyConnectedWorkload : public armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>
 {
 public:
-    ClFullyConnectedFloatWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
-                                  const armnn::WorkloadInfo& info,
-                                  std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
+    ClFullyConnectedWorkload(const armnn::FullyConnectedQueueDescriptor& descriptor,
+                             const armnn::WorkloadInfo& info,
+                             std::shared_ptr<arm_compute::MemoryManagerOnDemand>& memoryManager);
 
-    using armnn::FloatWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
+    using armnn::BaseWorkload<armnn::FullyConnectedQueueDescriptor>::m_Data;
     void Execute() const override;
 
 private:
diff --git a/src/armnn/backends/test/ArmComputeCl.cpp b/src/armnn/backends/test/ArmComputeCl.cpp
index 2c1d8b6..d8a70f0 100644
@@ -42,6 +42,8 @@ ARMNN_AUTO_TEST_CASE(ReLu6Uint8, BoundedReLuUint8UpperBoundOnlyTest)
 ARMNN_AUTO_TEST_CASE(SimpleFullyConnected, FullyConnectedFloat32Test, false, false)
 ARMNN_AUTO_TEST_CASE(SimpleFullyConnectedWithBias, FullyConnectedFloat32Test, true, false)
 ARMNN_AUTO_TEST_CASE(SimpleFullyConnectedWithTranspose, FullyConnectedFloat32Test, false, true)
+ARMNN_AUTO_TEST_CASE(FullyConnectedUint8, FullyConnectedUint8Test, false)
+ARMNN_AUTO_TEST_CASE(FullyConnectedBiasedUint8, FullyConnectedUint8Test, true)
 
 ARMNN_AUTO_TEST_CASE(FullyConnectedLarge, FullyConnectedLargeTest, false)
 ARMNN_AUTO_TEST_CASE(FullyConnectedLargeTransposed, FullyConnectedLargeTest, true)
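
The two new cases above run the shared FullyConnectedUint8Test with the bias flag off and on. A hedged sketch of the reference arithmetic such a test can compare against, accumulating in int32 with zero points removed and then requantizing (names and data layout here are illustrative, not the actual LayerTests code):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <vector>

    std::vector<uint8_t> RefFullyConnectedU8(
        const std::vector<uint8_t>& in, int32_t inOff, float inScale,
        const std::vector<uint8_t>& w, int32_t wOff, float wScale, // [outputSize][inputSize]
        const std::vector<int32_t>& bias,                          // scale = inScale * wScale
        unsigned int inputSize, unsigned int outputSize,
        float outScale, int32_t outOff)
    {
        std::vector<uint8_t> out(outputSize);
        for (unsigned int o = 0; o < outputSize; ++o)
        {
            int32_t acc = bias.empty() ? 0 : bias[o];
            for (unsigned int i = 0; i < inputSize; ++i)
            {
                acc += (static_cast<int32_t>(in[i]) - inOff) *
                       (static_cast<int32_t>(w[o * inputSize + i]) - wOff);
            }
            const float real = acc * inScale * wScale; // dequantized accumulator
            const int32_t q  = outOff + static_cast<int32_t>(std::lround(real / outScale));
            out[o] = static_cast<uint8_t>(std::min(255, std::max(0, q))); // clamp to uint8
        }
        return out;
    }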
diff --git a/src/armnn/backends/test/CreateWorkloadCl.cpp b/src/armnn/backends/test/CreateWorkloadCl.cpp
index a273582..bce265c 100644
@@ -268,12 +268,12 @@ static void ClCreateFullyConnectedWorkloadTest()
 
 BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloatWorkloadTest)
 {
-    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float32>();
+    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedWorkload, armnn::DataType::Float32>();
 }
 
 BOOST_AUTO_TEST_CASE(CreateFullyConnectedFloat16WorkloadTest)
 {
-    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedFloatWorkload, armnn::DataType::Float16>();
+    ClCreateFullyConnectedWorkloadTest<ClFullyConnectedWorkload, armnn::DataType::Float16>();
 }
 
 template <typename NormalizationWorkloadType, typename armnn::DataType DataType>