IVGCVSW-1931: Add data layout param for ResizeBilinear
authorJames Conroy <james.conroy@arm.com>
Mon, 1 Oct 2018 08:15:19 +0000 (09:15 +0100)
committerMatthew Bentham <matthew.bentham@arm.com>
Wed, 10 Oct 2018 15:16:58 +0000 (16:16 +0100)
* Added data layout parameter to ResizeBilinear
  descriptor, in order to support NHWC.

Change-Id: Ifdbc4529127b7329a056d0a68e2e42b175aeea4a

include/armnn/Descriptors.hpp
src/backends/WorkloadData.cpp
src/backends/cl/workloads/ClResizeBilinearFloatWorkload.cpp

index 30c8144..2de031e 100644 (file)
@@ -92,7 +92,7 @@ struct ViewsDescriptor
     friend void swap(ViewsDescriptor& first, ViewsDescriptor& second);
 private:
     OriginsDescriptor m_Origins;
-    uint32_t** m_ViewSizes;
+    uint32_t**        m_ViewSizes;
 };
 
 /// Convenience template to create an OriginsDescriptor to use when creating a Merger layer for performing concatenation
@@ -308,10 +308,12 @@ struct ResizeBilinearDescriptor
     ResizeBilinearDescriptor()
     : m_TargetWidth(0)
     , m_TargetHeight(0)
+    , m_DataLayout(DataLayout::NCHW)
     {}
 
-    uint32_t m_TargetWidth;
-    uint32_t m_TargetHeight;
+    uint32_t   m_TargetWidth;
+    uint32_t   m_TargetHeight;
+    DataLayout m_DataLayout;
 };
 
 struct ReshapeDescriptor
index c5c607d..8b28b47 100644 (file)
@@ -664,8 +664,11 @@ void ResizeBilinearQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
     }
 
     {
-        const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[1];
-        const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[1];
+        // DataLayout is NCHW by default (channelsIndex = 1)
+        const unsigned int channelsIndex = this->m_Parameters.m_DataLayout == armnn::DataLayout::NHWC ? 3 : 1;
+
+        const unsigned int inputChannelCount = workloadInfo.m_InputTensorInfos[0].GetShape()[channelsIndex];
+        const unsigned int outputChannelCount = workloadInfo.m_OutputTensorInfos[0].GetShape()[channelsIndex];
         if (inputChannelCount != outputChannelCount)
         {
             throw InvalidArgumentException(
index 499466e..1a33035 100644 (file)
@@ -8,14 +8,17 @@
 #include <backends/CpuTensorHandle.hpp>
 #include <backends/cl/ClLayerSupport.hpp>
 #include <backends/aclCommon/ArmComputeUtils.hpp>
+#include <backends/aclCommon/ArmComputeTensorUtils.hpp>
 
 #include "ClWorkloadUtils.hpp"
 
+using namespace armnn::armcomputetensorutils;
+
 namespace armnn
 {
 
 ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinearQueueDescriptor& descriptor,
-                                                               const WorkloadInfo& info)
+                                                             const WorkloadInfo& info)
     : FloatWorkload<ResizeBilinearQueueDescriptor>(descriptor, info)
 {
     m_Data.ValidateInputsOutputs("ClResizeBilinearFloatWorkload", 1, 1);
@@ -23,6 +26,9 @@ ClResizeBilinearFloatWorkload::ClResizeBilinearFloatWorkload(const ResizeBilinea
     arm_compute::ICLTensor& input  = static_cast<IClTensorHandle*>(m_Data.m_Inputs[0])->GetTensor();
     arm_compute::ICLTensor& output = static_cast<IClTensorHandle*>(m_Data.m_Outputs[0])->GetTensor();
 
+    (&input)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout));
+    (&output)->info()->set_data_layout(ConvertDataLayout(m_Data.m_Parameters.m_DataLayout));
+
     m_ResizeBilinearLayer.configure(&input, &output, arm_compute::InterpolationPolicy::BILINEAR,
                                     arm_compute::BorderMode::REPLICATE, arm_compute::PixelValue(0.f),
                                     arm_compute::SamplingPolicy::TOP_LEFT);