- Missing optimizations for the K=1 GEMM case were recently detected.
- Add corresponding TCs accordingly.
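For reference, when K=1, op(A) degenerates to a single column and op(B) to a single row, so the GEMM collapses to a scaled outer product plus the beta term. A scalar sketch of what the dedicated K=1 path computes (illustration only, the actual kernel is NEON-vectorized and the name below is hypothetical):

    void hgemm_K1_ref(const __fp16 *A, const __fp16 *B, __fp16 *C,
                      unsigned int M, unsigned int N, float alpha, float beta) {
      for (unsigned int m = 0; m < M; ++m)
        for (unsigned int n = 0; n < N; ++n)
          C[m * N + n] = alpha * A[m] * B[n] + beta * C[m * N + n];
    }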
**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped
Signed-off-by: skykongkong8 <ss.kong@samsung.com>
const unsigned int ldc) {
#if (defined USE__FP16 && USE_NEON)
- nntrainer::neon::hgemm(A, B, C, M, N, K, alpha, beta, TransA == CblasTrans,
+ nntrainer::neon::custom_hgemm(A, B, C, M, N, K, alpha, beta, TransA == CblasTrans,
TransB == CblasTrans);
#else
float *A_ = new float[M * K];
return retIdx;
}
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
- uint32_t K, float alpha, float beta, bool TransA, bool TransB) {
- if (K == 1) {
- return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
- }
- // dynamic creation to avoid reaching stack limit(causes segmentation fault)
- float *C32 = (float *)malloc(M * N * sizeof(float));
-
- // performing beta*C
- unsigned int idx = 0;
- unsigned int size = M * N;
- unsigned int size8 = (size >> 3) << 3;
- unsigned int size4 = (size >> 2) << 2;
- if (std::fpclassify(beta) != FP_ZERO) {
- for (; idx < size8; idx += 8) {
- float16x8_t c =
- vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
-
- vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
- vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
- }
- // remaining 4
- for (; idx < size4; idx += 4) {
- float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
-
- vst1q_f32(&C32[idx], vcvt_f32_f16(c));
- }
-
- // remaining values if dimensions not a multiple of 8
- for (; idx < size; idx++) {
- C32[idx] = C[idx] * beta;
- }
- } else {
- float32x4_t zeros = vmovq_n_f32(0.F);
- for (; idx < size4; idx += 4) {
- vst1q_f32(&C32[idx], zeros);
- }
- for (; idx < size; idx++) {
- C32[idx] = 0.F;
- }
- }
-
- hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
-
- copy_fp32_to_fp16(M * N, C32, C);
- free(C32);
+void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
+ uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+ bool TransB) {
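+  // forwards to hgemm(), which handles the K == 1 shortcut and the fp32 accumulation internally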
+ hgemm(A, B, C, M, N, K, alpha, beta, TransA, TransB);
}
-
void ele_mul(const unsigned int N, const __fp16 *X, const __fp16 *Y, __fp16 *Z,
float alpha, float beta) {
unsigned int i = 0;
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
+void custom_hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M, uint32_t N,
uint32_t K, float alpha, float beta, bool TransA, bool TransB);
*
*/
-#include <cmath>
#include <hgemm.h>
-#include <hgemm_kernel_1x4.h>
-#include <hgemm_kernel_1x8.h>
-#include <hgemm_kernel_4x4.h>
-#include <hgemm_kernel_4x8.h>
-#include <hgemm_kernel_8x16.h>
-#include <hgemm_kernel_8x8.h>
-#include <hgemm_kernel_pack.h>
+#include <hgemm_noTrans.h>
#include <hgemm_padding.h>
+#include <hgemm_transA.h>
+#include <hgemm_transAB.h>
+#include <hgemm_transB.h>
#include <hgemm_util.h>
-#include <limits>
-#include <matrix_transpose_neon.h>
+#include <hgemm_common.h>
+#include <cmath>
-#define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
-#define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
-#define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
-#define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
-#define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
-#define HGEMM_KERNEL_8x16 hgemm_kernel_8x16
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta) {
- const float eps = std::numeric_limits<float>::epsilon();
- if (std::abs(alpha - 1.F) < eps) {
- hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
- } else {
- hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+           unsigned int N, unsigned int K, float alpha, float beta,
+           bool TransA, bool TransB) {
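+  // K == 1 collapses to an outer product handled directly in fp16;
+  // skip the fp32 accumulation buffer allocated below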
+ if (K == 1) {
+ return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
}
-}
+  // dynamic allocation to avoid reaching the stack limit (which causes a segmentation fault)
+ float *C32 = (float *)malloc(M * N * sizeof(float));
+
+ // performing beta*C
+ unsigned int idx = 0;
+ unsigned int size = M * N;
+ unsigned int size8 = (size >> 3) << 3;
+ unsigned int size4 = (size >> 2) << 2;
+
+ if (std::fpclassify(beta) != FP_ZERO) {
+ for (; idx < size8; idx += 8) {
+ float16x8_t c =
+ vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
+
+ vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
+ vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
+ }
+ // remaining 4
+ for (; idx < size4; idx += 4) {
+ float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
+
+ vst1q_f32(&C32[idx], vcvt_f32_f16(c));
+ }
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C32,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha, float beta) {
- // used bitwise operator instead of modulo for performance
- // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
- if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
- hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
- hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
+  // remaining values if size is not a multiple of 4
+ for (; idx < size; idx++) {
+ C32[idx] = C[idx] * beta;
+ }
} else {
- hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ float32x4_t zeros = vmovq_n_f32(0.F);
+ for (; idx < size4; idx += 4) {
+ vst1q_f32(&C32[idx], zeros);
+ }
+ for (; idx < size; idx++) {
+ C32[idx] = 0.F;
+ }
}
-}
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha, float beta) {
- if (alpha == 1.F) {
- // used bitwise operator instead of modulo for performance
- // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
- if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_8x16(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_8x8(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
- hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
- hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
- } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
- }
+ hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
+
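+  // narrow the fp32 accumulation buffer back into the fp16 output C:
+  // 8 elements per iteration, scalar tail for the remainder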
+  unsigned int L = M * N;
+  unsigned int L8 = (L >> 3) << 3;
+
+ for (unsigned int idx = 0; idx < L8; idx += 8) {
+ float32x4_t x1 = vld1q_f32(&C32[idx]);
+ float32x4_t x2 = vld1q_f32(&C32[idx + 4]);
+
+ float16x8_t y1 = vcombine_f16(vcvt_f16_f32(x1), vcvt_f16_f32(x2));
+
+ vst1q_f16(&C[idx], y1);
+ }
+ for (unsigned int idx = L8; idx < L; ++idx) {
+ C[idx] = static_cast<__fp16>(C32[idx]);
}
+
+ free(C32);
}
void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
}
}
-void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
- const float eps = std::numeric_limits<float>::epsilon();
- float16x8_t a_vec;
- unsigned int N8 = (N >> 3) << 3;
- for (unsigned int m = 0; m < M; ++m) {
- a_vec = vmovq_n_f16(alpha * A[m]);
- if (std::fpclassify(beta) != FP_ZERO) {
- for (unsigned int n = 0; n < N8; n += 8) {
- vst1q_f16(&C[m * ldc + n],
- vaddq_f16(vmulq_f16(a_vec, vld1q_f16(&B[n])),
- vmulq_n_f16(vld1q_f16(&C[m * ldc + n]), beta)));
- }
- } else {
- for (unsigned int n = 0; n < N8; n += 8) {
- vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
- }
- }
- for (unsigned int n = N8; n < N; ++n) {
- C[m * ldc + n] = alpha * A[m] * B[n] + beta * C[m * ldc + n];
- }
- }
-}
-
-void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
- float beta) {
- __fp16 *A_T = alignedMalloc(M * K);
-
- transpose_neon<__fp16>(K, M, A, M, A_T, K);
-
- hgemm_K1_noTrans(M, N, K, A_T, lda, B, ldb, C, ldc, alpha, beta);
-
- free(A_T);
-}
-
-void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
- float beta) {
- __fp16 *B_T = alignedMalloc(K * N);
-
- transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
- hgemm_K1_noTrans(M, N, K, A, lda, B_T, ldb, C, ldc, alpha, beta);
-
- free(B_T);
-}
-
-void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *A_T = alignedMalloc(M * K);
- __fp16 *B_T = alignedMalloc(K * N);
-
- transpose_neon<__fp16>(K, M, A, M, A_T, K);
- transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
- hgemm_K1_noTrans(M, N, K, A_T, lda, B_T, ldb, C, ldc, alpha, beta);
-
- free(A_T);
- free(B_T);
-}
-
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
- packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_1) {
- m2_min = 3 * GEMM_UNROLLING_1;
- } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
- m2_min = 2 * GEMM_UNROLLING_1;
- } else if (m2_min > GEMM_UNROLLING_1) {
- m2_min = GEMM_UNROLLING_1;
- }
-
- packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms));
-
- HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
- C + mms * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
- packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_1) {
- m2_min = 3 * GEMM_UNROLLING_1;
- } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
- m2_min = 2 * GEMM_UNROLLING_1;
- } else if (m2_min > GEMM_UNROLLING_1) {
- m2_min = GEMM_UNROLLING_1;
- }
-
- packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms));
-
- HGEMM_KERNEL_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
- C + mms * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
- packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_4) {
- m2_min = 3 * GEMM_UNROLLING_4;
- } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
- m2_min = 2 * GEMM_UNROLLING_4;
- } else if (m2_min > GEMM_UNROLLING_4) {
- m2_min = GEMM_UNROLLING_4;
- }
-
- packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms));
-
- HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
- C + mms * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- unsigned int l1stride = 1;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
- GEMM_UNROLLING_8;
- } else {
- l1stride = 0;
- }
- packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_1) {
- m2_min = 3 * GEMM_UNROLLING_1;
- } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
- m2_min = 2 * GEMM_UNROLLING_1;
- } else if (m2_min > GEMM_UNROLLING_1) {
- m2_min = GEMM_UNROLLING_1;
- }
-
- packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms) * l1stride);
-
- HGEMM_KERNEL_1x8(m2_min, n_min, k_min,
- sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
- ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- unsigned int l1stride = 1;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
- GEMM_UNROLLING_8;
- } else {
- l1stride = 0;
- }
- packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_1) {
- m2_min = 3 * GEMM_UNROLLING_1;
- } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
- m2_min = 2 * GEMM_UNROLLING_1;
- } else if (m2_min > GEMM_UNROLLING_1) {
- m2_min = GEMM_UNROLLING_1;
- }
-
- packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms) * l1stride);
-
- HGEMM_KERNEL_1x8(m2_min, n_min, k_min,
- sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
- ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
- }
-
- packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
- packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_4) {
- m2_min = 3 * GEMM_UNROLLING_4;
- } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
- m2_min = 2 * GEMM_UNROLLING_4;
- } else if (m2_min > GEMM_UNROLLING_4) {
- m2_min = GEMM_UNROLLING_4;
- }
-
- packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms));
-
- HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
- C + mms * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- unsigned int l1stride = 1;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
- GEMM_UNROLLING_8;
- } else {
- l1stride = 0;
- }
- packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_4) {
- m2_min = 3 * GEMM_UNROLLING_4;
- } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
- m2_min = 2 * GEMM_UNROLLING_4;
- } else if (m2_min > GEMM_UNROLLING_4) {
- m2_min = GEMM_UNROLLING_4;
- }
-
- packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms) * l1stride);
-
- HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
- sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
- ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- unsigned int l1stride = 1;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
- GEMM_UNROLLING_8;
- } else {
- l1stride = 0;
- }
- packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_4) {
- m2_min = 3 * GEMM_UNROLLING_4;
- } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
- m2_min = 2 * GEMM_UNROLLING_4;
- } else if (m2_min > GEMM_UNROLLING_4) {
- m2_min = GEMM_UNROLLING_4;
- }
-
- packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms) * l1stride);
-
- HGEMM_KERNEL_4x8(m2_min, n_min, k_min,
- sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
- ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
-
- __fp16 *sa = alignedMalloc(M * K);
- __fp16 *sb = alignedMalloc(K * N);
-
- unsigned int ms, mms, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
- packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
-
- for (mms = ms; mms < ms + m_min; mms += m2_min) {
- m2_min = (ms + m_min) - mms;
- if (m2_min >= 3 * GEMM_UNROLLING_8) {
- m2_min = 3 * GEMM_UNROLLING_8;
- } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
- m2_min = 2 * GEMM_UNROLLING_8;
- } else if (m2_min > GEMM_UNROLLING_8) {
- m2_min = GEMM_UNROLLING_8;
- }
-
- packing_A8(m2_min, k_min, A + mms * lda + ks, lda,
- sa + k_min * (mms - ms));
-
- HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
- C + mms * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
-
- packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
- HGEMM_KERNEL_8x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sa);
- free(sb);
-}
-
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
-
- __fp16 *sA = alignedMalloc(M * K);
- __fp16 *sB = alignedMalloc(K * N);
-
- unsigned int ms, ms2, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
- packing_B8(k_min, n_min, B + ks * ldb, ldb, sB);
-
- for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
- m2_min = (ms + m_min) - ms2;
- if (m2_min >= 3 * GEMM_UNROLLING_8) {
- m2_min = 3 * GEMM_UNROLLING_8;
- } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
- m2_min = 2 * GEMM_UNROLLING_8;
- } else if (m2_min > GEMM_UNROLLING_8) {
- m2_min = GEMM_UNROLLING_8;
- }
-
- packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
- sA + k_min * (ms2 - ms));
-
- HGEMM_KERNEL_8x8(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
- C + ms2 * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
-
- packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sB);
- HGEMM_KERNEL_8x8(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sA);
- free(sB);
-}
-
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sA = alignedMalloc(M * K);
- __fp16 *sB = alignedMalloc(K * N);
-
- unsigned int ms, ms2, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- unsigned int stride_l1 = 1;
-
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
- GEMM_UNROLLING_16;
- } else {
- stride_l1 = 0;
- }
- packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
-
- for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
- m2_min = (ms + m_min) - ms2;
- if (m2_min >= 3 * GEMM_UNROLLING_8) {
- m2_min = 3 * GEMM_UNROLLING_8;
- } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
- m2_min = 2 * GEMM_UNROLLING_8;
- } else if (m2_min > GEMM_UNROLLING_8) {
- m2_min = GEMM_UNROLLING_8;
- }
-
- packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
- sA + k_min * (ms2 - ms) * stride_l1);
-
- HGEMM_KERNEL_8x16(m2_min, n_min, k_min,
- sA + stride_l1 * k_min * (ms2 - ms), sB,
- C + ms2 * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
-
- packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
- HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sA);
- free(sB);
-}
-
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
- __fp16 *sA = alignedMalloc(M * K);
- __fp16 *sB = alignedMalloc(K * N);
-
- unsigned int ms, ms2, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- unsigned int stride_l1 = 1;
-
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
- GEMM_UNROLLING_16;
- } else {
- stride_l1 = 0;
- }
- packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
-
- for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
- m2_min = (ms + m_min) - ms2;
- if (m2_min >= 3 * GEMM_UNROLLING_8) {
- m2_min = 3 * GEMM_UNROLLING_8;
- } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
- m2_min = 2 * GEMM_UNROLLING_8;
- } else if (m2_min > GEMM_UNROLLING_8) {
- m2_min = GEMM_UNROLLING_8;
- }
-
- packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
- sA + k_min * (ms2 - ms));
-
- HGEMM_KERNEL_8x16(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
- C + ms2 * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min =
- (n_min / 2 + GEMM_UNROLLING_16 - 1) & ~(GEMM_UNROLLING_16 - 1);
-
- }
-
- packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
-
- HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sA);
- free(sB);
-}
-
-void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
-
- unsigned int k = 0;
- unsigned int N8 = (N >> 3) << 3;
- __fp16 a[16];
- for (; (K - k) >= 16; k += 16) {
- for (unsigned int m = 0; m < M; m++) {
- vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
- vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));
- for (unsigned int n = 0; n < N8; n += 8) {
- float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]);
-
- float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
- vcvt_f32_f16(vget_low_f16(b0_7_0)));
- float32x4_t c0_7_high_32 = vaddq_f32(
- vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
-
- vst1q_f32(&C[m * N + n], c0_7_low_32);
- vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
- }
- if (N != N8) {
- unsigned int n = N8;
- __fp16 valsB_0[8];
- __fp16 valsB_1[8];
- __fp16 valsB_2[8];
- __fp16 valsB_3[8];
- __fp16 valsB_4[8];
- __fp16 valsB_5[8];
- __fp16 valsB_6[8];
- __fp16 valsB_7[8];
- __fp16 valsB_8[8];
- __fp16 valsB_9[8];
- __fp16 valsB_10[8];
- __fp16 valsB_11[8];
- __fp16 valsB_12[8];
- __fp16 valsB_13[8];
- __fp16 valsB_14[8];
- __fp16 valsB_15[8];
- float valsC[8];
- for (unsigned int idx = n; idx < N; idx++) {
- valsB_0[idx - n] = B[k * N + idx];
- valsB_1[idx - n] = B[(k + 1) * N + idx];
- valsB_2[idx - n] = B[(k + 2) * N + idx];
- valsB_3[idx - n] = B[(k + 3) * N + idx];
- valsB_4[idx - n] = B[(k + 4) * N + idx];
- valsB_5[idx - n] = B[(k + 5) * N + idx];
- valsB_6[idx - n] = B[(k + 6) * N + idx];
- valsB_7[idx - n] = B[(k + 7) * N + idx];
- valsB_8[idx - n] = B[(k + 8) * N + idx];
- valsB_9[idx - n] = B[(k + 9) * N + idx];
- valsB_10[idx - n] = B[(k + 10) * N + idx];
- valsB_11[idx - n] = B[(k + 11) * N + idx];
- valsB_12[idx - n] = B[(k + 12) * N + idx];
- valsB_13[idx - n] = B[(k + 13) * N + idx];
- valsB_14[idx - n] = B[(k + 14) * N + idx];
- valsB_15[idx - n] = B[(k + 15) * N + idx];
- valsC[idx - n] = C[m * N + idx];
- }
-
- float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]);
-
- float32x4_t c0_7_low_32 =
- vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
- float32x4_t c0_7_high_32 =
- vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
- vst1q_f32(valsC, c0_7_low_32);
- vst1q_f32(valsC + 4, c0_7_high_32);
-
- for (unsigned int idx = n; idx < N; idx++) {
- C[m * N + idx] = valsC[idx - n];
- }
- }
- }
- }
-
- for (; (K - k) >= 8; k += 8) {
- for (unsigned int m = 0; m < M; m++) {
- vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
-
- for (unsigned int n = 0; n < N8; n += 8) {
- float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
-
- float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
- vcvt_f32_f16(vget_low_f16(b0_7_0)));
- float32x4_t c0_7_high_32 = vaddq_f32(
- vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
-
- vst1q_f32(&C[m * N + n], c0_7_low_32);
- vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
- }
- if (N != N8) {
- unsigned int n = N8;
- __fp16 valsB_0[8];
- __fp16 valsB_1[8];
- __fp16 valsB_2[8];
- __fp16 valsB_3[8];
- __fp16 valsB_4[8];
- __fp16 valsB_5[8];
- __fp16 valsB_6[8];
- __fp16 valsB_7[8];
- float valsC[8];
- for (unsigned int idx = n; idx < N; idx++) {
- valsB_0[idx - n] = B[k * N + idx];
- valsB_1[idx - n] = B[(k + 1) * N + idx];
- valsB_2[idx - n] = B[(k + 2) * N + idx];
- valsB_3[idx - n] = B[(k + 3) * N + idx];
- valsB_4[idx - n] = B[(k + 4) * N + idx];
- valsB_5[idx - n] = B[(k + 5) * N + idx];
- valsB_6[idx - n] = B[(k + 6) * N + idx];
- valsB_7[idx - n] = B[(k + 7) * N + idx];
- valsC[idx - n] = C[m * N + idx];
- }
-
- float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
-
- float32x4_t c0_7_low_32 =
- vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
- float32x4_t c0_7_high_32 =
- vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
- vst1q_f32(valsC, c0_7_low_32);
- vst1q_f32(valsC + 4, c0_7_high_32);
-
- for (unsigned int idx = n; idx < N; idx++) {
- C[m * N + idx] = valsC[idx - n];
- }
- }
- }
- }
-
- for (; (K - k) >= 4; k += 4) {
- for (unsigned int m = 0; m < M; m++) {
- vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));
-
- for (unsigned int n = 0; n < N8; n += 8) {
-
- float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
- b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
- float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]);
- b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
-
- float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
- vcvt_f32_f16(vget_low_f16(b0_7_0)));
- float32x4_t c0_7_high_32 = vaddq_f32(
- vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
-
- c0_7_low_32 =
- vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2)));
- c0_7_high_32 =
- vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2)));
-
- vst1q_f32(&C[m * N + n], c0_7_low_32);
- vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
- }
- if (N != N8) {
- unsigned int n = N8;
- __fp16 valsB_0[8];
- __fp16 valsB_1[8];
- __fp16 valsB_2[8];
- __fp16 valsB_3[8];
- float valsC[8];
- for (unsigned int idx = n; idx < N; idx++) {
- valsB_0[idx - n] = B[k * N + idx];
- valsB_1[idx - n] = B[(k + 1) * N + idx];
- valsB_2[idx - n] = B[(k + 2) * N + idx];
- valsB_3[idx - n] = B[(k + 3) * N + idx];
- valsC[idx - n] = C[m * N + idx];
- }
-
- float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
- b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
-
- float32x4_t c0_7_low_32 =
- vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
- float32x4_t c0_7_high_32 =
- vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
- vst1q_f32(valsC, c0_7_low_32);
- vst1q_f32(valsC + 4, c0_7_high_32);
-
- for (unsigned int idx = n; idx < N; idx++) {
- C[m * N + idx] = valsC[idx - n];
- }
- }
- }
- }
-
- for (; k < K; k++) {
- for (unsigned int m = 0; m < M; m++) {
- __fp16 a0 = alpha * A[m * K + k];
-
- for (unsigned int n = 0; n < N8; n += 8) {
- float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0);
-
- float32x4_t c0_7_low_32 =
- vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7)));
-
- float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]),
- vcvt_f32_f16(vget_high_f16(b0_7)));
-
- vst1q_f32(&C[m * N + n], c0_7_low_32);
- vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
- }
- if (N != N8) {
- unsigned int n = N8;
- __fp16 valsB[8];
- float valsC[8];
- for (unsigned int idx = n; idx < N; idx++) {
- valsB[idx - n] = B[k * N + idx];
- valsC[idx - n] = C[m * N + idx];
- }
-
- float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0);
-
- float32x4_t c0_7_low_32 =
- vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
-
- float32x4_t c0_7_high_32 =
- vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
-
- vst1q_f32(valsC, c0_7_low_32);
- vst1q_f32(valsC + 4, c0_7_high_32);
-
- for (unsigned int idx = n; idx < N; idx++) {
- C[m * N + idx] = valsC[idx - n];
- }
- }
- }
- }
-}
-
-void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha, float beta) {
-
- __fp16 *sA = alignedMalloc(M * K);
- __fp16 *sB = alignedMalloc(K * N);
-
- unsigned int ms, ms2, ns, ks;
- unsigned int m_min, m2_min, n_min, k_min;
- unsigned int stride_l1 = 1;
-
- for (ms = 0; ms < M; ms += M_BLOCKING) {
- m_min = M - ms;
- if (m_min > M_BLOCKING) {
- m_min = M_BLOCKING;
- }
-
- for (ks = 0; ks < K; ks += k_min) {
- k_min = K - ks;
- if (k_min >= (K_BLOCKING << 1)) {
- k_min = K_BLOCKING;
- } else if (k_min > K_BLOCKING) {
- k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
- }
-
- n_min = N;
- if (N >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (N > N_BLOCKING) {
- n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
- GEMM_UNROLLING_16;
- } else {
- stride_l1 = 0;
- }
- packing_transB16(k_min, n_min, B + (ks), ldb, sB);
-
- for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
- m2_min = (ms + m_min) - ms2;
- if (m2_min >= 3 * GEMM_UNROLLING_8) {
- m2_min = 3 * GEMM_UNROLLING_8;
- } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
- m2_min = 2 * GEMM_UNROLLING_8;
- } else if (m2_min > GEMM_UNROLLING_8) {
- m2_min = GEMM_UNROLLING_8;
- }
- packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
- sA + k_min * (ms2 - ms) * stride_l1);
- HGEMM_KERNEL_8x16(m2_min, n_min, k_min,
- sA + k_min * (ms2 - ms) * stride_l1, sB,
- C + ms2 * ldc, ldc);
- }
-
- for (ns = n_min; ns < N; ns += n_min) {
- n_min = N - ns;
- if (n_min >= N_BLOCKING * 2) {
- n_min = N_BLOCKING;
- } else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
- }
- packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
- HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
- }
- }
- }
-
- free(sA);
- free(sB);
-}
-
-void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta) {
- if (((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0)) {
- return hgemm_transB_8x16(M, N, K, A, K, B, K, C, N, alpha, beta);
- } else {
- return hgemm_transB_fallback(A, B, C, M, N, K, alpha, beta);
- }
-}
-
-void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha, float beta) {
- __fp16 *B_T = alignedMalloc(K * N);
-
- transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
- hgemm_noTrans(A, B_T, C, M, N, K, alpha, beta);
-
- free(B_T);
-}
-
-void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta) {
- __fp16 *A_T = alignedMalloc(M * K);
-
- transpose_neon<__fp16>(K, M, A, M, A_T, K);
-
- hgemm_noTrans(A_T, B, C, M, N, K, alpha, beta);
-
- free(A_T);
-}
-
-void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta) {
- __fp16 *A_T = alignedMalloc(M * K);
- __fp16 *B_T = alignedMalloc(K * N);
-
- transpose_neon<__fp16>(K, M, A, M, A_T, K);
- transpose_neon<__fp16>(N, K, B, K, B_T, N);
-
- hgemm_noTrans(A_T, B_T, C, M, N, K, alpha, beta);
-
- free(A_T);
- free(B_T);
-}
-
-void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
- uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
bool TransB) {
unsigned int lda = (TransA) ? M : K;
unsigned int ldb = (TransB) ? K : N;
+
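+  // With K == 1, op(A) is an M-element vector and op(B) an N-element vector
+  // whose memory layout is identical with or without transposition, so the
+  // noTrans kernel covers every TransA/TransB combination.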
+ return hgemm_K1_noTrans(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
+
if (!TransA && TransB) {
hgemm_K1_transB(M, N, K, A, lda, B, ldb, C, N, alpha, beta);
} else if (TransA && !TransB) {
/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C float * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha = 1.F,
- float beta = 0.F);
-
-/**
- * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
* @param[in] A __fp16 * for Matrix A
* @param[in] B __fp16 * for Matrix B
* @param[in] C __fp16 * for Matrix C
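+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows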
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha = 1.F, float beta = 0.F);
+void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+           unsigned int N, unsigned int K, float alpha, float beta,
+           bool TransA, bool TransB);
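+/**
+ * Usage sketch (illustrative only): C(2x4) = 1.0 * A(2x3) * B(3x4) + 0.0 * C
+ *   __fp16 A[2 * 3], B[3 * 4], C[2 * 4];
+ *   // ... fill A and B ...
+ *   hgemm(A, B, C, 2, 4, 3, 1.F, 0.F, false, false);
+ */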
/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
unsigned int N, unsigned int K, float alpha = 1.F, float beta = 0.F,
bool TransA = false, bool TransB = false);
-
-/**
- * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-/**
- * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-/**
- * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-/**
- * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, __fp16 *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
- *
- * @param M length of the row of matrix A
- * @param N length of the col of matrix B
- * @param K length of the col of matrix A
- * @param A input matrix A
- * @param lda length of the col of matrix A
- * @param B input matrix B
- * @param ldb length of the col of matrix B
- * @param C output matrix C
- * @param ldc length of the col of matrix C
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-
-/**
- * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta);
-/**
- * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta);
-
-void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha, float beta);
-
-/**
- * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
- const __fp16 *A, unsigned int lda, const __fp16 *B,
- unsigned int ldb, float *C, unsigned int ldc,
- float alpha = 1.F, float beta = 0.F);
-/**
- * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
- * where op(X) is one of X or X**T
- * @param[in] A __fp16 * for Matrix A
- * @param[in] B __fp16 * for Matrix B
- * @param[in] C __fp16 * for Matrix C
- * @param[in] M number of op(A)'s and C's row
- * @param[in] N number of op(B)'s and C's columns
- * @param[in] K number of op(A)'s and columns and op(B)'s rows
- * @param[in] alpha float number
- * @param[in] beta float number
- */
-void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha, float beta);
/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
* where op(X) is one of X or X**T
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, uint32_t M,
- uint32_t N, uint32_t K, float alpha, float beta, bool TransA,
+void hgemm_K1(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta, bool TransA,
bool TransB);
#include <arm_neon.h>
#include <assert.h>
+
#define A(i, j) a[(i) * lda + (j)]
#define B(i, j) b[(i) * ldb + (j)]
#define C(i, j) c[(i) * ldc + (j)]
#define VL_FP16 (8)
#define VL_FP16_HALF (4)
+
/**
* @todo Add macro for instructions in other CPU architectures
*/
--- /dev/null
+// #include <hgemm_kernel_1x4.h>
+// #include <hgemm_kernel_1x8.h>
+// #include <hgemm_kernel_4x4.h>
+// #include <hgemm_kernel_4x8.h>
+// #include <hgemm_kernel_8x16.h>
+// #include <hgemm_kernel_8x8.h>
+
+// #define HGEMM_KERNEL_1x4 hgemm_kernel_1x4
+// #define HGEMM_KERNEL_4x4 hgemm_kernel_4x4
+// #define HGEMM_KERNEL_1x8 hgemm_kernel_1x8
+// #define HGEMM_KERNEL_4x8 hgemm_kernel_4x8
+// #define HGEMM_KERNEL_8x8 hgemm_kernel_8x8
+// #define HGEMM_KERNEL_8x16 hgemm_kernel_8x16
\ No newline at end of file
*
*/
-#include <hgemm_common.h>
#include <stdlib.h>
+#include <arm_neon.h>
+#include <assert.h>
/**
* @brief hgemm 1x4 kernel sc = sa * sb
__fp16 *a = sa, *b = sb, *c = sc;
unsigned int i, j, l;
for (i = 0; i < M; i++) {
- for (j = 0; j < N; j += VL_FP16_HALF) {
+ for (j = 0; j < N; j += 4) {
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);
- for (l = 0; l < K; l += VL_FP16_HALF) {
+ for (l = 0; l < K; l += 4) {
float16x4_t v24 = {0.F};
float16x4_t v0 = vld1_f16(b);
float16_t v16 = *a;
float *c = sc;
unsigned int i, j, l;
for (i = 0; i < M; i++) {
- for (j = 0; j < N; j += VL_FP16_HALF) {
+ for (j = 0; j < N; j += 4) {
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);
- for (l = 0; l < K; l += VL_FP16_HALF) {
+ for (l = 0; l < K; l += 4) {
float16x4_t v24 = {0.F};
float16x4_t v0 = vld1_f16(b);
float16_t v16 = *a;
*
*/
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
#include <stdlib.h>
// 1. Partial sum 64 digits : worst accuracy, best latency
*
*/
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
#include <stdlib.h>
#define INIT_KERNEL_4x4() \
__fp16 *a = sa, *b = sb, *c = sc;
unsigned int i, j, l;
- for (i = 0; i < M; i += VL_FP16_HALF) {
- for (j = 0; j < N; j += VL_FP16_HALF) {
+ for (i = 0; i < M; i += 4) {
+ for (j = 0; j < N; j += 4) {
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);
float16x4_t v27;
INIT_KERNEL_4x4();
- for (l = 0; l < K; l += VL_FP16_HALF) {
+ for (l = 0; l < K; l += 4) {
float16x4_t v0 = vld1_f16(b);
float16x4_t v16 = vld1_f16(a);
unsigned int i, j, l;
unsigned int K16 = (K >> 4) << 4;
unsigned int K8 = (K >> 3) << 3;
- for (i = 0; i < M; i += VL_FP16_HALF) {
- for (j = 0; j < N; j += VL_FP16_HALF) {
+ for (i = 0; i < M; i += 4) {
+ for (j = 0; j < N; j += 4) {
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);
*
*/
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
#include <stdlib.h>
#define INIT_KERNEL_4X8() \
*
*/
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
#include <iostream>
#include <stdlib.h>
assert(M > 0 && N > 0 && K > 0);
assert(M % 8 == 0 && N % 16 == 0 && K % 4 == 0);
- // std::cout << " m : " << M << " , n : " << N << " , k : " << K << std::endl;
-
__fp16 *a = sa, *b = sb;
float *c = sc;
unsigned int i, j, l;
*
*/
-#include <hgemm_common.h>
+#include <arm_neon.h>
+#include <assert.h>
#include <stdlib.h>
#define INIT_KERNEL_8x8() \
__fp16 *a = sa, *b = sb, *c = sc;
unsigned int i, j, l;
- for (i = 0; i < M; i += VL_FP16) {
- for (j = 0; j < N; j += VL_FP16) {
+ for (i = 0; i < M; i += 8) {
+ for (j = 0; j < N; j += 8) {
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);
unsigned int K4 = (K >> 2) << 2;
unsigned int K8 = (K >> 3) << 3;
unsigned int K16 = (K >> 4) << 4;
- for (i = 0; i < M; i += VL_FP16) {
- for (j = 0; j < N; j += VL_FP16) {
+ for (i = 0; i < M; i += 8) {
+ for (j = 0; j < N; j += 8) {
__builtin_prefetch(b, 0, 3);
__builtin_prefetch(a, 0, 3);
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_noTrans.cpp
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of non-transposed case
+ *
+ */
+
+#include <cmath>
+#include <limits>
+
+#include <hgemm_common.h>
+#include <hgemm_kernel_1x4.h>
+#include <hgemm_kernel_1x8.h>
+#include <hgemm_kernel_4x4.h>
+#include <hgemm_kernel_4x8.h>
+#include <hgemm_kernel_8x16.h>
+#include <hgemm_kernel_8x8.h>
+#include <hgemm_kernel_pack.h>
+#include <hgemm_noTrans.h>
+#include <hgemm_util.h>
+#include <matrix_transpose_neon.h>
+// #include <hgemm_kernel.h>
+
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta) {
+ const float eps = std::numeric_limits<float>::epsilon();
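+  // alpha is compared against 1.0 with an epsilon tolerance: only the
+  // alpha == 1 case maps onto the packed tile kernels; any other alpha is
+  // routed to hgemm_noTrans_fallback.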
+ if (std::abs(alpha - 1.F) < eps) {
+ hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
+ } else {
+ hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ }
+}
+
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C32,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta) {
+ // used bitwise operator instead of modulo for performance
+ // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
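+  // Tile kernels are tried from the largest footprint (8x16) down to 1x4;
+  // the first whose divisibility constraints on M, N and K hold is taken,
+  // and hgemm_noTrans_fallback covers every remaining shape.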
+ if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+ hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+ hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else {
+ hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ }
+}
+
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta) {
+ if (alpha == 1.F) {
+ // used bitwise operator instead of modulo for performance
+ // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
+ if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_8x16(M, N, K, A, K, B, N, C, N, alpha, beta);
+ } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_8x8(M, N, K, A, K, B, N, C, N, alpha, beta);
+ } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
+ hgemm_noTrans_4x8(M, N, K, A, K, B, N, C, N, alpha, beta);
+ } else if ((M & 0x3) == 0 && (N & 0x3) == 0 && (K & 0x3) == 0) {
+ hgemm_noTrans_4x4(M, N, K, A, K, B, N, C, N, alpha, beta);
+ } else if ((N & 0x7) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_1x8(M, N, K, A, K, B, N, C, N, alpha, beta);
+ } else if ((N & 0x3) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_1x4(M, N, K, A, K, B, N, C, N, alpha, beta);
+ }
+ }
+}
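+
+// Note: unlike the float-accumulator overload above, this __fp16-output
+// overload has no fallback branch; when alpha != 1.F or none of the tile
+// constraints hold, C is left untouched, so callers are expected to
+// guarantee one of the tiled paths applies.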
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
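+  // Three-level cache blocking: M in chunks of M_BLOCKING, K in chunks of
+  // K_BLOCKING and N in chunks of N_BLOCKING; each A panel is packed into sa
+  // and each B panel into sb so the 1x4 micro-kernel streams contiguous data.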
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_1) {
+ m2_min = 3 * GEMM_UNROLLING_1;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+ m2_min = 2 * GEMM_UNROLLING_1;
+ } else if (m2_min > GEMM_UNROLLING_1) {
+ m2_min = GEMM_UNROLLING_1;
+ }
+
+ packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ hgemm_kernel_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_1) {
+ m2_min = 3 * GEMM_UNROLLING_1;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+ m2_min = 2 * GEMM_UNROLLING_1;
+ } else if (m2_min > GEMM_UNROLLING_1) {
+ m2_min = GEMM_UNROLLING_1;
+ }
+
+ packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ hgemm_kernel_1x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_1x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ hgemm_kernel_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int l1stride = 1;
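+  // When N fits in a single block (N <= N_BLOCKING), l1stride is set to 0 so
+  // every packed A panel is written to, and read from, the start of sa
+  // instead of a per-panel offset, presumably to keep the packed data
+  // resident in L1.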
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+ GEMM_UNROLLING_8;
+ } else {
+ l1stride = 0;
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_1) {
+ m2_min = 3 * GEMM_UNROLLING_1;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+ m2_min = 2 * GEMM_UNROLLING_1;
+ } else if (m2_min > GEMM_UNROLLING_1) {
+ m2_min = GEMM_UNROLLING_1;
+ }
+
+ packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms) * l1stride);
+
+ hgemm_kernel_1x8(m2_min, n_min, k_min,
+ sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+ ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int l1stride = 1;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+ GEMM_UNROLLING_8;
+ } else {
+ l1stride = 0;
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_1) {
+ m2_min = 3 * GEMM_UNROLLING_1;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_1) {
+ m2_min = 2 * GEMM_UNROLLING_1;
+ } else if (m2_min > GEMM_UNROLLING_1) {
+ m2_min = GEMM_UNROLLING_1;
+ }
+
+ packing_A1(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms) * l1stride);
+
+ hgemm_kernel_1x8(m2_min, n_min, k_min,
+ sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+ ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_1 - 1) & ~(GEMM_UNROLLING_1 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_1x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ hgemm_kernel_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int l1stride = 1;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+ GEMM_UNROLLING_8;
+ } else {
+ l1stride = 0;
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms) * l1stride);
+
+ hgemm_kernel_4x8(m2_min, n_min, k_min,
+ sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+ ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int l1stride = 1;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_8 - 1) / GEMM_UNROLLING_8) *
+ GEMM_UNROLLING_8;
+ } else {
+ l1stride = 0;
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms) * l1stride);
+
+ hgemm_kernel_4x8(m2_min, n_min, k_min,
+ sa + l1stride * k_min * (mms - ms), sb, C + mms * ldc,
+ ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_4x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+
+ packing_A8(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ hgemm_kernel_8x8(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ hgemm_kernel_8x8(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ __fp16 *sA = alignedMalloc(M * K);
+ __fp16 *sB = alignedMalloc(K * N);
+
+ unsigned int ms, ms2, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+ packing_B8(k_min, n_min, B + ks * ldb, ldb, sB);
+
+ for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+ m2_min = (ms + m_min) - ms2;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+
+ packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+ sA + k_min * (ms2 - ms));
+
+ hgemm_kernel_8x8(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
+ C + ms2 * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ packing_B8(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+ hgemm_kernel_8x8(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sA);
+ free(sB);
+}
+
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sA = alignedMalloc(M * K);
+ __fp16 *sB = alignedMalloc(K * N);
+
+ unsigned int ms, ms2, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int stride_l1 = 1;
+
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+ GEMM_UNROLLING_16;
+ } else {
+ stride_l1 = 0;
+ }
+ packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
+
+ for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+ m2_min = (ms + m_min) - ms2;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+
+ packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+ sA + k_min * (ms2 - ms) * stride_l1);
+
+ hgemm_kernel_8x16(m2_min, n_min, k_min,
+ sA + stride_l1 * k_min * (ms2 - ms), sB,
+ C + ms2 * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+
+ packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+ hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sA);
+ free(sB);
+}
+
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *sA = alignedMalloc(M * K);
+ __fp16 *sB = alignedMalloc(K * N);
+
+ unsigned int ms, ms2, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int stride_l1 = 1;
+
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+ GEMM_UNROLLING_16;
+ } else {
+ stride_l1 = 0;
+ }
+ packing_B16(k_min, n_min, B + ks * ldb, ldb, sB);
+
+ for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+ m2_min = (ms + m_min) - ms2;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+
+ packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+ sA + k_min * (ms2 - ms));
+
+ hgemm_kernel_8x16(m2_min, n_min, k_min, sA + k_min * (ms2 - ms), sB,
+ C + ms2 * ldc, ldc);
+ }
+
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min =
+ (n_min / 2 + GEMM_UNROLLING_16 - 1) & ~(GEMM_UNROLLING_16 - 1);
+ }
+
+ packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
+
+ hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sA);
+ free(sB);
+}
+
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ unsigned int k = 0;
+ unsigned int N8 = (N >> 3) << 3;
+ __fp16 a[16];
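+  // The K dimension is consumed in chunks of 16, 8, 4 and finally 1; for each
+  // chunk the A values are pre-scaled by alpha into a[], multiplied against
+  // rows of B and accumulated into the float buffer C. N8 is N rounded down
+  // to a multiple of 8 for the vectorized column loop; tail columns are
+  // staged through small on-stack arrays.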
+ for (; (K - k) >= 16; k += 16) {
+ for (unsigned int m = 0; m < M; m++) {
+ vst1q_f16(&a[0], vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+ vst1q_f16(&a[8], vmulq_n_f16(vld1q_f16(&A[m * K + k + 8]), alpha));
+ for (unsigned int n = 0; n < N8; n += 8) {
+ float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 8) * N + n]), a[8]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 9) * N + n]), a[9]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 10) * N + n]), a[10]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 11) * N + n]), a[11]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 12) * N + n]), a[12]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 13) * N + n]), a[13]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 14) * N + n]), a[14]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 15) * N + n]), a[15]);
+
+ float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+ vcvt_f32_f16(vget_low_f16(b0_7_0)));
+ float32x4_t c0_7_high_32 = vaddq_f32(
+ vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
+ unsigned int n = N8;
+ __fp16 valsB_0[8];
+ __fp16 valsB_1[8];
+ __fp16 valsB_2[8];
+ __fp16 valsB_3[8];
+ __fp16 valsB_4[8];
+ __fp16 valsB_5[8];
+ __fp16 valsB_6[8];
+ __fp16 valsB_7[8];
+ __fp16 valsB_8[8];
+ __fp16 valsB_9[8];
+ __fp16 valsB_10[8];
+ __fp16 valsB_11[8];
+ __fp16 valsB_12[8];
+ __fp16 valsB_13[8];
+ __fp16 valsB_14[8];
+ __fp16 valsB_15[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB_0[idx - n] = B[k * N + idx];
+ valsB_1[idx - n] = B[(k + 1) * N + idx];
+ valsB_2[idx - n] = B[(k + 2) * N + idx];
+ valsB_3[idx - n] = B[(k + 3) * N + idx];
+ valsB_4[idx - n] = B[(k + 4) * N + idx];
+ valsB_5[idx - n] = B[(k + 5) * N + idx];
+ valsB_6[idx - n] = B[(k + 6) * N + idx];
+ valsB_7[idx - n] = B[(k + 7) * N + idx];
+ valsB_8[idx - n] = B[(k + 8) * N + idx];
+ valsB_9[idx - n] = B[(k + 9) * N + idx];
+ valsB_10[idx - n] = B[(k + 10) * N + idx];
+ valsB_11[idx - n] = B[(k + 11) * N + idx];
+ valsB_12[idx - n] = B[(k + 12) * N + idx];
+ valsB_13[idx - n] = B[(k + 13) * N + idx];
+ valsB_14[idx - n] = B[(k + 14) * N + idx];
+ valsB_15[idx - n] = B[(k + 15) * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_8), a[8]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_9), a[9]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_10), a[10]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_11), a[11]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_12), a[12]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_13), a[13]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_14), a[14]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_15), a[15]);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+
+ for (; (K - k) >= 8; k += 8) {
+ for (unsigned int m = 0; m < M; m++) {
+ vst1q_f16(a, vmulq_n_f16(vld1q_f16(&A[m * K + k]), alpha));
+
+ for (unsigned int n = 0; n < N8; n += 8) {
+ float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 4) * N + n]), a[4]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 5) * N + n]), a[5]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 6) * N + n]), a[6]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 7) * N + n]), a[7]);
+
+ float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+ vcvt_f32_f16(vget_low_f16(b0_7_0)));
+ float32x4_t c0_7_high_32 = vaddq_f32(
+ vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
+ unsigned int n = N8;
+ __fp16 valsB_0[8];
+ __fp16 valsB_1[8];
+ __fp16 valsB_2[8];
+ __fp16 valsB_3[8];
+ __fp16 valsB_4[8];
+ __fp16 valsB_5[8];
+ __fp16 valsB_6[8];
+ __fp16 valsB_7[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB_0[idx - n] = B[k * N + idx];
+ valsB_1[idx - n] = B[(k + 1) * N + idx];
+ valsB_2[idx - n] = B[(k + 2) * N + idx];
+ valsB_3[idx - n] = B[(k + 3) * N + idx];
+ valsB_4[idx - n] = B[(k + 4) * N + idx];
+ valsB_5[idx - n] = B[(k + 5) * N + idx];
+ valsB_6[idx - n] = B[(k + 6) * N + idx];
+ valsB_7[idx - n] = B[(k + 7) * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_4), a[4]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_5), a[5]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_6), a[6]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_7), a[7]);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+
+ for (; (K - k) >= 4; k += 4) {
+ for (unsigned int m = 0; m < M; m++) {
+ vst1_f16(a, vmul_n_f16(vld1_f16(&A[m * K + k]), alpha));
+
+ for (unsigned int n = 0; n < N8; n += 8) {
+
+ float16x8_t b0_7_0 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a[0]);
+ b0_7_0 = vfmaq_n_f16(b0_7_0, vld1q_f16(&B[(k + 1) * N + n]), a[1]);
+ float16x8_t b0_7_2 = vmulq_n_f16(vld1q_f16(&B[(k + 2) * N + n]), a[2]);
+ b0_7_2 = vfmaq_n_f16(b0_7_2, vld1q_f16(&B[(k + 3) * N + n]), a[3]);
+
+ float32x4_t c0_7_low_32 = vaddq_f32(vld1q_f32(&C[m * N + n]),
+ vcvt_f32_f16(vget_low_f16(b0_7_0)));
+ float32x4_t c0_7_high_32 = vaddq_f32(
+ vld1q_f32(&C[m * N + n + 4]), vcvt_f32_f16(vget_high_f16(b0_7_0)));
+
+ c0_7_low_32 =
+ vaddq_f32(c0_7_low_32, vcvt_f32_f16(vget_low_f16(b0_7_2)));
+ c0_7_high_32 =
+ vaddq_f32(c0_7_high_32, vcvt_f32_f16(vget_high_f16(b0_7_2)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
+ unsigned int n = N8;
+ __fp16 valsB_0[8];
+ __fp16 valsB_1[8];
+ __fp16 valsB_2[8];
+ __fp16 valsB_3[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB_0[idx - n] = B[k * N + idx];
+ valsB_1[idx - n] = B[(k + 1) * N + idx];
+ valsB_2[idx - n] = B[(k + 2) * N + idx];
+ valsB_3[idx - n] = B[(k + 3) * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB_0), a[0]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_1), a[1]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_2), a[2]);
+ b = vfmaq_n_f16(b, vld1q_f16(valsB_3), a[3]);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+
+ for (; k < K; k++) {
+ for (unsigned int m = 0; m < M; m++) {
+ __fp16 a0 = alpha * A[m * K + k];
+
+ for (unsigned int n = 0; n < N8; n += 8) {
+ float16x8_t b0_7 = vmulq_n_f16(vld1q_f16(&B[k * N + n]), a0);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(&C[m * N + n]), vcvt_f32_f16(vget_low_f16(b0_7)));
+
+ float32x4_t c0_7_high_32 = vaddq_f32(vld1q_f32(&C[m * N + n + 4]),
+ vcvt_f32_f16(vget_high_f16(b0_7)));
+
+ vst1q_f32(&C[m * N + n], c0_7_low_32);
+ vst1q_f32(&C[m * N + n + 4], c0_7_high_32);
+ }
+ if (N != N8) {
+ unsigned int n = N8;
+ __fp16 valsB[8];
+ float valsC[8];
+ for (unsigned int idx = n; idx < N; idx++) {
+ valsB[idx - n] = B[k * N + idx];
+ valsC[idx - n] = C[m * N + idx];
+ }
+
+ float16x8_t b = vmulq_n_f16(vld1q_f16(valsB), a0);
+
+ float32x4_t c0_7_low_32 =
+ vaddq_f32(vld1q_f32(valsC), vcvt_f32_f16(vget_low_f16(b)));
+
+ float32x4_t c0_7_high_32 =
+ vaddq_f32(vld1q_f32(valsC + 4), vcvt_f32_f16(vget_high_f16(b)));
+
+ vst1q_f32(valsC, c0_7_low_32);
+ vst1q_f32(valsC + 4, c0_7_high_32);
+
+ for (unsigned int idx = n; idx < N; idx++) {
+ C[m * N + idx] = valsC[idx - n];
+ }
+ }
+ }
+ }
+}
+
+void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ float16x8_t a_vec;
+ unsigned int N8 = (N >> 3) << 3;
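+  // With K == 1 the GEMM degenerates into a rank-1 update:
+  // C[m][n] = alpha * A[m] * B[n] (+ beta * C[m][n] when beta != 0),
+  // processed eight fp16 columns at a time with a scalar tail.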
+ for (unsigned int m = 0; m < M; ++m) {
+ a_vec = vmovq_n_f16(alpha * A[m]);
+ if (std::fpclassify(beta) != FP_ZERO) {
+ for (unsigned int n = 0; n < N8; n += 8) {
+ vst1q_f16(&C[m * ldc + n],
+ vaddq_f16(vmulq_f16(a_vec, vld1q_f16(&B[n])),
+ vmulq_n_f16(vld1q_f16(&C[m * ldc + n]), beta)));
+ }
+ } else {
+ for (unsigned int n = 0; n < N8; n += 8) {
+ vst1q_f16(&C[m * ldc + n], vmulq_f16(a_vec, vld1q_f16(&B[n])));
+ }
+ }
+ for (unsigned int n = N8; n < N; ++n) {
+ C[m * ldc + n] = alpha * A[m] * B[n] + beta * C[m * ldc + n];
+ }
+ }
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_noTrans.h
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of non-transposed case
+ *
+ */
+
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 1x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_1x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 4x8 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with 8x16 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation when K = 1 : C = alpha*A*B + beta*C
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+/**
+ * @brief general hgemm noTrans fallback with neon : C = alpha*A*B + beta*C
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_fallback(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm noTrans computation with neon : C = alpha*A*B + beta*C
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of A's and C's rows
+ * @param[in] N number of B's and C's columns
+ * @param[in] K number of A's columns and B's rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha = 1.F,
+ float beta = 0.F);
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C __fp16 * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F);
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_transA.cpp
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of transposed A case
+ *
+ */
+
+#include <hgemm_noTrans.h>
+#include <hgemm_transA.h>
+#include <hgemm_util.h>
+#include <matrix_transpose_neon.h>
+
+void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta) {
+ __fp16 *A_T = alignedMalloc(M * K);
+
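+  // A arrives in its transposed (K x M) layout; materialize a contiguous
+  // M x K copy so the regular noTrans path below can be reused as-is.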
+ transpose_neon<__fp16>(K, M, A, M, A_T, K);
+
+ hgemm_noTrans(A_T, B, C, M, N, K, alpha, beta);
+
+ free(A_T);
+}
+
+void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
+ float beta) {
+ __fp16 *A_T = alignedMalloc(M * K);
+
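+  // Same idea as hgemm_transA: undo the transpose of A, then delegate to the
+  // K = 1 noTrans routine.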
+ transpose_neon<__fp16>(K, M, A, M, A_T, K);
+
+ hgemm_K1_noTrans(M, N, K, A_T, lda, B, ldb, C, ldc, alpha, beta);
+
+ free(A_T);
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_transA.h
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of transposed A case
+ *
+ */
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transA(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta);
+/**
+ * @brief hgemm computation with neon when K = 1 and A is transposed : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_transA(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_transAB.cpp
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of transposed AB case
+ *
+ */
+
+#include <hgemm_noTrans.h>
+#include <hgemm_transAB.h>
+#include <hgemm_util.h>
+#include <matrix_transpose_neon.h>
+
+void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta) {
+ __fp16 *A_T = alignedMalloc(M * K);
+ __fp16 *B_T = alignedMalloc(K * N);
+
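+  // Both operands are transposed: restore A to M x K and B to K x N layout,
+  // then fall through to the plain noTrans implementation.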
+ transpose_neon<__fp16>(K, M, A, M, A_T, K);
+ transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+ hgemm_noTrans(A_T, B_T, C, M, N, K, alpha, beta);
+
+ free(A_T);
+ free(B_T);
+}
+
+void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha, float beta) {
+ __fp16 *A_T = alignedMalloc(M * K);
+ __fp16 *B_T = alignedMalloc(K * N);
+
+ transpose_neon<__fp16>(K, M, A, M, A_T, K);
+ transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+ hgemm_K1_noTrans(M, N, K, A_T, lda, B_T, ldb, C, ldc, alpha, beta);
+
+ free(A_T);
+ free(B_T);
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_transAB.h
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of transposed AB case
+ *
+ */
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transAB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta);
+/**
+ * @brief hgemm computation with neon when K = 1 and both A and B are transposed : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_transAB(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
\ No newline at end of file
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_transB.cpp
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of transposed B case
+ *
+ */
+
+#include <cmath>
+#include <hgemm_kernel_8x16.h>
+#include <hgemm_common.h>
+// #include <hgemm_kernel.h>
+#include <hgemm_kernel_pack.h>
+#include <hgemm_noTrans.h>
+#include <hgemm_transB.h>
+#include <hgemm_util.h>
+#include <limits>
+#include <matrix_transpose_neon.h>
+
+/// @todo replace the direct hgemm_kernel_8x16 calls below with a macro kernel
+/// (e.g. HGEMM_KERNEL_8x16) selectable at compile time
+
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ __fp16 *sA = alignedMalloc(M * K);
+ __fp16 *sB = alignedMalloc(K * N);
+
+ unsigned int ms, ms2, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ unsigned int stride_l1 = 1;
+
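+  // Three-level cache blocking: M is tiled by M_BLOCKING, K by K_BLOCKING and
+  // N by N_BLOCKING (assumed to be provided via hgemm_common.h). Within each
+  // tile, B columns are packed into 16-wide panels and A rows into 8-high
+  // panels so the 8x16 micro-kernel streams over contiguous memory.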
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+ GEMM_UNROLLING_16;
+ } else {
+ stride_l1 = 0;
+ }
+ packing_transB16(k_min, n_min, B + (ks), ldb, sB);
+
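+      // Pack 8-row slices of A and multiply them against the packed B panel
+      // with the 8x16 micro-kernel. When N fits in a single n-block
+      // (stride_l1 == 0) each slice overwrites the same spot in sA, since no
+      // later n-blocks need the full packed A.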
+ for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+ m2_min = (ms + m_min) - ms2;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+ packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+ sA + k_min * (ms2 - ms) * stride_l1);
+ hgemm_kernel_8x16(m2_min, n_min, k_min,
+ sA + k_min * (ms2 - ms) * stride_l1, sB,
+ C + ms2 * ldc, ldc);
+ }
+
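+      // Remaining n-blocks: repack the next panel of B and reuse the A
+      // panels packed above.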
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ }
+ packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
+ hgemm_kernel_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns,
+ ldc);
+ }
+ }
+ }
+
+ free(sA);
+ free(sB);
+}
+
+void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta) {
+ const float eps = std::numeric_limits<float>::epsilon();
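+  // The packed 8x16 path assumes full tiles (M % 8 == 0, N % 16 == 0,
+  // K % 8 == 0) and alpha == 1; anything else goes through the
+  // transpose-then-noTrans fallback.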
+ if (((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0 &&
+ (std::abs(alpha - 1.F) < eps))) {
+ return hgemm_transB_8x16(M, N, K, A, K, B, K, C, N, alpha, beta);
+ } else {
+ return hgemm_transB_fallback(A, B, C, M, N, K, alpha, beta);
+ }
+}
+
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta) {
+ __fp16 *B_T = alignedMalloc(K * N);
+
+ transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+ hgemm_noTrans(A, B_T, C, M, N, K, alpha, beta);
+
+ free(B_T);
+}
+
+void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc, float alpha,
+ float beta) {
+ __fp16 *B_T = alignedMalloc(K * N);
+
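+  // B is stored transposed (N x K); with K == 1 this transpose amounts to a
+  // straight copy into K x N order before reusing the K = 1 noTrans kernel.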
+ transpose_neon<__fp16>(N, K, B, K, B_T, N);
+
+ hgemm_K1_noTrans(M, N, K, A, lda, B_T, ldb, C, ldc, alpha, beta);
+
+ free(B_T);
+}
--- /dev/null
+// SPDX-License-Identifier: Apache-2.0
+/**
+ * Copyright (C) 2024 Sungsik Kong <ss.kong@samsung.com>
+ *
+ * @file hgemm_transB.h
+ * @date 10 July 2024
+ * @see https://github.com/nnstreamer/nntrainer
+ * @author Sungsik Kong <ss.kong@samsung.com>
+ * @bug No known bugs except for NYI items
+ * @brief This is half-precision GEMM interface of transposed B case
+ *
+ */
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
+ unsigned int N, unsigned int K, float alpha, float beta);
+
+/**
+ * @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where B is explicitly transposed before reusing the noTrans path
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+                           unsigned int M, unsigned int N, unsigned int K,
+                           float alpha, float beta);
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * where op(X) is one of X or X**T
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] lda length of the col of matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] ldb length of the col of matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] ldc length of the col of matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
+/**
+ * @brief hgemm computation with neon when K = 1 and B is transposed : Y = alpha*op(A)*op(B) + beta*C,
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda length of the col of matrix A
+ * @param B input matrix B
+ * @param ldb length of the col of matrix B
+ * @param C output matrix C
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_K1_transB(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, __fp16 *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
'hgemm_padding_a.cpp',
'hgemm_padding_b.cpp',
'hgemm_kernel_pack.cpp',
+ 'hgemm_noTrans.cpp',
+ 'hgemm_transA.cpp',
+ 'hgemm_transB.cpp',
+ 'hgemm_transAB.cpp',
]
foreach s : hgemm_sources
EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
}
+TEST(nntrainer_Tensor, dot_gemm_K1) {
+ /// @note GEMM : A X B = C
+ int batch = 1;
+ int channel = 1;
+ int height = 56;
+ int width = 1;
+
+ int height_b = 1;
+ int width_b = 516;
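+  // A is 56x1 and B is 1x516, so the shared dimension K is 1 and this dot
+  // product exercises the newly added K == 1 GEMM path.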
+
+ bool transA = false;
+ bool transB = false;
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp16 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP16};
+
+ nntrainer::TensorDim::TensorType t_type_nchw_fp32 = {
+ nntrainer::Tformat::NCHW, nntrainer::Tdatatype::FP32};
+
+ nntrainer::Tensor A(batch, channel, height, width, t_type_nchw_fp16);
+ nntrainer::Tensor B(batch, channel, height_b, width_b, t_type_nchw_fp16);
+
+ nntrainer::Tensor A_fp32(batch, channel, height, width, t_type_nchw_fp32);
+ nntrainer::Tensor B_fp32(batch, channel, height_b, width_b, t_type_nchw_fp32);
+
+ const float alpha = 1e-1;
+ const int MOD = 10;
+
+ GEN_TEST_INPUT(A, ((i * (batch * height * channel) + j * (batch * height) +
+ k * (width) + l + 1) %
+ MOD) *
+ alpha);
+ GEN_TEST_INPUT_B(B, ((i * (batch * height_b * channel) +
+ j * (batch * height_b) + k * (width_b) + l + 1) %
+ MOD) *
+ alpha);
+ GEN_TEST_INPUT(A_fp32, ((i * (batch * height * channel) +
+ j * (batch * height) + k * (width) + l + 1) %
+ MOD) *
+ alpha);
+ GEN_TEST_INPUT_B(B_fp32, ((i * (batch * height_b * channel) +
+ j * (batch * height_b) + k * (width_b) + l + 1) %
+ MOD) *
+ alpha);
+
+ nntrainer::Tensor C = A.dot(B, transA, transB);
+
+ nntrainer::Tensor C_fp32 = A_fp32.dot(B_fp32, transA, transB);
+
+ float mseErrorNeon =
+ mse<__fp16>(C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+ double cosSimNeon = cosine_similarity<__fp16>(
+ C.getData<__fp16>(), C_fp32.getData<float>(), C.size());
+
+ const float epsilon = 1e-3 * width;
+
+ EXPECT_IN_RANGE(mseErrorNeon, 0, epsilon);
+ EXPECT_IN_RANGE((float)cosSimNeon, 0.99, 1);
+}
+
TEST(nntrainer_Tensor, dot_gemv_768_96000) {
/// @note GEMV : A X B = C
int batch = 1;