Fix race conditions in multithreaded GEMM3M
authorMartin Kroeker <martin@ruby.chemie.uni-freiburg.de>
Sat, 23 Nov 2019 18:54:56 +0000 (19:54 +0100)
committerGitHub <noreply@github.com>
Sat, 23 Nov 2019 18:54:56 +0000 (19:54 +0100)
by adding barriers (and a mutex lock for the non-OpenMP case) like it was already done for GEMM in level3_thread.c some time ago

driver/level3/level3_gemm3m_thread.c

index 4903aa5..21d431b 100644 (file)
@@ -408,7 +408,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
       /* Make sure if no one is using another buffer */
       for (i = 0; i < args -> nthreads; i++)
-       while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
+       while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};
 
       STOP_RPCC(waiting1);
 
@@ -441,7 +441,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
       for (i = 0; i < args -> nthreads; i++)
        job[mypos].working[i][CACHE_LINE_SIZE * bufferside] = (BLASLONG)buffer[bufferside];
-      }
+      WMB;
+       }
 
     current = mypos;
 
@@ -458,7 +459,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
          START_RPCC();
 
          /* thread has to wait */
-         while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
+         while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};
 
          STOP_RPCC(waiting2);
 
@@ -477,6 +478,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
        if (m_to - m_from == min_i) {
          job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
+       WMB;
        }
       }
     } while (current != mypos);
@@ -517,6 +519,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
        if (is + min_i >= m_to) {
          /* Thread doesn't need this buffer any more */
          job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
+       WMB;
        }
        }
 
@@ -541,7 +544,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
       /* Make sure if no one is using another buffer */
       for (i = 0; i < args -> nthreads; i++)
-       while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
+       while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};
 
       STOP_RPCC(waiting1);
 
@@ -595,7 +598,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
          START_RPCC();
 
          /* thread has to wait */
-         while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
+         while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};
 
          STOP_RPCC(waiting2);
 
@@ -613,6 +616,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
        if (m_to - m_from == min_i) {
          job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
+       WMB;
        }
       }
     } while (current != mypos);
@@ -677,7 +681,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
       /* Make sure if no one is using another buffer */
       for (i = 0; i < args -> nthreads; i++)
-       while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;};
+       while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {YIELDING;MB;};
 
       STOP_RPCC(waiting1);
 
@@ -731,7 +735,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
          START_RPCC();
 
          /* thread has to wait */
-         while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;};
+         while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {YIELDING;MB;};
 
          STOP_RPCC(waiting2);
 
@@ -748,8 +752,9 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
        }
 
        if (m_to - m_from == min_i) {
-         job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
-       }
+         job[current].working[mypos][CACHE_LINE_SIZE * bufferside] &= 0;
+       WMB;
+}
       }
     } while (current != mypos);
 
@@ -787,7 +792,8 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 #endif
        if (is + min_i >= m_to) {
          /* Thread doesn't need this buffer any more */
-         job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
+         job[current].working[mypos][CACHE_LINE_SIZE * bufferside] &= 0;
+         WMB;
        }
        }
 
@@ -804,7 +810,7 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 
   for (i = 0; i < args -> nthreads; i++) {
     for (xxx = 0; xxx < DIVIDE_RATE; xxx++) {
-      while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {YIELDING;};
+      while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {YIELDING;MB;};
     }
   }
 
@@ -840,6 +846,15 @@ static int inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
 static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
                       *range_n, FLOAT *sa, FLOAT *sb, BLASLONG mypos){
 
+#ifndef USE_OPENMP
+#ifndef OS_WINDOWS
+static pthread_mutex_t  level3_lock    = PTHREAD_MUTEX_INITIALIZER;
+#else
+CRITICAL_SECTION level3_lock;
+InitializeCriticalSection((PCRITICAL_SECTION)&level3_lock);
+#endif
+#endif
+
   blas_arg_t newarg;
 
   blas_queue_t queue[MAX_CPU_NUMBER];
@@ -869,6 +884,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
   mode  =  BLAS_SINGLE  | BLAS_REAL | BLAS_NODE;
 #endif
 
+#ifndef USE_OPENMP
+#ifndef OS_WINDOWS
+pthread_mutex_lock(&level3_lock);
+#else
+EnterCriticalSection((PCRITICAL_SECTION)&level3_lock);
+#endif
+#endif
+
   newarg.m        = args -> m;
   newarg.n        = args -> n;
   newarg.k        = args -> k;
@@ -973,6 +996,14 @@ static int gemm_driver(blas_arg_t *args, BLASLONG *range_m, BLASLONG
   free(job);
 #endif
 
+#ifndef USE_OPENMP
+#ifndef OS_WINDOWS
+  pthread_mutex_unlock(&level3_lock);
+#else
+  LeaveCriticalSection((PCRITICAL_SECTION)&level3_lock);
+#endif
+#endif
+
   return 0;
 }