COMPMID-3236: Extend CLGEMMLowpReduction kernels to multiply by a scalar value
authorMichele Di Giorgio <michele.digiorgio@arm.com>
Fri, 3 Apr 2020 11:40:10 +0000 (12:40 +0100)
committerMichele Di Giorgio <michele.digiorgio@arm.com>
Wed, 8 Apr 2020 08:43:39 +0000 (08:43 +0000)
Change-Id: Iebd6afac65d10a42d60c2c9df9e1895fadb205ae
Signed-off-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/2981
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>

arm_compute/core/CL/kernels/CLGEMMLowpReductionKernel.h
src/core/CL/cl_kernels/gemmlowp.cl
src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp

index 4e52a8029edb1068d6ae08fb5752d69d9726fc4d..71681cf628f72da4e89c12c5138952bfbf579baf 100644 (file)
@@ -29,6 +29,7 @@
 namespace arm_compute
 {
 class ICLTensor;
+struct GEMMLowpReductionKernelInfo;
 
 /** Common interface for all OpenCL reduction kernels */
 class ICLGEMMLowpReductionKernel : public ICLKernel
@@ -49,8 +50,13 @@ public:
      *
      * @param[in]  input  Input tensor. Data type supported: S8
      * @param[out] output Output row-vector of sums of all the entries in each row/col of input tensor. Data type supported: S32
+     * @param[in]  info   Kernel metadata:
+     *                    - k            Number of matrix columns/rows depending on the type of reduction.
+     *                    - is_reshaped  True if the matrix has been reshaped.
+     *                    - scalar       Scalar value to multiply each reduced column/row by.
+     *                    - mul_byscalar True if each reduced column/row must be multiplied by a scalar value.
      */
-    virtual void configure(const ICLTensor *input, ICLTensor *output) = 0;
+    virtual void configure(const ICLTensor *input, ICLTensor *output, const GEMMLowpReductionKernelInfo &info) = 0;
 
 protected:
     const ICLTensor *_input;
@@ -69,16 +75,26 @@ public:
      *
      * @param[in]  mtx_a          Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED
      * @param[out] vector_sum_row Output row-vector of sums of all the entries in each row of mtx_a. Data type supported: S32
+     * @param[in]  info           Kernel metadata:
+     *                            - k            Number of matrix columns/rows depending on the type of reduction.
+     *                            - is_reshaped  True if the matrix has been reshaped.
+     *                            - scalar       Scalar value to multiply each reduced column/row by.
+     *                            - mul_byscalar True if each reduced column/row must be multiplied by a scalar value.
      */
-    void configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row) override;
+    void configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row, const GEMMLowpReductionKernelInfo &info) override;
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixAReductionKernel
      *
      * @param[in] mtx_a          Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED
      * @param[in] vector_sum_row Output row-vector of sums of all the entries in each row of mtx_a. Data type supported: S32
+     * @param[in] info           Kernel metadata:
+     *                           - k            Number of matrix columns/rows depending on the type of reduction.
+     *                           - is_reshaped  True if the matrix has been reshaped.
+     *                           - scalar       Scalar value to multiply each reduced column/row by.
+     *                           - mul_byscalar True if each reduced column/row must be multiplied by a scalar value.
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row);
+    static Status validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, const GEMMLowpReductionKernelInfo &info);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
@@ -96,16 +112,26 @@ public:
      *
      * @param[in]  mtx_b          Input tensor. Data type supported: Data type supported: QASYMM8/QASYMM8_SIGNED
      * @param[out] vector_sum_col Output row-vector of sums of all the entries in each column of mtx_b. Data type supported: S32
+     * @param[in]  info           Kernel metadata:
+     *                            - k            Number of matrix columns/rows depending on the type of reduction.
+     *                            - is_reshaped  True if the matrix has been reshaped.
+     *                            - scalar       Scalar value to multiply each reduced column/row by.
+     *                            - mul_byscalar True if each reduced column/row must be multiplied by a scalar value.
      */
-    void configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col) override;
+    void configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col, const GEMMLowpReductionKernelInfo &info) override;
     /** Static function to check if given info will lead to a valid configuration of @ref CLGEMMLowpMatrixBReductionKernel
      *
      * @param[in] mtx_b          Input tensor. Data type supported: Data type supported: QASYMM8/QASYMM8_SIGNED
      * @param[in] vector_sum_col Output row-vector of sums of all the entries in each column of mtx_b. Data type supported: S32
+     * @param[in] info           Kernel metadata:
+     *                           - k            Number of matrix columns/rows depending on the type of reduction.
+     *                           - is_reshaped  True if the matrix has been reshaped.
+     *                           - scalar       Scalar value to multiply each reduced column/row by.
+     *                           - mul_byscalar True if each reduced column/row must be multiplied by a scalar value.
      *
      * @return a status
      */
-    static Status validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col);
+    static Status validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col, const GEMMLowpReductionKernelInfo &info);
 
     // Inherited methods overridden:
     void run(const Window &window, cl::CommandQueue &queue) override;
index 71de1d4b2788bddaa4fb3d1bc972f9e34b8d9967..b707ec81752974ac6633ec63cec3a6b06aad4c96 100644 (file)
@@ -1287,6 +1287,7 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
 
 #if defined(COLS_A)
 /** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A.
+ * It is also possible to multiply each reduced row by a scalar value, if SCALAR is passed at compile time.
  *
  * @note This stage is needed to handle the offset of matrix product
  *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
@@ -1294,8 +1295,9 @@ __kernel void gemmlowp_mm_native(IMAGE_DECLARATION(lhs),
  * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
  * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
  * @note The data type for the accumulation must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
+ * @note In case of scaling the scalar value must be passed at compile time using -DSCALAR (e.g. -DSCALAR=3)
  *
- * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8
+ * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8/QASYMM8_SIGNED
  * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
@@ -1342,11 +1344,15 @@ __kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src),
 
     sum_row += sum_row_32.s0 + sum_row_32.s1 + sum_row_32.s2 + sum_row_32.s3;
 
+#if defined(SCALAR)
+    sum_row *= (int)SCALAR;
+#endif // defined(SCALAR)
     *((__global int *)dst.ptr) = (int)sum_row;
 }
 
 #if defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
-/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A using the arm dot product instruction
+/** OpenCL kernel used to compute the row-vectors of sums of all the entries in each row of Matrix A using the arm dot product instruction.
+ * It is also possible to multiply each reduced row by a scalar value, if SCALAR is passed at compile time.
  *
  * @note This stage is needed to handle the offset of matrix product
  *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
@@ -1354,8 +1360,9 @@ __kernel void gemmlowp_matrix_a_reduction(TENSOR3D_DECLARATION(src),
  * @attention The number of matrix A columns needs to be passed at compile time using -DCOLS_A
  * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
  * @note The data type for the accumulation must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
+ * @note In case of scaling the scalar value must be passed at compile time using -DSCALAR (e.g. -DSCALAR=3)
  *
- * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8
+ * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8/QASYMM8_SIGNED
  * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
@@ -1408,6 +1415,9 @@ __kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
         sum_row += (ACC_DATA_TYPE)matrix_a[i];
     }
 
+#if defined(SCALAR)
+    sum_row *= (int)SCALAR;
+#endif // defined(SCALAR)
     *((__global int *)dst.ptr) = (int)sum_row;
 }
 #endif // defined(ARM_COMPUTE_OPENCL_DOT8_ENABLED) && defined(cl_arm_integer_dot_product_int8)
@@ -1415,6 +1425,7 @@ __kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
 
 #if defined(COLS_B) && defined(ROWS_B)
 /** OpenCL kernel used to compute the row-vectors of sums of all the entries in each column of Matrix B.
+ * It is also possible to multiply each reduced column by a scalar value, if SCALAR is passed at compile time.
  *
  * @note This stage is needed to handle the offset of matrix product
  *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
@@ -1422,8 +1433,9 @@ __kernel void gemmlowp_matrix_a_reduction_dot8(TENSOR3D_DECLARATION(src),
  * @attention The number of matrix B columns and rows needs to be passed at compile time using -DCOLS_B and -DROWS_B
  * @note The input data type must be passed at compile time using -DDATA_TYPE (i.e. -DDATA_TYPE=uchar)
  * @note The data type for the accumulation must be passed at compile time using -DACC_DATA_TYPE (i.e. -DACC_DATA_TYPE=uint)
+ * @note In case of scaling the scalar value must be passed at compile time using -DSCALAR (i.e. -DSCALAR=3)
  *
- * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8
+ * @param[in]  src_ptr                           Pointer to the source tensor. Supported data type: QASYMM8/QASYMM8_SIGNED
  * @param[in]  src_stride_x                      Stride of the source tensor in X dimension (in bytes)
  * @param[in]  src_step_x                        src_stride_x * number of elements along X processed per workitem(in bytes)
  * @param[in]  src_stride_y                      Stride of the source tensor in Y dimension (in bytes)
@@ -1480,7 +1492,11 @@ __kernel void gemmlowp_matrix_b_reduction(TENSOR3D_DECLARATION(src),
         matrix_b += src_stride_y;
     }
 
-    vstore16(convert_int16(sum_col_32), 0, (__global int *)dst.ptr);
+#if defined(SCALAR)
+    sum_col_32 *= (VEC_DATA_TYPE(ACC_DATA_TYPE, 16))SCALAR;
+#endif // defined(SCALAR)
+    VSTORE(16)
+    (sum_col_32, 0, (__global int *)dst.ptr);
 }
 #endif // defined(COLS_B) && defined(ROWS_B)
 
index 8f3d53a412136dffb349ba338e11e6876d0f599b..832e6281f4b01293ecd924d9c95449684b78a544 100644 (file)
 #include "arm_compute/core/AccessWindowStatic.h"
 #include "arm_compute/core/CL/CLHelpers.h"
 #include "arm_compute/core/CL/ICLTensor.h"
-#include "arm_compute/core/Error.h"
-#include "arm_compute/core/Helpers.h"
-#include "arm_compute/core/TensorInfo.h"
-#include "arm_compute/core/Types.h"
-#include "arm_compute/core/Utils.h"
-#include "arm_compute/core/Validate.h"
-#include "arm_compute/core/Window.h"
+#include "arm_compute/core/KernelDescriptors.h"
 #include "support/StringSupport.h"
 
-#include <cstddef>
-#include <cstdint>
-
 namespace arm_compute
 {
-class Coordinates;
-
 namespace
 {
 Status validate_arguments_matrix_a_reduction(const ITensorInfo *input, const ITensorInfo *output)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
 
+    if(output->total_size() > 0)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(0) != input->dimension(1), "Output vector must have length equal to the number of rows of the input matrix");
+    }
     return Status{};
 }
-std::pair<Status, Window> validate_and_configure_window_matrix_a_reduction(ITensorInfo *input, ITensorInfo *output)
-{
-    constexpr unsigned int num_elems_processed_per_iteration = 1;
-
-    Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
-
-    AccessWindowStatic     input_access(input, 0, 0, input->dimension(0), input->dimension(1));
-    AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
-
-    bool window_changed = update_window_and_padding(win, input_access, output_access);
-
-    output_access.set_valid_region(win, ValidRegion(Coordinates(0, 0), output->tensor_shape()));
-
-    Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
-    return std::make_pair(err, win);
-}
 
 Status validate_arguments_matrix_b_reduction(const ITensorInfo *input, const ITensorInfo *output)
 {
+    ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
-    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
 
+    if(output->total_size() > 0)
+    {
+        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::S32);
+        ARM_COMPUTE_RETURN_ERROR_ON_MSG(output->dimension(0) != input->dimension(0), "Output vector must have length equal to the number of columns of the input matrix");
+    }
     return Status{};
 }
 
@@ -100,7 +83,7 @@ ICLGEMMLowpReductionKernel::ICLGEMMLowpReductionKernel()
 {
 }
 
-void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row)
+void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTensor *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
 {
     // Perform validate step
     ARM_COMPUTE_ERROR_ON_NULLPTR(mtx_a, vector_sum_row);
@@ -114,6 +97,7 @@ void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTens
     build_opts.add_option("-DCOLS_A=" + support::cpp11::to_string(mtx_a->info()->dimension(0)));
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(mtx_a->info()->data_type()));
     build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(mtx_a->info()->data_type()));
+    build_opts.add_option_if(info.mul_by_scalar, "-DSCALAR=" + support::cpp11::to_string(info.scalar));
 
     const bool is_dot8_supported = dot8_supported(CLKernelLibrary::get().get_device());
 
@@ -123,9 +107,9 @@ void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTens
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(kernel_name, build_opts.options()));
 
     // Configure kernel window
-    auto win_config = validate_and_configure_window_matrix_a_reduction(_input->info(), _output->info());
-    ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
-    ICLKernel::configure_internal(win_config.second);
+    // This kernel does not need padding
+    Window win = calculate_max_window(*vector_sum_row->info(), Steps());
+    ICLKernel::configure_internal(win);
 
     _config_id = kernel_name;
     _config_id += "_";
@@ -136,10 +120,10 @@ void CLGEMMLowpMatrixAReductionKernel::configure(const ICLTensor *mtx_a, ICLTens
     _config_id += support::cpp11::to_string(_input->info()->dimension(2));
 }
 
-Status CLGEMMLowpMatrixAReductionKernel::validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row)
+Status CLGEMMLowpMatrixAReductionKernel::validate(const ITensorInfo *mtx_a, const ITensorInfo *vector_sum_row, const GEMMLowpReductionKernelInfo &info)
 {
+    ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_a_reduction(mtx_a, vector_sum_row));
-    ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_matrix_a_reduction(mtx_a->clone().get(), vector_sum_row->clone().get()).first);
 
     return Status{};
 }
@@ -168,7 +152,7 @@ void CLGEMMLowpMatrixAReductionKernel::run(const Window &window, cl::CommandQueu
     while(collapsed.slide_window_slice_2D(slice_out));
 }
 
-void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col)
+void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTensor *vector_sum_col, const GEMMLowpReductionKernelInfo &info)
 {
     ARM_COMPUTE_ERROR_ON_NULLPTR(mtx_b, vector_sum_col);
     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_matrix_b_reduction(mtx_b->info(), vector_sum_col->info()));
@@ -182,6 +166,7 @@ void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTens
     build_opts.add_option("-DROWS_B=" + support::cpp11::to_string(mtx_b->info()->dimension(1)));
     build_opts.add_option("-DDATA_TYPE=" + get_cl_type_from_data_type(mtx_b->info()->data_type()));
     build_opts.add_option("-DACC_DATA_TYPE=" + get_cl_dot8_acc_type_from_data_type(mtx_b->info()->data_type()));
+    build_opts.add_option_if(info.mul_by_scalar, "-DSCALAR=" + support::cpp11::to_string(info.scalar));
 
     // Create kernel
     _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel("gemmlowp_matrix_b_reduction", build_opts.options()));
@@ -192,8 +177,9 @@ void CLGEMMLowpMatrixBReductionKernel::configure(const ICLTensor *mtx_b, ICLTens
     ICLKernel::configure_internal(win_config.second);
 }
 
-Status CLGEMMLowpMatrixBReductionKernel::validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col)
+Status CLGEMMLowpMatrixBReductionKernel::validate(const ITensorInfo *mtx_b, const ITensorInfo *vector_sum_col, const GEMMLowpReductionKernelInfo &info)
 {
+    ARM_COMPUTE_UNUSED(info);
     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments_matrix_b_reduction(mtx_b, vector_sum_col));
     ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window_matrix_b_reduction(mtx_b->clone().get(), vector_sum_col->clone().get()).first);
 
index 9346e9357c62206f620cf19aea538b1d242a09da..90e5698fd8854309c29ed90f40e829b049e922e1 100644 (file)
@@ -28,6 +28,7 @@
 #include "arm_compute/core/CL/gemm/reshaped_only_rhs/CLGEMMReshapedOnlyRHSKernelConfiguration.h"
 #include "arm_compute/core/Error.h"
 #include "arm_compute/core/Helpers.h"
+#include "arm_compute/core/KernelDescriptors.h"
 #include "arm_compute/core/TensorInfo.h"
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Validate.h"
@@ -149,6 +150,9 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
         _mtx_b_reshape_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_tmp_b, rhs_info);
     }
 
+    // Using default reduction info
+    const GEMMLowpReductionKernelInfo reduction_info {};
+
     // Initialize matrix B reduction kernel only if _a_offset is not equal to 0
     if(_a_offset != 0)
     {
@@ -160,7 +164,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
         }
 
         // Configure Matrix B reduction kernel
-        _mtx_b_reduction_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_vector_sum_col);
+        _mtx_b_reduction_kernel.configure(_convert_to_qasymm8 ? &_qasymm8_weights : b, &_vector_sum_col, reduction_info);
     }
 
     // Initialize Matrix A reduction kernel only if _b_offset is not equal to 0
@@ -171,7 +175,7 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
         _memory_group.manage(&_vector_sum_row);
 
         // Configure matrix A reduction kernel
-        _mtx_a_reduction_kernel.configure(a, &_vector_sum_row);
+        _mtx_a_reduction_kernel.configure(a, &_vector_sum_row, reduction_info);
     }
 
     GEMMKernelInfo gemm_kernel_info;
@@ -356,13 +360,14 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
     TensorInfo info_vector_sum_col{};
     TensorInfo info_vector_sum_row{};
 
+    const GEMMLowpReductionKernelInfo reduction_info;
     // Validate matrix B reduction kernel only if _a_offset is not equal to 0
     if(a_offset != 0)
     {
         info_vector_sum_col = TensorInfo(compute_reductionA_shape(weights_info), 1, DataType::S32);
 
         // Configure Matrix B reduction kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixBReductionKernel::validate(&weights_info, &info_vector_sum_col));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixBReductionKernel::validate(&weights_info, &info_vector_sum_col, reduction_info));
     }
 
     // Validate Matrix A reduction kernel only if _b_offset is not equal to 0
@@ -371,7 +376,7 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
         info_vector_sum_row = TensorInfo(compute_reductionB_shape(*a), 1, DataType::S32);
 
         // Configure matrix A reduction kernel
-        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(a, &info_vector_sum_row));
+        ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpMatrixAReductionKernel::validate(a, &info_vector_sum_row, reduction_info));
     }
 
     GEMMKernelInfo gemm_kernel_info;