Use acq/rel semantics to pass flags/pointers in getrf_parallel.
authorAli Saidi <alisaidi@amazon.com>
Mon, 24 Feb 2020 05:45:30 +0000 (05:45 +0000)
committerAli Saidi <alisaidi@amazon.com>
Fri, 6 Mar 2020 06:22:31 +0000 (06:22 +0000)
The current implementation has locks, but the locks each only
have a critical section of one variable so atomic reads/writes
with barriers can be used to achieve the same behavior.

Like the previous patch, pthread_mutex_lock isn't fair, so in a
tight loop the previous thread that has the lock can keep it
starving another thread, even if that thread is about to write
the data that will stop the current thread from spinning.

On a 64c Arm system this improves performance by 20x on sgesv.goto.

lapack/getrf/getrf_parallel.c

index c82defc..c602822 100644 (file)
@@ -68,25 +68,16 @@ double sqrt(double);
 #define GETRF_FACTOR 1.00
 
 
-#if   defined(USE_PTHREAD_LOCK)
-static pthread_mutex_t    getrf_lock = PTHREAD_MUTEX_INITIALIZER;
-#elif defined(USE_PTHREAD_SPINLOCK)
-static pthread_spinlock_t getrf_lock = 0;
+#if (__STDC_VERSION__ >= 201112L)
+#define        atomic_load_long(p)             __atomic_load_n(p, __ATOMIC_RELAXED)
+#define        atomic_store_long(p, v)         __atomic_store_n(p, v, __ATOMIC_RELAXED)
 #else
-static BLASULONG  getrf_lock = 0UL;
-#endif
-
-#if   defined(USE_PTHREAD_LOCK)
-static pthread_mutex_t    getrf_flag_lock = PTHREAD_MUTEX_INITIALIZER;
-#elif defined(USE_PTHREAD_SPINLOCK)
-static pthread_spinlock_t getrf_flag_lock = 0;
-#else
-static BLASULONG  getrf_flag_lock = 0UL;
+#define        atomic_load_long(p)             (BLASLONG)(*(volatile BLASLONG*)(p))
+#define        atomic_store_long(p, v)         (*(volatile BLASLONG *)(p)) = (v)
 #endif
 
 
 
-
 static __inline BLASLONG FORMULA1(BLASLONG M, BLASLONG N, BLASLONG IS, BLASLONG BK, BLASLONG T) {
 
   double m = (double)(M - IS - BK);
@@ -119,11 +110,7 @@ static void inner_basic_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *ra
   FLOAT *d = (FLOAT *)args -> b + (k + k * lda) * COMPSIZE;
   FLOAT *sbb = sb;
 
-#if __STDC_VERSION__ >= 201112L
-  _Atomic BLASLONG *flag = (_Atomic BLASLONG *)args -> d;
-#else
   volatile BLASLONG *flag = (volatile BLASLONG *)args -> d;
-#endif
 
   blasint *ipiv = (blasint *)args -> c;
 
@@ -180,7 +167,10 @@ static void inner_basic_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *ra
       }
     }
 
-    if ((js + REAL_GEMM_R >= n) && (mypos >= 0)) flag[mypos * CACHE_LINE_SIZE] = 0;
+    if ((js + REAL_GEMM_R >= n) && (mypos >= 0)) {
+           MB;
+           atomic_store_long(&flag[mypos * CACHE_LINE_SIZE], 0);
+    }
 
     for (is = 0; is < m; is += GEMM_P){
       min_i = m - is;
@@ -201,14 +191,10 @@ static void inner_basic_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *ra
 /* Non blocking implementation */
 
 typedef struct {
-#if __STDC_VERSION__ >= 201112L
-  _Atomic
-#else
-  volatile
-#endif
-   BLASLONG working[MAX_CPU_NUMBER][CACHE_LINE_SIZE * DIVIDE_RATE];
+  volatile BLASLONG working[MAX_CPU_NUMBER][CACHE_LINE_SIZE * DIVIDE_RATE];
 } job_t;
 
+
 #define ICOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ITCOPY(M, N, (FLOAT *)(A) + ((Y) + (X) * (LDA)) * COMPSIZE, LDA, BUFFER);
 #define OCOPY_OPERATION(M, N, A, LDA, X, Y, BUFFER) GEMM_ONCOPY(M, N, (FLOAT *)(A) + ((X) + (Y) * (LDA)) * COMPSIZE, LDA, BUFFER);
 
@@ -246,11 +232,8 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
 
   blasint *ipiv = (blasint *)args -> c;
   BLASLONG jw;
-#if __STDC_VERSION__ >= 201112L
-  _Atomic BLASLONG *flag = (_Atomic BLASLONG *)args -> d;
-#else
   volatile BLASLONG *flag = (volatile BLASLONG *)args -> d;
-#endif
+
   if (args -> a == NULL) {
     TRSM_ILTCOPY(k, k, (FLOAT *)args -> b, lda, 0, sb);
     sbb = (FLOAT *)((((BLASULONG)(sb + k * k * COMPSIZE) + GEMM_ALIGN) & ~GEMM_ALIGN) + GEMM_OFFSET_B);
@@ -280,10 +263,9 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
 #if 1
     {
        do {
-           LOCK_COMMAND(&getrf_lock);
-           jw = job[mypos].working[i][CACHE_LINE_SIZE * bufferside];
-           UNLOCK_COMMAND(&getrf_lock);
+          jw =  atomic_load_long(&job[mypos].working[i][CACHE_LINE_SIZE * bufferside]);
        } while (jw);
+       MB;
     }
 #else
       while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {};
@@ -326,21 +308,17 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
     }
     MB;
     for (i = 0; i < args -> nthreads; i++) {
-      LOCK_COMMAND(&getrf_lock);
-      job[mypos].working[i][CACHE_LINE_SIZE * bufferside] = (BLASLONG)buffer[bufferside];
-      UNLOCK_COMMAND(&getrf_lock);
+      atomic_store_long(&job[mypos].working[i][CACHE_LINE_SIZE * bufferside], (BLASLONG)buffer[bufferside]);
     }
   }
 
-  LOCK_COMMAND(&getrf_flag_lock);
-  flag[mypos * CACHE_LINE_SIZE] = 0;
-  UNLOCK_COMMAND(&getrf_flag_lock);
+  MB;
+  atomic_store_long(&flag[mypos * CACHE_LINE_SIZE], 0);
 
   if (m == 0) {
+    MB;
     for (xxx = 0; xxx < DIVIDE_RATE; xxx++) {
-      LOCK_COMMAND(&getrf_lock);
-      job[mypos].working[mypos][CACHE_LINE_SIZE * xxx] = 0;
-      UNLOCK_COMMAND(&getrf_lock);
+      atomic_store_long(&job[mypos].working[mypos][CACHE_LINE_SIZE * xxx], 0);
     }
   }
 
@@ -366,10 +344,9 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
          if ((current != mypos) && (!is)) {
 #if 1
                do {
-                   LOCK_COMMAND(&getrf_lock);
-                   jw = job[current].working[mypos][CACHE_LINE_SIZE * bufferside];
-                   UNLOCK_COMMAND(&getrf_lock);
-               } while (jw == 0);
+                  jw =  atomic_load_long(&job[current].working[mypos][CACHE_LINE_SIZE * bufferside]);
+               } while (jw == 0);
+               MB;
 #else
                    while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {};
 #endif
@@ -381,9 +358,7 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
 
          MB;
          if (is + min_i >= m) {
-            LOCK_COMMAND(&getrf_lock);
-           job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
-            UNLOCK_COMMAND(&getrf_lock);
+           atomic_store_long(&job[current].working[mypos][CACHE_LINE_SIZE * bufferside], 0);
          }
        }
 
@@ -397,10 +372,9 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
     for (xxx = 0; xxx < DIVIDE_RATE; xxx++) {
 #if 1
        do {
-           LOCK_COMMAND(&getrf_lock);
-           jw = job[mypos].working[i][CACHE_LINE_SIZE *xxx];
-           UNLOCK_COMMAND(&getrf_lock);
+           jw = atomic_load_long(&job[mypos].working[i][CACHE_LINE_SIZE *xxx]);
        } while(jw != 0);
+       MB;
 #else
       while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {};
 #endif
@@ -443,12 +417,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
 #ifdef _MSC_VER
   BLASLONG flag[MAX_CPU_NUMBER * CACHE_LINE_SIZE];
 #else
-#if __STDC_VERSION__ >= 201112L
-  _Atomic
-#else  
-  volatile
-#endif  
-   BLASLONG flag[MAX_CPU_NUMBER * CACHE_LINE_SIZE] __attribute__((aligned(128)));
+  volatile BLASLONG flag[MAX_CPU_NUMBER * CACHE_LINE_SIZE] __attribute__((aligned(128)));
 #endif
 
 #ifndef COMPLEX
@@ -543,7 +512,11 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
       if (width > mn - is - bk) width = mn - is - bk;
     }
 
-    if (num_cpu > 0) exec_blas_async_wait(num_cpu, &queue[0]);
+
+    if (num_cpu > 0) {
+           WMB;
+           exec_blas_async_wait(num_cpu, &queue[0]);
+    }
 
     mm = m - bk - is;
     nn = n - bk - is;
@@ -608,7 +581,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
       queue[num_cpu].sa      = NULL;
       queue[num_cpu].sb      = NULL;
       queue[num_cpu].next    = &queue[num_cpu + 1];
-      flag[num_cpu * CACHE_LINE_SIZE] = 1;
+      atomic_store_long(&flag[num_cpu * CACHE_LINE_SIZE], 1);
 
       num_cpu ++;
 
@@ -637,6 +610,8 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
     if (num_cpu > 0) {
       queue[num_cpu - 1].next = NULL;
 
+      WMB;
+
       exec_blas_async(0, &queue[0]);
 
       inner_basic_thread(&newarg, NULL, range_n_mine, sa, sbb, -1);
@@ -647,14 +622,10 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
 
       for (i = 0; i < num_cpu; i ++) {
 #if 1
-             LOCK_COMMAND(&getrf_flag_lock);
-             f=flag[i*CACHE_LINE_SIZE];
-             UNLOCK_COMMAND(&getrf_flag_lock);
-             while (f!=0) {
-             LOCK_COMMAND(&getrf_flag_lock);
-             f=flag[i*CACHE_LINE_SIZE];
-             UNLOCK_COMMAND(&getrf_flag_lock);
-             };
+             do {
+                  f =  atomic_load_long(&flag[i*CACHE_LINE_SIZE]);
+             } while (f != 0);
+             MB;
 #else
               while (flag[i*CACHE_LINE_SIZE]) {};
 #endif
@@ -719,12 +690,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
   BLASLONG range[MAX_CPU_NUMBER + 1];
 
   BLASLONG width, nn, num_cpu;
-#if __STDC_VERSION__ >= 201112L
-  _Atomic
-#else  
-  volatile
-#endif
-   BLASLONG flag[MAX_CPU_NUMBER * CACHE_LINE_SIZE] __attribute__((aligned(128)));
+  volatile BLASLONG flag[MAX_CPU_NUMBER * CACHE_LINE_SIZE] __attribute__((aligned(128)));
 
 #ifndef COMPLEX
 #ifdef XDOUBLE
@@ -833,6 +799,8 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
     nn = n - bk - is;
     if (width > nn) width = nn;
 
+    WMB;
+
     if (num_cpu > 1)  exec_blas_async_wait(num_cpu - 1, &queue[1]);
 
     range[0] = 0;
@@ -867,7 +835,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
       queue[num_cpu].sa      = NULL;
       queue[num_cpu].sb      = NULL;
       queue[num_cpu].next    = &queue[num_cpu + 1];
-      flag[num_cpu * CACHE_LINE_SIZE] = 1;
+      atomic_store_long(&flag[num_cpu * CACHE_LINE_SIZE], 1);
 
       num_cpu ++;
     }
@@ -882,6 +850,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
     range_n_new[0] = offset + is;
     range_n_new[1] = offset + is + bk;
 
+    WMB;
     if (num_cpu > 1) {
 
       exec_blas_async(1, &queue[1]);
@@ -917,7 +886,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
 
 #endif
 
-      for (i = 1; i < num_cpu; i ++) while (flag[i * CACHE_LINE_SIZE]) {};
+      for (i = 1; i < num_cpu; i ++) while (atomic_load_long(&flag[i * CACHE_LINE_SIZE])) {};
 
       TRSM_ILTCOPY(bk, bk, a + (is +  is * lda) * COMPSIZE, lda, 0, sb);