[neon] Optmize sgemv
authorskykongkong8 <ss.kong@samsung.com>
Fri, 6 Oct 2023 07:17:36 +0000 (16:17 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 10 Oct 2023 05:41:23 +0000 (14:41 +0900)
- Instead of declaring explicit register variable, declaring the function in inline code can save the number of register variable in use.
- This way, we can load more of variables to accelerate sgemv computation

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/blas_neon.cpp

index d08f1bdb1cb8e5cf03523c805c4e419d7597b9a2..8bb4532f6ca69cc2252045e20f67ea6ccf5a4923 100644 (file)
@@ -281,6 +281,7 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
   float Y32[rows];
 
   unsigned int idx = 0;
+
   for (; rows - idx >= 8; idx += 8) {
     float32x4_t y0_3 = vcvt_f32_f16(vld1_f16(&Y[idx]));
     float32x4_t y4_7 = vcvt_f32_f16(vld1_f16(&Y[idx + 4]));
@@ -296,125 +297,166 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
 
     vst1q_f32(&Y32[idx], y0_3_32);
   }
-  while (idx < rows) {
-    Y32[idx] = Y[idx] * beta;
-    ++idx;
+  for (; idx < rows; ++idx) {
+    Y32[idx] = beta * Y[idx];
   }
 
   idx = 0;
+  for (; cols - idx >= 64; idx += 64) {
+    float16x8_t x0_7 = vld1q_f16(&X[idx]);
+    float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
+    float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
+    float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);
+
+    float16x8_t x32_39 = vld1q_f16(&X[idx + 32]);
+    float16x8_t x40_47 = vld1q_f16(&X[idx + 40]);
+    float16x8_t x48_55 = vld1q_f16(&X[idx + 48]);
+    float16x8_t x56_63 = vld1q_f16(&X[idx + 56]);
+
+    if (alpha != 1.0) {
+      x0_7 = vmulq_n_f16(x0_7, alpha);
+      x8_15 = vmulq_n_f16(x8_15, alpha);
+      x16_23 = vmulq_n_f16(x16_23, alpha);
+      x24_31 = vmulq_n_f16(x24_31, alpha);
+      x32_39 = vmulq_n_f16(x32_39, alpha);
+      x40_47 = vmulq_n_f16(x40_47, alpha);
+      x48_55 = vmulq_n_f16(x48_55, alpha);
+      x56_63 = vmulq_n_f16(x56_63, alpha);
+    }
+
+    const __fp16 *__restrict w;
+
+    float yVal_low;
+    float yVal_high;
+
+    for (unsigned int j = 0; j < rows; ++j) {
+      w = &A[j * cols + idx];
+      float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+      float16x8_t wvec8_15 = vld1q_f16(&w[8]);
+      float16x8_t wvec16_23 = vld1q_f16(&w[16]);
+      float16x8_t wvec24_31 = vld1q_f16(&w[24]);
+
+      float16x8_t wvec32_39 = vld1q_f16(&w[32]);
+      float16x8_t wvec40_47 = vld1q_f16(&w[40]);
+      float16x8_t wvec48_55 = vld1q_f16(&w[48]);
+      float16x8_t wvec56_63 = vld1q_f16(&w[56]);
+
+      float16x8_t y = vmulq_f16(wvec0_7, x0_7);
+      y = vfmaq_f16(y, wvec8_15, x8_15);
+      y = vfmaq_f16(y, wvec16_23, x16_23);
+      y = vfmaq_f16(y, wvec24_31, x24_31);
+
+      y = vfmaq_f16(y, wvec32_39, x32_39);
+      y = vfmaq_f16(y, wvec40_47, x40_47);
+      y = vfmaq_f16(y, wvec48_55, x48_55);
+      y = vfmaq_f16(y, wvec56_63, x56_63);
+
+      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
+
+      Y32[j] += yVal_low + yVal_high;
+    }
+  }
   for (; cols - idx >= 32; idx += 32) {
-    float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
-    float32x4_t x4_7_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 4]));
-    float32x4_t x8_11_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 8]));
-    float32x4_t x12_15_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 12]));
-    float32x4_t x16_19_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 16]));
-    float32x4_t x20_23_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 20]));
-    float32x4_t x24_27_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 24]));
-    float32x4_t x28_31_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 28]));
+    float16x8_t x0_7 = vld1q_f16(&X[idx]);
+    float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
+    float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
+    float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);
 
     if (alpha != 1.0) {
-      x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
-      x4_7_f32 = vmulq_n_f32(x4_7_f32, alpha);
-      x8_11_f32 = vmulq_n_f32(x8_11_f32, alpha);
-      x12_15_f32 = vmulq_n_f32(x12_15_f32, alpha);
-      x16_19_f32 = vmulq_n_f32(x16_19_f32, alpha);
-      x20_23_f32 = vmulq_n_f32(x20_23_f32, alpha);
-      x24_27_f32 = vmulq_n_f32(x24_27_f32, alpha);
-      x28_31_f32 = vmulq_n_f32(x28_31_f32, alpha);
+      x0_7 = vmulq_n_f16(x0_7, alpha);
+      x8_15 = vmulq_n_f16(x8_15, alpha);
+      x16_23 = vmulq_n_f16(x16_23, alpha);
+      x24_31 = vmulq_n_f16(x24_31, alpha);
     }
 
     const __fp16 *__restrict w;
 
+    float yVal_low;
+    float yVal_high;
+
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
-      float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
-      float32x4_t wvec4_7_f32 = vcvt_f32_f16(vld1_f16(&w[4]));
-      float32x4_t wvec8_11_f32 = vcvt_f32_f16(vld1_f16(&w[8]));
-      float32x4_t wvec12_15_f32 = vcvt_f32_f16(vld1_f16(&w[12]));
-      float32x4_t wvec16_19_f32 = vcvt_f32_f16(vld1_f16(&w[16]));
-      float32x4_t wvec20_23_f32 = vcvt_f32_f16(vld1_f16(&w[20]));
-      float32x4_t wvec24_27_f32 = vcvt_f32_f16(vld1_f16(&w[24]));
-      float32x4_t wvec28_31_f32 = vcvt_f32_f16(vld1_f16(&w[28]));
+      float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+      float16x8_t wvec8_15 = vld1q_f16(&w[8]);
+      float16x8_t wvec16_23 = vld1q_f16(&w[16]);
+      float16x8_t wvec24_31 = vld1q_f16(&w[24]);
 
-      float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
-      y0 = vfmaq_f32(y0, wvec4_7_f32, x4_7_f32);
-      y0 = vfmaq_f32(y0, wvec8_11_f32, x8_11_f32);
-      y0 = vfmaq_f32(y0, wvec12_15_f32, x12_15_f32);
-      y0 = vfmaq_f32(y0, wvec16_19_f32, x16_19_f32);
-      y0 = vfmaq_f32(y0, wvec20_23_f32, x20_23_f32);
-      y0 = vfmaq_f32(y0, wvec24_27_f32, x24_27_f32);
-      y0 = vfmaq_f32(y0, wvec28_31_f32, x28_31_f32);
+      float16x8_t y = vmulq_f16(wvec0_7, x0_7);
+      y = vfmaq_f16(y, wvec8_15, x8_15);
+      y = vfmaq_f16(y, wvec16_23, x16_23);
+      y = vfmaq_f16(y, wvec24_31, x24_31);
 
-      Y32[j] += vaddvq_f32(y0);
+      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
+
+      Y32[j] += yVal_low + yVal_high;
     }
   }
   for (; cols - idx >= 16; idx += 16) {
-    float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
-    float32x4_t x4_7_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 4]));
-    float32x4_t x8_11_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 8]));
-    float32x4_t x12_15_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 12]));
+    float16x8_t x0_7 = vld1q_f16(&X[idx]);
+    float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
 
     if (alpha != 1.0) {
-      x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
-      x4_7_f32 = vmulq_n_f32(x4_7_f32, alpha);
-      x8_11_f32 = vmulq_n_f32(x8_11_f32, alpha);
-      x12_15_f32 = vmulq_n_f32(x12_15_f32, alpha);
+      x0_7 = vmulq_n_f16(x0_7, alpha);
+      x8_15 = vmulq_n_f16(x8_15, alpha);
     }
 
     const __fp16 *__restrict w;
-
+    float yVal_low;
+    float yVal_high;
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
-      float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
-      float32x4_t wvec4_7_f32 = vcvt_f32_f16(vld1_f16(&w[4]));
-      float32x4_t wvec8_11_f32 = vcvt_f32_f16(vld1_f16(&w[8]));
-      float32x4_t wvec12_15_f32 = vcvt_f32_f16(vld1_f16(&w[12]));
+      float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+      float16x8_t wvec8_15 = vld1q_f16(&w[8]);
+
+      float16x8_t y = vmulq_f16(wvec0_7, x0_7);
+      y = vfmaq_f16(y, wvec8_15, x8_15);
 
-      float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
-      y0 = vfmaq_f32(y0, wvec4_7_f32, x4_7_f32);
-      y0 = vfmaq_f32(y0, wvec8_11_f32, x8_11_f32);
-      y0 = vfmaq_f32(y0, wvec12_15_f32, x12_15_f32);
+      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
 
-      Y32[j] += vaddvq_f32(y0);
+      Y32[j] += yVal_low + yVal_high;
     }
   }
   for (; cols - idx >= 8; idx += 8) {
-    float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
-    float32x4_t x4_7_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 4]));
+    float16x8_t x0_7 = vld1q_f16(&X[idx]);
 
     if (alpha != 1.0) {
-      x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
-      x4_7_f32 = vmulq_n_f32(x4_7_f32, alpha);
+      x0_7 = vmulq_n_f16(x0_7, alpha);
     }
 
     const __fp16 *__restrict w;
 
+    float yVal_low;
+    float yVal_high;
+
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
-      float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
-      float32x4_t wvec4_7_f32 = vcvt_f32_f16(vld1_f16(&w[4]));
+      float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+      float16x8_t y = vmulq_f16(wvec0_7, x0_7);
 
-      float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
-      y0 = vfmaq_f32(y0, wvec4_7_f32, x4_7_f32);
+      yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+      yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
 
-      Y32[j] += vaddvq_f32(y0);
+      Y32[j] += yVal_low + yVal_high;
     }
   }
   for (; cols - idx >= 4; idx += 4) {
-    float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
+    float16x4_t x0_3 = vld1_f16(&X[idx]);
 
     if (alpha != 1.0) {
-      x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
+      x0_3 = vmul_n_f16(x0_3, alpha);
     }
 
     const __fp16 *__restrict w;
 
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
-      float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
-      float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
+      float16x4_t wvec0_3 = (vld1_f16(&w[0]));
+      float16x4_t y0 = vmul_f16(wvec0_3, x0_3);
 
-      Y32[j] += vaddvq_f32(y0);
+      Y32[j] += vaddvq_f32(vcvt_f32_f16(y0));
     }
   }
 
@@ -425,14 +467,14 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
       x0_3[j] = 0;
     }
 
-    float32x4_t x0_3_f32 = vcvt_f32_f16(x0_3);
-
     if (alpha != 1.0) {
-      x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
+      x0_3 = vmul_n_f16(x0_3, alpha);
     }
 
     const __fp16 *__restrict w;
 
+    __fp16 yVal;
+
     for (unsigned int j = 0; j < rows; ++j) {
       w = &A[j * cols + idx];
       float16x4_t wvec0_3 = vld1_f16(&w[0]);
@@ -441,11 +483,11 @@ void sgemv_neon_fp16(const __fp16 *A, const __fp16 *X, __fp16 *Y, uint32_t rows,
         wvec0_3[k] = 0;
       }
 
-      float32x4_t wvec0_3_f32 = vcvt_f32_f16(wvec0_3);
+      float16x4_t y0 = vmul_f16(wvec0_3, x0_3);
 
-      float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
-
-      Y32[j] += vaddvq_f32(y0);
+      for (int k = 0; k < cols - idx; ++k) {
+        Y32[j] += y0[k];
+      }
     }
   }