From 14fe30cc098bf2d06e4a6baee07be17f9cb8c00e Mon Sep 17 00:00:00 2001 From: Kai Li Date: Mon, 31 Mar 2014 22:29:36 +0800 Subject: [PATCH] Add curandGetErrorString and use it to redefine CURAND_CHECK --- include/caffe/common.hpp | 8 +++++++- src/caffe/common.cpp | 32 ++++++++++++++++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index 7c33620..d395bd7 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -38,7 +38,12 @@ private:\ cublasStatus_t status = condition; \ CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << cublasGetErrorString(status); \ } while(0) -#define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS) + +#define CURAND_CHECK(condition) \ + do { \ + curandStatus_t status = condition; \ + CHECK_EQ(status, CURAND_STATUS_SUCCESS) << curandGetErrorString(status); \ + } while(0) // CUDA: grid stride looping #define CUDA_KERNEL_LOOP(i, n) \ @@ -135,6 +140,7 @@ class Caffe { // NVIDIA_CUDA-5.5_Samples/common/inc/helper_cuda.h const char* cublasGetErrorString(cublasStatus_t& error); +const char* curandGetErrorString(curandStatus_t error); // CUDA: thread number configuration. // Use 1024 threads per block, which requires cuda sm_2x or above, diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 70c3eef..c2d5058 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -169,4 +169,36 @@ const char* cublasGetErrorString(cublasStatus_t& error) { return "Unknown cublas status"; } +const char* curandGetErrorString(curandStatus_t error) { + switch (error) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } + return "Unknown curand status"; +} + } // namespace caffe -- 2.7.4