unsigned int N, unsigned int K, float alpha, float beta) {
const float eps = std::numeric_limits<float>::epsilon();
if (std::abs(alpha - 1.F) < eps) {
- // used bitwise operator instead of modulo for performance
- // e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
- if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
- hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x3) == 0) {
- hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
- hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
- } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
- hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
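+ // When K is not a multiple of 8, zero-pad A and B along K so the
+ // fixed-tile kernels can still be used; otherwise dispatch directly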
+ if ((K & 0x7) != 0) {
+ hgemm_noTrans_padding_wrt_K(A, B, C32, M, N, K, alpha, beta);
} else {
- hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ hgemm_noTrans_strict(A, B, C32, M, N, K, alpha, beta);
}
- } else
+ } else {
+ hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ }
+}
+
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C32,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta) {
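+ // "strict" variant: performs no padding; callers are expected to pass
+ // K as a multiple of 8 (see hgemm_noTrans_padding_wrt_K)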
+ // bitwise AND is used instead of modulo for performance:
+ // e.g., (M % 8) equals (M & 0x7), which extracts the last 3 bits of M
+ if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_8x16(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((M & 0x7) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_8x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((M & 0x3) == 0 && (N & 0x7) == 0 && (K & 0x7) == 0) {
+ hgemm_noTrans_4x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((K & 0x7) == 0 && (N & 0x7) == 0) {
+ hgemm_noTrans_1x8(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else if ((K & 0x7) == 0 && (N & 0x3) == 0) {
+ hgemm_noTrans_1x4(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ } else {
hgemm_noTrans_fallback(M, N, K, A, K, B, N, C32, N, alpha, beta);
+ }
}
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta) {
if (alpha == 1.F) {
// used bitwise operator instead of modulo for performance
}
}
+void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha, float beta) {
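+ // K8_high rounds K up to the next multiple of 8 and K8_low rounds it
+ // down; e.g., K = 13 gives K8_high = 16, K8_low = 8. The padded zero
+ // entries add nothing to the dot products, so the result is unchanged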
+ const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
+ const unsigned int K8_low = (K >> 3) << 3;
+
+
+ __fp16 *A8 = new __fp16[M * K8_high];
+ __fp16 *B8 = new __fp16[K8_high * N];
+
+ float16x8_t ZEROS = vmovq_n_f16(0.F);
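+ // 8-lane fp16 zero vector used to fill the padded tail rows of B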
+
+ // Make zero-padded A matrix
+ for (unsigned int m = 0; m < M; ++m) {
+ unsigned int k = 0;
+ for (; k < K8_low; k += 8) {
+ vst1q_f16(&A8[m * K8_high + k], vld1q_f16(&A[m * K + k]));
+ }
+ for (; k < K; ++k) {
+ A8[m * K8_high + k] = A[m * K + k];
+ }
+ for (; k < K8_high; ++k) {
+ A8[m * K8_high + k] = 0.F;
+ }
+ }
+
+ // Make zero-padded B matrix
+ unsigned int k = 0;
+ unsigned int N8 = (N >> 3) << 3;
+ for (; k < K; ++k) {
+ unsigned int n = 0;
+ for (; n < N8; n += 8) {
+ vst1q_f16(&B8[k * N + n], vld1q_f16(&B[k * N + n]));
+ }
+ for (; n < N; ++n) {
+ B8[k * N + n] = B[k * N + n];
+ }
+ }
+ for (; k < K8_high; ++k) {
+ unsigned int n = 0;
+ for (; n < N8; n += 8) {
+ vst1q_f16(&B8[k * N + n], ZEROS);
+ }
+ for (; n < N; ++n) {
+ B8[k * N + n] = 0.F;
+ }
+ }
+
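+ // multiply the padded buffers; C stays M x N, and since the padded
+ // entries are zero this computes exactly alpha*A*B + beta*C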
+ hgemm_noTrans_strict(A8, B8, C, M, N, K8_high, alpha, beta);
+
+ delete[] A8;
+ delete[] B8;
+}
+
void hgemm_noTrans_1x4(unsigned int M, unsigned int N, unsigned int K,
const __fp16 *A, unsigned int lda, const __fp16 *B,
unsigned int ldb, __fp16 *C, unsigned int ldc,
float alpha, float beta) {
+ // M, N, K here are the full matrix dimensions
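+ // aligned working buffers holding packed copies of A and B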
__fp16 *sA = alignedMalloc(M * K);
__fp16 *sB = alignedMalloc(K * N);
* @param[in] alpha float number
* @param[in] beta float number
*/
-void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
unsigned int N, unsigned int K, float alpha = 1.F,
float beta = 0.F);
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_strict(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F);
+
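+// Usage sketch (illustrative only; buffer names and sizes are assumptions,
+// and K should satisfy the strict path's multiple-of-8 expectation):
+//   std::vector<__fp16> A(M * K), B(K * N);
+//   std::vector<float> C32(M * N, 0.F);
+//   hgemm_noTrans_strict(A.data(), B.data(), C32.data(), M, N, K);
+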
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * after zero-padding A and B along K to the next multiple of 8
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_padding_wrt_K(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F);
+
+/**
+ * @brief hgemm computation with neon : Y = alpha*op(A)*op(B) + beta*C,
+ * K4 variant of hgemm_noTrans_padding_wrt_K
+ * @param[in] A __fp16 * for Matrix A
+ * @param[in] B __fp16 * for Matrix B
+ * @param[in] C float * for Matrix C
+ * @param[in] M number of op(A)'s and C's rows
+ * @param[in] N number of op(B)'s and C's columns
+ * @param[in] K number of op(A)'s columns and op(B)'s rows
+ * @param[in] alpha float number
+ * @param[in] beta float number
+ */
+void hgemm_noTrans_padding_wrt_K4(const __fp16 *A, const __fp16 *B, float *C,
+ unsigned int M, unsigned int N, unsigned int K,
+ float alpha = 1.F, float beta = 0.F);
+
/**
* @brief hgemm fallback with neon : Y = alpha*op(A)*op(B) + beta*C,
* @param M length of the row of matrix A