Add caffe::cublasGetErrorString and redefine CUBLAS_CHECK with it
authorKai Li <kaili_kloud@163.com>
Mon, 31 Mar 2014 14:23:22 +0000 (22:23 +0800)
committerKai Li <kaili_kloud@163.com>
Tue, 1 Apr 2014 02:25:40 +0000 (10:25 +0800)
include/caffe/common.hpp
src/caffe/common.cpp

index e436756..7c33620 100644 (file)
@@ -32,7 +32,12 @@ private:\
     cudaError_t error = condition; \
     CHECK_EQ(error, cudaSuccess) << cudaGetErrorString(error); \
   } while(0)
-#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
+
+#define CUBLAS_CHECK(condition) \
+  do { \
+    cublasStatus_t status = condition; \
+    CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << cublasGetErrorString(status); \
+  } while(0)
 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
 
 // CUDA: grid stride looping
@@ -128,6 +133,8 @@ class Caffe {
   DISABLE_COPY_AND_ASSIGN(Caffe);
 };
 
+// NVIDIA_CUDA-5.5_Samples/common/inc/helper_cuda.h
+const char* cublasGetErrorString(cublasStatus_t& error);
 
 // CUDA: thread number configuration.
 // Use 1024 threads per block, which requires cuda sm_2x or above,
index 59cbc56..70c3eef 100644 (file)
@@ -147,4 +147,26 @@ const void* Caffe::RNG::generator() const {
   return &generator_->rng;
 }
 
+const char* cublasGetErrorString(cublasStatus_t& error) {
+  switch (error) {
+  case CUBLAS_STATUS_SUCCESS:
+    return "CUBLAS_STATUS_SUCCESS";
+  case CUBLAS_STATUS_NOT_INITIALIZED:
+    return "CUBLAS_STATUS_NOT_INITIALIZED";
+  case CUBLAS_STATUS_ALLOC_FAILED:
+    return "CUBLAS_STATUS_ALLOC_FAILED";
+  case CUBLAS_STATUS_INVALID_VALUE:
+    return "CUBLAS_STATUS_INVALID_VALUE";
+  case CUBLAS_STATUS_ARCH_MISMATCH:
+    return "CUBLAS_STATUS_ARCH_MISMATCH";
+  case CUBLAS_STATUS_MAPPING_ERROR:
+    return "CUBLAS_STATUS_MAPPING_ERROR";
+  case CUBLAS_STATUS_EXECUTION_FAILED:
+    return "CUBLAS_STATUS_EXECUTION_FAILED";
+  case CUBLAS_STATUS_INTERNAL_ERROR:
+    return "CUBLAS_STATUS_INTERNAL_ERROR";
+  }
+  return "Unknown cublas status";
+}
+
 }  // namespace caffe