Add early returns and fix sign errors in workspace calculations
authorMartin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Thu, 27 Aug 2020 09:25:18 +0000 (11:25 +0200)
committerGitHub <noreply@github.com>
Thu, 27 Aug 2020 09:25:18 +0000 (11:25 +0200)
relapack/src/cgbtrf.c
relapack/src/cpbtrf.c
relapack/src/dgbtrf.c
relapack/src/dpbtrf.c
relapack/src/sgbtrf.c
relapack/src/spbtrf.c
relapack/src/zgbtrf.c
relapack/src/zpbtrf.c

index 61332c6..e52f2e6 100644 (file)
@@ -36,6 +36,7 @@ void RELAPACK_cgbtrf(
         return;
     }
 
+    if (*m == 0 || *n == 0) return;
     // Constant
     const float ZERO[] = { 0., 0. };
 
@@ -56,10 +57,10 @@ void RELAPACK_cgbtrf(
 
     // Allocate work space
     const blasint n1 = CREC_SPLIT(*n);
-    const blasint mWorkl = (kv > n1) ? MAX(1, *m - *kl) : kv;
-    const blasint nWorkl = (kv > n1) ? n1 : kv;
-    const blasint mWorku = (*kl > n1) ? n1 : *kl;
-    const blasint nWorku = (*kl > n1) ? MAX(0, *n - *kl) : *kl;
+    const blasint mWorkl = abs ( (kv > n1) ? MAX(1, *m - *kl) : kv);
+    const blasint nWorkl = abs ( (kv > n1) ? n1 : kv);
+    const blasint mWorku = abs ((*kl > n1) ? n1 : *kl);
+    const blasint nWorku = abs ((*kl > n1) ? MAX(0, *n - *kl) : *kl);
     float *Workl = malloc(mWorkl * nWorkl * 2 * sizeof(float));
     float *Worku = malloc(mWorku * nWorku * 2 * sizeof(float));
     LAPACK(claset)("L", &mWorkl, &nWorkl, ZERO, ZERO, Workl, &mWorkl);
@@ -82,7 +83,7 @@ static void RELAPACK_cgbtrf_rec(
     blasint *info
 ) {
 
-    if (*n <= MAX(CROSSOVER_CGBTRF, 1)) {
+    if (*n <= MAX(CROSSOVER_CGBTRF, 1)|| *n > *kl || *ldAb == 1) {
         // Unblocked
         LAPACK(cgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
         return;
index 971e547..a0fa138 100644 (file)
@@ -35,6 +35,8 @@ void RELAPACK_cpbtrf(
         return;
     }
 
+    if (*n == 0) return;
+
     // Clean char * arguments
     const char cleanuplo = lower ? 'L' : 'U';
 
@@ -43,8 +45,8 @@ void RELAPACK_cpbtrf(
 
     // Allocate work space
     const blasint n1 = CREC_SPLIT(*n);
-    const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
-    const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
+    const blasint mWork = abs((*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
+    const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
     float *Work = malloc(mWork * nWork * 2 * sizeof(float));
     LAPACK(claset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
 
@@ -64,7 +66,7 @@ static void RELAPACK_cpbtrf_rec(
     blasint *info
 ){
 
-    if (*n <= MAX(CROSSOVER_CPBTRF, 1)) {
+    if (*n <= MAX(CROSSOVER_CPBTRF, 1) || *ldAb==1) {
         // Unblocked
         LAPACK(cpbtf2)(uplo, n, kd, Ab, ldAb, info);
         return;
@@ -148,7 +150,7 @@ static void RELAPACK_cpbtrf_rec(
     }
 
     // recursion(A_BR)
-    if (*kd > n1)
+    if (*kd > n1 && ldA != 0)
         RELAPACK_cpotrf(uplo, &n2, A_BR, ldA, info);
     else
         RELAPACK_cpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);
index cdf06ad..aac10f2 100644 (file)
@@ -36,6 +36,8 @@ void RELAPACK_dgbtrf(
         return;
     }
 
+    if (*m == 0 || *n == 0) return;
+
     // Constant
     const double ZERO[] = { 0. };
 
@@ -83,7 +85,7 @@ static void RELAPACK_dgbtrf_rec(
     blasint *info
 ) {
 
-    if (*n <= MAX(CROSSOVER_DGBTRF, 1)) {
+    if (*n <= MAX(CROSSOVER_DGBTRF, 1) || *n > *kl || *ldAb == 1) {
         // Unblocked
         LAPACK(dgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
         return;
@@ -195,6 +197,7 @@ static void RELAPACK_dgbtrf_rec(
     // Worku = A_TRr
     LAPACK(dlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
     // Worku = A_TL \ Worku
+    if (ldWorku <= 0) return;
     BLAS(dtrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
     // A_TRr = Worku
     LAPACK(dlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
index 9380b28..94e9b80 100644 (file)
@@ -35,6 +35,8 @@ void RELAPACK_dpbtrf(
         return;
     }
 
+    if (*n == 0) return;
+
     // Clean char * arguments
     const char cleanuplo = lower ? 'L' : 'U';
 
@@ -43,8 +45,8 @@ void RELAPACK_dpbtrf(
 
     // Allocate work space
     const blasint n1 = DREC_SPLIT(*n);
-    const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
-    const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
+    const blasint mWork = abs((*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
+    const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
     double *Work = malloc(mWork * nWork * sizeof(double));
     LAPACK(dlaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
 
@@ -64,7 +66,7 @@ static void RELAPACK_dpbtrf_rec(
     blasint *info
 ){
 
-    if (*n <= MAX(CROSSOVER_DPBTRF, 1)) {
+    if (*n <= MAX(CROSSOVER_DPBTRF, 1) || *ldAb == 1) {
         // Unblocked
         LAPACK(dpbtf2)(uplo, n, kd, Ab, ldAb, info);
         return;
@@ -148,7 +150,7 @@ static void RELAPACK_dpbtrf_rec(
     }
 
     // recursion(A_BR)
-    if (*kd > n1)
+    if (*kd > n1 && ldA != 0)
         RELAPACK_dpotrf(uplo, &n2, A_BR, ldA, info);
     else
         RELAPACK_dpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);
index 3e3fdf4..76e84e6 100644 (file)
@@ -35,6 +35,13 @@ void RELAPACK_sgbtrf(
         return;
     }
 
+    if (*m == 0 || *n == 0) return;
+
+    if (*ldAb == 1) {
+        LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
+       return;
+    }
+
     // Constant
     const float ZERO[] = { 0. };
 
@@ -82,8 +89,9 @@ static void RELAPACK_sgbtrf_rec(
     blasint *info
 ) {
 
+    if (*m == 0 || *n == 0) return;
 
-    if (*n <= MAX(CROSSOVER_SGBTRF, 1)) {
+    if ( *n <= MAX(CROSSOVER_SGBTRF, 1) || *n > *kl || *ldAb == 1) {
         // Unblocked
         LAPACK(sgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
         return;
@@ -160,7 +168,7 @@ static void RELAPACK_sgbtrf_rec(
 
     // recursion(Ab_L, ipiv_T)
     RELAPACK_sgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
-
+    if (*info) return;
     // Workl = A_BLb
     LAPACK(slacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
 
@@ -222,8 +230,8 @@ static void RELAPACK_sgbtrf_rec(
 
     // recursion(Ab_BR, ipiv_B)
 //cause of infinite recursion here ?    
-//      RELAPACK_sgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
-        LAPACK(sgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
+      RELAPACK_sgbtrf_rec(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, Workl, ldWorkl, Worku, ldWorku, info);
+//        LAPACK(sgbtf2)(&m2, &n2, kl, ku, Ab_BR, ldAb, ipiv_B, info);
     if (*info)
         *info += n1;
     // shift pivots
index 26804dc..3302763 100644 (file)
@@ -35,6 +35,9 @@ void RELAPACK_spbtrf(
         return;
     }
 
+
+    if (*n == 0) return;
+
     // Clean char * arguments
     const char cleanuplo = lower ? 'L' : 'U';
 
@@ -43,8 +46,8 @@ void RELAPACK_spbtrf(
 
     // Allocate work space
     const blasint n1 = SREC_SPLIT(*n);
-    const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
-    const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
+    const blasint mWork = abs( (*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
+    const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
     float *Work = malloc(mWork * nWork * sizeof(float));
     LAPACK(slaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
 
@@ -64,7 +67,9 @@ static void RELAPACK_spbtrf_rec(
     blasint *info
 ){
 
-    if (*n <= MAX(CROSSOVER_SPBTRF, 1)) {
+    if (*n == 0 ) return;
+
+    if ( *n <= MAX(CROSSOVER_SPBTRF, 1) || *ldAb == 1) {
         // Unblocked
         LAPACK(spbtf2)(uplo, n, kd, Ab, ldAb, info);
         return;
@@ -148,7 +153,7 @@ static void RELAPACK_spbtrf_rec(
     }
 
     // recursion(A_BR)
-    if (*kd > n1)
+    if (*kd > n1 && ldA != 0)
         RELAPACK_spotrf(uplo, &n2, A_BR, ldA, info);
     else
         RELAPACK_spbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);
index d4ba417..5d7dfd3 100644 (file)
@@ -36,6 +36,8 @@ void RELAPACK_zgbtrf(
         return;
     }
 
+    if (*m == 0 || *n == 0) return;
+
     // Constant
     const double ZERO[] = { 0., 0. };
 
@@ -82,7 +84,7 @@ static void RELAPACK_zgbtrf_rec(
     blasint *info
 ) {
 
-    if (*n <= MAX(CROSSOVER_ZGBTRF, 1)) {
+    if (*n <= MAX(CROSSOVER_ZGBTRF, 1) || *n > *kl || *ldAb == 1) {
         // Unblocked
         LAPACK(zgbtf2)(m, n, kl, ku, Ab, ldAb, ipiv, info);
         return;
@@ -92,6 +94,7 @@ static void RELAPACK_zgbtrf_rec(
     const double ONE[]  = { 1., 0. };
     const double MONE[] = { -1., 0. };
     const blasint    iONE[] = { 1 };
+    const blasint min11 = -11;
 
     // Loop iterators
     blasint i, j;
@@ -158,6 +161,7 @@ static void RELAPACK_zgbtrf_rec(
 
     // recursion(Ab_L, ipiv_T)
     RELAPACK_zgbtrf_rec(m, &n1, kl, ku, Ab_L, ldAb, ipiv_T, Workl, ldWorkl, Worku, ldWorku, info);
+if (*info) return;
 
     // Workl = A_BLb
     LAPACK(zlacpy)("U", &m22, &n1, A_BLb, ldA, Workl, ldWorkl);
@@ -193,11 +197,21 @@ static void RELAPACK_zgbtrf_rec(
     }
 
     // A_TRl = A_TL \ A_TRl
+    if (*ldA < MAX(1,m1)) {
+        LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
+        return;
+    } else {
     BLAS(ztrsm)("L", "L", "N", "U", &m1, &n21, ONE, A_TL, ldA, A_TRl, ldA);
+    }
     // Worku = A_TRr
     LAPACK(zlacpy)("L", &m1, &n22, A_TRr, ldA, Worku, ldWorku);
     // Worku = A_TL \ Worku
+    if (*ldWorku < MAX(1,m1)) {
+        LAPACK(xerbla)("ZGBTRF", &min11, strlen("ZGBTRF"));
+        return;
+    } else {
     BLAS(ztrsm)("L", "L", "N", "U", &m1, &n22, ONE, A_TL, ldA, Worku, ldWorku);
+    }
     // A_TRr = Worku
     LAPACK(zlacpy)("L", &m1, &n22, Worku, ldWorku, A_TRr, ldA);
     // A_BRtl = A_BRtl - A_BLt * A_TRl
index fb0e1e9..8b09438 100644 (file)
@@ -35,6 +35,8 @@ void RELAPACK_zpbtrf(
         return;
     }
 
+    if (*n == 0) return;
+
     // Clean char * arguments
     const char cleanuplo = lower ? 'L' : 'U';
 
@@ -43,9 +45,10 @@ void RELAPACK_zpbtrf(
 
     // Allocate work space
     const blasint n1 = ZREC_SPLIT(*n);
-    const blasint mWork = (*kd > n1) ? (lower ? *n - *kd : n1) : *kd;
-    const blasint nWork = (*kd > n1) ? (lower ? n1 : *n - *kd) : *kd;
+    const blasint mWork = abs((*kd > n1) ? (lower ? *n - *kd : n1) : *kd);
+    const blasint nWork = abs((*kd > n1) ? (lower ? n1 : *n - *kd) : *kd);
     double *Work = malloc(mWork * nWork * 2 * sizeof(double));
+
     LAPACK(zlaset)(uplo, &mWork, &nWork, ZERO, ZERO, Work, &mWork);
 
     // Recursive kernel
@@ -64,7 +67,7 @@ static void RELAPACK_zpbtrf_rec(
     blasint *info
 ){
 
-    if (*n <= MAX(CROSSOVER_ZPBTRF, 1)) {
+    if (*n <= MAX(CROSSOVER_ZPBTRF, 1) || *ldAb == 1) {
         // Unblocked
         LAPACK(zpbtf2)(uplo, n, kd, Ab, ldAb, info);
         return;
@@ -148,7 +151,7 @@ static void RELAPACK_zpbtrf_rec(
     }
 
     // recursion(A_BR)
-    if (*kd > n1)
+    if (*kd > n1 && ldA != 0)
         RELAPACK_zpotrf(uplo, &n2, A_BR, ldA, info);
     else
         RELAPACK_zpbtrf_rec(uplo, &n2, kd, Ab_BR, ldAb, Work, ldWork, info);