95fd73ab5922795f50695ad753d12567a481de47
[platform/upstream/caffeonacl.git] / src / caffe / common.hpp
1 #ifndef CAFFE_COMMON_HPP_
2 #define CAFFE_COMMON_HPP_
3
4 #include <boost/shared_ptr.hpp>
5 #include <cublas_v2.h>
6 #include <cuda.h>
7 #include <curand.h>
8 #include <glog/logging.h>
9 #include <mkl_vsl.h>
10
11 #include "driver_types.h"
12
13 #define CUDA_CHECK(condition) CHECK_EQ((condition), cudaSuccess)
14 #define CUBLAS_CHECK(condition) CHECK_EQ((condition), CUBLAS_STATUS_SUCCESS)
15 #define CURAND_CHECK(condition) CHECK_EQ((condition), CURAND_STATUS_SUCCESS)
16 #define VSL_CHECK(condition) CHECK_EQ((condition), VSL_STATUS_OK)
17
18 #define CUDA_POST_KERNEL_CHECK \
19   if (cudaSuccess != cudaPeekAtLastError()) {\
20     LOG(FATAL) << "Cuda kernel failed. Error: " << cudaGetLastError(); \
21   }
22
23 #define INSTANTIATE_CLASS(classname) \
24   template class classname<float>; \
25   template class classname<double>
26
27 #define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented Yet"
28
29
30 namespace caffe {
31
32 // We will use the boost shared_ptr instead of the new C++11 one mainly
33 // because cuda does not work (at least now) well with C++11 features.
34 using boost::shared_ptr;
35
36 // For backward compatibility we will just use 512 threads per block
37 const int CAFFE_CUDA_NUM_THREADS = 512;
38
39 inline int CAFFE_GET_BLOCKS(const int N) {
40   return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
41 }
42
43 // A singleton class to hold common caffe stuff, such as the handler that
44 // caffe is going to use for cublas.
45 class Caffe {
46  public:
47   ~Caffe();
48   static Caffe& Get();
49   enum Brew { CPU, GPU };
50   enum Phase { TRAIN, TEST };
51
52   // The getters for the variables.
53   // Returns the cublas handle.
54   static cublasHandle_t cublas_handle();
55   // Returns the curand generator.
56   static curandGenerator_t curand_generator();
57   // Returns the MKL random stream.
58   static VSLStreamStatePtr vsl_stream();
59   // Returns the mode: running on CPU or GPU.
60   static Brew mode();
61   // Returns the phase: TRAIN or TEST.
62   static Phase phase();
63   // The setters for the variables
64   // Sets the mode.
65   static void set_mode(Brew mode);
66   // Sets the phase.
67   static void set_phase(Phase phase);
68   // Sets the random seed of both MKL and curand
69   static void set_random_seed(const unsigned int seed);
70
71  private:
72   // The private constructor to avoid duplicate instantiation.
73   Caffe();
74
75  protected:
76   static shared_ptr<Caffe> singleton_;
77   cublasHandle_t cublas_handle_;
78   curandGenerator_t curand_generator_;
79   VSLStreamStatePtr vsl_stream_;
80   Brew mode_;
81   Phase phase_;
82 };
83
84
85 }  // namespace caffe
86
87 #endif  // CAFFE_COMMON_HPP_