// performing beta*C
unsigned int idx = 0;
unsigned int size = M * N;
- for (; idx < (size - idx) && (size - idx) >= 8; idx += 8) {
- float16x8_t c = vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
+ if (beta != 0.F) {
+ for (; idx < (size - idx) && (size - idx) >= 8; idx += 8) {
+ float16x8_t c =
+ vmulq_n_f16(vld1q_f16(&C[idx]), static_cast<__fp16>(beta));
- vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
- vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
- }
- // remaining 4
- for (; idx < (size - idx) && (size - idx) >= 4; idx += 4) {
- float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
+ vst1q_f32(&C32[idx], vcvt_f32_f16(vget_low_f16(c)));
+ vst1q_f32(&C32[idx + 4], vcvt_f32_f16(vget_high_f16(c)));
+ }
+ // remaining 4
+ for (; idx < (size - idx) && (size - idx) >= 4; idx += 4) {
+ float16x4_t c = vmul_n_f16(vld1_f16(&C[idx]), static_cast<__fp16>(beta));
- vst1q_f32(&C32[idx], vcvt_f32_f16(c));
- }
+ vst1q_f32(&C32[idx], vcvt_f32_f16(c));
+ }
- // remaining values if dimensions not a multiple of 8
- for (; idx < size; idx++) {
- C32[idx] = C[idx] * beta;
+ // scalar tail: handle every element the 8-wide and 4-wide vector loops above did not cover
+ for (; idx < size; idx++) {
+ C32[idx] = C[idx] * beta;
+ }
+ } else {
+ float32x4_t zeros = vmovq_n_f32(0.F);
+ for (; idx < (size - idx) && (size - idx) >= 4; idx += 4) {
+ vst1q_f32(&C32[idx], zeros);
+ }
+ for (; idx < size; idx++) {
+ C32[idx] = 0.F;
+ }
}
if (!TransA && TransB) {
void hgemm_noTrans(const __fp16 *A, const __fp16 *B, float *C32, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta) {
- if (alpha == 1.F && beta == 0.F && N > 4) {
+ if (alpha == 1.F) {
// used bitwise operator instead of modulo for performance
// e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {
void hgemm_noTrans(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
unsigned int N, unsigned int K, float alpha, float beta) {
- if (alpha == 1.F && beta == 0.F) {
+ if (alpha == 1.F) {
// used bitwise operator instead of modulo for performance
// e.g (M % 8) is same as (M & 0x7) which will extract last 3 bits of M
if ((M & 0x7) == 0 && (N & 0xF) == 0 && (K & 0x7) == 0) {