IVGCVSW-2018 Support NHWC in the current ref implementation
authorMatteo Martincigh <matteo.martincigh@arm.com>
Tue, 16 Oct 2018 15:23:33 +0000 (16:23 +0100)
committerMatthew Bentham <matthew.bentham@arm.com>
Mon, 22 Oct 2018 15:57:54 +0000 (16:57 +0100)
 * Enabled the now supported ref layer tests
 * Re-enabled the failing test now that the bug has been fixed in
   ACL 1903a9976ae24f40cb2203364211ed62fcfbb985
 * Added CreateWorkload test for ref L2Normalization NHWC
 * Refactoring the ref L2Normalization for clarity

!armnn:153723

Change-Id: Id0067e49072b3e057ffe3ae3b70d928be6091c0f

src/backends/cl/test/ClLayerTests.cpp
src/backends/reference/test/RefCreateWorkloadTests.cpp
src/backends/reference/test/RefLayerTests.cpp
src/backends/reference/workloads/RefL2NormalizationFloat32Workload.cpp
src/backends/reference/workloads/RefL2NormalizationFloat32Workload.hpp

index 62ce2cb..3b1603c 100755 (executable)
@@ -188,8 +188,7 @@ ARMNN_AUTO_TEST_CASE(L2Normalization2d, L2Normalization2dTest)
 ARMNN_AUTO_TEST_CASE(L2Normalization3d, L2Normalization3dTest)
 ARMNN_AUTO_TEST_CASE(L2Normalization4d, L2Normalization4dTest)
 
-// NOTE: The following test hits a bug in ACL that makes it fail, keep it disabled until a patch is available in ACL
-//ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest)
 ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest)
 ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest)
 ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest)
index a8901d2..dc0348d 100644 (file)
@@ -449,18 +449,42 @@ BOOST_AUTO_TEST_CASE(CreateResizeBilinearFloat32Nhwc)
     RefCreateResizeBilinearTest<RefResizeBilinearFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
 }
 
-BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32)
+template <typename L2NormalizationWorkloadType, armnn::DataType DataType>
+static void RefCreateL2NormalizationTest(DataLayout dataLayout)
 {
     Graph graph;
     RefWorkloadFactory factory;
-    auto workload = CreateL2NormalizationWorkloadTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>
-            (factory, graph);
+    auto workload =
+            CreateL2NormalizationWorkloadTest<L2NormalizationWorkloadType, DataType>(factory, graph, dataLayout);
+
+    TensorShape inputShape;
+    TensorShape outputShape;
+
+    switch (dataLayout)
+    {
+        case DataLayout::NHWC:
+            inputShape  = { 5, 50, 67, 20 };
+            outputShape = { 5, 50, 67, 20 };
+            break;
+        case DataLayout::NCHW:
+        default:
+            inputShape  = { 5, 20, 50, 67 };
+            outputShape = { 5, 20, 50, 67 };
+            break;
+    }
 
     // Checks that outputs and inputs are as we expect them (see definition of CreateL2NormalizationWorkloadTest).
-    CheckInputOutput(
-        std::move(workload),
-        TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32),
-        TensorInfo({ 5, 20, 50, 67 }, armnn::DataType::Float32));
+    CheckInputOutput(std::move(workload), TensorInfo(inputShape, DataType), TensorInfo(outputShape, DataType));
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32)
+{
+    RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NCHW);
+}
+
+BOOST_AUTO_TEST_CASE(CreateL2NormalizationFloat32Nhwc)
+{
+    RefCreateL2NormalizationTest<RefL2NormalizationFloat32Workload, armnn::DataType::Float32>(DataLayout::NHWC);
 }
 
 template <typename ReshapeWorkloadType, armnn::DataType DataType>
index 797051e..2815e34 100644 (file)
@@ -211,11 +211,10 @@ ARMNN_AUTO_TEST_CASE(Pad2d, Pad2dTest)
 ARMNN_AUTO_TEST_CASE(Pad3d, Pad3dTest)
 ARMNN_AUTO_TEST_CASE(Pad4d, Pad4dTest)
 
-// NOTE: These tests are disabled until NHWC is supported by the reference L2Normalization implementation.
-//ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest);
-//ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest);
+ARMNN_AUTO_TEST_CASE(L2Normalization1dNhwc, L2Normalization1dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization2dNhwc, L2Normalization2dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization3dNhwc, L2Normalization3dNhwcTest)
+ARMNN_AUTO_TEST_CASE(L2Normalization4dNhwc, L2Normalization4dNhwcTest)
 
 // Constant
 ARMNN_AUTO_TEST_CASE(Constant, ConstantTest)
index 973c87b..d21cfa9 100644 (file)
@@ -22,26 +22,32 @@ void RefL2NormalizationFloat32Workload::Execute() const
     const TensorInfo& inputInfo = GetTensorInfo(m_Data.m_Inputs[0]);
     const TensorInfo& outputInfo = GetTensorInfo(m_Data.m_Outputs[0]);
 
-    TensorBufferArrayView<const float> input(inputInfo.GetShape(), GetInputTensorDataFloat(0, m_Data));
-    TensorBufferArrayView<float> output(outputInfo.GetShape(), GetOutputTensorDataFloat(0, m_Data));
+    TensorBufferArrayView<const float> input(inputInfo.GetShape(),
+                                             GetInputTensorDataFloat(0, m_Data),
+                                             m_Data.m_Parameters.m_DataLayout);
+    TensorBufferArrayView<float> output(outputInfo.GetShape(),
+                                        GetOutputTensorDataFloat(0, m_Data),
+                                        m_Data.m_Parameters.m_DataLayout);
 
-    const unsigned int batchSize = inputInfo.GetShape()[0];
-    const unsigned int depth = inputInfo.GetShape()[1];
-    const unsigned int rows = inputInfo.GetShape()[2];
-    const unsigned int cols = inputInfo.GetShape()[3];
+    DataLayoutIndexed dataLayout(m_Data.m_Parameters.m_DataLayout);
 
-    for (unsigned int n = 0; n < batchSize; ++n)
+    const unsigned int batches  = inputInfo.GetShape()[0];
+    const unsigned int channels = inputInfo.GetShape()[dataLayout.GetChannelsIndex()];
+    const unsigned int height   = inputInfo.GetShape()[dataLayout.GetHeightIndex()];
+    const unsigned int width    = inputInfo.GetShape()[dataLayout.GetWidthIndex()];
+
+    for (unsigned int n = 0; n < batches; ++n)
     {
-        for (unsigned int d = 0; d < depth; ++d)
+        for (unsigned int c = 0; c < channels; ++c)
         {
-            for (unsigned int h = 0; h < rows; ++h)
+            for (unsigned int h = 0; h < height; ++h)
             {
-                for (unsigned int w = 0; w < cols; ++w)
+                for (unsigned int w = 0; w < width; ++w)
                 {
                     float reduction = 0.0;
-                    for (unsigned int c = 0; c < depth; ++c)
+                    for (unsigned int d = 0; d < channels; ++d)
                     {
-                        const float value = input.Get(n, c, h, w);
+                        const float value = input.Get(n, d, h, w);
                         reduction += value * value;
                     }
 
@@ -51,7 +57,7 @@ void RefL2NormalizationFloat32Workload::Execute() const
                     //   backend.
                     // - The reference semantics for this operator do not include this parameter.
                     const float scale = 1.0f / sqrtf(reduction);
-                    output.Get(n, d, h, w) = input.Get(n, d, h, w) * scale;
+                    output.Get(n, c, h, w) = input.Get(n, c, h, w) * scale;
                 }
             }
         }
index 67055a9..b2e3795 100644 (file)
@@ -15,7 +15,8 @@ class RefL2NormalizationFloat32Workload : public Float32Workload<L2Normalization
 {
 public:
     using Float32Workload<L2NormalizationQueueDescriptor>::Float32Workload;
-    virtual void Execute() const override;
+
+    void Execute() const override;
 };
 
 } //namespace armnn