IVGCVSW-3142 Refactor DataLayoutIndexed and TensorBufferArrayView
authorMatteo Martincigh <matteo.martincigh@arm.com>
Wed, 5 Jun 2019 08:02:41 +0000 (09:02 +0100)
committerMatteo Martincigh <matteo.martincigh@arm.com>
Wed, 5 Jun 2019 08:10:50 +0000 (09:10 +0100)
for convenience

 * Added GetIndex method to DataLayoutIndexed
 * Refactored TensorBufferArrayView::Get to use the new method

Change-Id: Iae08b2761bddeda9e935b25e6bc4985f2d386cd3
Signed-off-by: Matteo Martincigh <matteo.martincigh@arm.com>
src/armnnUtils/DataLayoutIndexed.cpp
src/armnnUtils/DataLayoutIndexed.hpp
src/backends/reference/workloads/TensorBufferArrayView.hpp

index db27de4..b02f07e 100644 (file)
@@ -5,6 +5,8 @@
 
 #include "DataLayoutIndexed.hpp"
 
+#include <boost/assert.hpp>
+
 using namespace armnn;
 
 namespace armnnUtils
@@ -31,13 +33,45 @@ DataLayoutIndexed::DataLayoutIndexed(armnn::DataLayout dataLayout)
     }
 }
 
-// Definition in include/armnn/Types.hpp
+unsigned int DataLayoutIndexed::GetIndex(const TensorShape& shape,
+                                         unsigned int batchIndex, unsigned int channelIndex,
+                                         unsigned int heightIndex, unsigned int widthIndex) const
+{
+    BOOST_ASSERT( batchIndex < shape[0] || ( shape[0] == 0 && batchIndex == 0 ) );
+    BOOST_ASSERT( channelIndex < shape[m_ChannelsIndex] ||
+                ( shape[m_ChannelsIndex] == 0 && channelIndex == 0) );
+    BOOST_ASSERT( heightIndex < shape[m_HeightIndex] ||
+                ( shape[m_HeightIndex] == 0 && heightIndex == 0) );
+    BOOST_ASSERT( widthIndex < shape[m_WidthIndex] ||
+                ( shape[m_WidthIndex] == 0 && widthIndex == 0) );
+
+    // Offset the given indices appropriately depending on the data layout
+    switch (m_DataLayout)
+    {
+    case DataLayout::NHWC:
+        batchIndex  *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
+        heightIndex *= shape[m_WidthIndex] * shape[m_ChannelsIndex];
+        widthIndex  *= shape[m_ChannelsIndex];
+        // channelIndex stays unchanged
+        break;
+    case DataLayout::NCHW:
+    default:
+        batchIndex   *= shape[1] * shape[2] * shape[3]; // batchIndex *= heightIndex * widthIndex * channelIndex
+        channelIndex *= shape[m_HeightIndex] * shape[m_WidthIndex];
+        heightIndex  *= shape[m_WidthIndex];
+        // widthIndex stays unchanged
+        break;
+    }
+
+    // Get the value using the correct offset
+    return batchIndex + channelIndex + heightIndex + widthIndex;
+}
+
 bool operator==(const DataLayout& dataLayout, const DataLayoutIndexed& indexed)
 {
     return dataLayout == indexed.GetDataLayout();
 }
 
-// Definition in include/armnn/Types.hpp
 bool operator==(const DataLayoutIndexed& indexed, const DataLayout& dataLayout)
 {
     return indexed.GetDataLayout() == dataLayout;
index 1cf2a09..5bb8e0d 100644 (file)
@@ -2,8 +2,11 @@
 // Copyright © 2017 Arm Ltd. All rights reserved.
 // SPDX-License-Identifier: MIT
 //
+
 #pragma once
+
 #include <armnn/Types.hpp>
+#include <armnn/Tensor.hpp>
 
 namespace armnnUtils
 {
@@ -18,6 +21,9 @@ public:
     unsigned int      GetChannelsIndex() const { return m_ChannelsIndex; }
     unsigned int      GetHeightIndex()   const { return m_HeightIndex; }
     unsigned int      GetWidthIndex()    const { return m_WidthIndex; }
+    unsigned int      GetIndex(const armnn::TensorShape& shape,
+                               unsigned int batchIndex, unsigned int channelIndex,
+                               unsigned int heightIndex, unsigned int widthIndex) const;
 
 private:
     armnn::DataLayout m_DataLayout;
index aecec67..c064072 100644 (file)
@@ -30,34 +30,7 @@ public:
 
     DataType& Get(unsigned int b, unsigned int c, unsigned int h, unsigned int w) const
     {
-        BOOST_ASSERT( b < m_Shape[0] || ( m_Shape[0]   == 0 && b == 0 ) );
-        BOOST_ASSERT( c < m_Shape[m_DataLayout.GetChannelsIndex()] ||
-            ( m_Shape[m_DataLayout.GetChannelsIndex()] == 0 && c == 0) );
-        BOOST_ASSERT( h < m_Shape[m_DataLayout.GetHeightIndex()] ||
-            ( m_Shape[m_DataLayout.GetHeightIndex()]   == 0 && h == 0) );
-        BOOST_ASSERT( w < m_Shape[m_DataLayout.GetWidthIndex()] ||
-            ( m_Shape[m_DataLayout.GetWidthIndex()]    == 0 && w == 0) );
-
-        // Offset the given indices appropriately depending on the data layout.
-        switch (m_DataLayout.GetDataLayout())
-        {
-        case DataLayout::NHWC:
-            b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index;
-            h *= m_Shape[m_DataLayout.GetWidthIndex()] * m_Shape[m_DataLayout.GetChannelsIndex()];
-            w *= m_Shape[m_DataLayout.GetChannelsIndex()];
-            // c stays unchanged
-            break;
-        case DataLayout::NCHW:
-        default:
-            b *= m_Shape[1] * m_Shape[2] * m_Shape[3]; // b *= height_index * width_index * channel_index;
-            c *= m_Shape[m_DataLayout.GetHeightIndex()] * m_Shape[m_DataLayout.GetWidthIndex()];
-            h *= m_Shape[m_DataLayout.GetWidthIndex()];
-            // w stays unchanged
-            break;
-        }
-
-        // Get the value using the correct offset.
-        return m_Data[b + c + h + w];
+        return m_Data[m_DataLayout.GetIndex(m_Shape, b, c, h, w)];
     }
 
 private: