IVGCVSW-1865 - Support NHWC for Convolution2D (CpuRef)
[platform/upstream/armnn.git] / src / backends / reference / workloads / ConvImpl.hpp
index 4c9ab2a..60a3622 100644 (file)
@@ -63,21 +63,26 @@ static void ConvImpl(ConvData data,
         throw InvalidArgumentException("Bias is enabled but the bias data is invalid");
     }
 
-    const TensorInfo& inputInfo0 = GetTensorInfo(data.m_Inputs[0]);
+    const TensorInfo& inputInfo0  = GetTensorInfo(data.m_Inputs[0]);
     const TensorInfo& outputInfo0 = GetTensorInfo(data.m_Outputs[0]);
 
+    const DataLayoutIndexed dataLayoutIndexed(data.m_Parameters.m_DataLayout);
+    const unsigned int channelsIndex = dataLayoutIndexed.GetChannelsIndex();
+    const unsigned int heightIndex   = dataLayoutIndexed.GetHeightIndex();
+    const unsigned int widthIndex    = dataLayoutIndexed.GetWidthIndex();
+
     unsigned int depthMult      = depthwise ? filterInfo.GetShape()[0] : 1;
-    unsigned int channelsInput  = filterInfo.GetShape()[1];
+    unsigned int channelsInput  = filterInfo.GetShape()[channelsIndex];
     unsigned int channelsOutput = depthwise ? channelsInput * depthMult : filterInfo.GetShape()[0];
 
     unsigned int batchSize    = outputInfo0.GetShape()[0];
-    unsigned int heightOutput = outputInfo0.GetShape()[2];
-    unsigned int widthOutput  = outputInfo0.GetShape()[3];
-    unsigned int heightInput  = inputInfo0.GetShape()[2];
-    unsigned int widthInput   = inputInfo0.GetShape()[3];
+    unsigned int heightOutput = outputInfo0.GetShape()[heightIndex];
+    unsigned int widthOutput  = outputInfo0.GetShape()[widthIndex];
+    unsigned int heightInput  = inputInfo0.GetShape()[heightIndex];
+    unsigned int widthInput   = inputInfo0.GetShape()[widthIndex];
 
-    unsigned int heightFilter = filterInfo.GetShape()[2];
-    unsigned int widthFilter  = filterInfo.GetShape()[3];
+    unsigned int heightFilter = filterInfo.GetShape()[heightIndex];
+    unsigned int widthFilter  = filterInfo.GetShape()[widthIndex];
 
     unsigned int paddingTop = data.m_Parameters.m_PadTop;
     unsigned int paddingLeft = data.m_Parameters.m_PadLeft;