#include <hgemm_kernel_8x16.h>
#include <hgemm_kernel_8x8.h>
#include <hgemm_kernel_pack.h>
+#include <hgemm_padding.h>
#include <hgemm_util.h>
#include <limits>
#include <matrix_transpose_neon.h>
unsigned int N, unsigned int K, float alpha, float beta) {
const float eps = std::numeric_limits<float>::epsilon();
if (std::abs(alpha - 1.F) < eps) {
- if ((K & 0x7) != 0) {
- hgemm_noTrans_padding_wrt_K(A, B, C32, M, N, K, alpha, beta);
- } else {
- hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
- }
+ hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
} else {
hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
}
}
}
-void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha, float beta) {
- const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
- const unsigned int K8_low = (K >> 3) << 3;
-
- const unsigned int lda = K;
- const unsigned int ldb = N;
+void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta, bool TransA,
+ bool TransB) {
+  /// @note Padding standard : the 8x16 kernel is the only kernel that
+  /// outperforms single-precision GEMM so far, so padding forcibly makes every
+  /// GEMM case use it. Note that padding is not the optimal way here, just an
+  /// option that is easier to implement. Fine-grained packing should be
+  /// supported in the future for optimal performance.
- __fp16 *A8 = alignedMalloc(M * K8_high);
- __fp16 *B8 = alignedMalloc(K8_high * N);
+ __fp16 *A_ = (__fp16 *)A, *B_ = (__fp16 *)B;
+ unsigned int M_ = M, N_ = N, K_ = K;
+ bool pad_A = false, pad_B = false;
- float16x8_t ZEROS = vmovq_n_f16(0.F);
+  // Case 2 : M < 8, K < 16, N < 16 - all smaller than the 8x16 kernel tile,
+  // so padding would be redundant; dispatch directly
+ if (M < 8 && K < 16 && N < 16)
+ return hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
- // Make zero-padded A matrix
- for (unsigned int m = 0; m < M; ++m) {
- unsigned int k = 0;
- for (; k < K8_low; k += 8) {
- vst1q_f16(&A8[m * K8_high + k], vld1q_f16(&A[m * K + k]));
- }
- for (; k < K; ++k) {
- A8[m * K8_high + k] = A[m * K + k];
- }
- for (; k < K8_high; ++k) {
- A8[m * K8_high + k] = 0.F;
- }
- }
+  __fp16 *Ap = nullptr;
+  __fp16 *Bp = nullptr;
- // Make zero-padded B matrix
- unsigned int k = 0;
- unsigned int N8 = (N >> 3) << 3;
- for (; k < K; ++k) {
- unsigned int n = 0;
- for (; n < N8; n += 8) {
- vst1q_f16(&B8[k * N + n], vld1q_f16(&B[k * N + n]));
- }
- for (; n < N; ++n) {
- B8[k * N + n] = B[k * N + n];
- }
+ const unsigned int M8_high = ((M - 1) / 8 + 1) * 8;
+ const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
+ const unsigned int N16_high = ((N - 1) / 16 + 1) * 16;
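+  // e.g., M = 10, N = 30, K = 20 is padded up to M_ = 16, N_ = 32, K_ = 24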
+
+ if ((M8_high != M) || (K8_high != K)) {
+ pad_A = true;
+ Ap = alignedMalloc(M8_high * K8_high);
+ hgemm_padding_A(A, Ap, M, K, M8_high, K8_high, TransA);
+ A_ = Ap;
+ M_ = M8_high;
+ K_ = K8_high;
}
- for (; k < K8_high; ++k) {
- unsigned int n = 0;
- for (; n < N8; n += 8) {
- vst1q_f16(&B8[k * N + n], ZEROS);
- }
- for (; n < N; ++n) {
- B8[k * N + n] = 0.F;
- }
+ if ((K8_high != K) || (N16_high != N)) {
+ pad_B = true;
+ Bp = alignedMalloc(K8_high * N16_high);
+ hgemm_padding_B(B, Bp, K, N, K8_high, N16_high, TransB);
+ B_ = Bp;
+ K_ = K8_high;
+ N_ = N16_high;
}
- hgemm_noTrans_strict(A8, B8, C, M, N, K8_high, alpha, beta);
+ hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
+
+ if (pad_A)
+ free(Ap);
+ if (pad_B)
+ free(Bp);
+}
- free(A8);
- free(B8);
+void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
+ unsigned int M, unsigned int N, unsigned int K, float alpha,
+ float beta, bool TransA, bool TransB) {
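+  // Dispatch to the kernel path matching the transpose flags of A and B.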
+ if (!TransA && !TransB) {
+ hgemm_noTrans(A, B, C32, M, N, K, alpha, beta);
+ } else if (TransA && !TransB) {
+ hgemm_transA(A, B, C32, M, N, K, alpha, beta);
+ } else if (!TransA && TransB) {
+ hgemm_transB(A, B, C32, M, N, K, alpha, beta);
+ } else { // TransA && TransB
+ hgemm_transAB(A, B, C32, M, N, K, alpha, beta);
+ }
}
void hgemm_K1_noTrans(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha, float beta) {
- // M, N, K is full M, N, K here
-
__fp16 *sA = alignedMalloc(M * K);
__fp16 *sB = alignedMalloc(K * N);
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, float *C, unsigned int ldc,
float alpha, float beta) {
-
__fp16 *sA = alignedMalloc(M * K);
__fp16 *sB = alignedMalloc(K * N);
if (n_min >= N_BLOCKING * 2) {
n_min = N_BLOCKING;
} else if (n_min > N_BLOCKING) {
- n_min = (n_min / 2 + GEMM_UNROLLING_8 - 1) & ~(GEMM_UNROLLING_8 - 1);
+ n_min =
+ (n_min / 2 + GEMM_UNROLLING_16 - 1) & ~(GEMM_UNROLLING_16 - 1);
}
packing_B16(k_min, n_min, B + ns + ldb * ks, ldb, sB);
HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
}
}
}
}
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha, float beta) {
+
+ __fp16 *sA = alignedMalloc(M * K);
+ __fp16 *sB = alignedMalloc(K * N);
+
+ unsigned int ms, ms2, ns, ks;
+ unsigned int m_min, m2_min, n_min, k_min;
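+  // stride_l1 == 1 keeps each packed A sub-block at its own offset in sA so
+  // the later n blocks can reuse the whole packed M block; it is set to 0
+  // below when the entire N range fits in a single block and no reuse happens.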
+ unsigned int stride_l1 = 1;
+
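+  // Blocked GEMM : iterate M in M_BLOCKING chunks and K in K_BLOCKING chunks,
+  // packing A and B blocks before invoking the 8x16 kernel.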
+ for (ms = 0; ms < M; ms += M_BLOCKING) {
+ m_min = M - ms;
+ if (m_min > M_BLOCKING) {
+ m_min = M_BLOCKING;
+ }
+
+ for (ks = 0; ks < K; ks += k_min) {
+ k_min = K - ks;
+ if (k_min >= (K_BLOCKING << 1)) {
+ k_min = K_BLOCKING;
+ } else if (k_min > K_BLOCKING) {
+ k_min = (k_min / 2 + GEMM_UNROLLING_4 - 1) & ~(GEMM_UNROLLING_4 - 1);
+ }
+
+ n_min = N;
+ if (N >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (N > N_BLOCKING) {
+ n_min = ((n_min / 2 + GEMM_UNROLLING_16 - 1) / GEMM_UNROLLING_16) *
+ GEMM_UNROLLING_16;
+ } else {
+ stride_l1 = 0;
+ }
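+      // Pack the first n block of the transposed B for this K block into sB.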
+ packing_transB16(k_min, n_min, B + (ks), ldb, sB);
+
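+      // Split the M block into sub-blocks of at most 3 * GEMM_UNROLLING_8
+      // rows, pack each A sub-block, and run the kernel on the first n block.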
+ for (ms2 = ms; ms2 < ms + m_min; ms2 += m2_min) {
+ m2_min = (ms + m_min) - ms2;
+ if (m2_min >= 3 * GEMM_UNROLLING_8) {
+ m2_min = 3 * GEMM_UNROLLING_8;
+ } else if (m2_min >= 2 * GEMM_UNROLLING_8) {
+ m2_min = 2 * GEMM_UNROLLING_8;
+ } else if (m2_min > GEMM_UNROLLING_8) {
+ m2_min = GEMM_UNROLLING_8;
+ }
+ packing_A8(m2_min, k_min, A + ms2 * lda + ks, lda,
+ sA + k_min * (ms2 - ms) * stride_l1);
+ HGEMM_KERNEL_8x16(m2_min, n_min, k_min,
+ sA + k_min * (ms2 - ms) * stride_l1, sB,
+ C + ms2 * ldc, ldc);
+ }
+
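+      // Remaining n blocks : reuse the packed A in sA and repack B per block.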
+ for (ns = n_min; ns < N; ns += n_min) {
+ n_min = N - ns;
+ if (n_min >= N_BLOCKING * 2) {
+ n_min = N_BLOCKING;
+ } else if (n_min > N_BLOCKING) {
+          n_min =
+            (n_min / 2 + GEMM_UNROLLING_16 - 1) & ~(GEMM_UNROLLING_16 - 1);
+ }
+ packing_transB16(k_min, n_min, B + ks + ldb * ns, ldb, sB);
+ HGEMM_KERNEL_8x16(m_min, n_min, k_min, sA, sB, C + ms * ldc + ns, ldc);
+ }
+ }
+ }
+
+ free(sA);
+ free(sB);
+}
+
void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta) {
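+  // Use the 8x16 path only when M and K are multiples of 8 and N is a
+  // multiple of 16; otherwise fall back to an explicit transpose of B.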
+  if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
+ return hgemm_transB_8x16(M, N, K, A, K, B, K, C, N, alpha, beta);
+ } else {
+ return hgemm_transB_fallback(A, B, C, M, N, K, alpha, beta);
+ }
+}
+
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta) {
__fp16 *B_T = alignedMalloc(K * N);
transpose_neon<__fp16>(N, K, B, K, B_T, N);
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
- unsigned int N, unsigned int K, float alpha = 1.F,
- float beta = 0.F);
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F);
/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha = 1.F, float beta = 0.F);
+void hgemm_ensure_divisibility(const __fp16 *A, const __fp16 *B, float *C32,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F,
+ bool TransA = false, bool TransB = false);
- /**
+/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
* @param[in] A __fp16 * for Matrix A
* @param[in] B __fp16 * for Matrix B
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm_noTrans_padding_wrt_K4(const __fp16 *A, const __fp16 *B, float *C,
- unsigned int M, unsigned int N, unsigned int K,
- float alpha = 1.F, float beta = 0.F);
+void hgemm_classify(const __fp16 *A, const __fp16 *B, float *C32,
+                    unsigned int M, unsigned int N, unsigned int K,
+                    float alpha = 1.F, float beta = 0.F, bool TransA = false,
+                    bool TransB = false);
/**
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
*/
void hgemm_transB(const __fp16 *A, const __fp16 *B, float *C, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta);
+
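+/**
+ * @brief hgemm fallback for the transposed-B case : Y = alpha*A*B**T + beta*C,
+ * implemented by explicitly transposing B before multiplication
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */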
+void hgemm_transB_fallback(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta);
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*A*B**T + beta*C,
+ * using the 8x16 kernel for the transposed-B case
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] lda leading dimension of A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] ldb leading dimension of B
+ * @param[in] C float * for Matrix C
+ * @param[in] ldc leading dimension of C
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_transB_8x16(unsigned int M, unsigned int N, unsigned int K,
+ const __fp16 *A, unsigned int lda, const __fp16 *B,
+ unsigned int ldb, float *C, unsigned int ldc,
+ float alpha = 1.F, float beta = 0.F);
/**
* @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
* where op(X) is one of X or X**T