fix sve dtrsm kernels
authorBine Brank <binebrank@gmail.com>
Sat, 15 Jan 2022 20:02:14 +0000 (21:02 +0100)
committerBine Brank <binebrank@gmail.com>
Sat, 15 Jan 2022 20:02:14 +0000 (21:02 +0100)
kernel/arm64/trsm_kernel_LN_sve.c
kernel/arm64/trsm_kernel_LT_sve.c
kernel/arm64/trsm_kernel_RT_sve.c
kernel/arm64/trsm_lncopy_sve.c
kernel/arm64/trsm_ltcopy_sve.c
kernel/arm64/trsm_uncopy_sve.c
kernel/arm64/trsm_utcopy_sve.c

index c29c3b5..57f79ac 100644 (file)
@@ -182,8 +182,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k,  FLOAT dummy1,
 
     i = m % sve_size;
     if (i) {
-      aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE;
-      cc = c + ((m & ~(i - 1)) - i)     * COMPSIZE;
+      aa = a + (m - i) * k * COMPSIZE;
+      cc = c + (m - i)     * COMPSIZE;
 
       if (k - kk > 0) {
         GEMM_KERNEL(i, GEMM_UNROLL_N, k - kk, dm1,
@@ -205,10 +205,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k,  FLOAT dummy1,
 
     }
 
+    int mod = i;
     i = sve_size;
     if (i <= m) {
-      aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE;
-      cc = c + ((m & ~(sve_size - 1)) - sve_size)     * COMPSIZE;
+      aa = a + (m - mod - sve_size) * k * COMPSIZE;
+      cc = c + (m - mod - sve_size)     * COMPSIZE;
 
       do {
         if (k - kk > 0) {
@@ -217,7 +218,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k,  FLOAT dummy1,
               ZERO,
 #endif
               aa + sve_size * kk * COMPSIZE,
-              b +  sve_size * kk * COMPSIZE,
+              b +  GEMM_UNROLL_N * kk * COMPSIZE,
               cc,
               ldc);
         }
@@ -251,8 +252,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k,  FLOAT dummy1,
 
         i = m % sve_size;
         if (i) {
-          aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE;
-          cc = c + ((m & ~(i - 1)) - i)     * COMPSIZE;
+          aa = a + (m - i) * k * COMPSIZE;
+          cc = c + (m - i)     * COMPSIZE;
 
           if (k - kk > 0) {
             GEMM_KERNEL(i, j, k - kk, dm1,
@@ -273,10 +274,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k,  FLOAT dummy1,
 
         }
 
+        int mod = i;
         i = sve_size;
         if (i <= m) {
-          aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE;
-          cc = c + ((m & ~(sve_size - 1)) - sve_size)     * COMPSIZE;
+          aa = a + (m - mod - sve_size) * k * COMPSIZE;
+          cc = c + (m - mod - sve_size)     * COMPSIZE;
 
           do {
             if (k - kk > 0) {
index 7f54597..8c6a57a 100644 (file)
@@ -257,7 +257,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1,
           i += sve_size;
         }
 
-        i = sve_size % m;
+        i = m % sve_size;
         if (i) {
           if (kk > 0) {
             GEMM_KERNEL(i, j, kk, dm1,
index d93ebe7..efafc9d 100644 (file)
@@ -258,23 +258,23 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k,  FLOAT dummy1,
       if (i <= m) {
        do {
          if (k - kk > 0) {
-           GEMM_KERNEL(GEMM_UNROLL_M, GEMM_UNROLL_N, k - kk, dm1,
+           GEMM_KERNEL(sve_size, GEMM_UNROLL_N, k - kk, dm1,
 #ifdef COMPLEX
                        ZERO,
 #endif
-                       aa + GEMM_UNROLL_M * kk * COMPSIZE,
+                       aa + sve_size * kk * COMPSIZE,
                        b  + GEMM_UNROLL_N * kk * COMPSIZE,
                        cc,
                        ldc);
          }
 
-         solve(GEMM_UNROLL_M, GEMM_UNROLL_N,
-               aa + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_M * COMPSIZE,
+         solve(sve_size, GEMM_UNROLL_N,
+               aa + (kk - GEMM_UNROLL_N) * sve_size * COMPSIZE,
                b  + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE,
                cc, ldc);
 
-         aa += GEMM_UNROLL_M * k * COMPSIZE;
-         cc += GEMM_UNROLL_M     * COMPSIZE;
+         aa += sve_size * k * COMPSIZE;
+         cc += sve_size     * COMPSIZE;
          i += sve_size;
        } while (i <= m);
       }
index d96a1f3..7f480dc 100644 (file)
 
 int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
 
-  BLASLONG i, ii, j, jj;
+  BLASLONG i, ii, jj;
 
   FLOAT *ao;
 
   jj = offset;
-  int js = 0;
 #ifdef DOUBLE
+  int64_t js = 0;
   svint64_t index = svindex_s64(0LL, lda);
   svbool_t pn = svwhilelt_b64(js, n);
   int n_active = svcntp_b64(svptrue_b64(), pn);
 #else
+  int32_t js = 0;
   svint32_t index = svindex_s32(0, lda);
   svbool_t pn = svwhilelt_b32(js, n);
   int n_active = svcntp_b32(svptrue_b32(), pn);
@@ -74,25 +75,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
       if (ii == jj) {
         for (int j = 0; j < n_active; j++) {
           for (int k = 0; k < j; k++) {
-            *(b + j * n_active + k) = *(a + k * lda + j);
+            *(b + j * n_active + k) = *(ao + k * lda + j);
           }
-          *(b + j * n_active + j) = INV(*(a + j * lda + j));
+          *(b + j * n_active + j) = INV(*(ao + j * lda + j));
         }
-      }
-
-      if (ii > jj) {
-        for (int j = 0; j < n_active; j++) {
+        ao += n_active;
+        b += n_active * n_active;
+        i += n_active;
+        ii += n_active;
+      } else {
+        if (ii > jj) {
           svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
           svst1(pn, b, aj_vec);
-          ao++;
         }
-
+        ao++;
+        b += n_active;
+        i++;
+        ii++;
       }
-
-      b += n_active * n_active;
-
-      i += n_active;
-      ii += n_active;
     } while (i < m);
 
 
index 9012f7f..d7b2a4e 100644 (file)
 
 int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
 
-  BLASLONG i, ii, j, jj;
+  BLASLONG i, ii, jj;
 
   FLOAT *ao;
 
   jj = offset;
-  int js = 0;
 #ifdef DOUBLE
-  svint64_t index = svindex_s64(0LL, lda);
+  int64_t js = 0;
   svbool_t pn = svwhilelt_b64(js, n);
   int n_active = svcntp_b64(svptrue_b64(), pn);
 #else
-  svint32_t index = svindex_s32(0, lda);
+  int32_t js = 0;
   svbool_t pn = svwhilelt_b32(js, n);
   int n_active = svcntp_b32(svptrue_b32(), pn);
 #endif
@@ -73,26 +72,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
 
       if (ii == jj) {
         for (int j = 0; j < n_active; j++) {
-          *(b + j * n_active + j) = INV(*(a + j * lda + j));
+          *(b + j * n_active + j) = INV(*(ao + j * lda + j));
           for (int k = j+1; k < n_active; k++) {
-            *(b + j * n_active + k) = *(a + j * lda + k);
+            *(b + j * n_active + k) = *(ao + j * lda + k);
           }
         }
-      }
-
-      if (ii < jj) {
-        for (int j = 0; j < n_active; j++) {
+        b += n_active * n_active;
+        ao += lda * n_active;
+        i += n_active;
+        ii += n_active;
+      } else {
+        if (ii < jj) {
           svfloat64_t aj_vec = svld1(pn, ao);
           svst1(pn, b, aj_vec);
-          ao += lda;
         }
-
+        ao += lda;
+        b += n_active;
+        i ++;
+        ii ++;
       }
-
-      b += n_active * n_active;
-
-      i += n_active;
-      ii += n_active;
     } while (i < m);
 
 
index 242e99f..b285145 100644 (file)
 
 int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
 
-  BLASLONG i, ii, j, jj;
+  BLASLONG i, ii, jj;
 
   FLOAT *ao;
 
   jj = offset;
-  int js = 0;
 #ifdef DOUBLE
+  int64_t js = 0;
   svint64_t index = svindex_s64(0LL, lda);
   svbool_t pn = svwhilelt_b64(js, n);
   int n_active = svcntp_b64(svptrue_b64(), pn);
 #else
+  int32_t js = 0;
   svint32_t index = svindex_s32(0, lda);
   svbool_t pn = svwhilelt_b32(js, n);
   int n_active = svcntp_b32(svptrue_b32(), pn);
@@ -73,25 +74,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
 
       if (ii == jj) {
         for (int j = 0; j < n_active; j++) {
-          *(b + j * n_active + j) = INV(*(a + j * lda + j));
+          *(b + j * n_active + j) = INV(*(ao + j * lda + j));
           for (int k = j+1; k < n_active; k++) {
-            *(b + j * n_active + k) = *(a + k * lda + j);
+            *(b + j * n_active + k) = *(ao + k * lda + j);
           }
         }
-      }
-
-      if (ii < jj) {
-        for (int j = 0; j < n_active; j++) {
+        ao += n_active;
+        b += n_active * n_active;
+        i += n_active;
+        ii += n_active;
+      } else {
+        if (ii < jj) {
           svfloat64_t aj_vec = svld1_gather_index(pn, ao, index);
           svst1(pn, b, aj_vec);
-          ao++;
         }
+        ao++;
+        b += n_active;
+        i++;
+        ii++;
       }
-
-      b += n_active * n_active;
-
-      i += n_active;
-      ii += n_active;
     } while (i < m);
 
 
index 9eefb8c..5589558 100644 (file)
 
 int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){
 
-  BLASLONG i, ii, j, jj;
+  BLASLONG i, ii, jj;
 
   FLOAT *ao;
 
   jj = offset;
-  int js = 0;
 #ifdef DOUBLE
-  svint64_t index = svindex_s64(0LL, lda);
+  int64_t js = 0;
   svbool_t pn = svwhilelt_b64(js, n);
   int n_active = svcntp_b64(svptrue_b64(), pn);
 #else
-  svint32_t index = svindex_s32(0, lda);
+  int32_t js = 0;
   svbool_t pn = svwhilelt_b32(js, n);
   int n_active = svcntp_b32(svptrue_b32(), pn);
 #endif
@@ -74,25 +73,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT
       if (ii == jj) {
         for (int j = 0; j < n_active; j++) {
           for (int k = 0; k < j; k++) {
-            *(b + j * n_active + k) = *(a + j * lda + k);
+            *(b + j * n_active + k) = *(ao + j * lda + k);
           }
-          *(b + j * n_active + j) = INV(*(a + j * lda + j));
+          *(b + j * n_active + j) = INV(*(ao + j * lda + j));
         }
-      }
-
-      if (ii > jj) {
-        for (int j = 0; j < n_active; j++) {
+        ao += lda * n_active;
+        b += n_active * n_active;
+        i += n_active;
+        ii += n_active;
+      } else {
+        if (ii > jj) {
           svfloat64_t aj_vec = svld1(pn, ao);
           svst1(pn, b, aj_vec);
-          ao += lda;
         }
-
-      }
-
-      b += n_active * n_active;
-
-      i += n_active;
-      ii += n_active;
+        ao += lda;
+        b += n_active;
+        i ++;
+        ii ++;
+      } 
     } while (i < m);