Add curandGetErrorString and use it to redefine CURAND_CHECK
authorKai Li <kaili_kloud@163.com>
Mon, 31 Mar 2014 14:29:36 +0000 (22:29 +0800)
committerKai Li <kaili_kloud@163.com>
Tue, 1 Apr 2014 02:27:13 +0000 (10:27 +0800)
include/caffe/common.hpp
src/caffe/common.cpp

index 7c33620..d395bd7 100644 (file)
@@ -38,7 +38,12 @@ private:\
     cublasStatus_t status = condition; \
     CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << cublasGetErrorString(status); \
   } while(0)
-#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
+
+#define CURAND_CHECK(condition) \
+  do { \
+    curandStatus_t status = condition; \
+    CHECK_EQ(status, CURAND_STATUS_SUCCESS) << curandGetErrorString(status); \
+  } while(0)
 
 // CUDA: grid stride looping
 #define CUDA_KERNEL_LOOP(i, n) \
@@ -135,6 +140,7 @@ class Caffe {
 
 // NVIDIA_CUDA-5.5_Samples/common/inc/helper_cuda.h
 const char* cublasGetErrorString(cublasStatus_t& error);
+const char* curandGetErrorString(curandStatus_t error);
 
 // CUDA: thread number configuration.
 // Use 1024 threads per block, which requires cuda sm_2x or above,
index 70c3eef..c2d5058 100644 (file)
@@ -169,4 +169,36 @@ const char* cublasGetErrorString(cublasStatus_t& error) {
   return "Unknown cublas status";
 }
 
+const char* curandGetErrorString(curandStatus_t error) {
+  switch (error) {
+  case CURAND_STATUS_SUCCESS:
+    return "CURAND_STATUS_SUCCESS";
+  case CURAND_STATUS_VERSION_MISMATCH:
+    return "CURAND_STATUS_VERSION_MISMATCH";
+  case CURAND_STATUS_NOT_INITIALIZED:
+    return "CURAND_STATUS_NOT_INITIALIZED";
+  case CURAND_STATUS_ALLOCATION_FAILED:
+    return "CURAND_STATUS_ALLOCATION_FAILED";
+  case CURAND_STATUS_TYPE_ERROR:
+    return "CURAND_STATUS_TYPE_ERROR";
+  case CURAND_STATUS_OUT_OF_RANGE:
+    return "CURAND_STATUS_OUT_OF_RANGE";
+  case CURAND_STATUS_LENGTH_NOT_MULTIPLE:
+    return "CURAND_STATUS_LENGTH_NOT_MULTIPLE";
+  case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED:
+    return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED";
+  case CURAND_STATUS_LAUNCH_FAILURE:
+    return "CURAND_STATUS_LAUNCH_FAILURE";
+  case CURAND_STATUS_PREEXISTING_FAILURE:
+    return "CURAND_STATUS_PREEXISTING_FAILURE";
+  case CURAND_STATUS_INITIALIZATION_FAILED:
+    return "CURAND_STATUS_INITIALIZATION_FAILED";
+  case CURAND_STATUS_ARCH_MISMATCH:
+    return "CURAND_STATUS_ARCH_MISMATCH";
+  case CURAND_STATUS_INTERNAL_ERROR:
+    return "CURAND_STATUS_INTERNAL_ERROR";
+  }
+  return "Unknown curand status";
+}
+
 }  // namespace caffe