IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / backends / backendsCommon / test / SplitterTestImpl.hpp
index e88356c..004060f 100644 (file)
@@ -16,7 +16,7 @@
 
 #include <test/TensorHelpers.hpp>
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 std::vector<LayerTestResult<T,3>> SplitterTestCommon(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
@@ -46,15 +46,15 @@ std::vector<LayerTestResult<T,3>> SplitterTestCommon(
 
 
     // Define the tensor descriptors.
-    armnn::TensorInfo inputTensorInfo({ inputChannels, inputHeight, inputWidth }, armnn::GetDataType<T>());
+    armnn::TensorInfo inputTensorInfo({ inputChannels, inputHeight, inputWidth }, ArmnnType);
 
     // Outputs of the original split.
-    armnn::TensorInfo outputTensorInfo1({ outputChannels1, outputHeight1, outputWidth1 }, armnn::GetDataType<T>());
-    armnn::TensorInfo outputTensorInfo2({ outputChannels2, outputHeight2, outputWidth2 }, armnn::GetDataType<T>());
+    armnn::TensorInfo outputTensorInfo1({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType);
+    armnn::TensorInfo outputTensorInfo2({ outputChannels2, outputHeight2, outputWidth2 }, ArmnnType);
 
     // Outputs of the subsequent subtensor split.
-    armnn::TensorInfo outputTensorInfo3({ outputChannels1, outputHeight1, outputWidth1 }, armnn::GetDataType<T>());
-    armnn::TensorInfo outputTensorInfo4({ outputChannels1, outputHeight1, outputWidth1 }, armnn::GetDataType<T>());
+    armnn::TensorInfo outputTensorInfo3({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType);
+    armnn::TensorInfo outputTensorInfo4({ outputChannels1, outputHeight1, outputWidth1 }, ArmnnType);
 
     // Set quantization parameters if the requested type is a quantized type.
     // The quantization doesn't really matter as the splitter operator doesn't dequantize/quantize.
@@ -245,13 +245,13 @@ std::vector<LayerTestResult<T,3>> SplitterTestCommon(
 }
 
 
-template <typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 LayerTestResult<T, 3> CopyViaSplitterTestImpl(
     armnn::IWorkloadFactory& workloadFactory,
     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
     float qScale, int32_t qOffset)
 {
-    const armnn::TensorInfo tensorInfo({ 3, 6, 5 }, armnn::GetDataType<T>());
+    const armnn::TensorInfo tensorInfo({ 3, 6, 5 }, ArmnnType);
     auto input = MakeTensor<T, 3>(tensorInfo, QuantizedVector<T>(qScale, qOffset,
                                                                  {
                                                                      1.0f, 2.0f, 3.0f, 4.0f, 5.0f,