IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / backends / backendsCommon / test / MergerTestImpl.hpp
index e0b8233..2bdfe28 100644 (file)
@@ -4,6 +4,8 @@
 //
 #pragma once
 
+#include "TypeUtils.hpp"
+
 #include <armnn/INetwork.hpp>
 
 #include <backendsCommon/test/CommonTestUtils.hpp>
@@ -47,17 +49,18 @@ INetworkPtr CreateMergerNetwork(const std::vector<TensorShape>& inputShapes,
     return net;
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType>
 void MergerDim0EndToEnd(const std::vector<BackendId>& backends)
 {
     using namespace armnn;
+    using T = ResolveType<ArmnnType>;
 
     unsigned int concatAxis = 0;
     const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
     const TensorShape& outputShape = { 4, 3, 2, 2 };
 
     // Builds up the structure of the network
-    INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+    INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
 
     BOOST_TEST_CHECKPOINT("create a network");
 
@@ -110,17 +113,18 @@ void MergerDim0EndToEnd(const std::vector<BackendId>& backends)
     EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType>
 void MergerDim1EndToEnd(const std::vector<BackendId>& backends)
 {
     using namespace armnn;
+    using T = ResolveType<ArmnnType>;
 
     unsigned int concatAxis = 1;
     const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
     const TensorShape& outputShape = { 2, 6, 2, 2 };
 
     // Builds up the structure of the network
-    INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+    INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
 
     BOOST_TEST_CHECKPOINT("create a network");
 
@@ -173,17 +177,18 @@ void MergerDim1EndToEnd(const std::vector<BackendId>& backends)
     EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType>
 void MergerDim2EndToEnd(const std::vector<BackendId>& backends)
 {
     using namespace armnn;
+    using T = ResolveType<ArmnnType>;
 
     unsigned int concatAxis = 2;
     const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
     const TensorShape& outputShape = { 2, 3, 4, 2 };
 
     // Builds up the structure of the network
-    INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+    INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
 
     BOOST_TEST_CHECKPOINT("create a network");
 
@@ -236,7 +241,7 @@ void MergerDim2EndToEnd(const std::vector<BackendId>& backends)
     EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 void MergerDim3EndToEnd(const std::vector<BackendId>& backends)
 {
     using namespace armnn;
@@ -246,7 +251,7 @@ void MergerDim3EndToEnd(const std::vector<BackendId>& backends)
     const TensorShape& outputShape = { 2, 3, 2, 4 };
 
     // Builds up the structure of the network
-    INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+    INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
 
     BOOST_TEST_CHECKPOINT("create a network");