[blas/neon] SGEMM Neon execution for any M value
authorDebadri Samaddar <s.debadri@samsung.com>
Fri, 8 Sep 2023 11:13:34 +0000 (16:43 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 8 Sep 2023 12:55:19 +0000 (21:55 +0900)
Used padded calculations for SGEMM using NEON for any value of M.
Where M is the number of rows in output matrix.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_neon.cpp

index fffb409..03c8043 100644 (file)
@@ -213,7 +213,7 @@ static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                        const unsigned int ldc) {
 
 #ifdef USE__FP16
-  if ((M % 8 == 0) && (N % 8 == 0) && (K % 8 == 0)) {
+  if ((N % 8 == 0) && (K % 8 == 0)) {
     nntrainer::neon::sgemm_neon_fp16(A, B, C, M, N, K, alpha, beta,
                                      TransA == CblasTrans,
                                      TransB == CblasTrans);
index 4110546..de13426 100644 (file)
@@ -750,12 +750,19 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
   float16x8_t v_beta = vmovq_n_f16(beta);
 
   // performing beta*C
-  for (unsigned int idx = 0; idx < (M * N); idx += 8) {
+  unsigned int idx = 0;
+  unsigned int size = M * N;
+  for (; idx < (size - idx) >= 8; idx += 8) {
     float16x8_t c = vld1q_f16(&C[idx]);
     c = vmulq_f16(v_beta, c);
     vst1q_f16(&C[idx], c);
   }
 
+  // remaining values if dimensions not a multiple of 8
+  for (; idx < size; idx++) {
+    C[idx] *= beta;
+  }
+
   __fp16 r[4];
 
   if (!TransA && TransB) {
@@ -815,7 +822,8 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
       for (unsigned int k = 0; k < K; k++) {
 
         __fp16 b = alpha * B[n * K + k];
-        for (unsigned int m = 0; m < M; m += 8) {
+        unsigned int m = 0;
+        for (; (M - m) >= 8; m += 8) {
           float16x8_t a = vld1q_f16(&A[k * M + m]);
           a = vmulq_n_f16(a, b);
           vst1q_f16(vals, a);
@@ -824,6 +832,21 @@ void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
           for (unsigned int idx = m; idx < m + 8; idx++)
             C[idx * N + n] += vals[idx - m];
         }
+
+        // remaining when M is not a multiple of 8
+        if (m < M) {
+          for (idx = m; idx < M; idx++) {
+            vals[idx - m] = A[k * M + idx];
+          }
+
+          float16x8_t a = vld1q_f16(vals);
+          a = vmulq_n_f16(a, b);
+          vst1q_f16(vals, a);
+
+          // calculations for all remaining M values
+          for (idx = m; idx < M; idx++)
+            C[idx * N + n] += vals[idx - m];
+        }
       }
     }
   }