Initial cuDNN v3 support
author Simon Layton <slayton58@gmail.com>
Wed, 8 Jul 2015 19:35:55 +0000 (15:35 -0400)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
Fri, 16 Oct 2015 00:50:09 +0000 (17:50 -0700)
include/caffe/vision_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/cudnn_conv_layer.cpp
src/caffe/layers/cudnn_conv_layer.cu
src/caffe/layers/cudnn_lcn_layer.cpp [new file with mode: 0644]
src/caffe/layers/cudnn_lcn_layer.cu [new file with mode: 0644]
src/caffe/layers/cudnn_lrn_layer.cpp [new file with mode: 0644]
src/caffe/layers/cudnn_lrn_layer.cu [new file with mode: 0644]
src/caffe/layers/lrn_layer.cpp
src/caffe/proto/caffe.proto
src/caffe/test/test_lrn_layer.cpp

diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 06bc045..237b05d 100644
@@ -304,13 +304,24 @@ class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
   bool handles_setup_;
   cudnnHandle_t* handle_;
   cudaStream_t*  stream_;
+
+  // algorithms for forward and backward convolutions
+  cudnnConvolutionFwdAlgo_t *fwd_algo_;
+  cudnnConvolutionBwdFilterAlgo_t *bwd_filter_algo_;
+  cudnnConvolutionBwdDataAlgo_t *bwd_data_algo_;
+
   vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
   cudnnTensorDescriptor_t    bias_desc_;
   cudnnFilterDescriptor_t      filter_desc_;
   vector<cudnnConvolutionDescriptor_t> conv_descs_;
   int bottom_offset_, top_offset_, bias_offset_;
-  size_t workspaceSizeInBytes;
-  void *workspace;
+
+  size_t *workspace_fwd_sizes_;
+  size_t *workspace_bwd_data_sizes_;
+  size_t *workspace_bwd_filter_sizes_;
+  size_t workspaceSizeInBytes;  // size of underlying storage
+  void *workspaceData;  // underlying storage
+  void **workspace;  // aliases into workspaceData
 };
 #endif
 
@@ -442,6 +453,65 @@ class LRNLayer : public Layer<Dtype> {
   vector<Blob<Dtype>*> product_bottom_vec_;
 };
 
+#ifdef USE_CUDNN
+
+template <typename Dtype>
+class CuDNNLRNLayer : public LRNLayer<Dtype> {
+ public:
+  explicit CuDNNLRNLayer(const LayerParameter& param)
+      : LRNLayer<Dtype>(param), handles_setup_(false) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~CuDNNLRNLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  bool handles_setup_;
+  cudnnHandle_t             handle_;
+  cudnnLRNDescriptor_t norm_desc_;
+  cudnnTensorDescriptor_t bottom_desc_, top_desc_;
+
+  int size_;
+  Dtype alpha_, beta_, k_;
+};
+
+template <typename Dtype>
+class CuDNNLCNLayer : public LRNLayer<Dtype> {
+ public:
+  explicit CuDNNLCNLayer(const LayerParameter& param)
+      : LRNLayer<Dtype>(param), handles_setup_(false), tempDataSize(0),
+        tempData1(NULL), tempData2(NULL) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual ~CuDNNLCNLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  bool handles_setup_;
+  cudnnHandle_t             handle_;
+  cudnnLRNDescriptor_t norm_desc_;
+  cudnnTensorDescriptor_t bottom_desc_, top_desc_;
+
+  int size_, pre_pad_;
+  Dtype alpha_, beta_, k_;
+
+  size_t tempDataSize;
+  void *tempData1, *tempData2;
+};
+
+#endif
 
 /**
  * @brief Pools the input image by taking the max, average, etc. within regions.
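
Note on the new members above: the single workspace pointer is replaced by one
underlying allocation (workspaceData) carved into per-group, per-stream aliases,
so that concurrent cuDNN calls never share scratch memory. A minimal sketch of
the aliasing, assuming every slice is sized for the largest single request (the
helper name and signature are illustrative, not part of the patch):

    #include <cstddef>

    // Carve one allocation into disjoint per-group/per-stream slices.
    void SetWorkspaceAliases(void* workspaceData, void** workspace,
                             int num_slices, size_t per_slice_bytes) {
      for (int g = 0; g < num_slices; ++g) {
        workspace[g] = static_cast<char*>(workspaceData) + g * per_slice_bytes;
      }
    }

This mirrors the pointer arithmetic done in CuDNNConvolutionLayer<Dtype>::Reshape
below.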
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 926c7d8..417ffe9 100644
@@ -54,10 +54,8 @@ shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
     return shared_ptr<Layer<Dtype> >(new PoolingLayer<Dtype>(param));
 #ifdef USE_CUDNN
   } else if (engine == PoolingParameter_Engine_CUDNN) {
-    PoolingParameter p_param = param.pooling_param();
-    if (p_param.pad() || p_param.pad_h() || p_param.pad_w() ||
-        param.top_size() > 1) {
-      LOG(INFO) << "CUDNN does not support padding or multiple tops. "
+    if (param.top_size() > 1) {
+      LOG(INFO) << "cuDNN does not support multiple tops. "
                 << "Using Caffe's own pooling layer.";
       return shared_ptr<Layer<Dtype> >(new PoolingLayer<Dtype>(param));
     }
@@ -70,6 +68,43 @@ shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
 
 REGISTER_LAYER_CREATOR(Pooling, GetPoolingLayer);
 
+// Get LRN layer according to engine.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetLRNLayer(const LayerParameter& param) {
+  LRNParameter_Engine engine = param.lrn_param().engine();
+
+  if (engine == LRNParameter_Engine_DEFAULT) {
+#ifdef USE_CUDNN
+    engine = LRNParameter_Engine_CUDNN;
+#else
+    engine = LRNParameter_Engine_CAFFE;
+#endif
+  }
+
+  if (engine == LRNParameter_Engine_CAFFE) {
+    return shared_ptr<Layer<Dtype> >(new LRNLayer<Dtype>(param));
+#ifdef USE_CUDNN
+  } else if (engine == LRNParameter_Engine_CUDNN) {
+    LRNParameter lrn_param = param.lrn_param();
+
+    if (lrn_param.norm_region() == LRNParameter_NormRegion_WITHIN_CHANNEL) {
+      return shared_ptr<Layer<Dtype> >(new CuDNNLCNLayer<Dtype>(param));
+    } else {
+      // local size is too big to be handled through cuDNN
+      if (param.lrn_param().local_size() > CUDNN_LRN_MAX_N) {
+        return shared_ptr<Layer<Dtype> >(new LRNLayer<Dtype>(param));
+      } else {
+        return shared_ptr<Layer<Dtype> >(new CuDNNLRNLayer<Dtype>(param));
+      }
+    }
+#endif
+  } else {
+    LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
+  }
+}
+
+REGISTER_LAYER_CREATOR(LRN, GetLRNLayer);
+
 // Get relu layer according to engine.
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetReLULayer(const LayerParameter& param) {
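
With the creator registered, the LRN engine can be chosen per layer. A hedged
example of requesting the cuDNN engine from C++ through the generated protobuf
accessors (the engine field is added to LRNParameter in the caffe.proto hunk
below; MakeLRNParam is a hypothetical helper, not part of the patch):

    #include "caffe/proto/caffe.pb.h"

    // Build a layer definition that forces the cuDNN LRN engine.
    // With DEFAULT, cuDNN is picked automatically when built with USE_CUDNN.
    caffe::LayerParameter MakeLRNParam() {
      caffe::LayerParameter param;
      param.set_name("norm1");
      param.set_type("LRN");
      param.mutable_lrn_param()->set_local_size(5);
      param.mutable_lrn_param()->set_engine(caffe::LRNParameter_Engine_CUDNN);
      return param;
    }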
diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp
index 3514fe2..d7b1e0d 100644
@@ -1,4 +1,5 @@
 #ifdef USE_CUDNN
+#include <algorithm>
 #include <vector>
 
 #include "caffe/filler.hpp"
@@ -24,13 +25,38 @@ void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
   // Initialize CUDA streams and cuDNN.
   stream_         = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
   handle_         = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+  // Initialize algorithm arrays
+  fwd_algo_       = new cudnnConvolutionFwdAlgo_t[bottom.size()];
+  bwd_filter_algo_= new cudnnConvolutionBwdFilterAlgo_t[bottom.size()];
+  bwd_data_algo_  = new cudnnConvolutionBwdDataAlgo_t[bottom.size()];
+
+  // initialize size arrays
+  workspace_fwd_sizes_ = new size_t[bottom.size()];
+  workspace_bwd_filter_sizes_ = new size_t[bottom.size()];
+  workspace_bwd_data_sizes_ = new size_t[bottom.size()];
+
+  // workspace data
   workspaceSizeInBytes = 0;
-  workspace = NULL;
+  workspaceData = NULL;
+  workspace = new void*[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+  for (size_t i = 0; i < bottom.size(); ++i) {
+    // initialize all to default algorithms
+    fwd_algo_[i] = (cudnnConvolutionFwdAlgo_t)0;
+    bwd_filter_algo_[i] = (cudnnConvolutionBwdFilterAlgo_t)0;
+    bwd_data_algo_[i] = (cudnnConvolutionBwdDataAlgo_t)0;
+    // default algorithms don't require workspace
+    workspace_fwd_sizes_[i] = 0;
+    workspace_bwd_data_sizes_[i] = 0;
+    workspace_bwd_filter_sizes_[i] = 0;
+  }
 
   for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
     CUDA_CHECK(cudaStreamCreate(&stream_[g]));
     CUDNN_CHECK(cudnnCreate(&handle_[g]));
     CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g]));
+    workspace[g] = NULL;
   }
 
   // Set the indexing parameters.
@@ -86,6 +112,10 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
   const int stride_h = stride_data[0];
   const int stride_w = stride_data[1];
 
+  // Specify workspace limit for kernels directly until we have a
+  // planning strategy and a rewrite of Caffe's GPU memory management
+  size_t workspace_limit_bytes = 8*1024*1024;
+
   for (int i = 0; i < bottom.size(); i++) {
     cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
         this->num_,
@@ -98,7 +128,104 @@ void CuDNNConvolutionLayer<Dtype>::Reshape(
         this->num_output_ * this->out_spatial_dim_,
         this->out_spatial_dim_, width_out, 1);
     cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i], bottom_descs_[i],
-        filter_desc_, pad_h, pad_w, stride_h, stride_w);
+        filter_desc_, pad_h, pad_w,
+        stride_h, stride_w);
+
+    // choose forward and backward algorithms + workspace(s)
+    CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[0],
+      bottom_descs_[i],
+      filter_desc_,
+      conv_descs_[i],
+      top_descs_[i],
+      CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+      workspace_limit_bytes,
+      &fwd_algo_[i]));
+
+    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[0],
+      bottom_descs_[i],
+      filter_desc_,
+      conv_descs_[i],
+      top_descs_[i],
+      fwd_algo_[i],
+      &(workspace_fwd_sizes_[i])));
+
+    // choose backward algorithm for filter
+    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle_[0],
+          bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
+          CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+          workspace_limit_bytes, &bwd_filter_algo_[i]) );
+
+    // get workspace for backwards filter algorithm
+    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_[0],
+          bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_,
+          bwd_filter_algo_[i], &workspace_bwd_filter_sizes_[i]));
+
+    // choose backward algo for data
+    CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle_[0],
+          filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
+          CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+        workspace_limit_bytes, &bwd_data_algo_[i]));
+
+    // get workspace size
+    CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle_[0],
+          filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i],
+          bwd_data_algo_[i], &workspace_bwd_data_sizes_[i]) );
+  }
+
+  // reduce over all workspace sizes to get a maximum to allocate / reallocate
+  size_t total_workspace_fwd = 0;
+  size_t total_workspace_bwd_data = 0;
+  size_t total_workspace_bwd_filter = 0;
+
+  for (size_t i = 0; i < bottom.size(); i++) {
+    total_workspace_fwd        = std::max(total_workspace_fwd,
+                                     workspace_fwd_sizes_[i]);
+    total_workspace_bwd_data   = std::max(total_workspace_bwd_data,
+                                     workspace_bwd_data_sizes_[i]);
+    total_workspace_bwd_filter = std::max(total_workspace_bwd_filter,
+                                     workspace_bwd_filter_sizes_[i]);
+  }
+  // get max over all operations
+  size_t max_workspace = std::max(total_workspace_fwd,
+                             total_workspace_bwd_data);
+  max_workspace = std::max(max_workspace, total_workspace_bwd_filter);
+  // ensure all groups have enough workspace
+  size_t total_max_workspace = max_workspace *
+                               (this->group_ * CUDNN_STREAMS_PER_GROUP);
+
+  // this is the total amount of storage needed over all groups + streams
+  if (total_max_workspace > workspaceSizeInBytes) {
+    LOG(INFO) << "Reallocating workspace storage: " << total_max_workspace;
+    workspaceSizeInBytes = total_max_workspace;
+
+    // free the existing workspace and allocate a new (larger) one
+    cudaFree(this->workspaceData);
+
+    cudaError_t err = cudaMalloc(&(this->workspaceData), workspaceSizeInBytes);
+    if (err != cudaSuccess) {
+      // force zero memory path
+      for (int i = 0; i < bottom.size(); i++) {
+        workspace_fwd_sizes_[i] = 0;
+        workspace_bwd_filter_sizes_[i] = 0;
+        workspace_bwd_data_sizes_[i] = 0;
+        fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+        bwd_filter_algo_[i] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+        bwd_data_algo_[i] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+      }
+
+      // NULL out all workspace pointers
+      for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+        workspace[g] = NULL;
+      }
+      // NULL out underlying data
+      workspaceData = NULL;
+      workspaceSizeInBytes = 0;
+    }
+
+    // set workspace aliases; after a failed allocation the sizes are zero
+    for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+      workspace[g] = reinterpret_cast<char *>(workspaceData) + g*max_workspace;
+    }
   }
 
   // Tensor descriptor for bias.
@@ -128,8 +255,15 @@ CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
     cudnnDestroy(handle_[g]);
   }
 
+  cudaFree(workspaceData);
   delete [] stream_;
   delete [] handle_;
+  delete [] fwd_algo_;
+  delete [] bwd_filter_algo_;
+  delete [] bwd_data_algo_;
+  delete [] workspace_fwd_sizes_;
+  delete [] workspace_bwd_data_sizes_;
+  delete [] workspace_bwd_filter_sizes_;
 }
 
 INSTANTIATE_CLASS(CuDNNConvolutionLayer);
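
The reduction in Reshape sizes a single allocation for the worst case over all
bottoms and all three operation types, then multiplies by the number of
workspace slices. As a sketch of that arithmetic (helper and numbers are
illustrative): with group_ = 2 and CUDNN_STREAMS_PER_GROUP = 3 there are 6
slices, so a per-op maximum of up to 8 MiB becomes one allocation of up to
48 MiB.

    #include <algorithm>
    #include <cstddef>

    // Mirror of the sizing reduction done in Reshape above.
    size_t TotalWorkspaceBytes(size_t max_fwd, size_t max_bwd_data,
                               size_t max_bwd_filter, int group,
                               int streams_per_group) {
      size_t max_ws = std::max(max_fwd, std::max(max_bwd_data, max_bwd_filter));
      // Every slice must be able to hold the largest single request.
      return max_ws * static_cast<size_t>(group * streams_per_group);
    }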
diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu
index 6911520..e88e4dd 100644
@@ -14,11 +14,6 @@ __global__ void sync_conv_groups() { }
 template <typename Dtype>
 void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
-  const int* kernel_shape_data = this->kernel_shape_.cpu_data();
-  const int kernel_h = kernel_shape_data[0];
-  const int kernel_w = kernel_shape_data[1];
-  const size_t workspace_limit_bytes =
-      kernel_h * kernel_w * this->channels_ * sizeof(int) + 1;
   const Dtype* weight = this->blobs_[0]->gpu_data();
   for (int i = 0; i < bottom.size(); ++i) {
     const Dtype* bottom_data = bottom[i]->gpu_data();
@@ -26,52 +21,13 @@ void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
 
     // Forward through cuDNN in parallel over groups.
     for (int g = 0; g < this->group_; g++) {
-      cudnnConvolutionFwdAlgo_t algo;
-
-      // pick the convolution algorithm
-      // TODO(shelhamer) this should be done during reshape
-      // TODO(shelhamer) the choice of automatic or manual algorithm picking
-      // should be exposed in proto
-      CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g],
-        bottom_descs_[i],
-        filter_desc_,
-        conv_descs_[i],
-        top_descs_[i],
-        CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
-        workspace_limit_bytes,  // memoryLimitInBytes,
-        &algo));
-
-      // get minimum size of the workspace needed for the desired algorithm
-      size_t workspaceSizeInBytes_temp = 0;
-
-      CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g],
-        bottom_descs_[i],
-        filter_desc_,
-        conv_descs_[i],
-        top_descs_[i],
-        algo,
-        &workspaceSizeInBytes_temp));
-
-      if (workspaceSizeInBytes_temp > workspaceSizeInBytes) {
-        workspaceSizeInBytes = workspaceSizeInBytes_temp;
-        // free the existing workspace and allocate a new (larger) one
-        cudaFree(this->workspace);
-        cudaError_t err = cudaMalloc(&(this->workspace), workspaceSizeInBytes);
-        if (err != cudaSuccess) {
-          // force zero memory path
-          algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
-          workspace = NULL;
-          workspaceSizeInBytes = 0;
-        }
-      }
-
       // Filters.
       CUDNN_CHECK(cudnnConvolutionForward(handle_[g],
             cudnn::dataType<Dtype>::one,
             bottom_descs_[i], bottom_data + bottom_offset_ * g,
             filter_desc_, weight + this->weight_offset_ * g,
             conv_descs_[i],
-            algo, workspace, workspaceSizeInBytes,
+            fwd_algo_[i], workspace[g], workspace_fwd_sizes_[i],
             cudnn::dataType<Dtype>::zero,
             top_descs_[i], top_data + top_offset_ * g));
 
@@ -101,10 +57,12 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   if (this->param_propagate_down_[0]) {
     weight = this->blobs_[0]->gpu_data();
     weight_diff = this->blobs_[0]->mutable_gpu_diff();
+    caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
   }
   Dtype* bias_diff = NULL;
   if (this->bias_term_ && this->param_propagate_down_[1]) {
     bias_diff = this->blobs_[1]->mutable_gpu_diff();
+    caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
   }
   for (int i = 0; i < top.size(); ++i) {
     const Dtype* top_diff = top[i]->gpu_diff();
@@ -122,11 +80,14 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
       // Gradient w.r.t. weights.
       if (this->param_propagate_down_[0]) {
         const Dtype* bottom_data = bottom[i]->gpu_data();
-        CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
+        CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3(
+              handle_[1*this->group_ + g],
               cudnn::dataType<Dtype>::one,
               bottom_descs_[i], bottom_data + bottom_offset_ * g,
               top_descs_[i],    top_diff + top_offset_ * g,
               conv_descs_[i],
+              bwd_filter_algo_[i], workspace[1*this->group_ + g],
+              workspace_bwd_filter_sizes_[i],
               cudnn::dataType<Dtype>::one,
               filter_desc_, weight_diff + this->weight_offset_ * g));
       }
@@ -137,11 +98,14 @@ void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
           weight = this->blobs_[0]->gpu_data();
         }
         Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
-        CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
+        CUDNN_CHECK(cudnnConvolutionBackwardData_v3(
+              handle_[2*this->group_ + g],
               cudnn::dataType<Dtype>::one,
               filter_desc_, weight + this->weight_offset_ * g,
               top_descs_[i], top_diff + top_offset_ * g,
               conv_descs_[i],
+              bwd_data_algo_[i], workspace[2*this->group_ + g],
+              workspace_bwd_data_sizes_[i],
               cudnn::dataType<Dtype>::zero,
               bottom_descs_[i], bottom_diff + bottom_offset_ * g));
       }
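
The caffe_gpu_set calls added above are needed because the backward-filter (and
bias) cuDNN calls are issued with beta = one (cudnn::dataType<Dtype>::one as the
accumulation scalar), so those gradients accumulate into the existing diff
across tops and groups rather than overwriting it. A CPU-side sketch of the
alpha/beta convention, for illustration only:

    // dst = alpha * src + beta * dst; with beta == 1 the result
    // accumulates, which is why the diffs must start from zero.
    void AxpbyAccumulate(float alpha, const float* src, float beta,
                         float* dst, int n) {
      for (int i = 0; i < n; ++i) {
        dst[i] = alpha * src[i] + beta * dst[i];
      }
    }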
diff --git a/src/caffe/layers/cudnn_lcn_layer.cpp b/src/caffe/layers/cudnn_lcn_layer.cpp
new file mode 100644
index 0000000..866d810
--- /dev/null
+++ b/src/caffe/layers/cudnn_lcn_layer.cpp
@@ -0,0 +1,77 @@
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNLCNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  LRNLayer<Dtype>::LayerSetUp(bottom, top);
+
+  CUDNN_CHECK(cudnnCreate(&handle_));
+  CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc_));
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+
+  // handles and descriptors are now set up
+  handles_setup_ = true;
+
+  size_ = this->layer_param().lrn_param().local_size();
+  pre_pad_ = (size_ - 1) / 2;
+  alpha_ = this->layer_param().lrn_param().alpha();
+  beta_ = this->layer_param().lrn_param().beta();
+  k_ = this->layer_param().lrn_param().k();
+}
+
+template <typename Dtype>
+void CuDNNLCNLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  LRNLayer<Dtype>::Reshape(bottom, top);
+  cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, bottom[0]->num(),
+      this->channels_, this->height_, this->width_);
+  cudnn::setTensor4dDesc<Dtype>(&top_desc_, bottom[0]->num(),
+      this->channels_, this->height_, this->width_);
+  CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, k_));
+
+  // allocate / reallocate tempData buffers
+  size_t totalSizeInBytes = sizeof(Dtype) * bottom[0]->num() *
+                            this->channels_ * this->height_ * this->width_;
+
+  if (totalSizeInBytes > tempDataSize) {
+    tempDataSize = totalSizeInBytes;
+
+    cudaFree(tempData1);
+    cudaFree(tempData2);
+
+    // allocate new buffers
+    CUDA_CHECK(cudaMalloc(&tempData1, totalSizeInBytes));
+    CUDA_CHECK(cudaMalloc(&tempData2, totalSizeInBytes));
+  }
+}
+
+template <typename Dtype>
+CuDNNLCNLayer<Dtype>::~CuDNNLCNLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
+  cudnnDestroyTensorDescriptor(bottom_desc_);
+  cudnnDestroyTensorDescriptor(top_desc_);
+
+  // destroy LRN handle
+  cudnnDestroy(handle_);
+
+  // free temp buffers
+  cudaFree(tempData1);
+  cudaFree(tempData2);
+}
+
+INSTANTIATE_CLASS(CuDNNLCNLayer);
+
+}   // namespace caffe
+#endif
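
Reshape keeps two scratch buffers, each the size of a full bottom blob, and only
reallocates when the required size grows. As a worked example under the shapes
used in the tests below (num = 2, channels = 7, height = width = 3, float):
2 * 7 * 3 * 3 * 4 = 504 bytes per buffer. A sketch of the sizing (helper name is
illustrative):

    #include <cstddef>

    // Grow-only sizing used for tempData1/tempData2 above.
    size_t LcnTempBytes(int num, int channels, int height, int width,
                        size_t dtype_size) {
      return dtype_size * static_cast<size_t>(num) * channels * height * width;
    }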
diff --git a/src/caffe/layers/cudnn_lcn_layer.cu b/src/caffe/layers/cudnn_lcn_layer.cu
new file mode 100644
index 0000000..c07ade7
--- /dev/null
+++ b/src/caffe/layers/cudnn_lcn_layer.cu
@@ -0,0 +1,50 @@
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNLCNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+
+  CUDNN_CHECK(cudnnDivisiveNormalizationForward(
+        handle_, norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS,
+        cudnn::dataType<Dtype>::one,
+        bottom_desc_, bottom_data,
+        NULL,  // srcMeansData
+        this->tempData1, this->tempData2,
+        cudnn::dataType<Dtype>::zero,
+        top_desc_, top_data) );
+}
+
+template <typename Dtype>
+void CuDNNLCNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+  CUDNN_CHECK(cudnnDivisiveNormalizationBackward(
+        handle_, norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS,
+        cudnn::dataType<Dtype>::one,
+        bottom_desc_, bottom_data,
+        NULL, top_diff,  // NULL - srcMeansData
+        this->tempData1, this->tempData2,
+        cudnn::dataType<Dtype>::zero,
+        bottom_desc_, bottom_diff,
+        NULL) );
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(CuDNNLCNLayer);
+
+}  // namespace caffe
+#endif
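
Passing CUDNN_DIVNORM_PRECOMPUTED_MEANS with a NULL means pointer treats the
means as zero, so the forward pass reduces to within-channel LRN over a
size x size spatial window W(h, w). Assuming cuDNN's formulation agrees with
Caffe's WITHIN_CHANNEL reference in the tests below (k defaults to 1):

    y_{n,c,h,w} = \frac{x_{n,c,h,w}}
        {\left(k + \frac{\alpha}{\mathrm{size}^2}
          \sum_{(i,j) \in W(h,w)} x_{n,c,i,j}^{2}\right)^{\beta}}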
diff --git a/src/caffe/layers/cudnn_lrn_layer.cpp b/src/caffe/layers/cudnn_lrn_layer.cpp
new file mode 100644
index 0000000..6e99214
--- /dev/null
+++ b/src/caffe/layers/cudnn_lrn_layer.cpp
@@ -0,0 +1,57 @@
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNLRNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  LRNLayer<Dtype>::LayerSetUp(bottom, top);
+
+  CUDNN_CHECK(cudnnCreate(&handle_));
+  CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc_));
+  cudnn::createTensor4dDesc<Dtype>(&bottom_desc_);
+  cudnn::createTensor4dDesc<Dtype>(&top_desc_);
+
+  // handles and descriptors are now set up
+  handles_setup_ = true;
+
+  size_ = this->layer_param().lrn_param().local_size();
+  alpha_ = this->layer_param().lrn_param().alpha();
+  beta_ = this->layer_param().lrn_param().beta();
+  k_ = this->layer_param().lrn_param().k();
+}
+
+template <typename Dtype>
+void CuDNNLRNLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  LRNLayer<Dtype>::Reshape(bottom, top);
+  cudnn::setTensor4dDesc<Dtype>(&bottom_desc_, bottom[0]->num(),
+      this->channels_, this->height_, this->width_);
+  cudnn::setTensor4dDesc<Dtype>(&top_desc_, bottom[0]->num(),
+      this->channels_, this->height_, this->width_);
+  CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, k_));
+}
+
+template <typename Dtype>
+CuDNNLRNLayer<Dtype>::~CuDNNLRNLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
+  cudnnDestroyTensorDescriptor(bottom_desc_);
+  cudnnDestroyTensorDescriptor(top_desc_);
+
+  // destroy LRN handle
+  cudnnDestroy(handle_);
+}
+
+INSTANTIATE_CLASS(CuDNNLRNLayer);
+
+}   // namespace caffe
+#endif
diff --git a/src/caffe/layers/cudnn_lrn_layer.cu b/src/caffe/layers/cudnn_lrn_layer.cu
new file mode 100644
index 0000000..f992303
--- /dev/null
+++ b/src/caffe/layers/cudnn_lrn_layer.cu
@@ -0,0 +1,48 @@
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void CuDNNLRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+
+  CUDNN_CHECK(cudnnLRNCrossChannelForward(
+        handle_, norm_desc_, CUDNN_LRN_CROSS_CHANNEL_DIM1,
+        cudnn::dataType<Dtype>::one,
+        bottom_desc_, bottom_data,
+        cudnn::dataType<Dtype>::zero,
+        top_desc_, top_data) );
+}
+
+template <typename Dtype>
+void CuDNNLRNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* top_diff = top[0]->gpu_diff();
+  const Dtype* top_data = top[0]->gpu_data();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+
+  CUDNN_CHECK(cudnnLRNCrossChannelBackward(
+        handle_, norm_desc_, CUDNN_LRN_CROSS_CHANNEL_DIM1,
+        cudnn::dataType<Dtype>::one,
+        top_desc_, top_data,
+        top_desc_, top_diff,
+        bottom_desc_, bottom_data,
+        cudnn::dataType<Dtype>::zero,
+        bottom_desc_, bottom_diff) );
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(CuDNNLRNLayer);
+
+}  // namespace caffe
+
+#endif
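
For reference, the cross-channel normalization computed here, written to match
the reference implementation in test_lrn_layer.cpp below, where N(c) is the
window of local_size channels centered on c, clipped to [0, C), and k defaults
to 1:

    y_{n,c,h,w} = \frac{x_{n,c,h,w}}
        {\left(k + \frac{\alpha}{\mathrm{size}}
          \sum_{c' \in N(c)} x_{n,c',h,w}^{2}\right)^{\beta}}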
diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp
index 36c1ace..d18a04e 100644
@@ -254,6 +254,5 @@ STUB_GPU_BACKWARD(LRNLayer, CrossChannelBackward);
 #endif
 
 INSTANTIATE_CLASS(LRNLayer);
-REGISTER_LAYER_CLASS(LRN);
 
 }  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index f52c941..af01b47 100644
@@ -721,6 +721,12 @@ message LRNParameter {
   }
   optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS];
   optional float k = 5 [default = 1.];
+  enum Engine {
+    DEFAULT = 0;
+    CAFFE = 1;
+    CUDNN = 2;
+  }
+  optional Engine engine = 6 [default = DEFAULT];
 }
 
 message MemoryDataParameter {
diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp
index c4e2f8e..78cf2d9 100644
@@ -246,5 +246,201 @@ TYPED_TEST(LRNLayerTest, TestGradientWithinChannel) {
       this->blob_top_vec_);
 }
 
+#ifdef USE_CUDNN
+template <typename Dtype>
+class CuDNNLRNLayerTest : public GPUDeviceTest<Dtype> {
+ protected:
+  CuDNNLRNLayerTest()
+      : epsilon_(Dtype(1e-5)),
+        blob_bottom_(new Blob<Dtype>()),
+        blob_top_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    Caffe::set_random_seed(1701);
+    blob_bottom_->Reshape(2, 7, 3, 3);
+    // fill the values
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+  virtual ~CuDNNLRNLayerTest() { delete blob_bottom_; delete blob_top_; }
+  void ReferenceLRNForward(const Blob<Dtype>& blob_bottom,
+      const LayerParameter& layer_param, Blob<Dtype>* blob_top);
+
+  Dtype epsilon_;
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_top_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+template <typename TypeParam>
+void CuDNNLRNLayerTest<TypeParam>::ReferenceLRNForward(
+    const Blob<TypeParam>& blob_bottom, const LayerParameter& layer_param,
+    Blob<TypeParam>* blob_top) {
+  typedef TypeParam Dtype;
+  blob_top->Reshape(blob_bottom.num(), blob_bottom.channels(),
+      blob_bottom.height(), blob_bottom.width());
+  Dtype* top_data = blob_top->mutable_cpu_data();
+  LRNParameter lrn_param = layer_param.lrn_param();
+  Dtype alpha = lrn_param.alpha();
+  Dtype beta = lrn_param.beta();
+  int size = lrn_param.local_size();
+  switch (lrn_param.norm_region()) {
+  case LRNParameter_NormRegion_ACROSS_CHANNELS:
+    for (int n = 0; n < blob_bottom.num(); ++n) {
+      for (int c = 0; c < blob_bottom.channels(); ++c) {
+        for (int h = 0; h < blob_bottom.height(); ++h) {
+          for (int w = 0; w < blob_bottom.width(); ++w) {
+            int c_start = c - (size - 1) / 2;
+            int c_end = min(c_start + size, blob_bottom.channels());
+            c_start = max(c_start, 0);
+            Dtype scale = 1.;
+            for (int i = c_start; i < c_end; ++i) {
+              Dtype value = blob_bottom.data_at(n, i, h, w);
+              scale += value * value * alpha / size;
+            }
+            *(top_data + blob_top->offset(n, c, h, w)) =
+              blob_bottom.data_at(n, c, h, w) / pow(scale, beta);
+          }
+        }
+      }
+    }
+    break;
+  case LRNParameter_NormRegion_WITHIN_CHANNEL:
+    for (int n = 0; n < blob_bottom.num(); ++n) {
+      for (int c = 0; c < blob_bottom.channels(); ++c) {
+        for (int h = 0; h < blob_bottom.height(); ++h) {
+          int h_start = h - (size - 1) / 2;
+          int h_end = min(h_start + size, blob_bottom.height());
+          h_start = max(h_start, 0);
+          for (int w = 0; w < blob_bottom.width(); ++w) {
+            Dtype scale = 1.;
+            int w_start = w - (size - 1) / 2;
+            int w_end = min(w_start + size, blob_bottom.width());
+            w_start = max(w_start, 0);
+            for (int nh = h_start; nh < h_end; ++nh) {
+              for (int nw = w_start; nw < w_end; ++nw) {
+                Dtype value = blob_bottom.data_at(n, c, nh, nw);
+                scale += value * value * alpha / (size * size);
+              }
+            }
+            *(top_data + blob_top->offset(n, c, h, w)) =
+              blob_bottom.data_at(n, c, h, w) / pow(scale, beta);
+          }
+        }
+      }
+    }
+    break;
+  default:
+    LOG(FATAL) << "Unknown normalization region.";
+  }
+}
+
+TYPED_TEST_CASE(CuDNNLRNLayerTest, TestDtypes);
+
+TYPED_TEST(CuDNNLRNLayerTest, TestForwardAcrossChannelsCuDNN) {
+  // typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  CuDNNLRNLayer<TypeParam> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  Blob<TypeParam> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
+                this->epsilon_);
+  }
+}
+
+TYPED_TEST(CuDNNLRNLayerTest, TestForwardAcrossChannelsLargeRegionCuDNN) {
+  typedef TypeParam Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_local_size(15);
+  CuDNNLRNLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  Blob<Dtype> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
+                this->epsilon_);
+  }
+}
+
+TYPED_TEST(CuDNNLRNLayerTest, TestGradientAcrossChannelsCuDNN) {
+  typedef TypeParam Dtype;
+  LayerParameter layer_param;
+  CuDNNLRNLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    this->blob_top_->mutable_cpu_diff()[i] = 1.;
+  }
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 this->blob_bottom_vec_);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(CuDNNLRNLayerTest, TestForwardWithinChannel) {
+  typedef TypeParam Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_norm_region(
+      LRNParameter_NormRegion_WITHIN_CHANNEL);
+  layer_param.mutable_lrn_param()->set_local_size(3);
+  CuDNNLCNLayer<Dtype> layer(layer_param);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  Blob<Dtype> top_reference;
+  this->ReferenceLRNForward(*(this->blob_bottom_), layer_param,
+      &top_reference);
+  for (int i = 0; i < this->blob_bottom_->count(); ++i) {
+    EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i],
+                this->epsilon_);
+  }
+}
+
+TYPED_TEST(CuDNNLRNLayerTest, TestGradientWithinChannel) {
+  typedef TypeParam Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_norm_region(
+      LRNParameter_NormRegion_WITHIN_CHANNEL);
+  layer_param.mutable_lrn_param()->set_local_size(3);
+  CuDNNLCNLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    this->blob_top_->mutable_cpu_diff()[i] = 1.;
+  }
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(CuDNNLRNLayerTest, TestGradientAcrossChannelsLargeRegionCuDNN) {
+  typedef TypeParam Dtype;
+  LayerParameter layer_param;
+  layer_param.mutable_lrn_param()->set_local_size(15);
+  CuDNNLRNLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-2);
+  layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    this->blob_top_->mutable_cpu_diff()[i] = 1.;
+  }
+  vector<bool> propagate_down(this->blob_bottom_vec_.size(), true);
+  layer.Backward(this->blob_top_vec_, propagate_down,
+                 this->blob_bottom_vec_);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+#endif
 
 }  // namespace caffe