[ hgemm ] Add 4x4 kernel-using f16-f32 hgemm_noTrans
authorskykongkong8 <ss.kong@samsung.com>
Fri, 12 Apr 2024 03:48:02 +0000 (12:48 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 22 May 2024 23:13:42 +0000 (08:13 +0900)
- Now Hgemm supports 4x4 f16-f32 partial accumulation strategy

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm.h

index faf6d21cbee89d0df4bd94835651af2b9a91099f..ffe51056606926fb6b5da1d6789c28381934eec4 100644 (file)
@@ -412,6 +412,73 @@ void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
   free(sb);
 }
 
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_4) {
+          m2_min = 3 * GEMM_UNROLLING_4;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+          m2_min = 2 * GEMM_UNROLLING_4;
+        } else if (m2_min > GEMM_UNROLLING_4) {
+          m2_min = GEMM_UNROLLING_4;
+        }
+
+        packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+
 void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
                        const __fp16 *A, unsigned int lda, const __fp16 *B,
                        unsigned int ldb, __fp16 *C, unsigned int ldc,
index b05d89cb01534a556a120176ed96057c82954d6c..7c8194edf28376e170a3bb9503b5a409a4d55c63 100644 (file)
@@ -181,6 +181,26 @@ void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
                        unsigned int ldb, __fp16 *C, unsigned int ldc,
                        float alpha = 1.F, float beta = 0.F);
 
+/**
+ * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
 /**
  * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
  *