[ hgemm ] Implement 4x4 f16-f32 kernel
authorskykongkong8 <ss.kong@samsung.com>
Fri, 12 Apr 2024 03:46:57 +0000 (12:46 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 22 May 2024 23:13:42 +0000 (08:13 +0900)
- Implement 4x4 GEMM kernel that works f16-f32 partial accumulation

**Self evaluation:**
1. Build test:     [X]Passed [ ]Failed [ ]Skipped
2. Run test:     [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: skykongkong8 <ss.kong@samsung.com>
nntrainer/tensor/hgemm/hgemm_kernel_4x4.h

index 6166b9407dca3dc5653c87dbd06c81d19403d8db..5efa58487926a0135768ea52a4573d1f978b64e8 100644 (file)
@@ -101,3 +101,88 @@ void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
     b = sb;
   }
 }
+
+/**
+ * @brief hgemm 4x4 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_4x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(M % 4 == 0 && N % 4 == 0 && K % 4 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i += VL_FP16_HALF) {
+    for (j = 0; j < N; j += VL_FP16_HALF) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      float16x4_t v24 = {0};
+      float16x4_t v25 = {0};
+      float16x4_t v26 = {0};
+      float16x4_t v27 = {0};
+
+      for (l = 0; l < K; l += VL_FP16_HALF) {
+        float16x4_t v0 = vld1_f16(b);
+        float16x4_t v16 = vld1_f16(a);
+
+        v24 = vfma_lane_f16(v24, v0, v16, 0);
+        v25 = vfma_lane_f16(v25, v0, v16, 1);
+        v26 = vfma_lane_f16(v26, v0, v16, 2);
+        v27 = vfma_lane_f16(v27, v0, v16, 3);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16x4_t v17 = vld1_f16(a + 4);
+
+        v24 = vfma_lane_f16(v24, v1, v17, 0);
+        v25 = vfma_lane_f16(v25, v1, v17, 1);
+        v26 = vfma_lane_f16(v26, v1, v17, 2);
+        v27 = vfma_lane_f16(v27, v1, v17, 3);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16x4_t v18 = vld1_f16(a + 8);
+
+        v24 = vfma_lane_f16(v24, v2, v18, 0);
+        v25 = vfma_lane_f16(v25, v2, v18, 1);
+        v26 = vfma_lane_f16(v26, v2, v18, 2);
+        v27 = vfma_lane_f16(v27, v2, v18, 3);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16x4_t v19 = vld1_f16(a + 12);
+
+        v24 = vfma_lane_f16(v24, v3, v19, 0);
+        v25 = vfma_lane_f16(v25, v3, v19, 1);
+        v26 = vfma_lane_f16(v26, v3, v19, 2);
+        v27 = vfma_lane_f16(v27, v3, v19, 3);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 16, 0, 3);
+
+        b += 16;
+        a += 16;
+
+        vst1_f32(c, vadd_f32(vld1_f32(c), vcvt_f32_f16(v24)));
+        vst1_f32(c + ldc, vadd_f32(vld1_f32(c + ldc), vcvt_f32_f16(v25)));
+        vst1_f32(c + 2 * ldc, vadd_f32(vld1_f32(c + 2 * ldc), vcvt_f32_f16(v26)));
+        vst1_f32(c + 3 * ldc,  vadd_f32(vld1_f32(c + 3 * ldc), vcvt_f32_f16(v27)));
+      }
+
+      c += 4;
+      a -= 4 * K;
+    }
+    sc += ldc * 4;
+    c = sc;
+    a += 4 * K;
+    b = sb;
+  }
+}
+