[ hgemm ] Support scaling factor beta in kernel-based hgemm
authorskykongkong8 <ss.kong@samsung.com>
Tue, 14 May 2024 07:29:53 +0000 (16:29 +0900)
committerMyungJoo Ham <myungjoo.ham@samsung.com>
Mon, 10 Jun 2024 09:43:30 +0000 (18:43 +0900)
- This commit allows hgemm to get beta condition as well.
- Note that beta for here is as follow:
C = alpha * A * B + beta * C
- In addition add zero-init code for beta = 0.F case. According to recent model profiling result, even for initialization, minimizing instruction is quite helpful more overall model latency reduction.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/blas_neon.cpp
nntrainer/tensor/hgemm/hgemm.cpp

index 32c8eb9349f9b4a54a70df80f4b7d90872129369..c8249fb4b8b9103f28038819294127da31bce179 100644 (file)
@@ -1595,22 +1595,33 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
   // performing beta*C
   unsigned int idx = 0;
   unsigned int size = M * N;
-  for (; idx < (size - idx) && (size - idx) >= 8; idx += 8) {
-    float16x8_t c = vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
+  if (beta != 0.F) {
+    for (; idx < (size - idx) && (size - idx) >= 8; idx += 8) {
+      float16x8_t c =
+        vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
 
-    vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
-    vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
-  }
-  // remaining 4
-  for (; idx < (size - idx) && (size - idx) >= 4; idx += 4) {
-    float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
+      vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
+      vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
+    }
+    // remaining 4
+    for (; idx < (size - idx) && (size - idx) >= 4; idx += 4) {
+      float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
 
-    vst1q_f32(&C32[idx], vcvt_f32_f16(c));
-  }
+      vst1q_f32(&C32[idx], vcvt_f32_f16(c));
+    }
 
-  // remaining values if dimensions not a multiple of 8
-  for (; idx < size; idx++) {
-    C32[idx] = C[idx] * beta;
+    // remaining values if dimensions not a multiple of 8
+    for (; idx < size; idx++) {
+      C32[idx] = C[idx] * beta;
+    }
+  } else {
+    float32x4_t zeros = vmovq_n_f32(0.F);
+    for (; idx < (size - idx) && (size - idx) >= 4; idx += 4) {
+      vst1q_f32(&C32[idx], zeros);
+    }
+    for (; idx < size; idx++) {
+      C32[idx] = 0.F;
+    }
   }
 
   if (!TransA && TransB) {
index be61cd5c91a46f03e1bee995761e221064ee1b2e..dd1c173a99e45e1fed4f5a175e50ae4ed3216137 100644 (file)
@@ -31,7 +31,7 @@
 
 void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
-  if (alpha == 1.F && beta == 0.F && N > 4) {
+  if (alpha == 1.F) {
     // used bitwise operator instead of modulo for performance
     // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
     if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
@@ -53,7 +53,7 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
 
 void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
-  if (alpha == 1.F && beta == 0.F) {
+  if (alpha == 1.F) {
     // used bitwise operator instead of modulo for performance
     // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
     if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {