[ BLAS ] Move to use CUBLAS for gemm
authorjijoong.moon <jijoong.moon@samsung.com>
Fri, 4 Sep 2020 00:45:33 +0000 (09:45 +0900)
committerJijoong Moon <jijoong.moon@samsung.com>
Wed, 9 Sep 2020 04:46:32 +0000 (13:46 +0900)
Move CUBLAS gemm routine into blas_interface

**Self evaluation:**
1. Build test:  [X]Passed [ ]Failed [ ]Skipped
2. Run test:  [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <jijoong.moon@samsung.com>
meson.build
meson_options.txt
nntrainer/include/blas_interface.h
nntrainer/src/blas_interface.cpp
nntrainer/src/tensor.cpp

index 4ba390c..9024e42 100644 (file)
@@ -86,6 +86,10 @@ openmp_dep = dependency('openmp')
 
 blas_dep = dummy_dep
 # Dependencies
+if get_option('enable-cublas')
+   add_project_arguments('-DUSE_CUBLAS=1', language:['c','cpp'])
+endif
+
 if get_option('enable-blas')
   add_project_arguments('-DUSE_BLAS=1', language:['c','cpp'])
   if build_platform == 'tizen'
index 49e7711..5a4b3f0 100644 (file)
@@ -1,5 +1,6 @@
 option('enable-tizen', type: 'boolean', value: false)
 option('enable-blas', type: 'boolean', value: true)
+option('enable-cublas', type: 'boolean', value: false)
 option('enable-app', type: 'boolean', value: true)
 option('install-app', type: 'boolean', value: true)
 option('use_gym', type: 'boolean', value: false)
index 9819297..e3310fe 100644 (file)
@@ -34,6 +34,11 @@ enum CBLAS_TRANSPOSE {
 };
 #endif
 
+#ifdef USE_CUBLAS
+#include <helper_cuda.h>
+#include <helper_functions.h>
+#endif
+
 namespace nntrainer {
 
 /* TODO : need to scopy, sscal, snrm2 */
index 9bbc4f2..bb9cd77 100644 (file)
@@ -123,7 +123,33 @@ void sgemm(CBLAS_ORDER order, CBLAS_TRANSPOSE TransA, CBLAS_TRANSPOSE TransB,
            const float *B, const unsigned int ldb, const float beta, float *C,
            const unsigned int ldc) {
 
-#ifdef USE_BLAS
+#ifdef USE_CUBLAS
+  int devID = 0;
+  cudaDeviceProp deviceProp;
+  cudaGetDeviceProperties(&deviceProp, devID);
+  float *d_A, *d_B, *d_C;
+
+  unsigned int size_A = M * K * sizeof(float);
+  unsigned int size_B = K * N * sizeof(float);
+  unsigned int size_C = M * N * sizeof(float);
+
+  cudaMalloc((void **)&d_A, size_A);
+  cudaMalloc((void **)&d_B, size_B);
+  cudaMemcpy(d_A, A, size_A, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_B, B, size_B, cudaMemcpyHostToDevice);
+  cudaMalloc((void **)&d_C, size_C);
+
+  cublasHandle_t handle;
+  cublasCreate(&handle);
+
+  cublasOperation_t transA = (TransA == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasOperation_t transB = (TransB == CblasTrans) ? CUBLAS_OP_T : CUBLAS_OP_N;
+  cublasSgemm(handle, transA, transB, N, M, K, &alpha, d_B, N, d_A, K, &beta,
+              d_C, N);
+
+  cudaMemcpy(C, d_C, size_C, cudaMemcpyDeviceToHost);
+  cublasDestroy(handle);
+#elif defined USE_BLAS
   cblas_sgemm(order, TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C,
               ldc);
 #else
index 2a3eef8..ca3eff6 100644 (file)
 
 #include <lazy_tensor.h>
 
-#ifdef USE_CUBLAS
-#include <helper_cuda.h>
-#include <helper_functions.h>
-#endif
-
 #define transposeloop(cl, ci, cj, ck, sl, si, sj, sk)                 \
   do {                                                                \
     unsigned int i, j, k, l;                                          \
@@ -484,38 +479,10 @@ Tensor Tensor::dot(Tensor const &m, bool trans, bool trans_m) const {
   const float alpha = 1.0f;
   const float beta = 0.0f;
 
-#ifdef USE_CUBLAS
-  int devID = 0;
-  cudaDeviceProp deviceProp;
-  cudaGetDeviceProperties(&deviceProp, devID);
-  float *d_A, *d_B, *d_C;
-
-  unsigned int size_A = this->length() * sizeof(float);
-  unsigned int size_B = m.length() * sizeof(float);
-  unsigned int size_C = result.length() * sizeof(float);
-
-  cudaMalloc((void **)&d_A, size_A);
-  cudaMalloc((void **)&d_B, size_B);
-  cudaMemcpy(d_A, data, size_A, cudaMemcpyHostToDevice);
-  cudaMemcpy(d_B, mdata, size_B, cudaMemcpyHostToDevice);
-  cudaMalloc((void **)&d_C, size_C);
-
-  cublasHandle_t handle;
-  cublasCreate(&handle);
-
-  cublasOperation_t transA = trans ? CUBLAS_OP_T : CUBLAS_OP_N;
-  cublasOperation_t transB = trans_m ? CUBLAS_OP_T : CUBLAS_OP_N;
-  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha, d_B, N, d_A, K,
-              &beta, d_C, N);
-
-  cudaMemcpy(rdata, d_C, size_C, cudaMemcpyDeviceToHost);
-  cublasDestroy(handle);
-#else
   enum CBLAS_TRANSPOSE transA = trans ? CblasTrans : CblasNoTrans;
   enum CBLAS_TRANSPOSE transB = trans_m ? CblasTrans : CblasNoTrans;
   sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, data, lda, mdata, ldb,
         beta, rdata, ldc);
-#endif
 
   return result;
 }