[ matrix_transpose/bugfix ] Prevent reading/saving data from/to unallocated memory
authorskykongkong8 <ss.kong@samsung.com>
Tue, 6 Aug 2024 04:35:37 +0000 (13:35 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 9 Aug 2024 04:47:39 +0000 (13:47 +0900)
- Previous transpose kernel occasionally load/save unallocated memory, and then masked it.
- Now, it does not read them at the first place, but load with for-loop
- This would deteriorate speed of fp16 matrix transpose, but won't be dominant in total model latency

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/matrix_transpose_neon/matrix_transpose_kernels_neon.h

index fd8290eca39399be6bad2cea58fb0f898461be75..bfb4141b1cfd91115b6f67f671dfc3e1f7fcf2a6 100644 (file)
@@ -73,14 +73,20 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
                                           unsigned int ld_src, __fp16 *dst,
                                           unsigned int ld_dst) {
 
-  uint16x4_t bitmask_v8 =
-    vld1_u16(reinterpret_cast<const uint16_t *>(masks[N]));
   float16x4_t input[4];
   float16x4_t ZEROS = vmov_n_f16(0.F);
 
   unsigned i;
   for (i = 0; i < M; ++i) {
-    input[i] = vbsl_f16(bitmask_v8, vld1_f16(&src[i * ld_src]), ZEROS);
+    if (N == 4) {
+      input[i] = vld1_f16(&src[i * ld_src]);
+    } else {
+      float16x4_t tmp = ZEROS;
+      for (unsigned int n = 0; n < N; ++n) {
+        tmp[n] = src[i * ld_src + n];
+      }
+      input[i] = tmp;
+    }
   }
   for (; i < 4; ++i) {
     input[i] = vmov_n_f16(0.F);
@@ -95,7 +101,6 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
     temp[i] = vmov_n_f16(0.F);
   }
 
-  bitmask_v8 = vld1_u16(reinterpret_cast<const uint16_t *>(masks[M]));
   for (i = 0; i < N; ++i) {
     if (i % 2 == 0) {
       input[i] =
@@ -106,10 +111,16 @@ static void transpose_kernel_mxn_neon_128(unsigned int N, const __fp16 *src,
         vcombine_f32(vget_high_f32(vcvt_f32_f16(temp[i / 2])),
                      vget_high_f32(vcvt_f32_f16(temp[2 + i / 2]))));
     }
-    vst1_f16(&dst[i * ld_dst],
-             vbsl_f16(bitmask_v8, input[i], vld1_f16(&dst[i * ld_dst])));
+    if (M == 4) {
+      vst1_f16(&dst[i * ld_dst], input[i]);
+    } else {
+      for (unsigned int m = 0; m < M; ++m) {
+        dst[i * ld_dst + m] = input[i][m];
+      }
+    }
   }
 }
+
 /**
  * @brief 8x8 sized kernel for matrix transpose in NEON
  *
@@ -182,6 +193,7 @@ static inline void transpose_kernel_8x8_neon(const __fp16 *src,
   vst1q_f16(&dst[6 * ld_dst], g);
   vst1q_f16(&dst[7 * ld_dst], h);
 }
+
 /**
  * @brief general case mxn sized matrix transpose kernel with 256 bit SIMD
  * register
@@ -198,12 +210,18 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
                                           unsigned int ld_src, __fp16 *dst,
                                           unsigned int ld_dst) {
   float16x8_t ZEROS = vmovq_n_f16(0.F);
-  uint16x8_t bitmask_v8 =
-    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[N]));
   float16x8_t input[8];
   unsigned i;
   for (i = 0; i < M; ++i) {
-    input[i] = vbslq_f16(bitmask_v8, vld1q_f16(&src[i * ld_src]), ZEROS);
+    if (N == 8) {
+      input[i] = vld1q_f16(&src[i * ld_src]);
+    } else {
+      float16x8_t tmp = ZEROS;
+      for (unsigned int n = 0; n < N; ++n) {
+        tmp[n] = src[i * ld_src + n];
+      }
+      input[i] = tmp;
+    }
   }
   for (; i < 8; ++i) {
     input[i] = ZEROS;
@@ -235,8 +253,6 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
       vbslq_f16(shuffle_mask, vextq_f16(temp[4 * i + 1], temp[4 * i + 1], 2),
                 temp[4 * i + 3]);
   }
-  bitmask_v8 =
-    vld1q_u16(reinterpret_cast<const uint16_t *>(neon_16bit_masks[M]));
   for (i = 0; i < N; ++i) {
     if (i < 4) {
       temp[i] =
@@ -245,7 +261,12 @@ static void transpose_kernel_mxn_neon_256(unsigned int N, const __fp16 *src,
       temp[i] =
         vcombine_f16(vget_high_f16(input[i - 4]), vget_high_f16(input[i]));
     }
-    vst1q_f16(&dst[i * ld_dst],
-              vbslq_f16(bitmask_v8, temp[i], vld1q_f16(&dst[i * ld_dst])));
+    if (M == 8) {
+      vst1q_f16(&dst[i * ld_dst], temp[i]);
+    } else {
+      for (unsigned int m = 0; m < M; ++m) {
+        dst[i * ld_dst + m] = temp[i][m];
+      }
+    }
   }
 }