Fixes for CuDNN layers: only destroy handles if setup
authorJeff Donahue <jeff.donahue@gmail.com>
Tue, 10 Feb 2015 02:08:50 +0000 (18:08 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Tue, 10 Feb 2015 02:15:51 +0000 (18:15 -0800)
include/caffe/common_layers.hpp
include/caffe/neuron_layers.hpp
include/caffe/vision_layers.hpp
src/caffe/layers/cudnn_conv_layer.cpp
src/caffe/layers/cudnn_pooling_layer.cpp
src/caffe/layers/cudnn_relu_layer.cpp
src/caffe/layers/cudnn_sigmoid_layer.cpp
src/caffe/layers/cudnn_softmax_layer.cpp
src/caffe/layers/cudnn_tanh_layer.cpp

index 3a5ccd2..c67822c 100644 (file)
@@ -377,7 +377,7 @@ template <typename Dtype>
 class CuDNNSoftmaxLayer : public SoftmaxLayer<Dtype> {
  public:
   explicit CuDNNSoftmaxLayer(const LayerParameter& param)
-      : SoftmaxLayer<Dtype>(param) {}
+      : SoftmaxLayer<Dtype>(param), handles_setup_(false) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -390,6 +390,7 @@ class CuDNNSoftmaxLayer : public SoftmaxLayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  bool handles_setup_;
   cudnnHandle_t             handle_;
   cudnnTensor4dDescriptor_t bottom_desc_;
   cudnnTensor4dDescriptor_t top_desc_;
index b65d984..46c8417 100644 (file)
@@ -417,7 +417,7 @@ template <typename Dtype>
 class CuDNNReLULayer : public ReLULayer<Dtype> {
  public:
   explicit CuDNNReLULayer(const LayerParameter& param)
-      : ReLULayer<Dtype>(param) {}
+      : ReLULayer<Dtype>(param), handles_setup_(false) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -430,6 +430,7 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  bool handles_setup_;
   cudnnHandle_t             handle_;
   cudnnTensor4dDescriptor_t bottom_desc_;
   cudnnTensor4dDescriptor_t top_desc_;
@@ -499,7 +500,7 @@ template <typename Dtype>
 class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
  public:
   explicit CuDNNSigmoidLayer(const LayerParameter& param)
-      : SigmoidLayer<Dtype>(param) {}
+      : SigmoidLayer<Dtype>(param), handles_setup_(false) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -512,6 +513,7 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  bool handles_setup_;
   cudnnHandle_t             handle_;
   cudnnTensor4dDescriptor_t bottom_desc_;
   cudnnTensor4dDescriptor_t top_desc_;
@@ -583,7 +585,7 @@ template <typename Dtype>
 class CuDNNTanHLayer : public TanHLayer<Dtype> {
  public:
   explicit CuDNNTanHLayer(const LayerParameter& param)
-      : TanHLayer<Dtype>(param) {}
+      : TanHLayer<Dtype>(param), handles_setup_(false) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -596,6 +598,7 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  bool handles_setup_;
   cudnnHandle_t             handle_;
   cudnnTensor4dDescriptor_t bottom_desc_;
   cudnnTensor4dDescriptor_t top_desc_;
index 38e369c..6cb507a 100644 (file)
@@ -230,7 +230,7 @@ template <typename Dtype>
 class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
  public:
   explicit CuDNNConvolutionLayer(const LayerParameter& param)
-      : ConvolutionLayer<Dtype>(param) {}
+      : ConvolutionLayer<Dtype>(param), handles_setup_(false) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -243,6 +243,7 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  bool handles_setup_;
   cudnnHandle_t* handle_;
   cudaStream_t*  stream_;
   vector<cudnnTensor4dDescriptor_t> bottom_descs_, top_descs_;
@@ -426,7 +427,7 @@ template <typename Dtype>
 class CuDNNPoolingLayer : public PoolingLayer<Dtype> {
  public:
   explicit CuDNNPoolingLayer(const LayerParameter& param)
-      : PoolingLayer<Dtype>(param) {}
+      : PoolingLayer<Dtype>(param), handles_setup_(false) {}
   virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
   virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
@@ -442,6 +443,7 @@ class CuDNNPoolingLayer : public PoolingLayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  bool handles_setup_;
   cudnnHandle_t             handle_;
   cudnnTensor4dDescriptor_t bottom_desc_, top_desc_;
   cudnnPoolingDescriptor_t  pooling_desc_;
index f74a3db..4a69ca2 100644 (file)
@@ -58,6 +58,8 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
   if (this->bias_term_) {
     cudnn::createTensor4dDesc<Dtype>(&bias_desc_);
   }
+
+  handles_setup_ = true;
 }
 
 template <typename Dtype>
@@ -98,6 +100,9 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
 
 template <typename Dtype>
 CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
   for (int i = 0; i < bottom_descs_.size(); i++) {
     cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
     cudnnDestroyTensor4dDescriptor(top_descs_[i]);
index fc0222a..dd90195 100644 (file)
@@ -22,6 +22,7 @@ void CuDNNPoolingLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   cudnn::createPoolingDesc<Dtype>(&pooling_desc_,
       this->layer_param_.pooling_param().pool(), &mode_,
       this->kernel_h_, this->kernel_w_, this->stride_h_, this->stride_w_);
+  handles_setup_ = true;
 }
 
 template <typename Dtype>
@@ -36,6 +37,9 @@ void CuDNNPoolingLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 CuDNNPoolingLayer<Dtype>::~CuDNNPoolingLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
   cudnnDestroyTensor4dDescriptor(bottom_desc_);
   cudnnDestroyTensor4dDescriptor(top_desc_);
   cudnnDestroyPoolingDescriptor(pooling_desc_);
index 20f486f..0b8a6bc 100644 (file)
@@ -15,6 +15,7 @@ void CuDNNReLULayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CUDNN_CHECK(cudnnCreate(&handle_));
   cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
   cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+  handles_setup_ = true;
 }
 
 template <typename Dtype>
@@ -31,6 +32,9 @@ void CuDNNReLULayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 CuDNNReLULayer<Dtype>::~CuDNNReLULayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
   cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
   cudnnDestroyTensor4dDescriptor(this->top_desc_);
   cudnnDestroy(this->handle_);
index a94c004..67bd9c3 100644 (file)
@@ -15,6 +15,7 @@ void CuDNNSigmoidLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CUDNN_CHECK(cudnnCreate(&handle_));
   cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
   cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+  handles_setup_ = true;
 }
 
 template <typename Dtype>
@@ -31,6 +32,9 @@ void CuDNNSigmoidLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 CuDNNSigmoidLayer<Dtype>::~CuDNNSigmoidLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
   cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
   cudnnDestroyTensor4dDescriptor(this->top_desc_);
   cudnnDestroy(this->handle_);
index 1a0f406..83a5b69 100644 (file)
@@ -19,6 +19,7 @@ void CuDNNSoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CUDNN_CHECK(cudnnCreate(&handle_));
   cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
   cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+  handles_setup_ = true;
 }
 
 template <typename Dtype>
@@ -35,6 +36,9 @@ void CuDNNSoftmaxLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 CuDNNSoftmaxLayer<Dtype>::~CuDNNSoftmaxLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
   cudnnDestroyTensor4dDescriptor(bottom_desc_);
   cudnnDestroyTensor4dDescriptor(top_desc_);
   cudnnDestroy(handle_);
index 39a3e14..b1d2b86 100644 (file)
@@ -15,6 +15,7 @@ void CuDNNTanHLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CUDNN_CHECK(cudnnCreate(&handle_));
   cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
   cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+  handles_setup_ = true;
 }
 
 template <typename Dtype>
@@ -31,6 +32,9 @@ void CuDNNTanHLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 CuDNNTanHLayer<Dtype>::~CuDNNTanHLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
   cudnnDestroyTensor4dDescriptor(this->bottom_desc_);
   cudnnDestroyTensor4dDescriptor(this->top_desc_);
   cudnnDestroy(this->handle_);