IVGCVSW-1865 - Support NHWC for Convolution2D (CpuRef)
authorNikhil Raj <nikhil.raj@arm.com>
Thu, 18 Oct 2018 09:11:04 +0000 (10:11 +0100)
committerMatthew Bentham <matthew.bentham@arm.com>
Mon, 22 Oct 2018 15:57:54 +0000 (16:57 +0100)
* Updated the ConvImpl.hpp to use DataLayoutIndex
* Enabled unit test for CpuRef
* Update CreateWorkload Tests for ref with NHWC

Change-Id: Id309b7ef677489d63dcb5e09bd48ab9624b5ebfb

src/backends/reference/test/RefCreateWorkloadTests.cpp
src/backends/reference/test/RefLayerTests.cpp
src/backends/reference/workloads/ConvImpl.hpp

index dc0348d..236267c 100644 (file)
@@ -177,17 +177,32 @@ BOOST_AUTO_TEST_CASE(CreateConvertFp32ToFp16Float16Workload)
         std::move(workload), TensorInfo({1, 3, 2, 3}, DataType::Float32), TensorInfo({1, 3, 2, 3}, DataType::Float16));
 }
 
-BOOST_AUTO_TEST_CASE(CreateConvolution2dWorkload)
+static void RefCreateConvolution2dWorkloadTest(DataLayout dataLayout = DataLayout::NCHW)
 {
-    Graph                graph;
+    Graph graph;
     RefWorkloadFactory factory;
-    auto                 workload = CreateConvolution2dWorkloadTest<RefConvolution2dFloat32Workload,
-                         DataType::Float32>(factory, graph);
+    auto workload = CreateConvolution2dWorkloadTest<RefConvolution2dFloat32Workload, DataType::Float32>
+                    (factory, graph, dataLayout);
+
+    std::initializer_list<unsigned int> inputShape  = (dataLayout == DataLayout::NCHW) ?
+        std::initializer_list<unsigned int>({2, 3, 8, 16}) : std::initializer_list<unsigned int>({2, 8, 16, 3});
+    std::initializer_list<unsigned int> outputShape = (dataLayout == DataLayout::NCHW) ?
+        std::initializer_list<unsigned int>({2, 2, 2, 10}) : std::initializer_list<unsigned int>({2, 2, 10, 2});
 
     // Checks that outputs and inputs are as we expect them (see definition of CreateConvolution2dWorkloadTest).
     CheckInputOutput(std::move(workload),
-                     TensorInfo({2, 3, 8, 16}, DataType::Float32),
-                     TensorInfo({2, 2, 2, 10}, DataType::Float32));
+                     TensorInfo(inputShape, DataType::Float32),
+                     TensorInfo(outputShape, DataType::Float32));
+}
+
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNchwWorkload)
+{
+    RefCreateConvolution2dWorkloadTest(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateConvolution2dFloatNhwcWorkload)
+{
+    RefCreateConvolution2dWorkloadTest(DataLayout::NHWC);
 }
 
 template <typename FullyConnectedWorkloadType, armnn::DataType DataType>
index 2137161..259739b 100644 (file)
@@ -36,6 +36,8 @@ ARMNN_AUTO_TEST_CASE(SimpleConvolution2dAsymmetricPaddingLargerThanHalfKernelSiz
     Convolution2dAsymmetricPaddingLargerThanHalfKernelSizeTest)
 ARMNN_AUTO_TEST_CASE(SimpleConvolution2dAsymmetricPadding, Convolution2dAsymmetricPaddingTest)
 
+ARMNN_AUTO_TEST_CASE(SimpleConvolution2dSquareNhwc, SimpleConvolution2d3x3NhwcTest, false)
+
 // Depthwise Convolution
 ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2d, DepthwiseConvolution2dTest, true)
 ARMNN_AUTO_TEST_CASE(DepthwiseConvolution2dUint8, DepthwiseConvolution2dUint8Test, true)
index 4c9ab2a..60a3622 100644 (file)
@@ -63,21 +63,26 @@ static void ConvImpl(ConvData data,
         throw InvalidArgumentException("Bias is enabled but the bias data is invalid");
     }
 
-    const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
+    const TensorInfo& inputInfo0  = GetTensorInfo(data.m_Inputs[0]);
     const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]);
 
+    const DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout);
+    const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
+    const unsigned int heightIndex   = dataLayoutIndexed.GetHeightIndex();
+    const unsigned int widthIndex    = dataLayoutIndexed.GetWidthIndex();
+
     unsigned int depthMult      = depthwise ? filterInfo.GetShape()[0] : 1;
-    unsigned int channelsInput  = filterInfo.GetShape()[1];
+    unsigned int channelsInput  = filterInfo.GetShape()[channelsIndex];
     unsigned int channelsOutput = depthwise ? channelsInput * depthMult : filterInfo.GetShape()[0];
 
     unsigned int batchSize    = outputInfo0.GetShape()[0];
-    unsigned int heightOutput = outputInfo0.GetShape()[2];
-    unsigned int widthOutput  = outputInfo0.GetShape()[3];
-    unsigned int heightInput  = inputInfo0.GetShape()[2];
-    unsigned int widthInput   = inputInfo0.GetShape()[3];
+    unsigned int heightOutput = outputInfo0.GetShape()[heightIndex];
+    unsigned int widthOutput  = outputInfo0.GetShape()[widthIndex];
+    unsigned int heightInput  = inputInfo0.GetShape()[heightIndex];
+    unsigned int widthInput   = inputInfo0.GetShape()[widthIndex];
 
-    unsigned int heightFilter = filterInfo.GetShape()[2];
-    unsigned int widthFilter  = filterInfo.GetShape()[3];
+    unsigned int heightFilter = filterInfo.GetShape()[heightIndex];
+    unsigned int widthFilter  = filterInfo.GetShape()[widthIndex];
 
     unsigned int paddingTop = data.m_Parameters.m_PadTop;
     unsigned int paddingLeft = data.m_Parameters.m_PadLeft;