IVGCVSW-3224 Add Uint8 support for Rsqrt
author nikraj01 <nikhil.raj@arm.com>
Fri, 14 Jun 2019 08:40:34 +0000 (09:40 +0100)
committer nikraj01 <nikhil.raj@arm.com>
Fri, 14 Jun 2019 08:40:34 +0000 (09:40 +0100)
Change-Id: I45598fc9b6d408b19d8d050e64c12b1d48535fa3
Signed-off-by: nikraj01 <nikhil.raj@arm.com>
src/backends/backendsCommon/WorkloadData.cpp
src/backends/backendsCommon/test/LayerTests.hpp
src/backends/reference/RefLayerSupport.cpp
src/backends/reference/RefWorkloadFactory.cpp
src/backends/reference/test/RefCreateWorkloadTests.cpp
src/backends/reference/test/RefLayerTests.cpp

index 1e14b65..20e1252 100644 (file)
@@ -1468,6 +1468,21 @@ void RsqrtQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
                               "RsqrtQueueDescriptor",
                               "input",
                               "output");
+
+    std::vector<DataType> supportedTypes =
+    {
+            DataType::Float16,
+            DataType::Float32,
+            DataType::QuantisedAsymm8
+    };
+
+    ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+                      supportedTypes,
+                      "RsqrtQueueDescriptor");
+
+    ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+                      {workloadInfo.m_InputTensorInfos[0].GetDataType()},
+                      "RsqrtQueueDescriptor");
 }
 
 void GatherQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
index 8bbd0d4..8a5a611 100644 (file)
@@ -887,8 +887,8 @@ LayerTestResult<T, 2> Rsqrt2dTestCommon(
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         const armnn::TensorInfo inputTensorInfo,
         const armnn::TensorInfo outputTensorInfo,
-        std::vector<T> inputValues,
-        std::vector<T> expectedOutputValues);
+        const std::vector<float>& inputValues,
+        const std::vector<float>& expectedOutputValues);
 
 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 2> Rsqrt2dTest(
@@ -1941,19 +1941,21 @@ std::vector<T> ConvertToDataType(const std::vector<float>& input,
     return output;
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T>
 LayerTestResult<T, 2> Rsqrt2dTestCommon(
         armnn::IWorkloadFactory& workloadFactory,
         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
         const armnn::TensorInfo inputTensorInfo,
         const armnn::TensorInfo outputTensorInfo,
-        std::vector<T> inputValues,
-        std::vector<T> expectedOutputValues)
+        const std::vector<float>& inputValues,
+        const std::vector<float>& expectedOutputValues)
 {
-    auto inputTensor = MakeTensor<T, 2>(inputTensorInfo, std::vector<T>(inputValues));
+    auto inputTensor = MakeTensor<T, 2>(inputTensorInfo, ConvertToDataType<ArmnnType>(inputValues,inputTensorInfo));
 
     LayerTestResult<T, 2> result(outputTensorInfo);
-    result.outputExpected = MakeTensor<T, 2>(outputTensorInfo, std::vector<T>(expectedOutputValues));
+
+    result.outputExpected = MakeTensor<T, 2>(outputTensorInfo,
+                                             ConvertToDataType<ArmnnType>(expectedOutputValues,outputTensorInfo));
 
     std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
     std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
@@ -1988,22 +1990,27 @@ LayerTestResult<T, 2> Rsqrt2dTest(
     const armnn::TensorShape inputShape{ 2, 2 };
     const armnn::TensorShape outputShape{ 2, 2 };
 
-    const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
-    const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
+    inputTensorInfo.SetQuantizationScale(0.1f);
+    inputTensorInfo.SetQuantizationOffset(0);
+
+    armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    outputTensorInfo.SetQuantizationScale(0.1f);
+    outputTensorInfo.SetQuantizationOffset(0);
 
-    std::vector<T> inputValues
-            {
-                    1.f, 4.f,
-                    16.f, 25.f
-            };
+    std::vector<float> inputValues
+    {
+        1.f, 4.f,
+        16.f, 25.f
+    };
 
-    std::vector<T> expectedOutputValues
-            {
-                    1.f, 0.5f,
-                    0.25f, 0.2f
-            };
+    std::vector<float> expectedOutputValues
+    {
+        1.f, 0.5f,
+        0.25f, 0.2f
+    };
 
-    return Rsqrt2dTestCommon<T>(workloadFactory, memoryManager,
+    return Rsqrt2dTestCommon<ArmnnType>(workloadFactory, memoryManager,
                                 inputTensorInfo, outputTensorInfo,
                                 inputValues, expectedOutputValues);
 }
@@ -2016,25 +2023,31 @@ LayerTestResult<T, 3> Rsqrt3dTest(
     const armnn::TensorShape inputShape{ 3, 1, 2 };
     const armnn::TensorShape outputShape{ 3, 1, 2 };
 
-    const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
-    const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
+    inputTensorInfo.SetQuantizationScale(0.1f);
+    inputTensorInfo.SetQuantizationOffset(0);
 
-    std::vector<T> inputValues
-            {
-                    1.f, 4.f, 16.f,
-                    25.f, 64.f, 100.f
-            };
+    armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    outputTensorInfo.SetQuantizationScale(0.1f);
+    outputTensorInfo.SetQuantizationOffset(0);
 
-    std::vector<T> expectedOutputValues
-            {
-                    1.f, 0.5f, 0.25f,
-                    0.2f, 0.125f, 0.1f
-            };
+    std::vector<float> inputValues
+    {
+        1.f, 4.f, 16.f,
+        25.f, 64.f, 100.f
+    };
 
-    auto inputTensor = MakeTensor<T, 3>(inputTensorInfo, std::vector<T>(inputValues));
+    std::vector<float> expectedOutputValues
+    {
+        1.f, 0.5f, 0.25f,
+        0.2f, 0.125f, 0.1f
+    };
+
+    auto inputTensor = MakeTensor<T, 3>(inputTensorInfo, ConvertToDataType<ArmnnType>(inputValues,inputTensorInfo));
 
     LayerTestResult<T, 3> result(outputTensorInfo);
-    result.outputExpected = MakeTensor<T, 3>(outputTensorInfo, std::vector<T>(expectedOutputValues));
+    result.outputExpected = MakeTensor<T, 3>(outputTensorInfo,
+                                             ConvertToDataType<ArmnnType>(expectedOutputValues,outputTensorInfo));
 
     std::unique_ptr<armnn::ITensorHandle> inputHandle = workloadFactory.CreateTensorHandle(inputTensorInfo);
     std::unique_ptr<armnn::ITensorHandle> outputHandle = workloadFactory.CreateTensorHandle(outputTensorInfo);
@@ -2069,20 +2082,23 @@ LayerTestResult<T, 2> RsqrtZeroTest(
     const armnn::TensorShape inputShape{ 1, 2 };
     const armnn::TensorShape outputShape{ 1, 2 };
 
-    const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
-    const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
+    inputTensorInfo.SetQuantizationScale(0.1f);
+
+    armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    outputTensorInfo.SetQuantizationScale(0.1f);
 
-    std::vector<T> inputValues
-            {
-                    0.f, -0.f
-            };
+    std::vector<float> inputValues
+    {
+        0.f, -0.f
+    };
 
-    std::vector<T> expectedOutputValues
-            {
-                    INFINITY, -INFINITY
-            };
+    std::vector<float> expectedOutputValues
+    {
+        INFINITY, -INFINITY
+    };
 
-    return Rsqrt2dTestCommon<T>(workloadFactory, memoryManager,
+    return Rsqrt2dTestCommon<ArmnnType>(workloadFactory, memoryManager,
                                 inputTensorInfo, outputTensorInfo,
                                 inputValues, expectedOutputValues);
 }
@@ -2095,20 +2111,25 @@ LayerTestResult<T, 2> RsqrtNegativeTest(
     const armnn::TensorShape inputShape{ 1, 2 };
     const armnn::TensorShape outputShape{ 1, 2 };
 
-    const armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
-    const armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType);
+    inputTensorInfo.SetQuantizationScale(0.1f);
+    inputTensorInfo.SetQuantizationOffset(0);
 
-    std::vector<T> inputValues
-            {
-                    -25.f, -16.f
-            };
+    armnn::TensorInfo outputTensorInfo(outputShape, ArmnnType);
+    outputTensorInfo.SetQuantizationScale(0.1f);
+    outputTensorInfo.SetQuantizationOffset(0);
 
-    std::vector<T> expectedOutputValues
-            {
-                    -NAN, -NAN
-            };
+    std::vector<float> inputValues
+    {
+        -25.f, -16.f
+    };
+
+    std::vector<float> expectedOutputValues
+    {
+        -NAN, -NAN
+    };
 
-    return Rsqrt2dTestCommon<T>(workloadFactory, memoryManager,
+    return Rsqrt2dTestCommon<ArmnnType>(workloadFactory, memoryManager,
                                 inputTensorInfo, outputTensorInfo,
                                 inputValues, expectedOutputValues);
 }
index f2ab9ed..b508dfd 100644 (file)
@@ -1116,11 +1116,26 @@ bool RefLayerSupport::IsRsqrtSupported(const TensorInfo& input,
                                        const TensorInfo& output,
                                        Optional<std::string&> reasonIfUnsupported) const
 {
-    ignore_unused(output);
-    return IsSupportedForDataTypeRef(reasonIfUnsupported,
-                                     input.GetDataType(),
-                                     &TrueFunc<>,
-                                     &FalseFuncU8<>);
+    bool supported = true;
+    std::array<DataType,2> supportedTypes =
+    {
+            DataType::Float32,
+            DataType::QuantisedAsymm8
+    };
+
+    supported &= CheckSupportRule(TypeAnyOf(input, supportedTypes), reasonIfUnsupported,
+                                  "Reference rsqrt: input type not supported");
+
+    supported &= CheckSupportRule(TypeAnyOf(output, supportedTypes), reasonIfUnsupported,
+                                  "Reference rsqrt: output type not supported");
+
+    supported &= CheckSupportRule(TypesAreEqual(input, output), reasonIfUnsupported,
+                                  "Reference rsqrt: input and output types not matching");
+
+    supported &= CheckSupportRule(ShapesAreSameTotalSize(input, output), reasonIfUnsupported,
+                                  "Reference Rsqrt: input and output shapes have different number of total elements");
+
+    return supported;
 }
 
 bool RefLayerSupport::IsSoftmaxSupported(const TensorInfo& input,
index 1ef88a0..cb26f26 100644 (file)
@@ -402,10 +402,6 @@ std::unique_ptr<IWorkload> RefWorkloadFactory::CreateRsqrt(const RsqrtQueueDescr
     {
         return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
     }
-    else if(IsUint8(info))
-    {
-        return MakeWorkload<NullWorkload, NullWorkload>(descriptor, info);
-    }
     return std::make_unique<RefRsqrtWorkload>(descriptor, info);
 }
 
index 5139888..dbcf201 100644 (file)
@@ -677,6 +677,11 @@ BOOST_AUTO_TEST_CASE(CreateRsqrtFloat32)
     RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::Float32>();
 }
 
+BOOST_AUTO_TEST_CASE(CreateRsqrtUint8)
+{
+    RefCreateRsqrtTest<RefRsqrtWorkload, armnn::DataType::QuantisedAsymm8>();
+}
+
 template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
 static void RefCreateL2NormalizationTest(DataLayout dataLayout)
 {
index fd01550..8ebb725 100644 (file)
@@ -552,6 +552,8 @@ ARMNN_AUTO_TEST_CASE(Rsqrt2d, Rsqrt2dTest<armnn::DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(Rsqrt3d, Rsqrt3dTest<armnn::DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(RsqrtZero, RsqrtZeroTest<armnn::DataType::Float32>)
 ARMNN_AUTO_TEST_CASE(RsqrtNegative, RsqrtNegativeTest<armnn::DataType::Float32>)
+ARMNN_AUTO_TEST_CASE(Rsqrt2dQuantisedAsymm8, Rsqrt2dTest<armnn::DataType::QuantisedAsymm8>)
+ARMNN_AUTO_TEST_CASE(Rsqrt3dQuantisedAsymm8, Rsqrt3dTest<armnn::DataType::QuantisedAsymm8>)
 
 // Permute
 ARMNN_AUTO_TEST_CASE(SimplePermuteFloat32, SimplePermuteFloat32Test)