[ hgemm ] Implement packing-blocking-kernel sequence for hgemm transB
authorskykongkong8 <ss.kong@samsung.com>
Wed, 10 Jul 2024 03:27:39 +0000 (12:27 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Tue, 30 Jul 2024 22:45:30 +0000 (07:45 +0900)
- Previously, hgemm transB computation was relying on transposing the entire matrix and using non-transpose sequence.
- For optimal performance, matrix packing-blocking-kernel sequence for transB case is explicitly implemented.
- Note that current implementation only supports for 8x16 gemm kernel.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/blas_neon.cpp
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm.h

index 3d5a0937b721543c5e480a53f5b6f85721f99ffd..8b931c729547b41e0f39852aac2d87a8aa21835d 100644 (file)
@@ -1628,16 +1628,8 @@ void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
     }
   }
 
-  if (!TransA && TransB) {
-    hgemm_transB(A, B, C32, M, N, K, alpha, beta);
-  } else if (TransA && !TransB) {
-    hgemm_transA(A, B, C32, M, N, K, alpha, beta);
-  } else if (!TransA && !TransB) {
-    hgemm_noTrans(A, B, C32, M, N, K, alpha, beta);
-  } else { // TransA && TransB
-    hgemm_transAB(A, B, C32, M, N, K, alpha, beta);
-  }
-
+  hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
+  
   copy_fp32_to_fp16(M * N, C32, C);
   free(C32);
 }
index e71136f3d10d8f1d72d75e6a2d68736631229cc1..fc99d701cbe8b7e829969af2de71be45eb4e255d 100644 (file)
@@ -21,6 +21,7 @@
 #include <hgemm_kernel_8x16.h>
 #include <hgemm_kernel_8x8.h>
 #include <hgemm_kernel_pack.h>
+#include <hgemm_padding.h>
 #include <hgemm_util.h>
 #include <limits>
 #include <matrix_transpose_neon.h>
@@ -36,11 +37,7 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
                    unsigned int N, unsigned int K, float alpha, float beta) {
   const float eps = std::numeric_limits<float>::epsilon();
   if (std::abs(alpha - 1.F) < eps) {
-    if ((K & 0x7) != 0) {
-      hgemm_noTrans_padding_wrt_K(A, B, C32, M, N, K, alpha, beta);
-    } else {
-      hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
-    }
+    hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
   } else {
     hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
   }
@@ -88,60 +85,68 @@ void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
   }
 }
 
-void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
-                                 unsigned int M, unsigned int N, unsigned int K,
-                                 float alpha, float beta) {
-  const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
-  const unsigned int K8_low = (K >> 3) << 3;
-
-  const unsigned int lda = K;
-  const unsigned int ldb = N;
+void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
+                               unsigned int M, unsigned int N, unsigned int K,
+                               float alpha, float beta, bool TransA,
+                               bool TransB) {
+  /// @note Padding standard : 8x16 is the only KERNEL that outperforms single
+  /// precision GEMM 'so far'. Padding will forcibly make every GEMM cases to
+  /// use it. Note that padding is not the optimal way here, but just an option
+  /// that is easier to implement. Fine-grained packing should be supported on
+  /// the future for optimal performance.
 
-  __fp16 *A8 = alignedMalloc(M * K8_high);
-  __fp16 *B8 = alignedMalloc(K8_high * N);
+  __fp16 *A_ = (__fp16 *)A, *B_ = (__fp16 *)B;
+  unsigned int M_ = M, N_ = N, K_ = K;
+  bool pad_A = false, pad_B = false;
 
-  float16x8_t ZEROS = vmovq_n_f16(0.F);
+  // Case 2 : smaller than 8, 16 | padding would be redundant?
+  if (M < 8 && K < 16 && N < 16)
+    return hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
 
-  // Make zero-padded A matrix
-  for (unsigned int m = 0; m < M; ++m) {
-    unsigned int k = 0;
-    for (; k < K8_low; k += 8) {
-      vst1q_f16(&A8[m * K8_high + k], vld1q_f16(&A[m * K + k]));
-    }
-    for (; k < K; ++k) {
-      A8[m * K8_high + k] = A[m * K + k];
-    }
-    for (; k < K8_high; ++k) {
-      A8[m * K8_high + k] = 0.F;
-    }
-  }
+  __fp16 *Ap;
+  __fp16 *Bp;
 
-  // Make zero-padded B matrix
-  unsigned int k = 0;
-  unsigned int N8 = (N >> 3) << 3;
-  for (; k < K; ++k) {
-    unsigned int n = 0;
-    for (; n < N8; n += 8) {
-      vst1q_f16(&B8[k * N + n], vld1q_f16(&B[k * N + n]));
-    }
-    for (; n < N; ++n) {
-      B8[k * N + n] = B[k * N + n];
-    }
+  const unsigned int M8_high = ((M - 1) / 8 + 1) * 8;
+  const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
+  const unsigned int N16_high = ((N - 1) / 16 + 1) * 16;
+
+  if ((M8_high != M) || (K8_high != K)) {
+    pad_A = true;
+    Ap = alignedMalloc(M8_high * K8_high);
+    hgemm_padding_A(A, Ap, M, K, M8_high, K8_high, TransA);
+    A_ = Ap;
+    M_ = M8_high;
+    K_ = K8_high;
   }
-  for (; k < K8_high; ++k) {
-    unsigned int n = 0;
-    for (; n < N8; n += 8) {
-      vst1q_f16(&B8[k * N + n], ZEROS);
-    }
-    for (; n < N; ++n) {
-      B8[k * N + n] = 0.F;
-    }
+  if ((K8_high != K) || (N16_high != N)) {
+    pad_B = true;
+    Bp = alignedMalloc(K8_high * N16_high);
+    hgemm_padding_B(B, Bp, K, N, K8_high, N16_high, TransB);
+    B_ = Bp;
+    K_ = K8_high;
+    N_ = N16_high;
   }
 
-  hgemm_noTrans_strict(A8, B8, C, M, N, K8_high, alpha, beta);
+  hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
+
+  if (pad_A)
+    free(Ap);
+  if (pad_B)
+    free(Bp);
+}
 
-  free(A8);
-  free(B8);
+void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
+                    unsigned int M, unsigned int N, unsigned int K, float alpha,
+                    float beta, bool TransA, bool TransB) {
+  if (!TransA && !TransB) {
+    hgemm_noTrans(A, B, C32, M, N, K, alpha, beta);
+  } else if (TransA && !TransB) {
+    hgemm_transA(A, B, C32, M, N, K, alpha, beta);
+  } else if (!TransA && TransB) {
+    hgemm_transB(A, B, C32, M, N, K, alpha, beta);
+  } else { // TransA && TransB
+    hgemm_transAB(A, B, C32, M, N, K, alpha, beta);
+  }
 }
 
 void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
@@ -898,8 +903,6 @@ void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
                         const __fp16 *A, unsigned int lda, const __fp16 *B,
                         unsigned int ldb, __fp16 *C, unsigned int ldc,
                         float alpha, float beta) {
-  // M, N, K is full M, N, K here
-
   __fp16 *sA = alignedMalloc(M * K);
   __fp16 *sB = alignedMalloc(K * N);
 
@@ -972,7 +975,6 @@ void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
                         const __fp16 *A, unsigned int lda, const __fp16 *B,
                         unsigned int ldb, float *C, unsigned int ldc,
                         float alpha, float beta) {
-
   __fp16 *sA = alignedMalloc(M * K);
   __fp16 *sB = alignedMalloc(K * N);
 
@@ -1027,10 +1029,13 @@ void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
         if (n_min >= N_BLOCKING * 2) {
           n_min = N_BLOCKING;
         } else if (n_min > N_BLOCKING) {
-          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+          n_min =
+            (n_min / 2 + GEMM_UNROLLING_16 - 1) & ~(GEMM_UNROLLING_16 - 1);
+
         }
 
         packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+
         HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
       }
     }
@@ -1324,8 +1329,88 @@ void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
   }
 }
 
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+
+  __fp16 *sA = alignedMalloc(M * K);
+  __fp16 *sB = alignedMalloc(K * N);
+
+  unsigned int ms, ms2, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  unsigned int stride_l1 = 1;
+
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+                GEMM_UNROLLING_16;
+      } else {
+        stride_l1 = 0;
+      }
+      packing_transB16(k_min, n_min, B + (ks), ldb, sB);
+
+      for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+        m2_min = (ms + m_min) - ms2;
+        if (m2_min >= 3 * GEMM_UNROLLING_8) {
+          m2_min = 3 * GEMM_UNROLLING_8;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+          m2_min = 2 * GEMM_UNROLLING_8;
+        } else if (m2_min > GEMM_UNROLLING_8) {
+          m2_min = GEMM_UNROLLING_8;
+        }
+        packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+                   sA + k_min * (ms2 - ms) * stride_l1);
+        HGEMM_KERNEL_8x16(m2_min, n_min, k_min,
+                          sA + k_min * (ms2 - ms) * stride_l1, sB,
+                          C + ms2 * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+        }
+        packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
+        HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sA);
+  free(sB);
+}
+
 void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
                   unsigned int N, unsigned int K, float alpha, float beta) {
+  if (((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0)) {
+    return hgemm_transB_8x16(M, N, K, A, K, B, K, C, N, alpha, beta);
+  } else {
+    return hgemm_transB_fallback(A, B, C, M, N, K, alpha, beta);
+  }
+}
+
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+                           unsigned int M, unsigned int N, unsigned int K,
+                           float alpha, float beta) {
   __fp16 *B_T = alignedMalloc(K * N);
 
   transpose_neon<__fp16>(N, K, B, K, B_T, N);
index 029076cd5b54d5659a8e3e836bf5bf998be742e3..c380eb8304dbf5c5bd1cc6bcd873d775055ab545 100644 (file)
@@ -38,9 +38,9 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
-                   unsigned int N, unsigned int K, float alpha = 1.F,
-                   float beta = 0.F);
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
+                          unsigned int M, unsigned int N, unsigned int K,
+                          float alpha = 1.F, float beta = 0.F);
 
 /**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
@@ -68,11 +68,12 @@ void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
-                                 unsigned int M, unsigned int N, unsigned int K,
-                                 float alpha = 1.F, float beta = 0.F);
+void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
+                               unsigned int M, unsigned int N, unsigned int K,
+                               float alpha = 1.F, float beta = 0.F,
+                               bool TransA = false, bool TransB = false);
 
-                                 /**
+/**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
  * @param[in] A __fp16 * for Matrix A
  * @param[in] B __fp16 * for Matrix B
@@ -83,9 +84,9 @@ void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
  * @param[in] alpha float number
  * @param[in] beta float number
  */
-void hgemm_noTrans_padding_wrt_K4(const __fp16 *A, const __fp16 *B, float *C,
-                                 unsigned int M, unsigned int N, unsigned int K,
-                                 float alpha = 1.F, float beta = 0.F);
+void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
+                    unsigned int N, unsigned int K, float alpha = 1.F, float beta = 0.F,
+                    bool TransA = false, bool TransB = false);
 
 /**
  * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
@@ -447,6 +448,27 @@ void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
  */
 void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
                   unsigned int N, unsigned int K, float alpha, float beta);
+
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+                           unsigned int M, unsigned int N, unsigned int K,
+                           float alpha, float beta);
+
+/**
+ * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's row
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s and columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
 /**
  * @brief     hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
  * where op(X) is one of X or X**T