s390x: fix cscal and zscal implementations
authorMarius Hillenbrand <mhillen@linux.ibm.com>
Mon, 14 Sep 2020 16:36:31 +0000 (18:36 +0200)
committerMarius Hillenbrand <mhillen@linux.ibm.com>
Mon, 21 Sep 2020 11:10:05 +0000 (13:10 +0200)
The implementation of complex scalar * vector multiplication for Z14
makes some LAPACK tests fail because the numerical differences to the
reference implementation exceed the threshold (as can be seen by running
make lapack-test and replacing kernel/zarch/cscal.c with a generic
implementation for comparison).

The complex multiplication uses terms of the form a * b + c * d for both
real and imaginary parts. The assembly code (and compiler-emitted code
as well) uses fused multiply add operations for the second product and
sum. The results can be "surprising", for example when both terms in the
imaginary part nearly cancel each other out. In that case, the second
product contributes more digits to the sum than the first product that
has been rounded before.

One option is to use separate multiplications (which then round the same
way) and a distinct add. Change the code to pursue that path, by (1)
requesting the compiler not to contract the operations into FMAs and (2)
replacing the assembly kernel with corresponding vectorized C code
(where change 1 also applies).

Signed-off-by: Marius Hillenbrand <mhillen@linux.ibm.com>
kernel/zarch/cscal.c
kernel/zarch/zscal.c

index f9e89a4..57bb89c 100644 (file)
@@ -25,67 +25,35 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
 
+/*
+ * Avoid contraction of floating point operations, specifically fused
+ * multiply-add, because they can cause unexpected results in complex
+ * multiplication.
+ */
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC optimize ("fp-contract=off")
+#endif
+
+#if defined(__clang__)
+#pragma clang fp contract(off)
+#endif
+
 #include "common.h"
+#include "vector-common.h"
 
-static void cscal_kernel_16(BLASLONG n, FLOAT *alpha, FLOAT *x) {
-  __asm__("vlrepf %%v0,0(%[alpha])\n\t"
-    "vlef   %%v1,4(%[alpha]),0\n\t"
-    "vlef   %%v1,4(%[alpha]),2\n\t"
-    "vflcsb %%v1,%%v1\n\t"
-    "vlef   %%v1,4(%[alpha]),1\n\t"
-    "vlef   %%v1,4(%[alpha]),3\n\t"
-    "srlg %[n],%[n],4\n\t"
-    "xgr   %%r1,%%r1\n\t"
-    "0:\n\t"
-    "pfd 2, 1024(%%r1,%[x])\n\t"
-    "vl   %%v16,0(%%r1,%[x])\n\t"
-    "vl   %%v17,16(%%r1,%[x])\n\t"
-    "vl   %%v18,32(%%r1,%[x])\n\t"
-    "vl   %%v19,48(%%r1,%[x])\n\t"
-    "vl   %%v20,64(%%r1,%[x])\n\t"
-    "vl   %%v21,80(%%r1,%[x])\n\t"
-    "vl   %%v22,96(%%r1,%[x])\n\t"
-    "vl   %%v23,112(%%r1,%[x])\n\t"
-    "verllg   %%v24,%%v16,32\n\t"
-    "verllg   %%v25,%%v17,32\n\t"
-    "verllg   %%v26,%%v18,32\n\t"
-    "verllg   %%v27,%%v19,32\n\t"
-    "verllg   %%v28,%%v20,32\n\t"
-    "verllg   %%v29,%%v21,32\n\t"
-    "verllg   %%v30,%%v22,32\n\t"
-    "verllg   %%v31,%%v23,32\n\t"
-    "vfmsb %%v16,%%v16,%%v0\n\t"
-    "vfmsb %%v17,%%v17,%%v0\n\t"
-    "vfmsb %%v18,%%v18,%%v0\n\t"
-    "vfmsb %%v19,%%v19,%%v0\n\t"
-    "vfmsb %%v20,%%v20,%%v0\n\t"
-    "vfmsb %%v21,%%v21,%%v0\n\t"
-    "vfmsb %%v22,%%v22,%%v0\n\t"
-    "vfmsb %%v23,%%v23,%%v0\n\t"
-    "vfmasb %%v16,%%v24,%%v1,%%v16\n\t"
-    "vfmasb %%v17,%%v25,%%v1,%%v17\n\t"
-    "vfmasb %%v18,%%v26,%%v1,%%v18\n\t"
-    "vfmasb %%v19,%%v27,%%v1,%%v19\n\t"
-    "vfmasb %%v20,%%v28,%%v1,%%v20\n\t"
-    "vfmasb %%v21,%%v29,%%v1,%%v21\n\t"
-    "vfmasb %%v22,%%v30,%%v1,%%v22\n\t"
-    "vfmasb %%v23,%%v31,%%v1,%%v23\n\t"
-    "vst %%v16,0(%%r1,%[x])\n\t"
-    "vst %%v17,16(%%r1,%[x])\n\t"
-    "vst %%v18,32(%%r1,%[x])\n\t"
-    "vst %%v19,48(%%r1,%[x])\n\t"
-    "vst %%v20,64(%%r1,%[x])\n\t"
-    "vst %%v21,80(%%r1,%[x])\n\t"
-    "vst %%v22,96(%%r1,%[x])\n\t"
-    "vst %%v23,112(%%r1,%[x])\n\t"
-    "agfi  %%r1,128\n\t"
-    "brctg %[n],0b"
-    : "+m"(*(FLOAT (*)[n * 2]) x),[n] "+&r"(n)
-    : [x] "a"(x), "m"(*(const FLOAT (*)[2]) alpha),
-       [alpha] "a"(alpha)
-    : "cc", "r1", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21",
-       "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
-       "v31");
+static void cscal_kernel_16(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x) {
+    vector_float da_r_vec = vec_splats(da_r);
+    vector_float da_i_vec = { -da_i, da_i, -da_i, da_i };
+
+    vector_float *x_vec_ptr = (vector_float *)x;
+
+#pragma GCC unroll 16
+    for (size_t i = 0; i < n/2; i++) {
+       vector_float x_vec = vec_load_hinted(x + i * VLEN_FLOATS);
+       vector_float x_swapped = {x_vec[1], x_vec[0], x_vec[3], x_vec[2]};
+
+       x_vec_ptr[i] =  x_vec * da_r_vec + x_swapped * da_i_vec;
+    }
 }
 
 static void cscal_kernel_16_zero_r(BLASLONG n, FLOAT *alpha, FLOAT *x) {
@@ -199,14 +167,12 @@ static void cscal_kernel_16_zero(BLASLONG n, FLOAT *x) {
     : "cc", "r1", "v0");
 }
 
-static void cscal_kernel_inc_8(BLASLONG n, FLOAT *alpha, FLOAT *x,
+static void cscal_kernel_inc_8(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x,
                                BLASLONG inc_x) {
   BLASLONG i;
   BLASLONG inc_x2 = 2 * inc_x;
   BLASLONG inc_x3 = inc_x2 + inc_x;
   FLOAT t0, t1, t2, t3;
-  FLOAT da_r = alpha[0];
-  FLOAT da_i = alpha[1];
 
   for (i = 0; i < n; i += 4) {
     t0 = da_r * x[0] - da_i * x[1];
@@ -324,9 +290,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
 
         BLASLONG n1 = n & -8;
         if (n1 > 0) {
-          alpha[0] = da_r;
-          alpha[1] = da_i;
-          cscal_kernel_inc_8(n1, alpha, x, inc_x);
+          cscal_kernel_inc_8(n1, da_r, da_i, x, inc_x);
           j = n1;
           i = n1 * inc_x;
         }
@@ -362,7 +326,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
     else if (da_i == 0)
       cscal_kernel_16_zero_i(n1, alpha, x);
     else
-      cscal_kernel_16(n1, alpha, x);
+      cscal_kernel_16(n1, da_r, da_i, x);
 
     i = n1 << 1;
     j = n1;
index a5a8f69..d39b844 100644 (file)
@@ -25,65 +25,35 @@ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *****************************************************************************/
 
+/*
+ * Avoid contraction of floating point operations, specifically fused
+ * multiply-add, because they can cause unexpected results in complex
+ * multiplication.
+ */
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC optimize ("fp-contract=off")
+#endif
+
+#if defined(__clang__)
+#pragma clang fp contract(off)
+#endif
+
 #include "common.h"
+#include "vector-common.h"
 
-static void zscal_kernel_8(BLASLONG n, FLOAT *alpha, FLOAT *x) {
-  __asm__("vlrepg %%v0,0(%[alpha])\n\t"
-    "vleg   %%v1,8(%[alpha]),0\n\t"
-    "wflcdb %%v1,%%v1\n\t"
-    "vleg   %%v1,8(%[alpha]),1\n\t"
-    "srlg %[n],%[n],3\n\t"
-    "xgr   %%r1,%%r1\n\t"
-    "0:\n\t"
-    "pfd 2, 1024(%%r1,%[x])\n\t"
-    "vl   %%v16,0(%%r1,%[x])\n\t"
-    "vl   %%v17,16(%%r1,%[x])\n\t"
-    "vl   %%v18,32(%%r1,%[x])\n\t"
-    "vl   %%v19,48(%%r1,%[x])\n\t"
-    "vl   %%v20,64(%%r1,%[x])\n\t"
-    "vl   %%v21,80(%%r1,%[x])\n\t"
-    "vl   %%v22,96(%%r1,%[x])\n\t"
-    "vl   %%v23,112(%%r1,%[x])\n\t"
-    "vpdi %%v24,%%v16,%%v16,4\n\t"
-    "vpdi %%v25,%%v17,%%v17,4\n\t"
-    "vpdi %%v26,%%v18,%%v18,4\n\t"
-    "vpdi %%v27,%%v19,%%v19,4\n\t"
-    "vpdi %%v28,%%v20,%%v20,4\n\t"
-    "vpdi %%v29,%%v21,%%v21,4\n\t"
-    "vpdi %%v30,%%v22,%%v22,4\n\t"
-    "vpdi %%v31,%%v23,%%v23,4\n\t"
-    "vfmdb %%v16,%%v16,%%v0\n\t"
-    "vfmdb %%v17,%%v17,%%v0\n\t"
-    "vfmdb %%v18,%%v18,%%v0\n\t"
-    "vfmdb %%v19,%%v19,%%v0\n\t"
-    "vfmdb %%v20,%%v20,%%v0\n\t"
-    "vfmdb %%v21,%%v21,%%v0\n\t"
-    "vfmdb %%v22,%%v22,%%v0\n\t"
-    "vfmdb %%v23,%%v23,%%v0\n\t"
-    "vfmadb %%v16,%%v24,%%v1,%%v16\n\t"
-    "vfmadb %%v17,%%v25,%%v1,%%v17\n\t"
-    "vfmadb %%v18,%%v26,%%v1,%%v18\n\t"
-    "vfmadb %%v19,%%v27,%%v1,%%v19\n\t"
-    "vfmadb %%v20,%%v28,%%v1,%%v20\n\t"
-    "vfmadb %%v21,%%v29,%%v1,%%v21\n\t"
-    "vfmadb %%v22,%%v30,%%v1,%%v22\n\t"
-    "vfmadb %%v23,%%v31,%%v1,%%v23\n\t"
-    "vst %%v16,0(%%r1,%[x])\n\t"
-    "vst %%v17,16(%%r1,%[x])\n\t"
-    "vst %%v18,32(%%r1,%[x])\n\t"
-    "vst %%v19,48(%%r1,%[x])\n\t"
-    "vst %%v20,64(%%r1,%[x])\n\t"
-    "vst %%v21,80(%%r1,%[x])\n\t"
-    "vst %%v22,96(%%r1,%[x])\n\t"
-    "vst %%v23,112(%%r1,%[x])\n\t"
-    "agfi  %%r1,128\n\t"
-    "brctg %[n],0b"
-    : "+m"(*(FLOAT (*)[n * 2]) x),[n] "+&r"(n)
-    : [x] "a"(x), "m"(*(const FLOAT (*)[2]) alpha),
-       [alpha] "a"(alpha)
-    : "cc", "r1", "v0", "v1", "v16", "v17", "v18", "v19", "v20", "v21",
-       "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30",
-       "v31");
+static void zscal_kernel_8(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x) {
+    vector_float da_r_vec = vec_splats(da_r);
+    vector_float da_i_vec = { -da_i, da_i };
+
+    vector_float * x_vec_ptr = (vector_float *)x;
+
+#pragma GCC unroll 16
+    for (size_t i = 0; i < n; i++) {
+       vector_float x_vec = vec_load_hinted(x + i * VLEN_FLOATS);
+       vector_float x_swapped = {x_vec[1], x_vec[0]};
+
+       x_vec_ptr[i] = x_vec * da_r_vec + x_swapped * da_i_vec;
+    }
 }
 
 static void zscal_kernel_8_zero_r(BLASLONG n, FLOAT *alpha, FLOAT *x) {
@@ -195,14 +165,12 @@ static void zscal_kernel_8_zero(BLASLONG n, FLOAT *x) {
     : "cc", "r1", "v0");
 }
 
-static void zscal_kernel_inc_8(BLASLONG n, FLOAT *alpha, FLOAT *x,
+static void zscal_kernel_inc_8(BLASLONG n, FLOAT da_r, FLOAT da_i, FLOAT *x,
                                BLASLONG inc_x) {
   BLASLONG i;
   BLASLONG inc_x2 = 2 * inc_x;
   BLASLONG inc_x3 = inc_x2 + inc_x;
   FLOAT t0, t1, t2, t3;
-  FLOAT da_r = alpha[0];
-  FLOAT da_i = alpha[1];
 
   for (i = 0; i < n; i += 4) {
     t0 = da_r * x[0] - da_i * x[1];
@@ -320,9 +288,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
 
         BLASLONG n1 = n & -8;
         if (n1 > 0) {
-          alpha[0] = da_r;
-          alpha[1] = da_i;
-          zscal_kernel_inc_8(n1, alpha, x, inc_x);
+          zscal_kernel_inc_8(n1, da_r, da_i, x, inc_x);
           j = n1;
           i = n1 * inc_x;
         }
@@ -358,7 +324,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
     else if (da_i == 0)
       zscal_kernel_8_zero_i(n1, alpha, x);
     else
-      zscal_kernel_8(n1, alpha, x);
+      zscal_kernel_8(n1, da_r, da_i, x);
 
     i = n1 << 1;
     j = n1;