IVGCVSW-2467 Remove GetDataType<T> function
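
The activation tests previously recovered the armnn::DataType from the C++
element type with armnn::GetDataType<T>(). Each test template now takes the
armnn::DataType as an explicit template parameter and derives the element
type from it (typename T = armnn::ResolveType<ArmnnType>), so call sites name
the data type directly and GetDataType<T> can be deleted. A minimal sketch of
the trait this relies on, in the spirit of ArmNN's ResolveType.hpp (the real
header covers more data types):

    #include <armnn/Types.hpp> // armnn::DataType
    #include <cstdint>

    namespace armnn
    {
    // Primary template is left undefined: using an unmapped DataType
    // fails at compile time rather than silently picking a wrong type.
    template<DataType DT>
    struct ResolveTypeImpl;

    template<>
    struct ResolveTypeImpl<DataType::Float32>         { using Type = float; };

    template<>
    struct ResolveTypeImpl<DataType::QuantisedAsymm8> { using Type = uint8_t; };

    template<DataType DT>
    using ResolveType = typename ResolveTypeImpl<DT>::Type;
    } // namespace armnn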
diff --git a/src/backends/backendsCommon/test/ActivationTestImpl.hpp b/src/backends/backendsCommon/test/ActivationTestImpl.hpp
index 46c700c..ca61302 100644
--- a/src/backends/backendsCommon/test/ActivationTestImpl.hpp
+++ b/src/backends/backendsCommon/test/ActivationTestImpl.hpp
@@ -19,7 +19,7 @@
 
 #include <algorithm>
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 4> BoundedReLuTestCommon(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -41,11 +41,9 @@ LayerTestResult<T, 4> BoundedReLuTestCommon(
     unsigned int outputChannels = inputChannels;
     unsigned int outputBatchSize = inputBatchSize;
 
-    armnn::TensorInfo inputTensorInfo({ inputBatchSize, inputChannels, inputHeight, inputWidth },
-        armnn::GetDataType<T>());
+    armnn::TensorInfo inputTensorInfo({ inputBatchSize, inputChannels, inputHeight, inputWidth }, ArmnnType);
 
-    armnn::TensorInfo outputTensorInfo({ outputBatchSize, outputChannels, outputHeight, outputWidth },
-        armnn::GetDataType<T>());
+    armnn::TensorInfo outputTensorInfo({ outputBatchSize, outputChannels, outputHeight, outputWidth }, ArmnnType);
 
     if(armnn::IsQuantizedType<T>())
     {
@@ -115,7 +113,7 @@ LayerTestResult<float, 4> BoundedReLuUpperAndLowerBoundTest(
      0.999f,       1.0f,    0.89f,      1.0f,
     };
 
-    return BoundedReLuTestCommon(
+    return BoundedReLuTestCommon<armnn::DataType::Float32>(
         workloadFactory, memoryManager, 1.0f, -1.0f, 1.0f, 0, 1.0f, 0, input, output,
         inputWidth, inputHeight, inputChannels, inputBatchSize);
 }
@@ -146,7 +144,7 @@ LayerTestResult<float, 4> BoundedReLuUpperBoundOnlyTest(
      0.999f,       1.2f,    0.89f,       6.0f,
     };
 
-    return BoundedReLuTestCommon(
+    return BoundedReLuTestCommon<armnn::DataType::Float32>(
         workloadFactory, memoryManager, 6.0f, 0.0f, 1.0f, 0, 1.0f, 0, input, output,
         inputWidth, inputHeight, inputChannels, inputBatchSize);
 }
@@ -176,10 +174,10 @@ LayerTestResult<uint8_t, 4> BoundedReLuUint8UpperBoundOnlyTest(
     float outputScale    = 6.0f / 255.0f;
     int32_t outputOffset = 0;
 
-    return BoundedReLuTestCommon(workloadFactory, memoryManager, 6.0f, 0.0f,
-                                 inputScale, inputOffset, outputScale, outputOffset,
-                                 input, output,
-                                 inputWidth, inputHeight, inputChannels, inputBatchSize);
+    return BoundedReLuTestCommon<armnn::DataType::QuantisedAsymm8>(
+        workloadFactory, memoryManager, 6.0f, 0.0f,
+        inputScale, inputOffset, outputScale, outputOffset,
+        input, output, inputWidth, inputHeight, inputChannels, inputBatchSize);
 }
 
 LayerTestResult<uint8_t, 4> BoundedReLuUint8UpperAndLowerBoundTest(
@@ -205,10 +203,10 @@ LayerTestResult<uint8_t, 4> BoundedReLuUint8UpperAndLowerBoundTest(
     int32_t inputOffset = 112;
     float inputScale    = 0.0125f;
 
-    return BoundedReLuTestCommon(workloadFactory, memoryManager, 1.0f, -1.0f,
-                                 inputScale, inputOffset, inputScale, inputOffset, // Input/output scale & offset same.
-                                 input, output,
-                                 inputWidth, inputHeight, inputChannels, inputBatchSize);
+    return BoundedReLuTestCommon<armnn::DataType::QuantisedAsymm8>(
+        workloadFactory, memoryManager, 1.0f, -1.0f,
+        inputScale, inputOffset, inputScale, inputOffset, // Input/output scale & offset same.
+        input, output, inputWidth, inputHeight, inputChannels, inputBatchSize);
 }
 
 namespace
@@ -303,7 +301,7 @@ LayerTestResult<float, 4> CompareBoundedReLuTest(
     return result;
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T,4> ConstantLinearActivationTestCommon(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -320,8 +318,8 @@ LayerTestResult<T,4> ConstantLinearActivationTestCommon(
 
     unsigned int shape[]  = {batchSize, inputChannels, inputHeight, inputWidth};
 
-    inputTensorInfo = armnn::TensorInfo(4, shape, armnn::GetDataType<T>());
-    outputTensorInfo = armnn::TensorInfo(4, shape, armnn::GetDataType<T>());
+    inputTensorInfo = armnn::TensorInfo(4, shape, ArmnnType);
+    outputTensorInfo = armnn::TensorInfo(4, shape, ArmnnType);
 
     // Set quantization parameters if the requested type is a quantized type.
     if(armnn::IsQuantizedType<T>())
@@ -368,17 +366,18 @@ LayerTestResult<float, 4> ConstantLinearActivationTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
 {
-    return ConstantLinearActivationTestCommon<float>(workloadFactory, memoryManager);
+    return ConstantLinearActivationTestCommon<armnn::DataType::Float32>(workloadFactory, memoryManager);
 }
 
 LayerTestResult<uint8_t, 4> ConstantLinearActivationUint8Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
 {
-    return ConstantLinearActivationTestCommon<uint8_t>(workloadFactory, memoryManager, 4.0f, 3);
+    return ConstantLinearActivationTestCommon<armnn::DataType::QuantisedAsymm8>(
+        workloadFactory, memoryManager, 4.0f, 3);
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 4> SimpleActivationTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -400,10 +399,8 @@ LayerTestResult<T, 4> SimpleActivationTest(
     constexpr static unsigned int outputChannels = inputChannels;
     constexpr static unsigned int outputBatchSize = inputBatchSize;
 
-    armnn::TensorInfo inputTensorInfo({ inputBatchSize, inputChannels, inputHeight, inputWidth },
-                                      armnn::GetDataType<T>());
-    armnn::TensorInfo outputTensorInfo({ outputBatchSize, outputChannels, outputHeight, outputWidth },
-                                       armnn::GetDataType<T>());
+    armnn::TensorInfo inputTensorInfo({ inputBatchSize, inputChannels, inputHeight, inputWidth }, ArmnnType);
+    armnn::TensorInfo outputTensorInfo({ outputBatchSize, outputChannels, outputHeight, outputWidth }, ArmnnType);
 
     // Set quantization parameters if the requested type is a quantized type.
     if(armnn::IsQuantizedType<T>())
@@ -448,7 +445,7 @@ LayerTestResult<T, 4> SimpleActivationTest(
     return result;
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 4> SimpleSigmoidTestCommon(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -470,32 +467,32 @@ LayerTestResult<T, 4> SimpleSigmoidTestCommon(
     std::vector<float> outputExpectedData(inputData.size());
     std::transform(inputData.begin(), inputData.end(), outputExpectedData.begin(), f);
 
-    return SimpleActivationTest<T>(workloadFactory,
-                                   memoryManager,
-                                   armnn::ActivationFunction::Sigmoid,
-                                   0.f,
-                                   0.f,
-                                   qScale,
-                                   qOffset,
-                                   inputData,
-                                   outputExpectedData);
+    return SimpleActivationTest<ArmnnType>(workloadFactory,
+                                           memoryManager,
+                                           armnn::ActivationFunction::Sigmoid,
+                                           0.f,
+                                           0.f,
+                                           qScale,
+                                           qOffset,
+                                           inputData,
+                                           outputExpectedData);
 }
 
 LayerTestResult<float, 4> SimpleSigmoidTest(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
 {
-    return SimpleSigmoidTestCommon<float>(workloadFactory, memoryManager, 0.0f, 0);
+    return SimpleSigmoidTestCommon<armnn::DataType::Float32>(workloadFactory, memoryManager, 0.0f, 0);
 }
 
 LayerTestResult<uint8_t, 4> SimpleSigmoidUint8Test(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager)
 {
-    return SimpleSigmoidTestCommon<uint8_t>(workloadFactory, memoryManager, 0.1f, 50);
+    return SimpleSigmoidTestCommon<armnn::DataType::QuantisedAsymm8>(workloadFactory, memoryManager, 0.1f, 50);
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T,4> CompareActivationTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -517,8 +514,8 @@ LayerTestResult<T,4> CompareActivationTestImpl(
 
     unsigned int shape[] = {batchSize, channels, height, width};
 
-    inputTensorInfo = armnn::TensorInfo(4, shape, armnn::GetDataType<T>());
-    outputTensorInfo = armnn::TensorInfo(4, shape, armnn::GetDataType<T>());
+    inputTensorInfo = armnn::TensorInfo(4, shape, ArmnnType);
+    outputTensorInfo = armnn::TensorInfo(4, shape, ArmnnType);
 
     // Set quantization parameters if the requested type is a quantized type.
     if(armnn::IsQuantizedType<T>())
@@ -596,7 +593,7 @@ LayerTestResult<float,4> CompareActivationTest(
     armnn::ActivationFunction f,
     unsigned int batchSize)
 {
-    return CompareActivationTestImpl<float>(
+    return CompareActivationTestImpl<armnn::DataType::Float32>(
         workloadFactory, memoryManager, refWorkloadFactory, f, batchSize);
 }
 
@@ -606,6 +603,6 @@ LayerTestResult<uint8_t,4> CompareActivationUint8Test(
     armnn::IWorkloadFactory& refWorkloadFactory,
     armnn::ActivationFunction f)
 {
-    return CompareActivationTestImpl<uint8_t>(
+    return CompareActivationTestImpl<armnn::DataType::QuantisedAsymm8>(
         workloadFactory, memoryManager, refWorkloadFactory, f, 5, 0.1f, 50);
 }
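
With the new signatures the data type is chosen once at the call site and the
element type follows from it. A short usage sketch, assuming a workloadFactory
and memoryManager as in the tests above:

    // T is deduced from the DataType template argument via ResolveType.
    auto f32Result = SimpleSigmoidTestCommon<armnn::DataType::Float32>(
        workloadFactory, memoryManager, 0.0f, 0);   // T = float
    auto u8Result  = SimpleSigmoidTestCommon<armnn::DataType::QuantisedAsymm8>(
        workloadFactory, memoryManager, 0.1f, 50);  // T = uint8_t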