From aaa2b1a861623eb012288c2b401fa923933da55c Mon Sep 17 00:00:00 2001 From: Bine Brank Date: Sat, 15 Jan 2022 21:02:14 +0100 Subject: [PATCH] fix sve dtrsm kernels --- kernel/arm64/trsm_kernel_LN_sve.c | 20 +++++++++++--------- kernel/arm64/trsm_kernel_LT_sve.c | 2 +- kernel/arm64/trsm_kernel_RT_sve.c | 12 ++++++------ kernel/arm64/trsm_lncopy_sve.c | 30 +++++++++++++++--------------- kernel/arm64/trsm_ltcopy_sve.c | 32 +++++++++++++++----------------- kernel/arm64/trsm_uncopy_sve.c | 29 +++++++++++++++-------------- kernel/arm64/trsm_utcopy_sve.c | 34 ++++++++++++++++------------------ 7 files changed, 79 insertions(+), 80 deletions(-) diff --git a/kernel/arm64/trsm_kernel_LN_sve.c b/kernel/arm64/trsm_kernel_LN_sve.c index c29c3b5..57f79ac 100644 --- a/kernel/arm64/trsm_kernel_LN_sve.c +++ b/kernel/arm64/trsm_kernel_LN_sve.c @@ -182,8 +182,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, i = m % sve_size; if (i) { - aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE; - cc = c + ((m & ~(i - 1)) - i) * COMPSIZE; + aa = a + (m - i) * k * COMPSIZE; + cc = c + (m - i) * COMPSIZE; if (k - kk > 0) { GEMM_KERNEL(i, GEMM_UNROLL_N, k - kk, dm1, @@ -205,10 +205,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, } + int mod = i; i = sve_size; if (i <= m) { - aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE; - cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE; + aa = a + (m - mod - sve_size) * k * COMPSIZE; + cc = c + (m - mod - sve_size) * COMPSIZE; do { if (k - kk > 0) { @@ -217,7 +218,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, ZERO, #endif aa + sve_size * kk * COMPSIZE, - b + sve_size * kk * COMPSIZE, + b + GEMM_UNROLL_N * kk * COMPSIZE, cc, ldc); } @@ -251,8 +252,8 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, i = m % sve_size; if (i) { - aa = a + ((m & ~(i - 1)) - i) * k * COMPSIZE; - cc = c + ((m & ~(i - 1)) - i) * COMPSIZE; + aa = a + (m - i) * k * COMPSIZE; + cc = c + (m - i) * COMPSIZE; if (k - kk > 0) { GEMM_KERNEL(i, j, k - kk, dm1, @@ -273,10 +274,11 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, } + int mod = i; i = sve_size; if (i <= m) { - aa = a + ((m & ~(sve_size - 1)) - sve_size) * k * COMPSIZE; - cc = c + ((m & ~(sve_size - 1)) - sve_size) * COMPSIZE; + aa = a + (m - mod - sve_size) * k * COMPSIZE; + cc = c + (m - mod - sve_size) * COMPSIZE; do { if (k - kk > 0) { diff --git a/kernel/arm64/trsm_kernel_LT_sve.c b/kernel/arm64/trsm_kernel_LT_sve.c index 7f54597..8c6a57a 100644 --- a/kernel/arm64/trsm_kernel_LT_sve.c +++ b/kernel/arm64/trsm_kernel_LT_sve.c @@ -257,7 +257,7 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, i += sve_size; } - i = sve_size % m; + i = m % sve_size; if (i) { if (kk > 0) { GEMM_KERNEL(i, j, kk, dm1, diff --git a/kernel/arm64/trsm_kernel_RT_sve.c b/kernel/arm64/trsm_kernel_RT_sve.c index d93ebe7..efafc9d 100644 --- a/kernel/arm64/trsm_kernel_RT_sve.c +++ b/kernel/arm64/trsm_kernel_RT_sve.c @@ -258,23 +258,23 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG k, FLOAT dummy1, if (i <= m) { do { if (k - kk > 0) { - GEMM_KERNEL(GEMM_UNROLL_M, GEMM_UNROLL_N, k - kk, dm1, + GEMM_KERNEL(sve_size, GEMM_UNROLL_N, k - kk, dm1, #ifdef COMPLEX ZERO, #endif - aa + GEMM_UNROLL_M * kk * COMPSIZE, + aa + sve_size * kk * COMPSIZE, b + GEMM_UNROLL_N * kk * COMPSIZE, cc, ldc); } - solve(GEMM_UNROLL_M, GEMM_UNROLL_N, - aa + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_M * COMPSIZE, + solve(sve_size, GEMM_UNROLL_N, + aa + (kk - GEMM_UNROLL_N) * sve_size * COMPSIZE, b + (kk - GEMM_UNROLL_N) * GEMM_UNROLL_N * COMPSIZE, cc, ldc); - aa += GEMM_UNROLL_M * k * COMPSIZE; - cc += GEMM_UNROLL_M * COMPSIZE; + aa += sve_size * k * COMPSIZE; + cc += sve_size * COMPSIZE; i += sve_size; } while (i <= m); } diff --git a/kernel/arm64/trsm_lncopy_sve.c b/kernel/arm64/trsm_lncopy_sve.c index d96a1f3..7f480dc 100644 --- a/kernel/arm64/trsm_lncopy_sve.c +++ b/kernel/arm64/trsm_lncopy_sve.c @@ -48,17 +48,18 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ - BLASLONG i, ii, j, jj; + BLASLONG i, ii, jj; FLOAT *ao; jj = offset; - int js = 0; #ifdef DOUBLE + int64_t js = 0; svint64_t index = svindex_s64(0LL, lda); svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); #else + int32_t js = 0; svint32_t index = svindex_s32(0, lda); svbool_t pn = svwhilelt_b32(js, n); int n_active = svcntp_b32(svptrue_b32(), pn); @@ -74,25 +75,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT if (ii == jj) { for (int j = 0; j < n_active; j++) { for (int k = 0; k < j; k++) { - *(b + j * n_active + k) = *(a + k * lda + j); + *(b + j * n_active + k) = *(ao + k * lda + j); } - *(b + j * n_active + j) = INV(*(a + j * lda + j)); + *(b + j * n_active + j) = INV(*(ao + j * lda + j)); } - } - - if (ii > jj) { - for (int j = 0; j < n_active; j++) { + ao += n_active; + b += n_active * n_active; + i += n_active; + ii += n_active; + } else { + if (ii > jj) { svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); svst1(pn, b, aj_vec); - ao++; } - + ao++; + b += n_active; + i++; + ii++; } - - b += n_active * n_active; - - i += n_active; - ii += n_active; } while (i < m); diff --git a/kernel/arm64/trsm_ltcopy_sve.c b/kernel/arm64/trsm_ltcopy_sve.c index 9012f7f..d7b2a4e 100644 --- a/kernel/arm64/trsm_ltcopy_sve.c +++ b/kernel/arm64/trsm_ltcopy_sve.c @@ -48,18 +48,17 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ - BLASLONG i, ii, j, jj; + BLASLONG i, ii, jj; FLOAT *ao; jj = offset; - int js = 0; #ifdef DOUBLE - svint64_t index = svindex_s64(0LL, lda); + int64_t js = 0; svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); #else - svint32_t index = svindex_s32(0, lda); + int32_t js = 0; svbool_t pn = svwhilelt_b32(js, n); int n_active = svcntp_b32(svptrue_b32(), pn); #endif @@ -73,26 +72,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT if (ii == jj) { for (int j = 0; j < n_active; j++) { - *(b + j * n_active + j) = INV(*(a + j * lda + j)); + *(b + j * n_active + j) = INV(*(ao + j * lda + j)); for (int k = j+1; k < n_active; k++) { - *(b + j * n_active + k) = *(a + j * lda + k); + *(b + j * n_active + k) = *(ao + j * lda + k); } } - } - - if (ii < jj) { - for (int j = 0; j < n_active; j++) { + b += n_active * n_active; + ao += lda * n_active; + i += n_active; + ii += n_active; + } else { + if (ii < jj) { svfloat64_t aj_vec = svld1(pn, ao); svst1(pn, b, aj_vec); - ao += lda; } - + ao += lda; + b += n_active; + i ++; + ii ++; } - - b += n_active * n_active; - - i += n_active; - ii += n_active; } while (i < m); diff --git a/kernel/arm64/trsm_uncopy_sve.c b/kernel/arm64/trsm_uncopy_sve.c index 242e99f..b285145 100644 --- a/kernel/arm64/trsm_uncopy_sve.c +++ b/kernel/arm64/trsm_uncopy_sve.c @@ -48,17 +48,18 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ - BLASLONG i, ii, j, jj; + BLASLONG i, ii, jj; FLOAT *ao; jj = offset; - int js = 0; #ifdef DOUBLE + int64_t js = 0; svint64_t index = svindex_s64(0LL, lda); svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); #else + int32_t js = 0; svint32_t index = svindex_s32(0, lda); svbool_t pn = svwhilelt_b32(js, n); int n_active = svcntp_b32(svptrue_b32(), pn); @@ -73,25 +74,25 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT if (ii == jj) { for (int j = 0; j < n_active; j++) { - *(b + j * n_active + j) = INV(*(a + j * lda + j)); + *(b + j * n_active + j) = INV(*(ao + j * lda + j)); for (int k = j+1; k < n_active; k++) { - *(b + j * n_active + k) = *(a + k * lda + j); + *(b + j * n_active + k) = *(ao + k * lda + j); } } - } - - if (ii < jj) { - for (int j = 0; j < n_active; j++) { + ao += n_active; + b += n_active * n_active; + i += n_active; + ii += n_active; + } else { + if (ii < jj) { svfloat64_t aj_vec = svld1_gather_index(pn, ao, index); svst1(pn, b, aj_vec); - ao++; } + ao++; + b += n_active; + i++; + ii++; } - - b += n_active * n_active; - - i += n_active; - ii += n_active; } while (i < m); diff --git a/kernel/arm64/trsm_utcopy_sve.c b/kernel/arm64/trsm_utcopy_sve.c index 9eefb8c..5589558 100644 --- a/kernel/arm64/trsm_utcopy_sve.c +++ b/kernel/arm64/trsm_utcopy_sve.c @@ -48,18 +48,17 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT *b){ - BLASLONG i, ii, j, jj; + BLASLONG i, ii, jj; FLOAT *ao; jj = offset; - int js = 0; #ifdef DOUBLE - svint64_t index = svindex_s64(0LL, lda); + int64_t js = 0; svbool_t pn = svwhilelt_b64(js, n); int n_active = svcntp_b64(svptrue_b64(), pn); #else - svint32_t index = svindex_s32(0, lda); + int32_t js = 0; svbool_t pn = svwhilelt_b32(js, n); int n_active = svcntp_b32(svptrue_b32(), pn); #endif @@ -74,25 +73,24 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *a, BLASLONG lda, BLASLONG offset, FLOAT if (ii == jj) { for (int j = 0; j < n_active; j++) { for (int k = 0; k < j; k++) { - *(b + j * n_active + k) = *(a + j * lda + k); + *(b + j * n_active + k) = *(ao + j * lda + k); } - *(b + j * n_active + j) = INV(*(a + j * lda + j)); + *(b + j * n_active + j) = INV(*(ao + j * lda + j)); } - } - - if (ii > jj) { - for (int j = 0; j < n_active; j++) { + ao += lda * n_active; + b += n_active * n_active; + i += n_active; + ii += n_active; + } else { + if (ii > jj) { svfloat64_t aj_vec = svld1(pn, ao); svst1(pn, b, aj_vec); - ao += lda; } - - } - - b += n_active * n_active; - - i += n_active; - ii += n_active; + ao += lda; + b += n_active; + i ++; + ii ++; + } } while (i < m); -- 2.7.4