COMPMID-3499: Fix integer overflow for large GEMM on NEON
authorSang-Hoon Park <sang-hoon.park@arm.com>
Thu, 21 May 2020 19:34:19 +0000 (20:34 +0100)
committerManuel Bottini <manuel.bottini@arm.com>
Tue, 26 May 2020 09:07:31 +0000 (10:07 +0100)
Change-Id: Id9eef3abc8a902b52ba61772f716f2ba2b97f7d4
Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/3245
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com>
Reviewed-by: Manuel Bottini <manuel.bottini@arm.com>
Reviewed-by: Michalis Spyrou <michalis.spyrou@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>

src/core/NEON/kernels/arm_gemm/transforms/a64_block16_interleave4_8bit.hpp

index 6a15fc42e4ab2c330dd2b0713ceb72778665d26f..6b742c8776a17d8afe49e3f63ff3b20a2802fe21 100644 (file)
@@ -38,7 +38,7 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in
 
     uint8_t zerobuff[16] = { 0 };
 
-    for (int y=y0; y<ymax; y+=4) {
+    for (uint64_t y = y0 ; y < static_cast<uint64_t>(ymax) ; y+=4) {
         const uint8_t *inptr0 = inptr + y * ldin + k0;
         const uint8_t *inptr1 = inptr0 + ldin;
         const uint8_t *inptr2 = inptr1 + ldin;
@@ -52,7 +52,7 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in
         int x=(kmax-k0);
         for (;x>15;x-=16) {
             /* Cope with ragged cases by copying from a buffer of zeroes instead */
-            if ((y + 3) >= ymax) {
+            if ((y + 3) >= static_cast<uint64_t>(ymax)) {
                 switch ((y + 3) - ymax) {
                     /* Everything falls through in here */
                     case 2:
@@ -90,7 +90,7 @@ void TransformImpl<4, 16, false, 1, 1, false>::Transform(T *out, const T *in, in
 
         if (x>0) {
             /* Need to duplicate this here, in case we didn't run the main loop. */
-            if ((y + 3) >= ymax) {
+            if ((y + 3) >= static_cast<uint64_t>(ymax)) {
                 switch ((y + 3) - ymax) {
                     /* Everything falls through in here */
                     case 2: