Add Elmar Peise's ReLAPACK
[platform/upstream/openblas.git] / relapack / src / zsytrf.c
1 #include "relapack.h"
2 #if XSYTRF_ALLOW_MALLOC
3 #include <stdlib.h>
4 #endif
5
6 static void RELAPACK_zsytrf_rec(const char *, const int *, const int *, int *,
7     double *, const int *, int *, double *, const int *, int *);
8
9
10 /** ZSYTRF computes the factorization of a complex symmetric matrix A using the Bunch-Kaufman diagonal pivoting method.
11  *
12  * This routine is functionally equivalent to LAPACK's zsytrf.
13  * For details on its interface, see
14  * http://www.netlib.org/lapack/explore-html/da/d94/zsytrf_8f.html
15  * */
16 void RELAPACK_zsytrf(
17     const char *uplo, const int *n,
18     double *A, const int *ldA, int *ipiv,
19     double *Work, const int *lWork, int *info
20 ) {
21
22     // Required work size
23     const int cleanlWork = *n * (*n / 2);
24     int minlWork = cleanlWork;
25 #if XSYTRF_ALLOW_MALLOC
26     minlWork = 1;
27 #endif
28
29     // Check arguments
30     const int lower = LAPACK(lsame)(uplo, "L");
31     const int upper = LAPACK(lsame)(uplo, "U");
32     *info = 0;
33     if (!lower && !upper)
34         *info = -1;
35     else if (*n < 0)
36         *info = -2;
37     else if (*ldA < MAX(1, *n))
38         *info = -4;
39     else if (*lWork < minlWork && *lWork != -1)
40         *info = -7;
41     else if (*lWork == -1) {
42         // Work size query
43         *Work = cleanlWork;
44         return;
45     }
46
47     // Ensure Work size
48     double *cleanWork = Work;
49 #if XSYTRF_ALLOW_MALLOC
50     if (!*info && *lWork < cleanlWork) {
51         cleanWork = malloc(cleanlWork * 2 * sizeof(double));
52         if (!cleanWork)
53             *info = -7;
54     }
55 #endif
56
57     if (*info) {
58         const int minfo = -*info;
59         LAPACK(xerbla)("ZSYTRF", &minfo);
60         return;
61     }
62
63     // Clean char * arguments
64     const char cleanuplo = lower ? 'L' : 'U';
65
66     // Dummy arguments
67     int nout;
68
69     // Recursive kernel
70     RELAPACK_zsytrf_rec(&cleanuplo, n, n, &nout, A, ldA, ipiv, cleanWork, n, info);
71
72 #if XSYTRF_ALLOW_MALLOC
73     if (cleanWork != Work)
74         free(cleanWork);
75 #endif
76 }
77
78
79 /** zsytrf's recursive compute kernel */
80 static void RELAPACK_zsytrf_rec(
81     const char *uplo, const int *n_full, const int *n, int *n_out,
82     double *A, const int *ldA, int *ipiv,
83     double *Work, const int *ldWork, int *info
84 ) {
85
86     // top recursion level?
87     const int top = *n_full == *n;
88
89     if (*n <= MAX(CROSSOVER_ZSYTRF, 3)) {
90         // Unblocked
91         if (top) {
92             LAPACK(zsytf2)(uplo, n, A, ldA, ipiv, info);
93             *n_out = *n;
94         } else
95             RELAPACK_zsytrf_rec2(uplo, n_full, n, n_out, A, ldA, ipiv, Work, ldWork, info);
96         return;
97     }
98
99     int info1, info2;
100
101     // Constants
102     const double ONE[]  = { 1., 0. };
103     const double MONE[] = { -1., 0. };
104     const int    iONE[] = { 1 };
105
106     // Loop iterator
107     int i;
108
109     const int n_rest = *n_full - *n;
110
111     if (*uplo == 'L') {
112         // Splitting (setup)
113         int n1 = ZREC_SPLIT(*n);
114         int n2 = *n - n1;
115
116         // Work_L *
117         double *const Work_L = Work;
118
119         // recursion(A_L)
120         int n1_out;
121         RELAPACK_zsytrf_rec(uplo, n_full, &n1, &n1_out, A, ldA, ipiv, Work_L, ldWork, &info1);
122         n1 = n1_out;
123
124         // Splitting (continued)
125         n2 = *n - n1;
126         const int n_full2 = *n_full - n1;
127
128         // *      *
129         // A_BL   A_BR
130         // A_BL_B A_BR_B
131         double *const A_BL   = A                 + 2 * n1;
132         double *const A_BR   = A + 2 * *ldA * n1 + 2 * n1;
133         double *const A_BL_B = A                 + 2 * *n;
134         double *const A_BR_B = A + 2 * *ldA * n1 + 2 * *n;
135
136         // *        *
137         // Work_BL Work_BR
138         // *       *
139         // (top recursion level: use Work as Work_BR)
140         double *const Work_BL =              Work                    + 2 * n1;
141         double *const Work_BR = top ? Work : Work + 2 * *ldWork * n1 + 2 * n1;
142         const int ldWork_BR = top ? n2 : *ldWork;
143
144         // ipiv_T
145         // ipiv_B
146         int *const ipiv_B = ipiv + n1;
147
148         // A_BR = A_BR - A_BL Work_BL'
149         RELAPACK_zgemmt(uplo, "N", "T", &n2, &n1, MONE, A_BL, ldA, Work_BL, ldWork, ONE, A_BR, ldA);
150         BLAS(zgemm)("N", "T", &n_rest, &n2, &n1, MONE, A_BL_B, ldA, Work_BL, ldWork, ONE, A_BR_B, ldA);
151
152         // recursion(A_BR)
153         int n2_out;
154         RELAPACK_zsytrf_rec(uplo, &n_full2, &n2, &n2_out, A_BR, ldA, ipiv_B, Work_BR, &ldWork_BR, &info2);
155
156         if (n2_out != n2) {
157             // undo 1 column of updates
158             const int n_restp1 = n_rest + 1;
159
160             // last column of A_BR
161             double *const A_BR_r = A_BR + 2 * *ldA * n2_out + 2 * n2_out;
162
163             // last row of A_BL
164             double *const A_BL_b = A_BL + 2 * n2_out;
165
166             // last row of Work_BL
167             double *const Work_BL_b = Work_BL + 2 * n2_out;
168
169             // A_BR_r = A_BR_r + A_BL_b Work_BL_b'
170             BLAS(zgemv)("N", &n_restp1, &n1, ONE, A_BL_b, ldA, Work_BL_b, ldWork, ONE, A_BR_r, iONE);
171         }
172         n2 = n2_out;
173
174         // shift pivots
175         for (i = 0; i < n2; i++)
176             if (ipiv_B[i] > 0)
177                 ipiv_B[i] += n1;
178             else
179                 ipiv_B[i] -= n1;
180
181         *info  = info1 || info2;
182         *n_out = n1 + n2;
183     } else {
184         // Splitting (setup)
185         int n2 = ZREC_SPLIT(*n);
186         int n1 = *n - n2;
187
188         // * Work_R
189         // (top recursion level: use Work as Work_R)
190         double *const Work_R = top ? Work : Work + 2 * *ldWork * n1;
191
192         // recursion(A_R)
193         int n2_out;
194         RELAPACK_zsytrf_rec(uplo, n_full, &n2, &n2_out, A, ldA, ipiv, Work_R, ldWork, &info2);
195         const int n2_diff = n2 - n2_out;
196         n2 = n2_out;
197
198         // Splitting (continued)
199         n1 = *n - n2;
200         const int n_full1  = *n_full - n2;
201
202         // * A_TL_T A_TR_T
203         // * A_TL   A_TR
204         // * *      *
205         double *const A_TL_T = A + 2 * *ldA * n_rest;
206         double *const A_TR_T = A + 2 * *ldA * (n_rest + n1);
207         double *const A_TL   = A + 2 * *ldA * n_rest        + 2 * n_rest;
208         double *const A_TR   = A + 2 * *ldA * (n_rest + n1) + 2 * n_rest;
209
210         // Work_L *
211         // *      Work_TR
212         // *      *
213         // (top recursion level: Work_R was Work)
214         double *const Work_L  = Work;
215         double *const Work_TR = Work + 2 * *ldWork * (top ? n2_diff : n1) + 2 * n_rest;
216         const int ldWork_L = top ? n1 : *ldWork;
217
218         // A_TL = A_TL - A_TR Work_TR'
219         RELAPACK_zgemmt(uplo, "N", "T", &n1, &n2, MONE, A_TR, ldA, Work_TR, ldWork, ONE, A_TL, ldA);
220         BLAS(zgemm)("N", "T", &n_rest, &n1, &n2, MONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, ldA);
221
222         // recursion(A_TL)
223         int n1_out;
224         RELAPACK_zsytrf_rec(uplo, &n_full1, &n1, &n1_out, A, ldA, ipiv, Work_L, &ldWork_L, &info1);
225
226         if (n1_out != n1) {
227             // undo 1 column of updates
228             const int n_restp1 = n_rest + 1;
229
230             // A_TL_T_l = A_TL_T_l + A_TR_T Work_TR_t'
231             BLAS(zgemv)("N", &n_restp1, &n2, ONE, A_TR_T, ldA, Work_TR, ldWork, ONE, A_TL_T, iONE);
232         }
233         n1 = n1_out;
234
235         *info  = info2 || info1;
236         *n_out = n1 + n2;
237     }
238 }