IVGCVSW-2467 Remove GetDataType<T> function
[platform/upstream/armnn.git] / src / backends / backendsCommon / test / ArithmeticTestImpl.hpp
index f70bf48..1d6cf1d 100644 (file)
@@ -4,6 +4,8 @@
 //
 #pragma once
 
+#include "TypeUtils.hpp"
+
 #include <armnn/INetwork.hpp>
 
 #include <backendsCommon/test/CommonTestUtils.hpp>
@@ -49,7 +51,7 @@ INetworkPtr CreateArithmeticNetwork(const std::vector<TensorShape>& inputShapes,
     return net;
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
                               const LayerType type,
                               const std::vector<T> expectedOutput)
@@ -60,7 +62,7 @@ void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
     const TensorShape& outputShape = { 2, 2, 2, 2 };
 
     // Builds up the structure of the network
-    INetworkPtr net = CreateArithmeticNetwork<GetDataType<T>()>(inputShapes, outputShape, type);
+    INetworkPtr net = CreateArithmeticNetwork<ArmnnType>(inputShapes, outputShape, type);
 
     BOOST_TEST_CHECKPOINT("create a network");
 
@@ -76,7 +78,7 @@ void ArithmeticSimpleEndToEnd(const std::vector<BackendId>& backends,
     EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
 }
 
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
 void ArithmeticBroadcastEndToEnd(const std::vector<BackendId>& backends,
                                  const LayerType type,
                                  const std::vector<T> expectedOutput)
@@ -87,7 +89,7 @@ void ArithmeticBroadcastEndToEnd(const std::vector<BackendId>& backends,
     const TensorShape& outputShape = { 1, 2, 2, 3 };
 
     // Builds up the structure of the network
-    INetworkPtr net = CreateArithmeticNetwork<GetDataType<T>()>(inputShapes, outputShape, type);
+    INetworkPtr net = CreateArithmeticNetwork<ArmnnType>(inputShapes, outputShape, type);
 
     BOOST_TEST_CHECKPOINT("create a network");