Ref #103: enhancement for small matrix dimensions
author wernsaar <wernsaar@googlemail.com>
Wed, 18 Jun 2014 13:04:11 +0000 (15:04 +0200)
committer wernsaar <wernsaar@googlemail.com>
Wed, 18 Jun 2014 13:04:11 +0000 (15:04 +0200)
interface/gemm.c

index 587175e..9ce7fe5 100644 (file)
@@ -72,7 +72,7 @@
 #endif
 
 #ifndef GEMM_MULTITHREAD_THRESHOLD
-# define GEMM_MULTITHREAD_THRESHOLD 4
+#define GEMM_MULTITHREAD_THRESHOLD 4
 #endif
 
 static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, FLOAT *, FLOAT *, BLASLONG) = {
@@ -400,14 +400,63 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
   mode |= (transa << BLAS_TRANSA_SHIFT);
   mode |= (transb << BLAS_TRANSB_SHIFT);
 
-  args.common = NULL;
+  int nthreads_max = num_cpu_avail(3);
+  int nthreads_avail = nthreads_max;
 
-  if(args.m <= GEMM_MULTITHREAD_THRESHOLD || args.n <= GEMM_MULTITHREAD_THRESHOLD 
-     || args.k <=GEMM_MULTITHREAD_THRESHOLD){
-    args.nthreads = 1;
-  }else{
-    args.nthreads = num_cpu_avail(3);
+#ifndef COMPLEX
+  double MNK = (double) args.m * (double) args.n * (double) args.k;
+  if ( MNK <= (1024.0  * (double) GEMM_MULTITHREAD_THRESHOLD)  )
+       nthreads_max = 1; 
+  else
+  {
+       if ( MNK <= (65536.0 * (double) GEMM_MULTITHREAD_THRESHOLD) )
+       {
+               nthreads_max = 4; 
+               if ( args.m < 16 * GEMM_MULTITHREAD_THRESHOLD )
+               {
+                       nthreads_max = 2; 
+                       if ( args.m <     3 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 1;
+                       if ( args.n <     1 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 1;
+                       if ( args.k <     3 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 1;
+               }
+               else
+               {
+                       if ( args.n <=    1 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 2;
+               }
+       }
   }
+#else
+  double MNK = (double) args.m * (double) args.n * (double) args.k;
+  if ( MNK <= (256.0  * (double) GEMM_MULTITHREAD_THRESHOLD)  )
+       nthreads_max = 1; 
+  else
+  {
+       if ( MNK <= (16384.0 * (double) GEMM_MULTITHREAD_THRESHOLD) )
+       {
+               nthreads_max = 4; 
+               if ( args.m < 3 * GEMM_MULTITHREAD_THRESHOLD )
+               {
+                       nthreads_max = 2; 
+                       if ( args.m <=    1 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 1;
+                       if ( args.n <     1 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 1;
+                       if ( args.k <     1 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 1;
+               }
+               else
+               {
+                               if ( args.n <     2 * GEMM_MULTITHREAD_THRESHOLD ) nthreads_max = 2;
+               }
+       }
+  }
+
+#endif
+  args.common = NULL;
+
+  if ( nthreads_max > nthreads_avail )
+       args.nthreads = nthreads_avail;
+  else
+       args.nthreads = nthreads_max;
+
+
  if (args.nthreads == 1) {
 #endif