From aa4139b34e1365fd5ef7278e53615721ff4506a2 Mon Sep 17 00:00:00 2001 From: Kai Li Date: Mon, 31 Mar 2014 22:23:22 +0800 Subject: [PATCH] Add caffe::cublasGetErrorString and redefine CUBLAS_CHECK with it --- include/caffe/common.hpp | 9 ++++++++- src/caffe/common.cpp | 22 ++++++++++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/include/caffe/common.hpp b/include/caffe/common.hpp index e436756..7c33620 100644 --- a/include/caffe/common.hpp +++ b/include/caffe/common.hpp @@ -32,7 +32,12 @@ private:\ cudaError_t error = condition; \ CHECK_EQ(error, cudaSuccess) << cudaGetErrorString(error); \ } while(0) -#define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS) + +#define CUBLAS_CHECK(condition) \ + do { \ + cublasStatus_t status = condition; \ + CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) << cublasGetErrorString(status); \ + } while(0) #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS) // CUDA: grid stride looping @@ -128,6 +133,8 @@ class Caffe { DISABLE_COPY_AND_ASSIGN(Caffe); }; +// NVIDIA_CUDA-5.5_Samples/common/inc/helper_cuda.h +const char* cublasGetErrorString(cublasStatus_t& error); // CUDA: thread number configuration. // Use 1024 threads per block, which requires cuda sm_2x or above, diff --git a/src/caffe/common.cpp b/src/caffe/common.cpp index 59cbc56..70c3eef 100644 --- a/src/caffe/common.cpp +++ b/src/caffe/common.cpp @@ -147,4 +147,26 @@ const void* Caffe::RNG::generator() const { return &generator_->rng; } +const char* cublasGetErrorString(cublasStatus_t& error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + } + return "Unknown cublas status"; +} + } // namespace caffe -- 2.7.4