[neon] Apply inline function style in sgemv_noTrans
authorskykongkong8 <ss.kong@samsung.com>
Thu, 12 Oct 2023 04:33:46 +0000 (13:33 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Mon, 16 Oct 2023 05:05:27 +0000 (14:05 +0900)
- By applying inline function style in sgemv, we use more 128 bit variables in a single iteration.
- Since noTrans sgemv is optimized in column-direction, this optimization is valid and proben by unittest result.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/blas_neon.cpp

index 4afe64225d9a3b17a6eb8a00292474b32cbf5389..148dede8898a136f06e745d0c2601f4ba70282a6 100644 (file)
@@ -277,7 +277,6 @@ void sgemv_transpose_neon(const float *A, const float *X, float *Y,
 void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
                      uint32_t cols, float alpha, float beta) {
   const __fp16 *__restrict x;
-  const float32x4_t v_beta_32 = vmovq_n_f32(beta);
   float Y32[rows];
 
   unsigned int idx = 0;
@@ -285,15 +284,15 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
   for (; rows - idx >= 8; idx += 8) {
     float32x4_t y0_3 = vcvt_f32_f16(vld1_f16(&Y[idx]));
     float32x4_t y4_7 = vcvt_f32_f16(vld1_f16(&Y[idx + 4]));
-    y0_3 = vmulq_f32(y0_3, v_beta_32);
-    y4_7 = vmulq_f32(y4_7, v_beta_32);
+    y0_3 = vmulq_n_f32(y0_3, beta);
+    y4_7 = vmulq_n_f32(y4_7, beta);
 
     vst1q_f32(&Y32[idx], y0_3);
     vst1q_f32(&Y32[idx + 4], y4_7);
   }
   for (; rows - idx >= 4; idx += 4) {
     float32x4_t y0_3_32 = vcvt_f32_f16(vld1_f16(&Y[idx]));
-    y0_3_32 = vmulq_f32(y0_3_32, v_beta_32);
+    y0_3_32 = vmulq_n_f32(y0_3_32, beta);
 
     vst1q_f32(&Y32[idx], y0_3_32);
   }
@@ -302,7 +301,7 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
   }
 
   idx = 0;
-  for (; cols - idx >= 64; idx += 64) {
+  for (; cols - idx >= 120; idx += 120) {
     float16x8_t x0_7 = vld1q_f16(&X[idx]);
     float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
     float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
@@ -313,6 +312,15 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
     float16x8_t x48_55 = vld1q_f16(&X[idx + 48]);
     float16x8_t x56_63 = vld1q_f16(&X[idx + 56]);
 
+    float16x8_t x64_71 = vld1q_f16(&X[idx + 64]);
+    float16x8_t x72_79 = vld1q_f16(&X[idx + 72]);
+    float16x8_t x80_87 = vld1q_f16(&X[idx + 80]);
+
+    float16x8_t x88_95 = vld1q_f16(&X[idx + 88]);
+    float16x8_t x96_103 = vld1q_f16(&X[idx + 96]);
+    float16x8_t x104_111 = vld1q_f16(&X[idx + 104]);
+    float16x8_t x112_120 = vld1q_f16(&X[idx + 112]);
+
     if (alpha != 1.0) {
       x0_7 = vmulq_n_f16(x0_7, alpha);
       x8_15 = vmulq_n_f16(x8_15, alpha);
@@ -322,12 +330,66 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
       x40_47 = vmulq_n_f16(x40_47, alpha);
       x48_55 = vmulq_n_f16(x48_55, alpha);
       x56_63 = vmulq_n_f16(x56_63, alpha);
+
+      x64_71 = vmulq_n_f16(x64_71, alpha);
+      x72_79 = vmulq_n_f16(x72_79, alpha);
+      x80_87 = vmulq_n_f16(x80_87, alpha);
+      x88_95 = vmulq_n_f16(x88_95, alpha);
+      x96_103 = vmulq_n_f16(x96_103, alpha);
+      x104_111 = vmulq_n_f16(x104_111, alpha);
+      x112_120 = vmulq_n_f16(x112_120, alpha);
     }
 
     const __fp16 *__restrict w;
 
-    float yVal_low;
-    float yVal_high;
+    for (unsigned int j = 0; j < rows; ++j) {
+      w = &A[j * cols + idx];
+      float16x8_t y = vmulq_f16(vld1q_f16(&w[0]), x0_7);
+      y = vfmaq_f16(y, vld1q_f16(&w[8]), x8_15);
+      y = vfmaq_f16(y, vld1q_f16(&w[16]), x16_23);
+      y = vfmaq_f16(y, vld1q_f16(&w[24]), x24_31);
+
+      y = vfmaq_f16(y, vld1q_f16(&w[32]), x32_39);
+      y = vfmaq_f16(y, vld1q_f16(&w[40]), x40_47);
+      y = vfmaq_f16(y, vld1q_f16(&w[48]), x48_55);
+      y = vfmaq_f16(y, vld1q_f16(&w[56]), x56_63);
+
+      y = vfmaq_f16(y, vld1q_f16(&w[64]), x64_71);
+      y = vfmaq_f16(y, vld1q_f16(&w[72]), x72_79);
+      y = vfmaq_f16(y, vld1q_f16(&w[80]), x80_87);
+
+      y = vfmaq_f16(y, vld1q_f16(&w[88]), x88_95);
+      y = vfmaq_f16(y, vld1q_f16(&w[96]), x96_103);
+      y = vfmaq_f16(y, vld1q_f16(&w[104]), x104_111);
+      y = vfmaq_f16(y, vld1q_f16(&w[112]), x112_120);
+
+      Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+                vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
+    }
+  }
+  for (; cols - idx >= 64; idx += 64) {
+    float16x8_t x0_7 = vld1q_f16(&X[idx]);
+    float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
+    float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
+    float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);
+
+    float16x8_t x32_39 = vld1q_f16(&X[idx + 32]);
+    float16x8_t x40_47 = vld1q_f16(&X[idx + 40]);
+    float16x8_t x48_55 = vld1q_f16(&X[idx + 48]);
+    float16x8_t x56_63 = vld1q_f16(&X[idx + 56]);
+
+    if (alpha != 1.0) {
+      x0_7 = vmulq_n_f16(x0_7, alpha);
+      x8_15 = vmulq_n_f16(x8_15, alpha);
+      x16_23 = vmulq_n_f16(x16_23, alpha);
+      x24_31 = vmulq_n_f16(x24_31, alpha);
+      x32_39 = vmulq_n_f16(x32_39, alpha);
+      x40_47 = vmulq_n_f16(x40_47, alpha);
+      x48_55 = vmulq_n_f16(x48_55, alpha);
+      x56_63 = vmulq_n_f16(x56_63, alpha);
+    }
+
+    const __fp16 *__restrict w;
 
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
@@ -351,10 +413,8 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
       y = vfmaq_f16(y, wvec48_55, x48_55);
       y = vfmaq_f16(y, wvec56_63, x56_63);
 
-      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
-      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
-      Y32[j] += yVal_low + yVal_high;
+      Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+                vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
     }
   }
   for (; cols - idx >= 32; idx += 32) {
@@ -372,9 +432,6 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
 
     const __fp16 *__restrict w;
 
-    float yVal_low;
-    float yVal_high;
-
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
       float16x8_t wvec0_7 = vld1q_f16(&w[0]);
@@ -387,10 +444,8 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
       y = vfmaq_f16(y, wvec16_23, x16_23);
       y = vfmaq_f16(y, wvec24_31, x24_31);
 
-      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
-      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
-      Y32[j] += yVal_low + yVal_high;
+      Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+                vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
     }
   }
   for (; cols - idx >= 16; idx += 16) {
@@ -403,8 +458,6 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
     }
 
     const __fp16 *__restrict w;
-    float yVal_low;
-    float yVal_high;
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
       float16x8_t wvec0_7 = vld1q_f16(&w[0]);
@@ -413,10 +466,8 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
       float16x8_t y = vmulq_f16(wvec0_7, x0_7);
       y = vfmaq_f16(y, wvec8_15, x8_15);
 
-      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
-      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
-      Y32[j] += yVal_low + yVal_high;
+      Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+                vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
     }
   }
   for (; cols - idx >= 8; idx += 8) {
@@ -428,18 +479,13 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
 
     const __fp16 *__restrict w;
 
-    float yVal_low;
-    float yVal_high;
-
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
       float16x8_t wvec0_7 = vld1q_f16(&w[0]);
       float16x8_t y = vmulq_f16(wvec0_7, x0_7);
 
-      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
-      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
-
-      Y32[j] += yVal_low + yVal_high;
+      Y32[j] += vaddvq_f32(vcvt_f32_f16(vget_low_f16(y))) +
+                vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
     }
   }
   for (; cols - idx >= 4; idx += 4) {
@@ -564,16 +610,15 @@ void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
           vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 10) * cols + idx]), x11);
         w2vec0_7_f16 =
           vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 11) * cols + idx]), x12);
-          w2vec0_7_f16 =
+        w2vec0_7_f16 =
           vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 12) * cols + idx]), x13);
-          w2vec0_7_f16 =
+        w2vec0_7_f16 =
           vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 13) * cols + idx]), x14);
-          w2vec0_7_f16 =
+        w2vec0_7_f16 =
           vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 14) * cols + idx]), x15);
-          w2vec0_7_f16 =
+        w2vec0_7_f16 =
           vfmaq_n_f16(w2vec0_7_f16, vld1q_f16(&A[(i + 15) * cols + idx]), x16);
 
-
         float32x4_t y0_3 = vaddq_f32(vld1q_f32(&Y32[idx]),
                                      vcvt_f32_f16(vget_low_f16(wvec0_7_f16)));
         y0_3 = vaddq_f32(y0_3, vcvt_f32_f16(vget_low_f16(w2vec0_7_f16)));
@@ -630,8 +675,7 @@ void sgemv_transpose_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y,
         vst1q_f32(&Y32[idx], y0_3);
       }
     }
-  } else 
-  if (rows % 8 == 0) {
+  } else if (rows % 8 == 0) {
     for (unsigned int i = 0; i < rows; i += 8) {
       __fp16 x = alpha * (X[i]);
       __fp16 x2 = alpha * (X[i + 1]);