[ hgemm ] Consider K=1 changes
authorskykongkong8 <ss.kong@samsung.com>
Thu, 27 Jun 2024 07:25:54 +0000 (16:25 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Fri, 28 Jun 2024 08:25:40 +0000 (17:25 +0900)
- Current implementation is rooted on general cases, thus optimize only w.r.t. K accumulation.
- However, when it comes to M,1 x 1,N computation, all optimizations like packing, transposing is no use.
- Implementing a explicit kernel function for such case resolved the latency issue.

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/blas_neon.cpp
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm.h

index f442fd8cb6adef9764da853477cebded2a68cbe2..576bc0e3e7bb663f6a65dce94c398bd6cea70e1d 100644 (file)
@@ -1588,7 +1588,11 @@ unsigned int isamax(const unsigned int N, const __fp16 *X) {
 
 void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
            uint32_t K, float alpha, float beta, bool TransA, bool TransB) {
-
+  if (K == 1) {
+    unsigned int lda = (TransA) ? M : K;
+    unsigned int ldb = (TransB) ? K : N;
+    return hgemm_K1(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
+  }
   // dynamic creation to avoid reaching stack limit(causes segmentation fault)
   float *C32 = (float *)malloc(M * N * sizeof(float));
 
index 2b48a630572755037e17eb7544854ae6f5e28516..3593488e123e209ad87f1f43fda0d611027e70ea 100644 (file)
@@ -144,6 +144,22 @@ void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
   free(B8);
 }
 
+void hgemm_K1(unsigned int M, unsigned int N, unsigned int K, const __fp16 *A,
+              unsigned int lda, const __fp16 *B, unsigned int ldb, __fp16 *C,
+              unsigned int ldc, float alpha, float beta) {
+  float16x8_t a_vec;
+  unsigned int N8 = (N >> 3) << 3;
+  for (unsigned int m = 0; m < M; ++m) {
+    a_vec = vmovq_n_f16(A[m]);
+    for (unsigned int n = 0; n < N8; n += 8) {
+      vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
+    }
+    for (unsigned int n = N8; n < N; ++n) {
+      C[m * ldc + n] = A[m] * B[n];
+    }
+  }
+}
+
 void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
                        const __fp16 *A, unsigned int lda, const __fp16 *B,
                        unsigned int ldb, __fp16 *C, unsigned int ldc,
index 87bdb0cc4da76f5d2b1a80d5e8529e8c4064560c..1f37f01f2f501a85db245c1baa8440cee1ba8d60 100644 (file)
@@ -106,6 +106,25 @@ void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
                             unsigned int ldb, float *C, unsigned int ldc,
                             float alpha = 1.F, float beta = 0.F);
 
+/**
+ * @brief     hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1(unsigned int M, unsigned int N, unsigned int K,
+                      const __fp16 *A, unsigned int lda, const __fp16 *B,
+                      unsigned int ldb, __fp16 *C, unsigned int ldc,
+                      float alpha = 1.F, float beta = 0.F);
+
 /**
  * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
  *