IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / backends / backendsCommon / test / SoftmaxTestImpl.hpp
index 97199e3..25ceda1 100644 (file)
@@ -19,7 +19,7 @@
 
 #include <algorithm>
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -32,13 +32,13 @@ LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
 
     unsigned int inputShape[] = { 2, 4 };
 
-    inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
+    inputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
     float qScale = 1.f / 256.f;
     int qOffset = 0;
     inputTensorInfo.SetQuantizationScale(qScale);
     inputTensorInfo.SetQuantizationOffset(qOffset);
 
-    outputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
+    outputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
     outputTensorInfo.SetQuantizationScale(qScale);
     outputTensorInfo.SetQuantizationOffset(qOffset);
 
@@ -87,7 +87,7 @@ LayerTestResult<T, 2> SimpleSoftmaxTestImpl(
     return ret;
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 2> CompareSoftmaxTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -103,8 +103,8 @@ LayerTestResult<T, 2> CompareSoftmaxTestImpl(
 
     unsigned int inputShape[] = { batchSize, channels };
 
-    inputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
-    outputTensorInfo = armnn::TensorInfo(2, inputShape, armnn::GetDataType<T>());
+    inputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
+    outputTensorInfo = armnn::TensorInfo(2, inputShape, ArmnnType);
     float qScale = 1.f / 256.f;
     int qOffset = 0;
     inputTensorInfo.SetQuantizationScale(qScale);