[hgemm] hgemm noTrans with 1x4 kernel
authorDebadri Samaddar <s.debadri@samsung.com>
Tue, 23 Apr 2024 06:30:16 +0000 (12:00 +0530)
committerJijoong Moon <jijoong.moon@samsung.com>
Thu, 25 Apr 2024 23:07:39 +0000 (08:07 +0900)
Added hgemm_kernel_1x4
Added hgemm_noTrans_1x4 calls
Added unittest dot_gemm_50_768_516

Signed-off-by: Debadri Samaddar <s.debadri@samsung.com>
nntrainer/tensor/hgemm/hgemm.cpp
nntrainer/tensor/hgemm/hgemm.h
nntrainer/tensor/hgemm/hgemm_kernel_1x4.h [new file with mode: 0644]
test/unittest/unittest_nntrainer_tensor_neon_fp16.cpp

index d57f9a8dbc22b1924875c069e1d721638cddfd4c..a41a5ba6dc12367a521cdc92e61bd15061ff9590 100644 (file)
@@ -13,6 +13,7 @@
  */
 
 #include <hgemm.h>
+#include <hgemm_kernel_1x4.h>
 #include <hgemm_kernel_1x8.h>
 #include <hgemm_kernel_4x4.h>
 #include <hgemm_kernel_4x8.h>
@@ -21,6 +22,7 @@
 #include <hgemm_kernel_pack.h>
 #include <hgemm_util.h>
 
+#define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
 #define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
 #define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
 #define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
@@ -38,6 +40,8 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
       hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
     } else if (N % 8 == 0) {
       hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+    } else if (N % 4 == 0) {
+      hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
     } else {
       hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
     }
@@ -58,10 +62,144 @@ void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
       hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
     } else if (M % 4 == 0 && N % 4 == 0 && K % 4 == 0) {
       hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
+    } else if (N % 4 == 0) {
+      hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
     }
   }
 }
 
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha, float beta) {
+  __fp16 *sa = alignedMalloc(M * K);
+  __fp16 *sb = alignedMalloc(K * N);
+
+  unsigned int ms, mms, ns, ks;
+  unsigned int m_min, m2_min, n_min, k_min;
+  for (ms = 0; ms < M; ms += M_BLOCKING) {
+    m_min = M - ms;
+    if (m_min > M_BLOCKING) {
+      m_min = M_BLOCKING;
+    }
+
+    for (ks = 0; ks < K; ks += k_min) {
+      k_min = K - ks;
+      if (k_min >= (K_BLOCKING << 1)) {
+        k_min = K_BLOCKING;
+      } else if (k_min > K_BLOCKING) {
+        k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+      }
+
+      n_min = N;
+      if (N >= N_BLOCKING * 2) {
+        n_min = N_BLOCKING;
+      } else if (N > N_BLOCKING) {
+        n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+      }
+      packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+      for (mms = ms; mms < ms + m_min; mms += m2_min) {
+        m2_min = (ms + m_min) - mms;
+        if (m2_min >= 3 * GEMM_UNROLLING_1) {
+          m2_min = 3 * GEMM_UNROLLING_1;
+        } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+          m2_min = 2 * GEMM_UNROLLING_1;
+        } else if (m2_min > GEMM_UNROLLING_1) {
+          m2_min = GEMM_UNROLLING_1;
+        }
+
+        packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+                   sa + k_min * (mms - ms));
+
+        HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+                         C + mms * ldc, ldc);
+      }
+
+      for (ns = n_min; ns < N; ns += n_min) {
+        n_min = N - ns;
+        if (n_min >= N_BLOCKING * 2) {
+          n_min = N_BLOCKING;
+        } else if (n_min > N_BLOCKING) {
+          n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+        }
+
+        packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+        HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+      }
+    }
+  }
+
+  free(sa);
+  free(sb);
+}
+
 void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
                        const __fp16 *A, unsigned int lda, const __fp16 *B,
                        unsigned int ldb, __fp16 *C, unsigned int ldc,
index 2d9cc2e31af060a4aa010e6e423c289aa7e359ee..b05d89cb01534a556a120176ed96057c82954d6c 100644 (file)
@@ -61,6 +61,46 @@ void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
                             unsigned int ldb, float *C, unsigned int ldc,
                             float alpha = 1.F, float beta = 0.F);
 
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, __fp16 *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix C
+ * @param B input matrix B
+ * @param ldb length of the col of matrix C
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
 /**
  * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
  *
diff --git a/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h b/nntrainer/tensor/hgemm/hgemm_kernel_1x4.h
new file mode 100644 (file)
index 0000000..c189f63
--- /dev/null
@@ -0,0 +1,144 @@
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Debadri Samaddar <s.debadri@samsung.com>
+ *
+ * @file   hgemm_kernel_1x4.h
+ * @date   23 April 2024
+ * @see    https://github.com/nnstreamer/nntrainer
+ * @author Debadri Samaddar <s.debadri@samsung.com>
+ * @bug    No known bugs except for NYI items
+ * @brief  This is half-precision GEMM 1x4 kernel
+ *
+ */
+
+#include <hgemm_common.h>
+#include <stdlib.h>
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, __fp16 *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 4 == 0);
+
+  __fp16 *a = sa, *b = sb, *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += VL_FP16_HALF) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      for (l = 0; l < K; l += VL_FP16_HALF) {
+        float16x4_t v24 = {0.F};
+        float16x4_t v0 = vld1_f16(b);
+        float16_t v16 = *a;
+
+        v24 = vfma_n_f16(v24, v0, v16);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16_t v17 = *(a + 1);
+
+        v24 = vfma_n_f16(v24, v1, v17);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16_t v18 = *(a + 2);
+
+        v24 = vfma_n_f16(v24, v2, v18);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16_t v19 = *(a + 3);
+
+        v24 = vfma_n_f16(v24, v3, v19);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 4, 0, 3);
+
+        b += 16;
+        a += 4;
+
+        v24 = vadd_f16(vld1_f16(c), v24);
+
+        vst1_f16(c, v24);
+      }
+      c += 4;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
+
+/**
+ * @brief hgemm 1x4 kernel sc = sa * sb
+ *
+ * @param m length of the row of matrix A
+ * @param n length of the col of matrix B
+ * @param k length of the col of matrix A
+ * @param sa sub-matrix of input matrix A
+ * @param sb sub-matrix of input matrix B
+ * @param sc sub-matrix of output matrix C
+ * @param ldc leading dimension of matrix C
+ */
+void hgemm_kernel_1x4(unsigned int M, unsigned int N, unsigned int K,
+                      __fp16 *sa, __fp16 *sb, float *sc, unsigned int ldc) {
+  assert(M > 0 && N > 0 && K > 0);
+  assert(N % 4 == 0);
+
+  __fp16 *a = sa, *b = sb;
+  float *c = sc;
+  unsigned int i, j, l;
+  for (i = 0; i < M; i++) {
+    for (j = 0; j < N; j += VL_FP16_HALF) {
+      __builtin_prefetch(b, 0, 3);
+      __builtin_prefetch(a, 0, 3);
+
+      for (l = 0; l < K; l += VL_FP16_HALF) {
+        float16x4_t v24 = {0.F};
+        float16x4_t v0 = vld1_f16(b);
+        float16_t v16 = *a;
+
+        v24 = vfma_n_f16(v24, v0, v16);
+
+        float16x4_t v1 = vld1_f16(b + 4);
+        float16_t v17 = *(a + 1);
+
+        v24 = vfma_n_f16(v24, v1, v17);
+
+        float16x4_t v2 = vld1_f16(b + 8);
+        float16_t v18 = *(a + 2);
+
+        v24 = vfma_n_f16(v24, v2, v18);
+
+        float16x4_t v3 = vld1_f16(b + 12);
+        float16_t v19 = *(a + 3);
+
+        v24 = vfma_n_f16(v24, v3, v19);
+
+        __builtin_prefetch(b + 16, 0, 3);
+        __builtin_prefetch(a + 4, 0, 3);
+
+        b += 16;
+        a += 4;
+
+        vst1q_f32(c, vaddq_f32(vld1q_f32(c), vcvt_f32_f16(v24)));
+      }
+      c += 4;
+      a -= K;
+    }
+    sc += ldc;
+    c = sc;
+    a += K;
+    b = sb;
+  }
+}
index 6454feea2a7a3c793647760443c4e6741d1df85e..53d01858ffb8856d24cb1463791a2a96a1420af9 100644 (file)
@@ -658,6 +658,67 @@ TEST(nntrainer_Tensor, dot_gemm_50_768_20000) {
   EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
 }
 
+TEST(nntrainer_Tensor, dot_gemm_50_768_516) {
+  /// @note GEMM : A X B = C
+  int batch = 1;
+  int channel = 1;
+  int height = 50;
+  int width = 768;
+
+  int height_b = 768;
+  int width_b = 516;
+
+  bool transA = false;
+  bool transB = false;
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+  nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+    nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+  nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+  nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+  nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+  nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+  const float alpha = 1e-1;
+  const int MOD = 10;
+
+  GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+                      k * (width) + l + 1) %
+                     MOD) *
+                      alpha);
+  GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+                        j * (batch * height_b) + k * (width_b) + l + 1) %
+                       MOD) *
+                        alpha);
+  GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+                           j * (batch * height) + k * (width) + l + 1) %
+                          MOD) *
+                           alpha);
+  GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+                             j * (batch * height_b) + k * (width_b) + l + 1) %
+                            MOD) *
+                             alpha);
+
+  nntrainer::Tensor C = A.dot(B, transA, transB);
+
+  nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+  float mseErrorNeon =
+    mse<__fp16>(C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+  double cosSimNeon = cosine_similarity<__fp16>(
+    C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+  const float epsilon = 1e-3 * width;
+
+  EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+  EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
 TEST(nntrainer_Tensor, dot_gemv_768_96000) {
   /// @note GEMV : A X B = C
   int batch = 1;