IVGCVSW-1865 - Support NHWC for Convolution2D (CpuRef)
[platform/upstream/armnn.git] / src / backends / reference / test / RefCreateWorkloadTests.cpp
index dc0348d..236267c 100644 (file)
@@ -177,17 +177,32 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Float16Workload)
         std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16));
 }
 
-BOOST_AUTO_TEST_CASE(CreateConvolution2dWorkload)
+static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
 {
-    Graph                graph;
+    Graph graph;
     RefWorkloadFactory factory;
-    auto                 workload = CreateConvolution2dWorkloadTest<RefConvolution2dFloat32Workload,
-                         DataType::Float32>(factory, graph);
+    auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dFloat32Workload, DataType::Float32>
+                    (factory, graph, dataLayout);
+
+    std::initializer_list<unsigned int> inputShape  = (dataLayout == DataLayout::NCHW) ?
+        std::initializer_list<unsigned int>({2, 3, 8, 16}) : std::initializer_list<unsigned int>({2, 8, 16, 3});
+    std::initializer_list<unsigned int> outputShape = (dataLayout == DataLayout::NCHW) ?
+        std::initializer_list<unsigned int>({2, 2, 2, 10}) : std::initializer_list<unsigned int>({2, 2, 10, 2});
 
     // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
     CheckInputOutput(std::move(workload),
-                     TensorInfo({2, 3, 8, 16}, DataType::Float32),
-                     TensorInfo({2, 2, 2, 10}, DataType::Float32));
+                     TensorInfo(inputShape, DataType::Float32),
+                     TensorInfo(outputShape, DataType::Float32));
+}
+
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNchwWorkload)
+{
+    RefCreateConvolution2dWorkloadTest(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNhwcWorkload)
+{
+    RefCreateConvolution2dWorkloadTest(DataLayout::NHWC);
 }
 
 template <typename FullyConnectedWorkloadType, armnn::DataType DataType>