float Y32[rows];
unsigned int idx = 0;
+
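+ // Initialize the fp32 accumulator Y32 = beta * Y, eight rows per iteration.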
for (; rows - idx >= 8; idx += 8) {
float32x4_t y0_3 = vcvt_f32_f16(vld1_f16(&Y[idx]));
float32x4_t y4_7 = vcvt_f32_f16(vld1_f16(&Y[idx + 4]));
y0_3 = vmulq_n_f32(y0_3, beta);
y4_7 = vmulq_n_f32(y4_7, beta);
vst1q_f32(&Y32[idx], y0_3);
vst1q_f32(&Y32[idx + 4], y4_7);
}
- while (idx < rows) {
- Y32[idx] = Y[idx] * beta;
- ++idx;
+ for (; idx < rows; ++idx) {
+ Y32[idx] = beta * Y[idx];
}
idx = 0;
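+ // Main loop: process X and each row of A in chunks of 64 columns, keeping the data in fp16.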
+ for (; cols - idx >= 64; idx += 64) {
+ float16x8_t x0_7 = vld1q_f16(&X[idx]);
+ float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
+ float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
+ float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);
+
+ float16x8_t x32_39 = vld1q_f16(&X[idx + 32]);
+ float16x8_t x40_47 = vld1q_f16(&X[idx + 40]);
+ float16x8_t x48_55 = vld1q_f16(&X[idx + 48]);
+ float16x8_t x56_63 = vld1q_f16(&X[idx + 56]);
+
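+ // Fold alpha into X once per chunk so the per-row inner loop only needs fp16 FMAs.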
+ if (alpha != 1.0) {
+ x0_7 = vmulq_n_f16(x0_7, alpha);
+ x8_15 = vmulq_n_f16(x8_15, alpha);
+ x16_23 = vmulq_n_f16(x16_23, alpha);
+ x24_31 = vmulq_n_f16(x24_31, alpha);
+ x32_39 = vmulq_n_f16(x32_39, alpha);
+ x40_47 = vmulq_n_f16(x40_47, alpha);
+ x48_55 = vmulq_n_f16(x48_55, alpha);
+ x56_63 = vmulq_n_f16(x56_63, alpha);
+ }
+
+ const __fp16 *__restrict w;
+
+ float yVal_low;
+ float yVal_high;
+
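+ // For every output row, multiply-accumulate the 64-column chunk of A against x.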
+ for (unsigned int j = 0; j < rows; ++j) {
+ w = &A[j * cols + idx];
+ float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+ float16x8_t wvec8_15 = vld1q_f16(&w[8]);
+ float16x8_t wvec16_23 = vld1q_f16(&w[16]);
+ float16x8_t wvec24_31 = vld1q_f16(&w[24]);
+
+ float16x8_t wvec32_39 = vld1q_f16(&w[32]);
+ float16x8_t wvec40_47 = vld1q_f16(&w[40]);
+ float16x8_t wvec48_55 = vld1q_f16(&w[48]);
+ float16x8_t wvec56_63 = vld1q_f16(&w[56]);
+
+ float16x8_t y = vmulq_f16(wvec0_7, x0_7);
+ y = vfmaq_f16(y, wvec8_15, x8_15);
+ y = vfmaq_f16(y, wvec16_23, x16_23);
+ y = vfmaq_f16(y, wvec24_31, x24_31);
+
+ y = vfmaq_f16(y, wvec32_39, x32_39);
+ y = vfmaq_f16(y, wvec40_47, x40_47);
+ y = vfmaq_f16(y, wvec48_55, x48_55);
+ y = vfmaq_f16(y, wvec56_63, x56_63);
+
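+ // Widen each fp16 half to fp32, then horizontal-add into the fp32 accumulator.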
+ yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+ yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
+
+ Y32[j] += yVal_low + yVal_high;
+ }
+ }
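+ // Remaining columns: same pattern for chunks of 32.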
for (; cols - idx >= 32; idx += 32) {
- float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
- float32x4_t x4_7_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 4]));
- float32x4_t x8_11_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 8]));
- float32x4_t x12_15_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 12]));
- float32x4_t x16_19_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 16]));
- float32x4_t x20_23_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 20]));
- float32x4_t x24_27_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 24]));
- float32x4_t x28_31_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 28]));
+ float16x8_t x0_7 = vld1q_f16(&X[idx]);
+ float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
+ float16x8_t x16_23 = vld1q_f16(&X[idx + 16]);
+ float16x8_t x24_31 = vld1q_f16(&X[idx + 24]);
if (alpha != 1.0) {
- x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
- x4_7_f32 = vmulq_n_f32(x4_7_f32, alpha);
- x8_11_f32 = vmulq_n_f32(x8_11_f32, alpha);
- x12_15_f32 = vmulq_n_f32(x12_15_f32, alpha);
- x16_19_f32 = vmulq_n_f32(x16_19_f32, alpha);
- x20_23_f32 = vmulq_n_f32(x20_23_f32, alpha);
- x24_27_f32 = vmulq_n_f32(x24_27_f32, alpha);
- x28_31_f32 = vmulq_n_f32(x28_31_f32, alpha);
+ x0_7 = vmulq_n_f16(x0_7, alpha);
+ x8_15 = vmulq_n_f16(x8_15, alpha);
+ x16_23 = vmulq_n_f16(x16_23, alpha);
+ x24_31 = vmulq_n_f16(x24_31, alpha);
}
const __fp16 *__restrict w;
+ float yVal_low;
+ float yVal_high;
+
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
- float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
- float32x4_t wvec4_7_f32 = vcvt_f32_f16(vld1_f16(&w[4]));
- float32x4_t wvec8_11_f32 = vcvt_f32_f16(vld1_f16(&w[8]));
- float32x4_t wvec12_15_f32 = vcvt_f32_f16(vld1_f16(&w[12]));
- float32x4_t wvec16_19_f32 = vcvt_f32_f16(vld1_f16(&w[16]));
- float32x4_t wvec20_23_f32 = vcvt_f32_f16(vld1_f16(&w[20]));
- float32x4_t wvec24_27_f32 = vcvt_f32_f16(vld1_f16(&w[24]));
- float32x4_t wvec28_31_f32 = vcvt_f32_f16(vld1_f16(&w[28]));
+ float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+ float16x8_t wvec8_15 = vld1q_f16(&w[8]);
+ float16x8_t wvec16_23 = vld1q_f16(&w[16]);
+ float16x8_t wvec24_31 = vld1q_f16(&w[24]);
- float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
- y0 = vfmaq_f32(y0, wvec4_7_f32, x4_7_f32);
- y0 = vfmaq_f32(y0, wvec8_11_f32, x8_11_f32);
- y0 = vfmaq_f32(y0, wvec12_15_f32, x12_15_f32);
- y0 = vfmaq_f32(y0, wvec16_19_f32, x16_19_f32);
- y0 = vfmaq_f32(y0, wvec20_23_f32, x20_23_f32);
- y0 = vfmaq_f32(y0, wvec24_27_f32, x24_27_f32);
- y0 = vfmaq_f32(y0, wvec28_31_f32, x28_31_f32);
+ float16x8_t y = vmulq_f16(wvec0_7, x0_7);
+ y = vfmaq_f16(y, wvec8_15, x8_15);
+ y = vfmaq_f16(y, wvec16_23, x16_23);
+ y = vfmaq_f16(y, wvec24_31, x24_31);
- Y32[j] += vaddvq_f32(y0);
+ yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+ yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
+
+ Y32[j] += yVal_low + yVal_high;
}
}
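+ // Chunks of 16 columns.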
for (; cols - idx >= 16; idx += 16) {
- float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
- float32x4_t x4_7_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 4]));
- float32x4_t x8_11_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 8]));
- float32x4_t x12_15_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 12]));
+ float16x8_t x0_7 = vld1q_f16(&X[idx]);
+ float16x8_t x8_15 = vld1q_f16(&X[idx + 8]);
if (alpha != 1.0) {
- x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
- x4_7_f32 = vmulq_n_f32(x4_7_f32, alpha);
- x8_11_f32 = vmulq_n_f32(x8_11_f32, alpha);
- x12_15_f32 = vmulq_n_f32(x12_15_f32, alpha);
+ x0_7 = vmulq_n_f16(x0_7, alpha);
+ x8_15 = vmulq_n_f16(x8_15, alpha);
}
const __fp16 *__restrict w;
-
+ float yVal_low;
+ float yVal_high;
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
- float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
- float32x4_t wvec4_7_f32 = vcvt_f32_f16(vld1_f16(&w[4]));
- float32x4_t wvec8_11_f32 = vcvt_f32_f16(vld1_f16(&w[8]));
- float32x4_t wvec12_15_f32 = vcvt_f32_f16(vld1_f16(&w[12]));
+ float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+ float16x8_t wvec8_15 = vld1q_f16(&w[8]);
+
+ float16x8_t y = vmulq_f16(wvec0_7, x0_7);
+ y = vfmaq_f16(y, wvec8_15, x8_15);
- float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
- y0 = vfmaq_f32(y0, wvec4_7_f32, x4_7_f32);
- y0 = vfmaq_f32(y0, wvec8_11_f32, x8_11_f32);
- y0 = vfmaq_f32(y0, wvec12_15_f32, x12_15_f32);
+ yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+ yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
- Y32[j] += vaddvq_f32(y0);
+ Y32[j] += yVal_low + yVal_high;
}
}
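+ // Chunks of 8 columns.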
for (; cols - idx >= 8; idx += 8) {
- float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
- float32x4_t x4_7_f32 = vcvt_f32_f16(vld1_f16(&X[idx + 4]));
+ float16x8_t x0_7 = vld1q_f16(&X[idx]);
if (alpha != 1.0) {
- x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
- x4_7_f32 = vmulq_n_f32(x4_7_f32, alpha);
+ x0_7 = vmulq_n_f16(x0_7, alpha);
}
const __fp16 *__restrict w;
+ float yVal_low;
+ float yVal_high;
+
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
- float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
- float32x4_t wvec4_7_f32 = vcvt_f32_f16(vld1_f16(&w[4]));
+ float16x8_t wvec0_7 = vld1q_f16(&w[0]);
+ float16x8_t y = vmulq_f16(wvec0_7, x0_7);
- float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
- y0 = vfmaq_f32(y0, wvec4_7_f32, x4_7_f32);
+ yVal_low = vaddvq_f32(vcvt_f32_f16(vget_low_f16(y)));
+ yVal_high = vaddvq_f32(vcvt_f32_f16(vget_high_f16(y)));
- Y32[j] += vaddvq_f32(y0);
+ Y32[j] += yVal_low + yVal_high;
}
}
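+ // Chunks of 4 columns.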
for (; cols - idx >= 4; idx += 4) {
- float32x4_t x0_3_f32 = vcvt_f32_f16(vld1_f16(&X[idx]));
+ float16x4_t x0_3 = vld1_f16(&X[idx]);
if (alpha != 1.0) {
- x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
+ x0_3 = vmul_n_f16(x0_3, alpha);
}
const __fp16 *__restrict w;
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
- float32x4_t wvec0_3_f32 = vcvt_f32_f16(vld1_f16(&w[0]));
- float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
+ float16x4_t wvec0_3 = vld1_f16(&w[0]);
+ float16x4_t y0 = vmul_f16(wvec0_3, x0_3);
- Y32[j] += vaddvq_f32(y0);
+ Y32[j] += vaddvq_f32(vcvt_f32_f16(y0));
}
}
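+ // Final 1-3 leftover columns: load a full vector and zero the out-of-range lanes.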
if (cols != idx) {
float16x4_t x0_3 = vld1_f16(&X[idx]);
for (unsigned int j = cols - idx; j < 4; ++j) {
x0_3[j] = 0;
}
- float32x4_t x0_3_f32 = vcvt_f32_f16(x0_3);
-
if (alpha != 1.0) {
- x0_3_f32 = vmulq_n_f32(x0_3_f32, alpha);
+ x0_3 = vmul_n_f16(x0_3, alpha);
}
const __fp16 *__restrict w;
+
for (unsigned int j = 0; j < rows; ++j) {
w = &A[j * cols + idx];
float16x4_t wvec0_3 = vld1_f16(&w[0]);
for (unsigned int k = cols - idx; k < 4; ++k) {
wvec0_3[k] = 0;
}
- float32x4_t wvec0_3_f32 = vcvt_f32_f16(wvec0_3);
+ float16x4_t y0 = vmul_f16(wvec0_3, x0_3);
- float32x4_t y0 = vmulq_f32(wvec0_3_f32, x0_3_f32);
-
- Y32[j] += vaddvq_f32(y0);
+ for (unsigned int k = 0; k < cols - idx; ++k) {
+ Y32[j] += y0[k];
+ }
}
}