MLCE-166: Add support for extracting indices in NEPoolingLayer 2x2 NHWC
author morgolock <pablo.tello@arm.com>
Fri, 3 Apr 2020 15:57:46 +0000 (16:57 +0100)
committer Pablo Marquez <pablo.tello@arm.com>
Tue, 5 May 2020 09:36:00 +0000 (09:36 +0000)
     * Added support for extracting pooling indices in NHWC for pool size 2x2
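
     A minimal usage sketch of the new path (illustrative only, not part of this
     change; it assumes the indices-aware NEPoolingLayer::configure() overload and
     the PoolingLayerInfo constructor taking a DataLayout, both of which may differ
     across releases):

         #include "arm_compute/core/Types.h"
         #include "arm_compute/runtime/NEON/functions/NEPoolingLayer.h"
         #include "arm_compute/runtime/Tensor.h"

         using namespace arm_compute;

         void max_pool_2x2_nhwc_with_indices()
         {
             Tensor src, dst, indices;

             // NHWC FP32 input: 16 channels, 32x32 spatial, single batch.
             TensorInfo src_info(TensorShape(16U, 32U, 32U), 1, DataType::F32);
             src_info.set_data_layout(DataLayout::NHWC);
             src.allocator()->init(src_info);

             // 2x2 max pooling, stride 2: the configuration this patch enables
             // for index extraction in NHWC.
             const PoolingLayerInfo pool_info(PoolingType::MAX, Size2D(2, 2),
                                              DataLayout::NHWC, PadStrideInfo(2, 2, 0, 0));

             NEPoolingLayer pool;
             pool.configure(&src, &dst, pool_info, &indices); // dst/indices infos auto-initialised

             src.allocator()->allocate();
             dst.allocator()->allocate();
             indices.allocator()->allocate(); // U32 element offsets of each maximum

             // ... fill src, then run ...
             pool.run();
         }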

Change-Id: Ib2a3468e794f58bbf2c03aba9f6b184b9d76b183
Signed-off-by: morgolock <pablo.tello@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2997
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>

arm_compute/core/NEON/kernels/NEPoolingLayerKernel.h
src/core/NEON/kernels/NEPoolingLayerKernel.cpp
tests/validation/NEON/PoolingLayer.cpp
tests/validation/fixtures/PoolingLayerFixture.h
tests/validation/reference/PoolingLayer.cpp
tests/validation/reference/PoolingLayer.h

index 6519ac72fedd7cb14a751ff1063b0b9780380d85..b0574b7cf67198c81efc0fd487d94d8fd4ca8645 100644 (file)
@@ -92,6 +92,12 @@ private:
      * @param[in] window       Output region on which to execute the kernel.
      */
     void pooling2_f32_nchw_maxpool_indices(const Window &window_input, const Window &window);
+    /** Function to perform 2x2 pooling and compute the pooling indices. The indices can be used for max unpool.
+     *
+     * @param[in] window_input Input region on which to execute the kernel.
+     * @param[in] window       Output region on which to execute the kernel.
+     */
+    void pooling2_f32_nhwc_maxpool_indices(const Window &window_input, const Window &window);
     /** Function to perform MxN pooling for 32-bit floating point values.
      *
      * @param[in] window_input    Input region on which to execute the kernel.
index fdbba815b4ee1dfa208a7d85ce0e866f8752887d..6d61f51f31e3111f638ede9d6c063d0052de23fa 100644 (file)
@@ -156,7 +156,7 @@ Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, c
         if(indices)
         {
             ARM_COMPUTE_RETURN_ERROR_ON_MSG((pool_size != Size2D(2, 2)), "Pooling indices only supported for pool size 2x2");
-            ARM_COMPUTE_RETURN_ERROR_ON_MSG(input->data_layout() == DataLayout::NHWC, "Pool indices only supported in NCHW");
+
             ARM_COMPUTE_RETURN_ERROR_ON((indices->dimension(get_data_layout_dimension_index(indices->data_layout(), DataLayoutDimension::WIDTH)) != pooled_w)
                                         || (indices->dimension(get_data_layout_dimension_index(indices->data_layout(), DataLayoutDimension::HEIGHT)) != pooled_h));
         }
@@ -183,7 +183,9 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITen
     if(indices)
     {
         // Indices auto inizialitation if not yet initialized
-        auto_init_if_empty(*indices, (input->clone()->set_tensor_shape(compute_pool_shape(*input, pool_info))).set_data_type(DataType::U32) /* we store the offset to the element */);
+        auto_init_if_empty(*indices, (input->clone()->set_tensor_shape(compute_pool_shape(*input,
+                                                                                          pool_info)))
+                           .set_data_type(DataType::U32) /* we store the offset to the element */);
     }
     const auto          data_layout                  = pool_info.data_layout == DataLayout::UNKNOWN ? input->data_layout() : pool_info.data_layout;
     unsigned int        num_elems_read_per_iteration = 0;
@@ -1750,24 +1752,126 @@ void NEPoolingLayerKernel::pooling7_f32_nchw(const Window &window_input, const W
 }
 
 void NEPoolingLayerKernel::poolingMxN_f32_nhwc(const Window &window_input, const Window &window, PoolingType pooling_type, bool exclude_padding)
+{
+    if(_pool_info.pool_size == Size2D(2, 2) && pooling_type == PoolingType::MAX && _indices)
+    {
+        pooling2_f32_nhwc_maxpool_indices(window_input, window);
+    }
+    else
+    {
+        Iterator input(_input, window_input);
+        Iterator output(_output, window);
+
+        const int pool_size_x     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().y() : _pool_info.pool_size.width;
+        const int pool_size_y     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().z() : _pool_info.pool_size.height;
+        const int pool_pad_right  = _pool_info.pad_stride_info.pad_right();
+        const int pool_pad_top    = _pool_info.pad_stride_info.pad_top();
+        const int pool_pad_left   = _pool_info.pad_stride_info.pad_left();
+        const int pool_pad_bottom = _pool_info.pad_stride_info.pad_bottom();
+        int       pool_stride_x   = 0;
+        int       pool_stride_y   = 0;
+        std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info.stride();
+        const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
+        const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
+
+        float32x4_t vres;
+
+        execute_window_loop(window, [&](const Coordinates & id)
+        {
+            const int idx_width    = id.y() * pool_stride_x;
+            const int idx_height   = id.z() * pool_stride_y;
+            const int pool_limit_y = pool_pad_top - idx_height;
+            const int pool_limit_x = pool_pad_left - idx_width;
+
+            const int pool_start_y = std::max(0, window_input.z().start() + pool_limit_y);
+            const int pool_end_y   = std::min(pool_size_y, window_input.z().end() + pool_limit_y);
+            const int pool_start_x = std::max(0, window_input.y().start() + pool_limit_x);
+            const int pool_end_x   = std::min(pool_size_x, window_input.y().end() + pool_limit_x);
+
+            if(pooling_type != PoolingType::MAX)
+            {
+                // Calculate scale
+                const float scale = calculate_avg_scale(exclude_padding, DataLayout::NHWC, id, pool_size_x, pool_size_y, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x,
+                                                        pool_stride_y);
+                const float32x4_t scale_v = vdupq_n_f32(scale);
+
+                // Perform pooling
+                vres = vdupq_n_f32(0.0f);
+
+                for(int y = pool_start_y; y < pool_end_y; ++y)
+                {
+                    for(int x = pool_start_x; x < pool_end_x; ++x)
+                    {
+                        const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
+                                                                                           (_input->info()->strides_in_bytes().z())));
+
+                        // Get power of 2 in case of l2 pooling and accumulate
+                        if(pooling_type == PoolingType::L2)
+                        {
+                            vres = vmlaq_f32(vres, data, data);
+                        }
+                        else
+                        {
+                            vres = vaddq_f32(vres, data);
+                        }
+                    }
+                }
+                // Divide by scale
+                vres = vmulq_f32(vres, scale_v);
+            }
+            else
+            {
+                vres = vdupq_n_f32(std::numeric_limits<float>::lowest());
+                for(int y = pool_start_y; y < pool_end_y; ++y)
+                {
+                    for(int x = pool_start_x; x < pool_end_x; ++x)
+                    {
+                        const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
+                                                                                           (_input->info()->strides_in_bytes().z())));
+                        vres                   = vmaxq_f32(vres, data);
+                    }
+                }
+            }
+
+            // Calculate square-root in case of l2 pooling
+            if(pooling_type == PoolingType::L2)
+            {
+                float32x4_t l2_res = { static_cast<float>(sqrt(vgetq_lane_f32(vres, 0))),
+                                       static_cast<float>(sqrt(vgetq_lane_f32(vres, 1))),
+                                       static_cast<float>(sqrt(vgetq_lane_f32(vres, 2))),
+                                       static_cast<float>(sqrt(vgetq_lane_f32(vres, 3)))
+                                     };
+                vres = l2_res;
+            }
+
+            // Store result
+            vst1q_f32(reinterpret_cast<float *>(output.ptr()), vres);
+        },
+        input, output);
+    }
+}
+
+void NEPoolingLayerKernel::pooling2_f32_nhwc_maxpool_indices(const Window &window_input, const Window &window)
 {
     Iterator input(_input, window_input);
     Iterator output(_output, window);
+    Iterator indices(_indices, window);
 
-    const int pool_size_x     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().y() : _pool_info.pool_size.width;
-    const int pool_size_y     = _pool_info.is_global_pooling ? _input->info()->tensor_shape().z() : _pool_info.pool_size.height;
-    const int pool_pad_right  = _pool_info.pad_stride_info.pad_right();
-    const int pool_pad_top    = _pool_info.pad_stride_info.pad_top();
-    const int pool_pad_left   = _pool_info.pad_stride_info.pad_left();
-    const int pool_pad_bottom = _pool_info.pad_stride_info.pad_bottom();
-    int       pool_stride_x   = 0;
-    int       pool_stride_y   = 0;
+    const int pool_pad_top  = _pool_info.pad_stride_info.pad_top();
+    const int pool_pad_left = _pool_info.pad_stride_info.pad_left();
+
+    int pool_stride_x = 0;
+    int pool_stride_y = 0;
     std::tie(pool_stride_x, pool_stride_y) = _pool_info.pad_stride_info.stride();
-    const int upper_bound_w = _input->info()->dimension(1) + (exclude_padding ? 0 : pool_pad_right);
-    const int upper_bound_h = _input->info()->dimension(2) + (exclude_padding ? 0 : pool_pad_bottom);
 
     float32x4_t vres;
 
+    const int pad_right   = _input->info()->padding().right;
+    const int pad_top     = _input->info()->padding().top;
+    const int in_stride_y = static_cast<int>(_input->info()->strides_in_bytes().y());
+    const int in_stride_z = static_cast<int>(_input->info()->strides_in_bytes().z());
+    const int in_stride_w = static_cast<int>(_input->info()->strides_in_bytes()[3]);
+
     execute_window_loop(window, [&](const Coordinates & id)
     {
         const int idx_width    = id.y() * pool_stride_x;
@@ -1776,70 +1880,53 @@ void NEPoolingLayerKernel::poolingMxN_f32_nhwc(const Window &window_input, const
         const int pool_limit_x = pool_pad_left - idx_width;
 
         const int pool_start_y = std::max(0, window_input.z().start() + pool_limit_y);
-        const int pool_end_y   = std::min(pool_size_y, window_input.z().end() + pool_limit_y);
         const int pool_start_x = std::max(0, window_input.y().start() + pool_limit_x);
-        const int pool_end_x   = std::min(pool_size_x, window_input.y().end() + pool_limit_x);
+        const int in_x0_offset = (pool_start_x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
+        const int in_x1_offset = (pool_start_x + 1 - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
+
+        const int in_x2_offset = (pool_start_x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y + 1 - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
+
+        const int in_x3_offset = (pool_start_x + 1 - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (pool_start_y + 1 - pool_pad_top) * static_cast<int>
+                                 (_input->info()->strides_in_bytes().z());
+
+        const auto in_x0_ptr = reinterpret_cast<const float *>(input.ptr() + in_x0_offset);
+        const auto in_x1_ptr = reinterpret_cast<const float *>(input.ptr() + in_x1_offset);
+        const auto in_x2_ptr = reinterpret_cast<const float *>(input.ptr() + in_x2_offset);
+        const auto in_x3_ptr = reinterpret_cast<const float *>(input.ptr() + in_x3_offset);
+        const auto v_x0      = vld1q_f32(in_x0_ptr);
+        const auto v_x1      = vld1q_f32(in_x1_ptr);
+        const auto v_x2      = vld1q_f32(in_x2_ptr);
+        const auto v_x3      = vld1q_f32(in_x3_ptr);
+        vres                 = vmaxq_f32(vmaxq_f32(v_x2, v_x3), vmaxq_f32(v_x0, v_x1));
+        // Store result
+        vst1q_f32(reinterpret_cast<float *>(output.ptr()), vres);
 
-        if(pooling_type != PoolingType::MAX)
-        {
-            // Calculate scale
-            const float scale = calculate_avg_scale(exclude_padding, DataLayout::NHWC, id, pool_size_x, pool_size_y, upper_bound_w, upper_bound_h, pool_pad_left, pool_pad_top, pool_stride_x,
-                                                    pool_stride_y);
-            const float32x4_t scale_v = vdupq_n_f32(scale);
+        const uint32_t offset_base = input.offset()
+                                     - sizeof(float) * pad_right * id.y() * pool_stride_x                                     /* subtract padding elems per row */
+                                     - pad_top * sizeof(float)                                                                /* top padding */
+                                     - sizeof(float) * pad_right * _input->info()->tensor_shape()[1] * id.z() * pool_stride_y /* for each Z plane there are width*pad_right padding elems */
+                                     - in_stride_w * id[3] + _input->info()->tensor_shape()[0] * sizeof(float) * id[3];
 
-            // Perform pooling
-            vres = vdupq_n_f32(0.0f);
+        const uint32_t offset_x0 = (uint32_t)offset_base / sizeof(float);
+        const uint32_t offset_x1 = (uint32_t)offset_x0 + in_stride_y / sizeof(float) - pad_right;
+        const uint32_t offset_x2 = (uint32_t)offset_x0 + in_stride_z / sizeof(float) - pad_right * _input->info()->tensor_shape()[1];
+        const uint32_t offset_x3 = (uint32_t)offset_x2 + in_stride_y / sizeof(float) - pad_right;
 
-            for(int y = pool_start_y; y < pool_end_y; ++y)
-            {
-                for(int x = pool_start_x; x < pool_end_x; ++x)
-                {
-                    const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
-                                                                                       (_input->info()->strides_in_bytes().z())));
+        const uint32x4_t voffset_x0   = { offset_x0, offset_x0 + 1, offset_x0 + 2, offset_x0 + 3 };
+        const uint32x4_t voffset_x1   = { offset_x1, offset_x1 + 1, offset_x1 + 2, offset_x1 + 3 };
+        const uint32x4_t voffset_x2   = { offset_x2, offset_x2 + 1, offset_x2 + 2, offset_x2 + 3 };
+        const uint32x4_t voffset_x3   = { offset_x3, offset_x3 + 1, offset_x3 + 2, offset_x3 + 3 };
+        const uint32x4_t tmp_indices0 = vbslq_u32(vcgtq_f32(v_x0, v_x1), voffset_x0, voffset_x1);
+        const uint32x4_t tmp_indices1 = vbslq_u32(vcgtq_f32(v_x2, v_x3), voffset_x2, voffset_x3);
+        const uint32x4_t tmp_indices2 = vbslq_u32(vcgtq_f32(vmaxq_f32(v_x0, v_x1), vmaxq_f32(v_x2, v_x3)), tmp_indices0, tmp_indices1);
 
-                    // Get power of 2 in case of l2 pooling and accumulate
-                    if(pooling_type == PoolingType::L2)
-                    {
-                        vres = vmlaq_f32(vres, data, data);
-                    }
-                    else
-                    {
-                        vres = vaddq_f32(vres, data);
-                    }
-                }
-            }
-            // Divide by scale
-            vres = vmulq_f32(vres, scale_v);
-        }
-        else
-        {
-            vres = vdupq_n_f32(std::numeric_limits<float>::lowest());
-            for(int y = pool_start_y; y < pool_end_y; ++y)
-            {
-                for(int x = pool_start_x; x < pool_end_x; ++x)
-                {
-                    const float32x4_t data = vld1q_f32(reinterpret_cast<const float *>(input.ptr() + (x - pool_pad_left) * static_cast<int>(_input->info()->strides_in_bytes().y()) + (y - pool_pad_top) * static_cast<int>
-                                                                                       (_input->info()->strides_in_bytes().z())));
-                    vres                   = vmaxq_f32(vres, data);
-                }
-            }
-        }
+        vst1q_u32(reinterpret_cast<uint32_t *>(indices.ptr()), tmp_indices2);
 
-        // Calculate square-root in case of l2 pooling
-        if(pooling_type == PoolingType::L2)
-        {
-            float32x4_t l2_res = { static_cast<float>(sqrt(vgetq_lane_f32(vres, 0))),
-                                   static_cast<float>(sqrt(vgetq_lane_f32(vres, 1))),
-                                   static_cast<float>(sqrt(vgetq_lane_f32(vres, 2))),
-                                   static_cast<float>(sqrt(vgetq_lane_f32(vres, 3)))
-                                 };
-            vres = l2_res;
-        }
-
-        // Store result
-        vst1q_f32(reinterpret_cast<float *>(output.ptr()), vres);
     },
-    input, output);
+    input, output, indices);
 }
 
 template <typename T>
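
The per-lane index selection in pooling2_f32_nhwc_maxpool_indices above relies on a NEON
compare-and-bit-select idiom (vcgtq_f32 / vbslq_u32). A self-contained sketch of that
idiom, with illustrative names only:

    #include <arm_neon.h>

    // For each of the four channel lanes, return off_a where a > b and off_b otherwise.
    // vcgtq_f32 yields an all-ones lane mask where the comparison holds;
    // vbslq_u32 then picks bits from its second operand under that mask.
    inline uint32x4_t select_max_offsets(float32x4_t a, float32x4_t b,
                                         uint32x4_t off_a, uint32x4_t off_b)
    {
        const uint32x4_t a_is_greater = vcgtq_f32(a, b);
        return vbslq_u32(a_is_greater, off_a, off_b);
    }
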
index a5876dcd0afddc3f4fbcb842c6265a2b9bd3e414..4b073d5352c9c4c3ec1269b946ed1f4f0dae3488 100644 (file)
@@ -35,7 +35,6 @@
 #include "tests/framework/datasets/Datasets.h"
 #include "tests/validation/Validation.h"
 #include "tests/validation/fixtures/PoolingLayerFixture.h"
-
 namespace arm_compute
 {
 namespace test
@@ -129,7 +128,7 @@ TEST_SUITE(FP32)
 FIXTURE_DATA_TEST_CASE(RunIndices, NEPoolingLayerIndicesFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(datasets::SmallShapes(), combine(PoolingLayerIndicesDatasetFPSmall,
                                                                                                                    framework::dataset::make("DataType",
                                                                                                                            DataType::F32))),
-                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW })
+                                                                                                                   framework::dataset::make("DataLayout", { DataLayout::NCHW, DataLayout::NHWC })
 
                                                                                                                   ))
 {
index 7f2d7ac2258c59688d0ca861cae7e44ff36cd206..eb40cea0c24a8f2c8086e7ff39f12f1c4799a6f4 100644 (file)
@@ -35,7 +35,6 @@
 #include "tests/framework/Fixture.h"
 #include "tests/validation/reference/PoolingLayer.h"
 #include <random>
-
 namespace arm_compute
 {
 namespace test
@@ -59,7 +58,7 @@ public:
 
         _pool_info = pool_info;
         _target    = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
-        _reference = compute_reference(shape, pool_info, data_type, input_qinfo, output_qinfo, indices);
+        _reference = compute_reference(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
     }
 
 protected:
@@ -92,7 +91,7 @@ protected:
         TensorType        src       = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout);
         const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info);
         TensorType        dst       = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout);
-        _target_indices             = create_tensor<TensorType>(dst_shape, DataType::U32, 1);
+        _target_indices             = create_tensor<TensorType>(dst_shape, DataType::U32, 1, output_qinfo, data_layout);
 
         // Create and configure function
         FunctionType pool_layer;
@@ -120,15 +119,14 @@ protected:
         return dst;
     }
 
-    SimpleTensor<T> compute_reference(const TensorShape &shape, PoolingLayerInfo info, DataType data_type,
+    SimpleTensor<T> compute_reference(TensorShape shape, PoolingLayerInfo info, DataType data_type, DataLayout data_layout,
                                       QuantizationInfo input_qinfo, QuantizationInfo output_qinfo, bool indices)
     {
         // Create reference
-        SimpleTensor<T> src{ shape, data_type, 1, input_qinfo };
+        SimpleTensor<T> src(shape, data_type, 1, input_qinfo);
         // Fill reference
         fill(src);
-
-        return reference::pooling_layer<T>(src, info, output_qinfo, indices ? &_ref_indices : nullptr);
+        return reference::pooling_layer<T>(src, info, output_qinfo, indices ? &_ref_indices : nullptr, data_layout);
     }
 
     TensorType             _target{};
index 1a1aebd1b42aa5ceb67efbd79d9dbd9d6be7789a..778e28d7c143dd47ce95ea4ffcf12be8e7d760f1 100644 (file)
@@ -38,7 +38,7 @@ namespace reference
 using namespace arm_compute::misc::shape_calculator;
 
 template <typename T, typename ACC_T, typename std::enable_if<is_floating_point<T>::value, int>::type>
-SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices)
+SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices, DataLayout data_layout)
 {
     ARM_COMPUTE_ERROR_ON(info.is_global_pooling && (src.shape().x() != src.shape().y()));
     // Create reference
@@ -62,8 +62,10 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling
     const auto h_src      = static_cast<int>(src.shape()[1]);
     const int  upper_dims = src.shape().total_size() / (w_src * h_src);
 
-    const auto w_dst = static_cast<int>(dst.shape()[0]);
-    const auto h_dst = static_cast<int>(dst.shape()[1]);
+    const auto  w_dst = static_cast<int>(dst.shape()[0]);
+    const auto  h_dst = static_cast<int>(dst.shape()[1]);
+    TensorShape shape_nhwc(src.shape());
+    permute(shape_nhwc, PermutationVector(2U, 0U, 1U));
 
     if(type == PoolingType::MAX)
     {
@@ -89,8 +91,15 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling
                             const auto val = static_cast<ACC_T>(src[r * h_src * w_src + y * w_src + x]);
                             if(val > max_val)
                             {
-                                max_val   = val;
-                                max_index = coord2index(src.shape(), Coordinates(x, y, r));
+                                max_val = val;
+                                if(data_layout == DataLayout::NCHW)
+                                {
+                                    max_index = coord2index(src.shape(), Coordinates(x, y, r));
+                                }
+                                else
+                                {
+                                    max_index = coord2index(shape_nhwc, Coordinates(r, x, y));
+                                }
                             }
                         }
                     }
@@ -159,48 +168,52 @@ SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const Pooling
     return dst;
 }
 
-template SimpleTensor<float> pooling_layer_internal<float>(const SimpleTensor<float> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices);
-template SimpleTensor<half> pooling_layer_internal<half>(const SimpleTensor<half> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices);
-template SimpleTensor<half> pooling_layer_internal<half, float>(const SimpleTensor<half> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices);
+template SimpleTensor<float> pooling_layer_internal<float>(const SimpleTensor<float> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices, DataLayout data_layout);
+
+template SimpleTensor<half> pooling_layer_internal<half>(const SimpleTensor<half> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices, DataLayout data_layout);
+
+template SimpleTensor<half> pooling_layer_internal<half, float>(const SimpleTensor<half> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices, DataLayout data_layout);
 
 template <typename T>
-SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices)
+SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices, DataLayout data_layout)
 {
     ARM_COMPUTE_UNUSED(output_qinfo);
-    return pooling_layer_internal<T, T>(src, info, indices);
+    return pooling_layer_internal<T, T>(src, info, indices, data_layout);
 }
 
 template <>
-SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices)
+SimpleTensor<uint8_t> pooling_layer<uint8_t>(const SimpleTensor<uint8_t> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices,
+                                             DataLayout data_layout)
 {
     SimpleTensor<float>   src_tmp = convert_from_asymmetric(src);
-    SimpleTensor<float>   dst_tmp = pooling_layer_internal<float>(src_tmp, info, indices);
+    SimpleTensor<float>   dst_tmp = pooling_layer_internal<float>(src_tmp, info, indices, data_layout);
     SimpleTensor<uint8_t> dst     = convert_to_asymmetric<uint8_t>(dst_tmp, output_qinfo);
     return dst;
 }
 
 template <>
-SimpleTensor<int8_t> pooling_layer<int8_t>(const SimpleTensor<int8_t> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices)
+SimpleTensor<int8_t> pooling_layer<int8_t>(const SimpleTensor<int8_t> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices, DataLayout data_layout)
 {
     SimpleTensor<float>  src_tmp = convert_from_asymmetric(src);
-    SimpleTensor<float>  dst_tmp = pooling_layer_internal<float>(src_tmp, info, indices);
+    SimpleTensor<float>  dst_tmp = pooling_layer_internal<float>(src_tmp, info, indices, data_layout);
     SimpleTensor<int8_t> dst     = convert_to_asymmetric<int8_t>(dst_tmp, output_qinfo);
     return dst;
 }
 
 template <>
-SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices)
+SimpleTensor<half> pooling_layer(const SimpleTensor<half> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices, DataLayout data_layout)
 {
     ARM_COMPUTE_UNUSED(output_qinfo);
     if(src.data_type() == DataType::F16 && info.fp_mixed_precision)
     {
-        return pooling_layer_internal<half, float>(src, info, indices);
+        return pooling_layer_internal<half, float>(src, info, indices, data_layout);
     }
 
     return pooling_layer_internal<half>(src, info, indices);
 }
 
-template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices);
+template SimpleTensor<float> pooling_layer(const SimpleTensor<float> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices, DataLayout data_layout);
+
 } // namespace reference
 } // namespace validation
 } // namespace test
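
The reference change above stores NHWC indices via coord2index(shape_nhwc, Coordinates(r, x, y)),
i.e. the linear element offset into the unpadded NHWC tensor. A hedged sketch of decoding such
an index back into coordinates, assuming a single batch and that convention (helper name is
illustrative only):

    #include <cstdint>

    struct NhwcCoord
    {
        int c; // channel
        int x; // width
        int y; // height
    };

    // index = (y * width + x) * channels + c  (batch 0)
    inline NhwcCoord decode_nhwc_index(uint32_t index, int channels, int width)
    {
        const int i = static_cast<int>(index);
        return NhwcCoord{ i % channels, (i / channels) % width, i / (channels * width) };
    }
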
index 3ca7f28d5aaa3cadd9bce909279397e8e87b8b0c..346f1c0c9f8ffb2b5990a7648a9c3958a4617987 100644 (file)
@@ -36,9 +36,12 @@ namespace validation
 namespace reference
 {
 template <typename T, typename ACC_T = T, typename std::enable_if<is_floating_point<T>::value, int>::type = 0>
-SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices);
+SimpleTensor<T> pooling_layer_internal(const SimpleTensor<T> &src, const PoolingLayerInfo &info, SimpleTensor<uint32_t> *indices, DataLayout data_layout = DataLayout::NCHW);
+
 template <typename T>
-SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices);
+SimpleTensor<T> pooling_layer(const SimpleTensor<T> &src, const PoolingLayerInfo &info, const QuantizationInfo &output_qinfo, SimpleTensor<uint32_t> *indices,
+                              DataLayout data_layout = DataLayout::NCHW);
+
 } // namespace reference
 } // namespace validation
 } // namespace test