[ hgemm ] Use zero padding in Non-8-divisible GEMM case
authorskykongkong8 <ss.kong@samsung.com>
Thu, 20 Jun 2024 10:56:08 +0000 (19:56 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 28 Jun 2024 04:48:47 +0000 (13:48 +0900)
- For temporary solution apply zero padding in non-8-K divisible case.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm.h
nntrainer/tensor/hgemm/hgemm_common.h

index b8827d0bd6a76d8364a18e17f327ae4b7a78fb76..c8a31f216c4eb9127bc705d743cee52a7d00c2b0 100644 (file)
@@ -36,26 +36,37 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
   const float eps = std::numeric_limits<float>::epsilon();
   if (std::abs(alpha - 1.F) < eps) {
-    // used bitwise operator instead of modulo for performance
-    // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
-    if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
-      hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
-      hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
-      hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
-      hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
-    } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
-      hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    if ((K & 0x7) != 0) {
+      hgemm_noTrans_padding_wrt_K(A, B, C32, M, N, K, alpha, beta);
     } else {
-      hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+      hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
     }
-  } else
+  } else {
+    hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  }
+}
+
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C32,
+                          unsigned int M, unsigned int N, unsigned int K,
+                          float alpha, float beta) {
+  // used bitwise operator instead of modulo for performance
+  // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
+  if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
+    hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+    hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+    hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+    hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+    hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  } else {
     hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  }
 }
 
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
   if (alpha == 1.F) {
     // used bitwise operator instead of modulo for performance
@@ -76,6 +87,62 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
   }
 }
 
+void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
+                                 unsigned int M, unsigned int N, unsigned int K,
+                                 float alpha, float beta) {
+  const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
+  const unsigned int K8_low = (K >> 3) << 3;
+
+  const unsigned int lda = K;
+  const unsigned int ldb = N;
+
+  __fp16 *A8 = new __fp16[M * K8_high];
+  __fp16 *B8 = new __fp16[K8_high * N];
+
+  float16x8_t ZEROS = vmovq_n_f16(0.F);
+
+  // Make zero-padded A matrix
+  for (unsigned int m = 0; m < M; ++m) {
+    unsigned int k = 0;
+    for (; k < K8_low; k += 8) {
+      vst1q_f16(&A8[m * K8_high + k], vld1q_f16(&A[m * K + k]));
+    }
+    for (; k < K; ++k) {
+      A8[m * K8_high + k] = A[m * K + k];
+    }
+    for (; k < K8_high; ++k) {
+      A8[m * K8_high + k] = 0.F;
+    }
+  }
+
+  // Make zero-padded B matrix
+  unsigned int k = 0;
+  unsigned int N8 = (N >> 3) << 3;
+  for (; k < K; ++k) {
+    unsigned int n = 0;
+    for (; n < N8; n += 8) {
+      vst1q_f16(&B8[k * N + n], vld1q_f16(&B[k * N + n]));
+    }
+    for (; n < N; ++n) {
+      B8[k * N + n] = B[k * N + n];
+    }
+  }
+  for (; k < K8_high; ++k) {
+    unsigned int n = 0;
+    for (; n < N8; n += 8) {
+      vst1q_f16(&B8[k * N + n], ZEROS);
+    }
+    for (; n < N; ++n) {
+      B8[k * N + n] = 0.F;
+    }
+  }
+
+  hgemm_noTrans_strict(A8, B8, C, M, N, K8_high, alpha, beta);
+
+  free(A8);
+  free(B8);
+}
+
 void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
                        const __fp16 *A, unsigned int lda, const __fp16 *B,
                        unsigned int ldb, __fp16 *C, unsigned int ldc,
@@ -762,6 +829,7 @@ void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
                         const __fp16 *A, unsigned int lda, const __fp16 *B,
                         unsigned int ldb, __fp16 *C, unsigned int ldc,
                         float alpha, float beta) {
+// M, N, K is full M, N, K here
 
   __fp16 *sA = alignedMalloc(M * K);
   __fp16 *sB = alignedMalloc(K * N);
index 8a071eadf3c173e99f5cbd9eb9862f07bef0aa17..87bdb0cc4da76f5d2b1a80d5e8529e8c4064560c 100644 (file)
@@ -38,10 +38,55 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
                    unsigned int N, unsigned int K, float alpha = 1.F,
                    float beta = 0.F);
 
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C,
+                          unsigned int M, unsigned int N, unsigned int K,
+                          float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
+                                 unsigned int M, unsigned int N, unsigned int K,
+                                 float alpha = 1.F, float beta = 0.F);
+
+                                 /**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_padding_wrt_K4(const __fp16 *A, const __fp16 *B, float *C,
+                                 unsigned int M, unsigned int N, unsigned int K,
+                                 float alpha = 1.F, float beta = 0.F);
+
 /**
  * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
  * @param M length of the row of matrix A
index 942648e5008b4cb1bbff6613cd0632ee52fd997c..d5c965c477fb18056501d8172659a49af8930a76 100644 (file)
@@ -17,7 +17,7 @@
 #define B(i, j) b[(i) * ldb + (j)]
 #define C(i, j) c[(i) * ldc + (j)]
 
-#define N_BLOCKING (384)
+#define N_BLOCKING (768)
 #define K_BLOCKING (256)
 #define M_BLOCKING (4096)
 #define GEMM_UNROLLING_16 (16)