From 1b2508362b9033468eb98ea4146e31ab50d14fa3 Mon Sep 17 00:00:00 2001 From: Ashwin Sekhar T K Date: Fri, 1 Jan 2021 02:09:40 -0800 Subject: [PATCH] arm64: Fix nrm2 for input vectors with Inf Fix double precision nrm2 kernels returning NaN when the input vectors contain Inf/-Inf. --- kernel/arm64/KERNEL.NEOVERSEN1 | 8 ++++---- kernel/arm64/KERNEL.THUNDERX2T99 | 8 ++++---- kernel/arm64/KERNEL.THUNDERX3T110 | 17 +++++++---------- kernel/arm64/dznrm2_thunderx2t99.c | 28 +++++++++++++++++++++++++++- 4 files changed, 42 insertions(+), 19 deletions(-) diff --git a/kernel/arm64/KERNEL.NEOVERSEN1 b/kernel/arm64/KERNEL.NEOVERSEN1 index 074d7215..ea010db4 100644 --- a/kernel/arm64/KERNEL.NEOVERSEN1 +++ b/kernel/arm64/KERNEL.NEOVERSEN1 @@ -91,10 +91,10 @@ IDAMAXKERNEL = iamax_thunderx2t99.c ICAMAXKERNEL = izamax_thunderx2t99.c IZAMAXKERNEL = izamax_thunderx2t99.c -SNRM2KERNEL = nrm2.S -DNRM2KERNEL = nrm2.S -CNRM2KERNEL = znrm2.S -ZNRM2KERNEL = znrm2.S +SNRM2KERNEL = scnrm2_thunderx2t99.c +DNRM2KERNEL = dznrm2_thunderx2t99.c +CNRM2KERNEL = scnrm2_thunderx2t99.c +ZNRM2KERNEL = dznrm2_thunderx2t99.c DDOTKERNEL = dot_thunderx2t99.c SDOTKERNEL = dot_thunderx2t99.c diff --git a/kernel/arm64/KERNEL.THUNDERX2T99 b/kernel/arm64/KERNEL.THUNDERX2T99 index 8333f60e..a20d0d4a 100644 --- a/kernel/arm64/KERNEL.THUNDERX2T99 +++ b/kernel/arm64/KERNEL.THUNDERX2T99 @@ -153,12 +153,12 @@ IDAMAXKERNEL = iamax_thunderx2t99.c ICAMAXKERNEL = izamax_thunderx2t99.c IZAMAXKERNEL = izamax_thunderx2t99.c -SNRM2KERNEL = nrm2.S -CNRM2KERNEL = nrm2.S +SNRM2KERNEL = scnrm2_thunderx2t99.c +CNRM2KERNEL = scnrm2_thunderx2t99.c #DNRM2KERNEL = dznrm2_thunderx2t99_fast.c #ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c -DNRM2KERNEL = znrm2.S -ZNRM2KERNEL = znrm2.S +DNRM2KERNEL = dznrm2_thunderx2t99.c +ZNRM2KERNEL = dznrm2_thunderx2t99.c DDOTKERNEL = dot_thunderx2t99.c diff --git a/kernel/arm64/KERNEL.THUNDERX3T110 b/kernel/arm64/KERNEL.THUNDERX3T110 index 4cdd8769..a20d0d4a 100644 --- a/kernel/arm64/KERNEL.THUNDERX3T110 +++ b/kernel/arm64/KERNEL.THUNDERX3T110 @@ -153,16 +153,13 @@ IDAMAXKERNEL = iamax_thunderx2t99.c ICAMAXKERNEL = izamax_thunderx2t99.c IZAMAXKERNEL = izamax_thunderx2t99.c -#SNRM2KERNEL = scnrm2_thunderx2t99.c -#CNRM2KERNEL = scnrm2_thunderx2t99.c -##DNRM2KERNEL = dznrm2_thunderx2t99_fast.c -##ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c -#DNRM2KERNEL = dznrm2_thunderx2t99.c -#ZNRM2KERNEL = dznrm2_thunderx2t99.c -SNRM2KERNEL = nrm2.S -DNRM2KERNEL = nrm2.S -CNRM2KERNEL = znrm2.S -ZNRM2KERNEL = znrm2.S +SNRM2KERNEL = scnrm2_thunderx2t99.c +CNRM2KERNEL = scnrm2_thunderx2t99.c +#DNRM2KERNEL = dznrm2_thunderx2t99_fast.c +#ZNRM2KERNEL = dznrm2_thunderx2t99_fast.c +DNRM2KERNEL = dznrm2_thunderx2t99.c +ZNRM2KERNEL = dznrm2_thunderx2t99.c + DDOTKERNEL = dot_thunderx2t99.c SDOTKERNEL = dot_thunderx2t99.c diff --git a/kernel/arm64/dznrm2_thunderx2t99.c b/kernel/arm64/dznrm2_thunderx2t99.c index b94f0cff..b021a283 100644 --- a/kernel/arm64/dznrm2_thunderx2t99.c +++ b/kernel/arm64/dznrm2_thunderx2t99.c @@ -58,6 +58,7 @@ extern int blas_level1_thread_with_return_value(int mode, BLASLONG m, BLASLONG n #define CUR_MAXINV "d8" #define CUR_MAXINV_V "v8.2d" #define CUR_MAX_V "v8.2d" +#define REGINF "d9" static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, double *ssq, double *scale) @@ -79,8 +80,10 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, " ble 9f //nrm2_kernel_L999 \n" "1: //nrm2_kernel_F_BEGIN: \n" + " mov x6, #0x7FF0000000000000 //+Infinity \n" " fmov "REGZERO", xzr \n" " fmov "REGONE", #1.0 \n" + " fmov "REGINF", x6 \n" " lsl "INC_X", "INC_X", #"INC_SHIFT" \n" " mov "J", "N" \n" " cmp "J", xzr \n" @@ -104,6 +107,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, " ldr d4, ["X"] \n" " fabs d4, d4 \n" " fmax "CUR_MAX", "SCALE", d4 \n" + " fcmp "CUR_MAX", "REGINF" \n" + " beq 10f \n" " fdiv "SCALE", "SCALE", "CUR_MAX" \n" " fmul "SCALE", "SCALE", "SCALE" \n" " fmul "SSQ", "SSQ", "SCALE" \n" @@ -116,6 +121,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, " ldr d3, ["X", #8] \n" " fabs d3, d3 \n" " fmax "CUR_MAX", "SCALE", d3 \n" + " fcmp "CUR_MAX", "REGINF" \n" + " beq 10f \n" " fdiv "SCALE", "SCALE", "CUR_MAX" \n" " fmul "SCALE", "SCALE", "SCALE" \n" " fmul "SSQ", "SSQ", "SCALE" \n" @@ -158,6 +165,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, " fmaxp v24.2d, v24.2d, v26.2d \n" " fmaxp v24.2d, v24.2d, v24.2d \n" " fmax "CUR_MAX", "SCALE", d24 \n" + " fcmp "CUR_MAX", "REGINF" \n" + " beq 10f \n" " fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n" " //dup "CUR_MAX_V", v7.d[0] \n" " fdiv "SCALE", "SCALE", "CUR_MAX" \n" @@ -217,6 +226,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, " fmaxp v24.2d, v24.2d, v26.2d \n" " fmaxp v24.2d, v24.2d, v24.2d \n" " fmax "CUR_MAX", "SCALE", d24 \n" + " fcmp "CUR_MAX", "REGINF" \n" + " beq 10f \n" " fdiv "CUR_MAXINV", "REGONE", "CUR_MAX" \n" " //dup "CUR_MAX_V", v7.d[0] \n" " fdiv "SCALE", "SCALE", "CUR_MAX" \n" @@ -265,6 +276,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, " ldr d4, ["X"] \n" " fabs d4, d4 \n" " fmax "CUR_MAX", "SCALE", d4 \n" + " fcmp "CUR_MAX", "REGINF" \n" + " beq 10f \n" " fdiv "SCALE", "SCALE", "CUR_MAX" \n" " fmul "SCALE", "SCALE", "SCALE" \n" " fmul "SSQ", "SSQ", "SCALE" \n" @@ -276,6 +289,8 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, " ldr d3, ["X", #8] \n" " fabs d3, d3 \n" " fmax "CUR_MAX", "SCALE", d3 \n" + " fcmp "CUR_MAX", "REGINF" \n" + " beq 10f \n" " fdiv "SCALE", "SCALE", "CUR_MAX" \n" " fmul "SCALE", "SCALE", "SCALE" \n" " fmul "SSQ", "SSQ", "SCALE" \n" @@ -291,6 +306,11 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, "9: //nrm2_kernel_L999: \n" " str "SSQ", [%[SSQ_]] \n" " str "SCALE", [%[SCALE_]] \n" + " b 11f \n" + "10: \n" + " str "REGINF", [%[SSQ_]] \n" + " str "REGINF", [%[SCALE_]] \n" + "11: \n" : : [SSQ_] "r" (ssq), //%0 @@ -300,7 +320,7 @@ static void nrm2_compute(BLASLONG n, FLOAT *x, BLASLONG inc_x, [INCX_] "r" (inc_x) //%4 : "cc", "memory", - "x0", "x1", "x2", "x3", "x4", "x5", + "x0", "x1", "x2", "x3", "x4", "x5", "x6", "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8" ); @@ -359,6 +379,12 @@ FLOAT CNAME(BLASLONG n, FLOAT *x, BLASLONG inc_x) cur_ssq = *ptr; cur_scale = *(ptr + 1); + if (cur_ssq == INFINITY) { + ssq = INFINITY; + scale = INFINITY; + break; + } + if (cur_scale != 0) { if (cur_scale > scale) { scale = (scale / cur_scale); -- 2.34.1