free(sb);
}
+// Blocked half-precision GEMM driver (C = A * B, neither operand transposed)
+// built on the 4x4 micro-kernel: tiles M/K/N into cache blocks, packs A and B
+// panels into contiguous buffers (sa / sb), and accumulates into float C.
+// NOTE(review): alpha and beta are accepted but never referenced in this body;
+// presumably the path assumes alpha = 1 and beta = 0 — confirm with callers.
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+ // Packed-panel scratch buffers for A (M x K) and B (K x N).
+ // NOTE(review): assumes alignedMalloc takes an element count of __fp16,
+ // not a byte count — confirm against its definition.
+ __fp16 *sa = alignedMalloc(M * K);
+ __fp16 *sb = alignedMalloc(K * N);
+
+ unsigned int ms, mms, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
+ // Outer blocking over the rows of A / C, at most M_BLOCKING rows per pass.
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ // Blocking over the shared K dimension; k_min is chosen below, so the
+ // loop step varies per iteration.
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ // Between 1x and 2x K_BLOCKING remaining: split roughly in half,
+ // rounded up to a multiple of GEMM_UNROLLING_4 so both halves stay
+ // kernel-friendly.
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ // Width of the first column panel of B (same halving rule as k_min).
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+ // Pack the first k_min x n_min panel of B once for this K block.
+ packing_B4(k_min, n_min, B + ks * ldb, ldb, sb);
+
+ // Row sub-blocks: pack each A panel once and immediately run the kernel
+ // against the already-packed first B panel. Sub-block heights are
+ // multiples of GEMM_UNROLLING_4 (up to 3x), except a final remainder.
+ for (mms = ms; mms < ms + m_min; mms += m2_min) {
+ m2_min = (ms + m_min) - mms;
+ if (m2_min >= 3 * GEMM_UNROLLING_4) {
+ m2_min = 3 * GEMM_UNROLLING_4;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_4) {
+ m2_min = 2 * GEMM_UNROLLING_4;
+ } else if (m2_min > GEMM_UNROLLING_4) {
+ m2_min = GEMM_UNROLLING_4;
+ }
+
+ // Pack rows [mms, mms + m2_min) of A into sa at the matching offset.
+ packing_A4(m2_min, k_min, A + mms * lda + ks, lda,
+ sa + k_min * (mms - ms));
+
+ HGEMM_KERNEL_4x4(m2_min, n_min, k_min, sa + k_min * (mms - ms), sb,
+ C + mms * ldc, ldc);
+ }
+
+ // Remaining column panels of B: the whole A block is now packed in sa,
+ // so only B needs re-packing per panel.
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+ n_min = (n_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ packing_B4(k_min, n_min, B + ns + ldb * ks, ldb, sb);
+ HGEMM_KERNEL_4x4(m_min, n_min, k_min, sa, sb, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sa);
+ free(sb);
+}
+
+
+
+
void hgemm_noTrans_4x8(unsigned int M, unsigned int N, unsigned int K,
                       const __fp16 *A, unsigned int lda, const __fp16 *B,
                       unsigned int ldb, __fp16 *C, unsigned int ldc,
                       float alpha = 1.F, float beta = 0.F);
+/**
+ * @brief hgemm noTrans computation with 4x4 kernel : C = A*B,
+ *
+ * @param M length of the row of matrix A
+ * @param N length of the col of matrix B
+ * @param K length of the col of matrix A
+ * @param A input matrix A
+ * @param lda leading dimension of matrix A
+ * @param B input matrix B
+ * @param ldb leading dimension of matrix B
+ * @param C output matrix C (accumulated in single precision)
+ * @param ldc length of the col of matrix C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_4x4(unsigned int M, unsigned int N, unsigned int K,
+                       const __fp16 *A, unsigned int lda, const __fp16 *B,
+                       unsigned int ldb, float *C, unsigned int ldc,
+                       float alpha = 1.F, float beta = 0.F);
+
/**
* @brief hgemm noTrans computation with 8x8 kernel : C = A*B,
*