IVGCVSW-2018 Support NHWC in the current ref implementation
authorMatteo Martincigh <matteo.martincigh@arm.com>
Tue, 16 Oct 2018 15:17:34 +0000 (16:17 +0100)
committerMatthew Bentham <matthew.bentham@arm.com>
Mon, 22 Oct 2018 15:57:54 +0000 (16:57 +0100)
 * Added NHWC support to TensorBufferArrayView class

Change-Id: I41e1d0acd226a471ec834e380389631d9236cb00

src/backends/reference/workloads/TensorBufferArrayView.hpp

index aba44e4..b149073 100644 (file)
@@ -20,6 +20,7 @@ public:
         , m_Data(data)
         , m_DataLayout(dataLayout)
     {
+        BOOST_ASSERT(m_Shape.GetNumDimensions() == 4);
     }
 
     DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
@@ -32,10 +33,26 @@ public:
         BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] ||
             ( m_Shape[m_DataLayout.GetWidthIndex()]    == 0 && w == 0) );
 
-        return m_Data[b * m_Shape[1] * m_Shape[2] * m_Shape[3]
-                    + c * m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()]
-                    + h * m_Shape[m_DataLayout.GetWidthIndex()]
-                    + w];
+        // Offset the given indices appropriately depending on the data layout.
+        switch (m_DataLayout.GetDataLayout())
+        {
+        case DataLayout::NHWC:
+            b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index;
+            h *= m_Shape[m_DataLayout.GetWidthIndex()] * m_Shape[m_DataLayout.GetChannelsIndex()];
+            w *= m_Shape[m_DataLayout.GetChannelsIndex()];
+            // c stays unchanged
+            break;
+        case DataLayout::NCHW:
+        default:
+            b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index;
+            c *= m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()];
+            h *= m_Shape[m_DataLayout.GetWidthIndex()];
+            // w stays unchanged
+            break;
+        }
+
+        // Get the value using the correct offset.
+        return m_Data[b + c + h + w];
     }
 
 private: