shared_ptr<Caffe> Caffe::singleton_;
Caffe::Caffe()
- : mode_(Caffe::CPU), phase_(Caffe::TRAIN) {
- CUBLAS_CHECK(cublasCreate(&cublas_handle_));
- CURAND_CHECK(curandCreateGenerator(&curand_generator_,
- CURAND_RNG_PSEUDO_DEFAULT));
- CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(curand_generator_,
- 1701ULL));
- VSL_CHECK(vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, 1701));
+ : mode_(Caffe::CPU), phase_(Caffe::TRAIN), cublas_handle_(NULL),
+ curand_generator_(NULL), vsl_stream_(NULL) {
+ if (cublasCreate(&cublas_handle_) != CUBLAS_STATUS_SUCCESS) {
+ LOG(ERROR) << "Cannot create Cublas handle. Cublas won't be available.";
+ }
+ if (curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)
+ != CURAND_STATUS_SUCCESS ||
+ curandSetPseudoRandomGeneratorSeed(curand_generator_, 1701ULL)
+ != CURAND_STATUS_SUCCESS) {
+ LOG(ERROR) << "Cannot create Curand generator. Curand won't be available.";
+ }
+ if (vslNewStream(&vsl_stream_, VSL_BRNG_MT19937, 1701) != VSL_STATUS_OK) {
+ LOG(ERROR) << "Cannot create vsl stream. VSL random number generator "
+ << "won't be available.";
+ }
}
Caffe::~Caffe() {
if (!vsl_stream_) VSL_CHECK(vslDeleteStream(&vsl_stream_));
};
-Caffe& Caffe::Get() {
- if (!singleton_) {
- singleton_.reset(new Caffe());
- }
- return *singleton_;
-};
-
-VSLStreamStatePtr Caffe::vsl_stream() {
- return Get().vsl_stream_;
-}
-
-cublasHandle_t Caffe::cublas_handle() {
- return Get().cublas_handle_;
-};
-
-curandGenerator_t Caffe::curand_generator() {
- return Get().curand_generator_;
-};
-
-Caffe::Brew Caffe::mode() {
- return Get().mode_;
-}
-
-void Caffe::set_mode(Caffe::Brew mode) {
- Get().mode_ = mode;
-}
-
-Caffe::Phase Caffe::phase() {
- return Get().phase_;
-}
-
-void Caffe::set_phase(Caffe::Phase phase) {
- Get().phase_ = phase;
-}
-
void Caffe::set_random_seed(const unsigned int seed) {
// Curand seed
// Yangqing's note: simply setting the generator seed does not seem to
// work on the tesla K20s, so I wrote the ugly reset thing below. It is not
// tested yet and I'll wait til Jeff finishes training.
- CURAND_CHECK(curandDestroyGenerator(curand_generator()));
- CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
- CURAND_RNG_PSEUDO_DEFAULT));
- CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(curand_generator(),
- seed));
+ if (Get().curand_generator_) {
+ CURAND_CHECK(curandDestroyGenerator(curand_generator()));
+ CURAND_CHECK(curandCreateGenerator(&Get().curand_generator_,
+ CURAND_RNG_PSEUDO_DEFAULT));
+ CURAND_CHECK(curandSetPseudoRandomGeneratorSeed(curand_generator(),
+ seed));
+ } else {
+ LOG(ERROR) << "Curand not available. Skipping setting the curand seed.";
+ }
// VSL seed
VSL_CHECK(vslDeleteStream(&(Get().vsl_stream_)));
VSL_CHECK(vslNewStream(&(Get().vsl_stream_), VSL_BRNG_MT19937, seed));
}
} // namespace caffe
-
return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS;
}
+
// A singleton class to hold common caffe stuff, such as the handler that
// caffe is going to use for cublas.
class Caffe {
+ private:
+ // The private constructor to avoid duplicate instantiation.
+ Caffe();
+
+ protected:
+ static shared_ptr<Caffe> singleton_;
+
public:
~Caffe();
- static Caffe& Get();
+ inline static Caffe& Get() {
+ if (!singleton_.get()) {
+ singleton_.reset(new Caffe());
+ }
+ return *singleton_;
+ }
enum Brew { CPU, GPU };
enum Phase { TRAIN, TEST };
// The getters for the variables.
// Returns the cublas handle.
- static cublasHandle_t cublas_handle();
+ inline static cublasHandle_t cublas_handle() { return Get().cublas_handle_; }
// Returns the curand generator.
- static curandGenerator_t curand_generator();
+ inline static curandGenerator_t curand_generator() {
+ return Get().curand_generator_;
+ }
// Returns the MKL random stream.
- static VSLStreamStatePtr vsl_stream();
+ inline static VSLStreamStatePtr vsl_stream() { return Get().vsl_stream_; }
// Returns the mode: running on CPU or GPU.
- static Brew mode();
+ inline static Brew mode() { return Get().mode_; }
// Returns the phase: TRAIN or TEST.
- static Phase phase();
+ inline static Phase phase() { return Get().phase_; }
// The setters for the variables
// Sets the mode.
- static void set_mode(Brew mode);
+ inline static void set_mode(Brew mode) { Get().mode_ = mode; }
// Sets the phase.
- static void set_phase(Phase phase);
+ inline static void set_phase(Phase phase) { Get().phase_ = phase; }
// Sets the random seed of both MKL and curand
static void set_random_seed(const unsigned int seed);
- private:
- // The private constructor to avoid duplicate instantiation.
- Caffe();
-
protected:
- static shared_ptr<Caffe> singleton_;
cublasHandle_t cublas_handle_;
curandGenerator_t curand_generator_;
VSLStreamStatePtr vsl_stream_;
Brew mode_;
Phase phase_;
-};
+ DISABLE_COPY_AND_ASSIGN(Caffe);
+};
} // namespace caffe