[blas/neon] NEON fp16 implementation of SGEMM
authorDebadri Samaddar <s.debadri@samsung.com>
Wed, 30 Aug 2023 13:57:47 +0000 (19:27 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 7 Sep 2023 00:03:35 +0000 (09:03 +0900)
SGEMM fp16 implmentation for Android(ARM) using NEON.

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/tensor/blas_interface.cpp
nntrainer/tensor/blas_neon.cpp
nntrainer/tensor/blas_neon.h

index ced2070de8d4762b23316e06e3715a46a1a72e5b..fffb4095b3f4438cfe21812166d8d118de7d6d87 100644 (file)
       Y[i * incY] = Y[i * incY] + static_cast<_FP16>(alpha) * X[i * incX]; \
   } while (0);
 
+#define sgemm_loop_fp16()                                                 \
+  do {                                                                    \
+    for (unsigned int m = 0; m < M; ++m) {                                \
+      for (unsigned int n = 0; n < N; ++n) {                              \
+        _FP16 c = 0;                                                      \
+        _FP16 c_old = C[m * ldc + n];                                     \
+        for (unsigned int k = 0; k < K; ++k) {                            \
+          _FP16 a, b;                                                     \
+          a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]); \
+          b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]); \
+          c += a * b;                                                     \
+        }                                                                 \
+        C[m * ldc + n] = static_cast<_FP16>(alpha) * c;                   \
+        if (beta != 0.0)                                                  \
+          C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;             \
+      }                                                                   \
+    }                                                                     \
+  } while (0);
+
 namespace nntrainer {
 
 #ifdef ENABLE_FP16
@@ -193,21 +212,17 @@ static void sgemm_FP16(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA,
                        const unsigned int ldb, const float beta, _FP16 *C,
                        const unsigned int ldc) {
 
-  for (unsigned int m = 0; m < M; ++m) {
-    for (unsigned int n = 0; n < N; ++n) {
-      _FP16 c = 0;
-      _FP16 c_old = C[m * ldc + n];
-      for (unsigned int k = 0; k < K; ++k) {
-        _FP16 a, b;
-        a = ((TransA == CblasTrans) ? A[k * lda + m] : A[m * lda + k]);
-        b = ((TransB == CblasTrans) ? B[n * ldb + k] : B[k * ldb + n]);
-        c += a * b;
-      }
-      C[m * ldc + n] = static_cast<_FP16>(alpha) * c;
-      if (beta != 0.0)
-        C[m * ldc + n] += static_cast<_FP16>(beta) * c_old;
-    }
+#ifdef USE__FP16
+  if ((M % 8 == 0) && (N % 8 == 0) && (K % 8 == 0)) {
+    nntrainer::neon::sgemm_neon_fp16(A, B, C, M, N, K, alpha, beta,
+                                     TransA == CblasTrans,
+                                     TransB == CblasTrans);
+  } else {
+    sgemm_loop_fp16();
   }
+#else
+  sgemm_loop_fp16();
+#endif
 }
 
 static unsigned int isamax_FP16(const unsigned int N, const _FP16 *X,
index 3d9cf2ee50632f5f7de3f53b2628ec1f2d33e01d..3249ab21a7bf4a1e8f32e2f6157867a6ff350fd4 100644 (file)
@@ -741,6 +741,90 @@ unsigned int isamax_neon_fp16(const unsigned int N, const __fp16 *X) {
 
   return retIdx;
 }
+
+void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+                     uint32_t N, uint32_t K, float alpha, float beta,
+                     bool TransA, bool TransB) {
+
+  float16x8_t v_alpha = vmovq_n_f16(alpha);
+  float16x8_t v_beta = vmovq_n_f16(beta);
+
+  // performing beta*C
+  for (unsigned int idx = 0; idx < (M * N); idx += 8) {
+    float16x8_t c = vld1q_f16(&C[idx]);
+    c = vmulq_f16(v_beta, c);
+    vst1q_f16(&C[idx], c);
+  }
+
+  __fp16 r[4];
+
+  if (!TransA && TransB) {
+    for (unsigned int m = 0; m < M; m++) {
+      for (unsigned int n = 0; n < N; n++) {
+        float16x8_t sum = vmovq_n_f16(0);
+
+        for (unsigned int k = 0; k < K; k += 8) {
+          float16x8_t a = vld1q_f16(&A[m * K + k]);
+          float16x8_t b = vld1q_f16(&B[n * K + k]);
+          sum = vfmaq_f16(sum, a, b);
+        }
+        sum = vmulq_f16(v_alpha, sum);
+
+        float16x4_t sum_high = vget_high_f16(sum);
+        float16x4_t sum_low = vget_low_f16(sum);
+
+        sum_low = vadd_f16(sum_high, sum_low);
+        vst1_f16(r, sum_low);
+
+        C[m * N + n] += r[0] + r[1] + r[2] + r[3];
+      }
+    }
+  } else if (TransA && !TransB) {
+    for (unsigned int k = 0; k < K; k++) {
+      for (unsigned int m = 0; m < M; m++) {
+        __fp16 a = alpha * A[k * M + m];
+
+        for (unsigned int n = 0; n < N; n += 8) {
+          float16x8_t b = vld1q_f16(&B[k * N + n]);
+
+          // load previously calculated C
+          float16x8_t c = vld1q_f16(&C[m * N + n]);
+          c = vfmaq_n_f16(c, b, a);
+          vst1q_f16(&C[m * N + n], c);
+        }
+      }
+    }
+  } else if (!TransA && !TransB) {
+    for (unsigned int k = 0; k < K; k++) {
+      for (unsigned int m = 0; m < M; m++) {
+        __fp16 a = alpha * A[m * K + k];
+
+        for (unsigned int n = 0; n < N; n += 8) {
+          float16x8_t b = vld1q_f16(&B[k * N + n]);
+
+          // load previously calculated C
+          float16x8_t c = vld1q_f16(&C[m * N + n]);
+          c = vfmaq_n_f16(c, b, a);
+          vst1q_f16(&C[m * N + n], c);
+        }
+      }
+    }
+  } else { // TransA && TransB
+    for (unsigned int m = 0; m < M; m++) {
+      for (unsigned int n = 0; n < N; n++) {
+        __fp16 sum = 0;
+        for (int k = 0; k < K; k++) {
+          __fp16 a = A[k * M + m];
+          __fp16 b = B[n * K + k];
+          sum += a * b;
+        }
+
+        sum = alpha * sum;
+        C[m * N + n] += sum;
+      }
+    }
+  }
+}
 #endif
 
 } // namespace nntrainer::neon
index 260fda001a2c3dfe453a8bd04f02d7689c70445a..5ce576b89ef3710d69e7c1e01b65225bb201bca4 100644 (file)
@@ -124,6 +124,22 @@ void scopy_neon_fp16(const unsigned int N, const __fp16 *X, __fp16 *Y);
  * @param[in] X __fp16 * for Vector X
  */
 unsigned int isamax_neon_fp16(const unsigned int N, const __fp16 *X);
+
+/**
+ * @brief     sgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void sgemm_neon_fp16(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+                     uint32_t N, uint32_t K, float alpha, float beta,
+                     bool TransA, bool TransB);
 #endif
 
 } // namespace nntrainer::neon