//
#pragma once
+#include "TypeUtils.hpp"
+
#include <armnn/INetwork.hpp>
#include <backendsCommon/test/CommonTestUtils.hpp>
return net;
}
-template<typename T>
+template<armnn::DataType ArmnnType>
void MergerDim0EndToEnd(const std::vector<BackendId>& backends)
{
using namespace armnn;
+ using T = ResolveType<ArmnnType>;
unsigned int concatAxis = 0;
const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
const TensorShape& outputShape = { 4, 3, 2, 2 };
// Builds up the structure of the network
- INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+ INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
BOOST_TEST_CHECKPOINT("create a network");
EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
}
-template<typename T>
+template<armnn::DataType ArmnnType>
void MergerDim1EndToEnd(const std::vector<BackendId>& backends)
{
using namespace armnn;
+ using T = ResolveType<ArmnnType>;
unsigned int concatAxis = 1;
const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
const TensorShape& outputShape = { 2, 6, 2, 2 };
// Builds up the structure of the network
- INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+ INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
BOOST_TEST_CHECKPOINT("create a network");
EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
}
-template<typename T>
+template<armnn::DataType ArmnnType>
void MergerDim2EndToEnd(const std::vector<BackendId>& backends)
{
using namespace armnn;
+ using T = ResolveType<ArmnnType>;
unsigned int concatAxis = 2;
const std::vector<TensorShape> inputShapes{{ 2, 3, 2, 2 }, { 2, 3, 2, 2 }};
const TensorShape& outputShape = { 2, 3, 4, 2 };
// Builds up the structure of the network
- INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+ INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
BOOST_TEST_CHECKPOINT("create a network");
EndToEndLayerTestImpl<T>(move(net), inputTensorData, expectedOutputData, backends);
}
-template<typename T>
+template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
void MergerDim3EndToEnd(const std::vector<BackendId>& backends)
{
using namespace armnn;
const TensorShape& outputShape = { 2, 3, 2, 4 };
// Builds up the structure of the network
- INetworkPtr net = CreateMergerNetwork<GetDataType<T>()>(inputShapes, outputShape, concatAxis);
+ INetworkPtr net = CreateMergerNetwork<ArmnnType>(inputShapes, outputShape, concatAxis);
BOOST_TEST_CHECKPOINT("create a network");