From 3ea4dadd3033b60397b485499bfac1f0e486d04b Mon Sep 17 00:00:00 2001 From: wernsaar Date: Fri, 25 Jul 2014 11:59:17 +0200 Subject: [PATCH] optimizations for trsm --- driver/level3/trsm_L.c | 8 ++++++-- driver/level3/trsm_R.c | 16 ++++++++++++---- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/driver/level3/trsm_L.c b/driver/level3/trsm_L.c index fa3b0d5..78da0eb 100644 --- a/driver/level3/trsm_L.c +++ b/driver/level3/trsm_L.c @@ -128,7 +128,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; GEMM_ONCOPY(min_l, min_jj, b + (ls + jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE); @@ -194,7 +196,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; GEMM_ONCOPY(min_l, min_jj, b + (ls - min_l + jjs * ldb) * COMPSIZE, ldb, sb + min_l * (jjs - js) * COMPSIZE); diff --git a/driver/level3/trsm_R.c b/driver/level3/trsm_R.c index b6ee956..169441d 100644 --- a/driver/level3/trsm_R.c +++ b/driver/level3/trsm_R.c @@ -123,7 +123,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY(min_l, min_jj, a + (ls + jjs * lda) * COMPSIZE, lda, sb + min_l * (jjs - js) * COMPSIZE); @@ -177,7 +179,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < min_j - min_l - ls + js; jjs += min_jj){ min_jj = min_j - min_l - ls + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY (min_l, min_jj, a + (ls + (ls + min_l + jjs) * lda) * COMPSIZE, lda, @@ -238,7 +242,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = js; jjs < js + min_j; jjs += min_jj){ min_jj = min_j + js - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY(min_l, min_jj, a + (ls + (jjs - min_j) * lda) * COMPSIZE, lda, sb + min_l * (jjs - js) * COMPSIZE); @@ -297,7 +303,9 @@ int CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa, FLO for(jjs = 0; jjs < min_j - js + ls; jjs += min_jj){ min_jj = min_j - js + ls - jjs; - if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; + if (min_jj > GEMM_UNROLL_N*3) min_jj = GEMM_UNROLL_N*3; + else + if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N; #ifndef TRANSA GEMM_ONCOPY (min_l, min_jj, a + (ls + (js - min_j + jjs) * lda) * COMPSIZE, lda, -- 2.7.4