- Missing implementations might trigger unit test failures on Android.
- This patch adds padding support for all combinations of the following conditions: matrix A / B, trans / noTrans, padding direction M / K / N. A minimal sketch of the rounding arithmetic is shown below.
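
For reference, a minimal sketch of the round-up arithmetic the padding relies on (illustrative only; `round_up` is a hypothetical helper, the patch inlines the same formula):

```cpp
#include <cstdio>

// Hypothetical helper: round x up to the nearest multiple of mult
// (x > 0, mult > 0), mirroring ((x - 1) / mult + 1) * mult in the patch.
static unsigned int round_up(unsigned int x, unsigned int mult) {
  return ((x - 1) / mult + 1) * mult;
}

int main() {
  // e.g., padding M = 10, K = 30, N = 20 for the 8x16 kernel:
  std::printf("M: 10 -> %u\n", round_up(10, 8));  // 16
  std::printf("K: 30 -> %u\n", round_up(30, 16)); // 32
  std::printf("N: 20 -> %u\n", round_up(20, 16)); // 32
  return 0;
}
```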
**Self evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test: [X]Passed [ ]Failed [ ]Skipped
Signed-off-by: skykongkong8 <ss.kong@samsung.com>
#include <hgemm_transAB.h>
#include <hgemm_transB.h>
#include <hgemm_util.h>
+#include <iostream>
#include <limits>
void hgemm(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
return hgemm_K1(A, B, C, M, N, K, alpha, beta, TransA, TransB);
}
// Allocate dynamically to avoid reaching the stack limit (would cause a segmentation fault)
- float *C32 = (float *)malloc(M * N * sizeof(float));
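+ // M8_high / N16_high: M and N rounded up to the nearest multiples of 8 and
+ // 16 (the 8x16 kernel tile); N8_low: N rounded down to a multiple of 8,
+ // i.e. the vectorizable prefix of each row.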
+ const unsigned int M8_high = ((M - 1) / 8 + 1) * 8;
+ const unsigned int N16_high = ((N - 1) / 16 + 1) * 16;
+ const unsigned int N8_low = (N >> 3) << 3;
+ float32x4_t ZEROS = vmovq_n_f32(0.F);
- // performing beta*C
- unsigned int idx = 0;
- unsigned int size = M * N;
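+ // fp32 accumulation buffer sized with the padded dimensions; converted
+ // back into C's original M x N __fp16 layout at the end.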
+ float *C32 = (float *)malloc(M8_high * N16_high * sizeof(float));
+
+ unsigned int size = M8_high * N16_high;
unsigned int size8 = (size >> 3) << 3;
unsigned int size4 = (size >> 2) << 2;
if (std::fpclassify(beta) != FP_ZERO) {
- for (; idx < size8; idx += 8) {
- float16x8_t c =
- vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
-
- vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
- vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
- }
- // remaining 4
- for (; idx < size4; idx += 4) {
- float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
-
- vst1q_f32(&C32[idx], vcvt_f32_f16(c));
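+ // beta != 0 : scale the original M x N contents of C by beta into the
+ // padded fp32 buffer, row by row.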
+ for (unsigned int m = 0; m < M; ++m) {
+ for (unsigned int n = 0; n < N8_low; n += 8) {
+ float16x8_t c =
+ vmulq_n_f16(vld1q_f16(&C[m * N + n]), static_cast<__fp16>(beta));
+ vst1q_f32(&C32[m * N16_high + n], vcvt_f32_f16(vget_low_f16(c)));
+ vst1q_f32(&C32[m * N16_high + n + 4], vcvt_f32_f16(vget_high_f16(c)));
+ }
+ for (unsigned int n = N8_low; n < N; ++n) {
+ C32[m * N16_high + n] = beta * C[m * N + n];
+ }
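+ // zero-fill the padded columns N..N16_high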
+ for (unsigned int n = N; n < N16_high; ++n) {
+ C32[m * N16_high + n] = 0.F;
+ }
}
-
- // remaining values if dimensions not a multiple of 8
- for (; idx < size; idx++) {
- C32[idx] = C[idx] * beta;
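+ // zero-fill the padded rows M..M8_high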
+ for (unsigned int m = M; m < M8_high; ++m) {
+ for (unsigned int n = 0; n < N16_high; n += 4) {
+ vst1q_f32(&C32[m * N16_high + n], ZEROS);
+ }
}
} else {
- float32x4_t zeros = vmovq_n_f32(0.F);
- for (; idx < size4; idx += 4) {
- vst1q_f32(&C32[idx], zeros);
+ for (unsigned int idx = 0; idx < size4; idx += 4) {
+ vst1q_f32(&C32[idx], ZEROS);
}
- for (; idx < size; idx++) {
+ for (unsigned int idx = size4; idx < size; idx++) {
C32[idx] = 0.F;
}
}
hgemm_ensure_divisibility(A, B, C32, M, N, K, alpha, beta, TransA, TransB);
- for (unsigned int idx = 0; idx < size8; idx += 8) {
- float32x4_t x1 = vld1q_f32(&C32[idx]);
- float32x4_t x2 = vld1q_f32(&C32[idx + 4]);
-
- float16x8_t y1 = vcombine_f16(vcvt_f16_f32(x1), vcvt_f16_f32(x2));
-
- vst1q_f16(&C[idx], y1);
- }
- for (unsigned int idx = size8; idx < size; ++idx) {
- C[idx] = static_cast<__fp16>(C32[idx]);
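+ // Convert the fp32 result back to __fp16, dropping the padded region.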
+ for (unsigned int m = 0; m < M; ++m) {
+ for (unsigned int n = 0; n < N8_low; n += 8) {
+ float32x4_t x1 = vld1q_f32(&C32[m * N16_high + n]);
+ float32x4_t x2 = vld1q_f32(&C32[m * N16_high + n + 4]);
+ vst1q_f16(&C[m * N + n],
+ vcombine_f16(vcvt_f16_f32(x1), vcvt_f16_f32(x2)));
+ }
+ for (unsigned int n = N8_low; n < N; ++n) {
+ C[m * N + n] = static_cast<__fp16>(C32[m * N16_high + n]);
+ }
}
free(C32);
bool TransB) {
/// @note Padding standard : the 8x16 kernel is the only one that outperforms
/// single-precision GEMM so far, so padding forcibly makes every GEMM case
- /// use it. Note that padding is not the optimal way here, but just an option
+ /// use it. Note that padding is not an optimal approach here, but just an option
/// that is easier to implement. Fine-grained packing, blocking, and
- /// corresponding kernels should be supported on the future for optimal
- /// performance.
+ /// corresponding kernels should be supported in the future for optimal
+ /// performance in terms of both latency and memory.
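+ /// e.g., M = 10 is padded up to 16 (next multiple of 8) and N = 20 up to
+ /// 32 (next multiple of 16) so that the 8x16 kernel applies everywhere.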
__fp16 *A_ = (__fp16 *)A, *B_ = (__fp16 *)B;
unsigned int M_ = M, N_ = N, K_ = K;
bool pad_A = false, pad_B = false;
- // Case 2 : smaller than 8, 16 | padding would be redundant
+ // When M < 8, K < 16, and N < 16, padding would be redundant.
if (M < 8 && K < 16 && N < 16)
return hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
__fp16 *Bp;
const unsigned int M8_high = ((M - 1) / 8 + 1) * 8;
- const unsigned int K8_high = ((K - 1) / 8 + 1) * 8;
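+ // NOTE: despite the name, K8_high rounds K up to a multiple of 16 here.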
+ const unsigned int K8_high = ((K - 1) / 16 + 1) * 16;
const unsigned int N16_high = ((N - 1) / 16 + 1) * 16;
if ((M8_high != M) || (K8_high != K)) {
N_ = N16_high;
}
hgemm_classify(A_, B_, C32, M_, N_, K_, alpha, beta, TransA, TransB);
if (pad_A)
float16x8_t ZEROS = vmovq_n_f16(0.F);
for (unsigned int m = 0; m < M; ++m) {
- for (unsigned int k = 0; k < K8_low; ++k) {
+ for (unsigned int k = 0; k < K8_low; k += 8) {
vst1q_f16(&Ap[m * K8 + k], vld1q_f16(&A[m * K + k]));
}
for (unsigned int k = K8_low; k < K; ++k) {
}
}
for (unsigned int m = M; m < M8; ++m) {
- for (unsigned int k = K; k < K8; ++k) {
- Ap[m * K8 + k] = ZEROS;
+ for (unsigned int k = 0; k < K8; k += 8) {
+ vst1q_f16(&Ap[m * K8 + k], ZEROS);
}
}
}
unsigned int K, unsigned int M8,
unsigned int K8) {
const unsigned int M8_low = (M >> 3) << 3;
-
for (unsigned int k = 0; k < K; ++k) {
for (unsigned int m = 0; m < M8_low; m += 8) {
- vst1q_f16(&Ap[k * M + m], vld1q_f16(&A[k * M + m]));
+ vst1q_f16(&Ap[k * M8 + m], vld1q_f16(&A[k * M + m]));
}
for (unsigned int m = M8_low; m < M; ++m) {
- Ap[k * M + m] = A[k * M + m];
+ Ap[k * M8 + m] = A[k * M + m];
}
for (unsigned int m = M; m < M8; ++m) {
- Ap[k * M + m] = 0.F;
+ Ap[k * M8 + m] = 0.F;
}
}
}
void hgemm_padding_A_Trans_wrt_K(const __fp16 *A, __fp16 *Ap, unsigned int M,
unsigned int K, unsigned int M8,
unsigned int K8) {
- std::cerr << "Error : hgemm_padding_A_Trans_wrt_K NYI!\n";
+ float16x8_t ZEROS = vmovq_n_f16(0.F);
+ for (unsigned int k = 0; k < K; ++k) {
+ for (unsigned int m = 0; m < M; m += 8) {
+ vst1q_f16(&Ap[k * M8 + m], vld1q_f16(&A[k * M + m]));
+ }
+ }
+ for (unsigned int k = K; k < K8; ++k) {
+ for (unsigned int m = 0; m < M; m += 8) {
+ vst1q_f16(&Ap[k * M8 + m], ZEROS);
+ }
+ }
}
void hgemm_padding_A_Trans_wrt_MK(const __fp16 *A, __fp16 *Ap, unsigned int M,
unsigned int K, unsigned int M8,
unsigned int K8) {
- std::cerr << "Error : hgemm_padding_A_Trans_wrt_MK NYI!\n";
+ float16x8_t ZEROS = vmovq_n_f16(0.F);
+ const unsigned int M8_low = (M >> 3) << 3;
+ for (unsigned int k = 0; k < K; ++k) {
+ for (unsigned int m = 0; m < M8_low; m += 8) {
+ vst1q_f16(&Ap[k * M8 + m], vld1q_f16(&A[k * M + m]));
+ }
+ for (unsigned int m = M8_low; m < M; ++m) {
+ Ap[k * M8 + m] = A[k * M + m];
+ }
+ for (unsigned int m = M; m < M8; ++m) {
+ Ap[k * M8 + m] = 0.F;
+ }
+ }
+ for (unsigned int k = K; k < K8; ++k) {
+ for (unsigned int m = 0; m < M8; m += 8) {
+ vst1q_f16(&Ap[k * M8 + m], ZEROS);
+ }
+ }
}
unsigned int N, unsigned int K8, unsigned int N16,
bool transB) {
if (transB) {
- hgemm_padding_B_Trans(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_Trans(B, Bp, K, N, K8, N16);
} else {
- hgemm_padding_B_noTrans(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_noTrans(B, Bp, K, N, K8, N16);
}
}
unsigned int N, unsigned int K8,
unsigned int N16) {
if (K != K8 && N != N16) {
- hgemm_padding_B_noTrans_wrt_KN(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_noTrans_wrt_KN(B, Bp, K, N, K8, N16);
} else if (K != K8) {
- hgemm_padding_B_noTrans_wrt_K(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_noTrans_wrt_K(B, Bp, K, N, K8, N16);
} else if (N != N16) {
- hgemm_padding_B_noTrans_wrt_N(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_noTrans_wrt_N(B, Bp, K, N, K8, N16);
} else {
std::cerr << "Error : No room for matrix B padding\n";
}
void hgemm_padding_B_Trans(const __fp16 *B, __fp16 *Bp, unsigned int K,
unsigned int N, unsigned int K8, unsigned int N16) {
if (K != K8 && N != N16) {
- hgemm_padding_B_Trans_wrt_KN(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_Trans_wrt_KN(B, Bp, K, N, K8, N16);
} else if (K != K8) {
- hgemm_padding_B_Trans_wrt_K(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_Trans_wrt_K(B, Bp, K, N, K8, N16);
} else if (N != N16) {
- hgemm_padding_B_Trans_wrt_N(B, Bp, K, N, K8, N16);
+ return hgemm_padding_B_Trans_wrt_N(B, Bp, K, N, K8, N16);
} else {
std::cerr << "Error : No room for matrix B padding\n";
}
void hgemm_padding_B_noTrans_wrt_N(const __fp16 *B, __fp16 *Bp, unsigned int K,
unsigned int N, unsigned int K8,
unsigned int N16) {
- std::cerr << "Error : hgemm_padding_B_noTrans_wrt_N NYI!\n";
+ const unsigned int N8_low = (N >> 3) << 3;
+ for (unsigned int k = 0; k < K; ++k) {
+ for (unsigned int n = 0; n < N8_low; n += 8) {
+ vst1q_f16(&Bp[k * N16 + n], vld1q_f16(&B[k * N + n]));
+ }
+ for (unsigned int n = N8_low; n < N; ++n) {
+ Bp[k * N16 + n] = B[k * N + n];
+ }
+ for (unsigned int n = N; n < N16; ++n) {
+ Bp[k * N16 + n] = 0.F;
+ }
+ }
}
void hgemm_padding_B_noTrans_wrt_K(const __fp16 *B, __fp16 *Bp, unsigned int K,
void hgemm_padding_B_noTrans_wrt_KN(const __fp16 *B, __fp16 *Bp, unsigned int K,
unsigned int N, unsigned int K8,
unsigned int N16) {
- std::cerr << "Error : hgemm_padding_B_noTrans_wrt_KN NYI!\n";
+ const unsigned int N8_low = (N >> 3) << 3;
+ float16x8_t ZEROS = vmovq_n_f16(0.F);
+ for (unsigned int k = 0; k < K; ++k) {
+ for (unsigned int n = 0; n < N8_low; n += 8) {
+ vst1q_f16(&Bp[k * N16 + n], vld1q_f16(&B[k * N + n]));
+ }
+ for (unsigned int n = N8_low; n < N; ++n) {
+ Bp[k * N16 + n] = B[k * N + n];
+ }
+ for (unsigned int n = N; n < N16; ++n) {
+ Bp[k * N16 + n] = 0.F;
+ }
+ }
+ for (unsigned int k = K; k < K8; ++k) {
+ for (unsigned int n = 0; n < N16; n += 8) {
+ vst1q_f16(&Bp[k * N16 + n], ZEROS);
+ }
+ }
}
void hgemm_padding_B_Trans_wrt_N(const __fp16 *B, __fp16 *Bp, unsigned int K,
unsigned int N, unsigned int K8,
unsigned int N16) {
- std::cerr << "Error : hgemm_padding_B_Trans_wrt_N NYI!\n";
+ float16x8_t ZEROS = vmovq_n_f16(0.F);
+
+ for (unsigned int n = 0; n < N; ++n) {
+ for (unsigned int k = 0; k < K; k += 8) {
+ vst1q_f16(&Bp[n * K8 + k], vld1q_f16(&B[n * K + k]));
+ }
+ }
+ for (unsigned int n = N; n < N16; ++n) {
+ for (unsigned int k = 0; k < K; k += 8) {
+ vst1q_f16(&Bp[n * K8 + k], ZEROS);
+ }
+ }
}
void hgemm_padding_B_Trans_wrt_K(const __fp16 *B, __fp16 *Bp, unsigned int K,
void hgemm_padding_B_Trans_wrt_KN(const __fp16 *B, __fp16 *Bp, unsigned int K,
unsigned int N, unsigned int K8,
unsigned int N16) {
- std::cerr << "Error : hgemm_padding_B_Trans_wrt_KN NYI!\n";
+ const unsigned int K8_low = (K >> 3) << 3;
+ float16x8_t ZEROS = vmovq_n_f16(0.F);
+
+ for (unsigned int n = 0; n < N; ++n) {
+ for (unsigned int k = 0; k < K8_low; k += 8) {
+ vst1q_f16(&Bp[n * K8 + k], vld1q_f16(&B[n * K + k]));
+ }
+ for (unsigned int k = K8_low; k < K; ++k) {
+ Bp[n * K8 + k] = B[n * K + k];
+ }
+ for (unsigned int k = K; k < K8; ++k) {
+ Bp[n * K8 + k] = 0.F;
+ }
+ }
+ for (unsigned int n = N; n < N16; ++n) {
+ for (unsigned int k = 0; k < K8; k += 8) {
+ vst1q_f16(&Bp[n * K8 + k], ZEROS);
+ }
+ }
}