From ecac7ff6286642420eb5db723c382e74bf82c9d7 Mon Sep 17 00:00:00 2001 From: Simon Layton Date: Wed, 8 Jul 2015 15:35:55 -0400 Subject: [PATCH] Initial cuDNN v3 support --- include/caffe/vision_layers.hpp | 74 ++++++++++++- src/caffe/layer_factory.cpp | 43 +++++++- src/caffe/layers/cudnn_conv_layer.cpp | 138 +++++++++++++++++++++++- src/caffe/layers/cudnn_conv_layer.cu | 58 ++-------- src/caffe/layers/cudnn_lcn_layer.cpp | 77 +++++++++++++ src/caffe/layers/cudnn_lcn_layer.cu | 50 +++++++++ src/caffe/layers/cudnn_lrn_layer.cpp | 57 ++++++++++ src/caffe/layers/cudnn_lrn_layer.cu | 48 +++++++++ src/caffe/layers/lrn_layer.cpp | 1 - src/caffe/proto/caffe.proto | 6 ++ src/caffe/test/test_lrn_layer.cpp | 196 ++++++++++++++++++++++++++++++++++ 11 files changed, 692 insertions(+), 56 deletions(-) create mode 100644 src/caffe/layers/cudnn_lcn_layer.cpp create mode 100644 src/caffe/layers/cudnn_lcn_layer.cu create mode 100644 src/caffe/layers/cudnn_lrn_layer.cpp create mode 100644 src/caffe/layers/cudnn_lrn_layer.cu diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index 06bc045..237b05d 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -304,13 +304,24 @@ class CuDNNConvolutionLayer : public ConvolutionLayer { bool handles_setup_; cudnnHandle_t* handle_; cudaStream_t* stream_; + + // algorithms for forward and backwards convolutions + cudnnConvolutionFwdAlgo_t *fwd_algo_; + cudnnConvolutionBwdFilterAlgo_t *bwd_filter_algo_; + cudnnConvolutionBwdDataAlgo_t *bwd_data_algo_; + vector bottom_descs_, top_descs_; cudnnTensorDescriptor_t bias_desc_; cudnnFilterDescriptor_t filter_desc_; vector conv_descs_; int bottom_offset_, top_offset_, bias_offset_; - size_t workspaceSizeInBytes; - void *workspace; + + size_t *workspace_fwd_sizes_; + size_t *workspace_bwd_data_sizes_; + size_t *workspace_bwd_filter_sizes_; + size_t workspaceSizeInBytes; // size of underlying storage + void *workspaceData; // underlying storage + 
void **workspace; // aliases into workspaceData }; #endif @@ -442,6 +453,65 @@ class LRNLayer : public Layer { vector*> product_bottom_vec_; }; +#ifdef USE_CUDNN + +template +class CuDNNLRNLayer : public LRNLayer { + public: + explicit CuDNNLRNLayer(const LayerParameter& param) + : LRNLayer(param), handles_setup_(false) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNLRNLayer(); + + protected: + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + bool handles_setup_; + cudnnHandle_t handle_; + cudnnLRNDescriptor_t norm_desc_; + cudnnTensorDescriptor_t bottom_desc_, top_desc_; + + int size_; + Dtype alpha_, beta_, k_; +}; + +template +class CuDNNLCNLayer : public LRNLayer { + public: + explicit CuDNNLCNLayer(const LayerParameter& param) + : LRNLayer(param), handles_setup_(false), tempDataSize(0), + tempData1(NULL), tempData2(NULL) {} + virtual void LayerSetUp(const vector*>& bottom, + const vector*>& top); + virtual void Reshape(const vector*>& bottom, + const vector*>& top); + virtual ~CuDNNLCNLayer(); + + protected: + virtual void Forward_gpu(const vector*>& bottom, + const vector*>& top); + virtual void Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom); + + bool handles_setup_; + cudnnHandle_t handle_; + cudnnLRNDescriptor_t norm_desc_; + cudnnTensorDescriptor_t bottom_desc_, top_desc_; + + int size_, pre_pad_; + Dtype alpha_, beta_, k_; + + size_t tempDataSize; + void *tempData1, *tempData2; +}; + +#endif /** * @brief Pools the input image by taking the max, average, etc. within regions. 
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp index 926c7d8..417ffe9 100644 --- a/src/caffe/layer_factory.cpp +++ b/src/caffe/layer_factory.cpp @@ -54,10 +54,8 @@ shared_ptr > GetPoolingLayer(const LayerParameter& param) { return shared_ptr >(new PoolingLayer(param)); #ifdef USE_CUDNN } else if (engine == PoolingParameter_Engine_CUDNN) { - PoolingParameter p_param = param.pooling_param(); - if (p_param.pad() || p_param.pad_h() || p_param.pad_w() || - param.top_size() > 1) { - LOG(INFO) << "CUDNN does not support padding or multiple tops. " + if (param.top_size() > 1) { + LOG(INFO) << "cuDNN does not support multiple tops. " << "Using Caffe's own pooling layer."; return shared_ptr >(new PoolingLayer(param)); } @@ -70,6 +68,44 @@ shared_ptr > GetPoolingLayer(const LayerParameter& param) { REGISTER_LAYER_CREATOR(Pooling, GetPoolingLayer); +// Get LRN layer according to engine +template +shared_ptr > GetLRNLayer(const LayerParameter& param) { + LRNParameter_Engine engine = param.lrn_param().engine(); + + if (engine == LRNParameter_Engine_DEFAULT) { +#ifdef USE_CUDNN + engine = LRNParameter_Engine_CUDNN; +#else + engine = LRNParameter_Engine_CAFFE; +#endif + } + + if (engine == LRNParameter_Engine_CAFFE) { + return shared_ptr >(new LRNLayer(param)); +#ifdef USE_CUDNN + } else if (engine == LRNParameter_Engine_CUDNN) { + LRNParameter lrn_param = param.lrn_param(); + + if (lrn_param.norm_region() == LRNParameter_NormRegion_WITHIN_CHANNEL) { + return shared_ptr >(new CuDNNLCNLayer(param)); + } else { + // local size is too big to be handled through cuDNN + if (param.lrn_param().local_size() > CUDNN_LRN_MAX_N) { + return shared_ptr >(new LRNLayer(param)); + } else { + return shared_ptr >(new CuDNNLRNLayer(param)); + } + } +#endif + } else { + LOG(FATAL) << "Layer " << param.name() << " has unknown engine."; + throw; // Avoids missing return warning. + } +} + +REGISTER_LAYER_CREATOR(LRN, GetLRNLayer); + +// Get relu layer according to engine. 
template shared_ptr > GetReLULayer(const LayerParameter& param) { diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp index 3514fe2..d7b1e0d 100644 --- a/src/caffe/layers/cudnn_conv_layer.cpp +++ b/src/caffe/layers/cudnn_conv_layer.cpp @@ -1,4 +1,5 @@ #ifdef USE_CUDNN +#include #include #include "caffe/filler.hpp" @@ -24,13 +25,38 @@ void CuDNNConvolutionLayer::LayerSetUp( // Initialize CUDA streams and cuDNN. stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP]; + + // Initialize algorithm arrays + fwd_algo_ = new cudnnConvolutionFwdAlgo_t[bottom.size()]; + bwd_filter_algo_= new cudnnConvolutionBwdFilterAlgo_t[bottom.size()]; + bwd_data_algo_ = new cudnnConvolutionBwdDataAlgo_t[bottom.size()]; + + // initialize size arrays + workspace_fwd_sizes_ = new size_t[bottom.size()]; + workspace_bwd_filter_sizes_ = new size_t[bottom.size()]; + workspace_bwd_data_sizes_ = new size_t[bottom.size()]; + + // workspace data workspaceSizeInBytes = 0; - workspace = NULL; + workspaceData = NULL; + workspace = new void*[this->group_ * CUDNN_STREAMS_PER_GROUP]; + + for (size_t i = 0; i < bottom.size(); ++i) { + // initialize all to default algorithms + fwd_algo_[i] = (cudnnConvolutionFwdAlgo_t)0; + bwd_filter_algo_[i] = (cudnnConvolutionBwdFilterAlgo_t)0; + bwd_data_algo_[i] = (cudnnConvolutionBwdDataAlgo_t)0; + // default algorithms don't require workspace + workspace_fwd_sizes_[i] = 0; + workspace_bwd_data_sizes_[i] = 0; + workspace_bwd_filter_sizes_[i] = 0; + } for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) { CUDA_CHECK(cudaStreamCreate(&stream_[g])); CUDNN_CHECK(cudnnCreate(&handle_[g])); CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g])); + workspace[g] = NULL; } // Set the indexing parameters. 
@@ -86,6 +112,10 @@ void CuDNNConvolutionLayer::Reshape( const int stride_h = stride_data[0]; const int stride_w = stride_data[1]; + // Specify workspace limit for kernels directly until we have a + // planning strategy and a rewrite of Caffe's GPU memory management + size_t workspace_limit_bytes = 8*1024*1024; + for (int i = 0; i < bottom.size(); i++) { cudnn::setTensor4dDesc(&bottom_descs_[i], this->num_, @@ -98,7 +128,104 @@ void CuDNNConvolutionLayer::Reshape( this->num_output_ * this->out_spatial_dim_, this->out_spatial_dim_, width_out, 1); cudnn::setConvolutionDesc(&conv_descs_[i], bottom_descs_[i], - filter_desc_, pad_h, pad_w, stride_h, stride_w); + filter_desc_, pad_h, pad_w, + stride_h, stride_w); + + // choose forward and backward algorithms + workspace(s) + CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[0], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, + workspace_limit_bytes, + &fwd_algo_[i])); + + CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[0], + bottom_descs_[i], + filter_desc_, + conv_descs_[i], + top_descs_[i], + fwd_algo_[i], + &(workspace_fwd_sizes_[i]))); + + // choose backward algorithm for filter + CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(handle_[0], + bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_, + CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT, + workspace_limit_bytes, &bwd_filter_algo_[i]) ); + + // get workspace for backwards filter algorithm + CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(handle_[0], + bottom_descs_[i], top_descs_[i], conv_descs_[i], filter_desc_, + bwd_filter_algo_[i], &workspace_bwd_filter_sizes_[i])); + + // choose backward algo for data + CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(handle_[0], + filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i], + CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT, + workspace_limit_bytes, &bwd_data_algo_[i])); + + 
// get workspace size + CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(handle_[0], + filter_desc_, top_descs_[i], conv_descs_[i], bottom_descs_[i], + bwd_data_algo_[i], &workspace_bwd_data_sizes_[i]) ); + } + + // reduce over all workspace sizes to get a maximum to allocate / reallocate + size_t total_workspace_fwd = 0; + size_t total_workspace_bwd_data = 0; + size_t total_workspace_bwd_filter = 0; + + for (size_t i = 0; i < bottom.size(); i++) { + total_workspace_fwd = std::max(total_workspace_fwd, + workspace_fwd_sizes_[i]); + total_workspace_bwd_data = std::max(total_workspace_bwd_data, + workspace_bwd_data_sizes_[i]); + total_workspace_bwd_filter = std::max(total_workspace_bwd_filter, + workspace_bwd_filter_sizes_[i]); + } + // get max over all operations + size_t max_workspace = std::max(total_workspace_fwd, + total_workspace_bwd_data); + max_workspace = std::max(max_workspace, total_workspace_bwd_filter); + // ensure all groups have enough workspace + size_t total_max_workspace = max_workspace * + (this->group_ * CUDNN_STREAMS_PER_GROUP); + + // this is the total amount of storage needed over all groups + streams + if (total_max_workspace > workspaceSizeInBytes) { + LOG(INFO) << "Reallocating workspace storage: " << total_max_workspace; + workspaceSizeInBytes = total_max_workspace; + + // free the existing workspace and allocate a new (larger) one + cudaFree(this->workspaceData); + + cudaError_t err = cudaMalloc(&(this->workspaceData), workspaceSizeInBytes); + if (err != cudaSuccess) { + // force zero memory path + for (int i = 0; i < bottom.size(); i++) { + workspace_fwd_sizes_[i] = 0; + workspace_bwd_filter_sizes_[i] = 0; + workspace_bwd_data_sizes_[i] = 0; + fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; + bwd_filter_algo_[i] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0; + bwd_data_algo_[i] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0; + } + + // NULL out all workspace pointers + for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) { + 
workspace[g] = NULL; + } + // NULL out underlying data + workspaceData = NULL; + workspaceSizeInBytes = 0; + } + + // if we succeed in the allocation, set pointer aliases for workspaces + for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) { + workspace[g] = reinterpret_cast(workspaceData) + g*max_workspace; + } } // Tensor descriptor for bias. @@ -128,8 +255,15 @@ CuDNNConvolutionLayer::~CuDNNConvolutionLayer() { cudnnDestroy(handle_[g]); } + cudaFree(workspaceData); delete [] stream_; delete [] handle_; + delete [] fwd_algo_; + delete [] bwd_filter_algo_; + delete [] bwd_data_algo_; + delete [] workspace_fwd_sizes_; + delete [] workspace_bwd_data_sizes_; + delete [] workspace_bwd_filter_sizes_; } INSTANTIATE_CLASS(CuDNNConvolutionLayer); diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu index 6911520..e88e4dd 100644 --- a/src/caffe/layers/cudnn_conv_layer.cu +++ b/src/caffe/layers/cudnn_conv_layer.cu @@ -14,11 +14,6 @@ __global__ void sync_conv_groups() { } template void CuDNNConvolutionLayer::Forward_gpu( const vector*>& bottom, const vector*>& top) { - const int* kernel_shape_data = this->kernel_shape_.cpu_data(); - const int kernel_h = kernel_shape_data[0]; - const int kernel_w = kernel_shape_data[1]; - const size_t workspace_limit_bytes = - kernel_h * kernel_w * this->channels_ * sizeof(int) + 1; const Dtype* weight = this->blobs_[0]->gpu_data(); for (int i = 0; i < bottom.size(); ++i) { const Dtype* bottom_data = bottom[i]->gpu_data(); @@ -26,52 +21,13 @@ void CuDNNConvolutionLayer::Forward_gpu( // Forward through cuDNN in parallel over groups. 
for (int g = 0; g < this->group_; g++) { - cudnnConvolutionFwdAlgo_t algo; - - // pick the convolution algorithm - // TODO(shelhamer) this should be done during reshape - // TODO(shelhamer) the choice of automatic or manual algorithm picking - // should be exposed in proto - CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(handle_[g], - bottom_descs_[i], - filter_desc_, - conv_descs_[i], - top_descs_[i], - CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT, - workspace_limit_bytes, // memoryLimitInBytes, - &algo)); - - // get minimum size of the workspace needed for the desired algorithm - size_t workspaceSizeInBytes_temp = 0; - - CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(handle_[g], - bottom_descs_[i], - filter_desc_, - conv_descs_[i], - top_descs_[i], - algo, - &workspaceSizeInBytes_temp)); - - if (workspaceSizeInBytes_temp > workspaceSizeInBytes) { - workspaceSizeInBytes = workspaceSizeInBytes_temp; - // free the existing workspace and allocate a new (larger) one - cudaFree(this->workspace); - cudaError_t err = cudaMalloc(&(this->workspace), workspaceSizeInBytes); - if (err != cudaSuccess) { - // force zero memory path - algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM; - workspace = NULL; - workspaceSizeInBytes = 0; - } - } - // Filters. 
CUDNN_CHECK(cudnnConvolutionForward(handle_[g], cudnn::dataType::one, bottom_descs_[i], bottom_data + bottom_offset_ * g, filter_desc_, weight + this->weight_offset_ * g, conv_descs_[i], - algo, workspace, workspaceSizeInBytes, + fwd_algo_[i], workspace[g], workspace_fwd_sizes_[i], cudnn::dataType::zero, top_descs_[i], top_data + top_offset_ * g)); @@ -101,10 +57,12 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, if (this->param_propagate_down_[0]) { weight = this->blobs_[0]->gpu_data(); weight_diff = this->blobs_[0]->mutable_gpu_diff(); + caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff); } Dtype* bias_diff = NULL; if (this->bias_term_ && this->param_propagate_down_[1]) { bias_diff = this->blobs_[1]->mutable_gpu_diff(); + caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff); } for (int i = 0; i < top.size(); ++i) { const Dtype* top_diff = top[i]->gpu_diff(); @@ -122,11 +80,14 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, // Gradient w.r.t. weights. 
if (this->param_propagate_down_[0]) { const Dtype* bottom_data = bottom[i]->gpu_data(); - CUDNN_CHECK(cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g], + CUDNN_CHECK(cudnnConvolutionBackwardFilter_v3( + handle_[1*this->group_ + g], cudnn::dataType::one, bottom_descs_[i], bottom_data + bottom_offset_ * g, top_descs_[i], top_diff + top_offset_ * g, conv_descs_[i], + bwd_filter_algo_[i], workspace[1*this->group_ + g], + workspace_bwd_filter_sizes_[i], cudnn::dataType::one, filter_desc_, weight_diff + this->weight_offset_ * g)); } @@ -137,11 +98,14 @@ void CuDNNConvolutionLayer::Backward_gpu(const vector*>& top, weight = this->blobs_[0]->gpu_data(); } Dtype* bottom_diff = bottom[i]->mutable_gpu_diff(); - CUDNN_CHECK(cudnnConvolutionBackwardData(handle_[2*this->group_ + g], + CUDNN_CHECK(cudnnConvolutionBackwardData_v3( + handle_[2*this->group_ + g], cudnn::dataType::one, filter_desc_, weight + this->weight_offset_ * g, top_descs_[i], top_diff + top_offset_ * g, conv_descs_[i], + bwd_data_algo_[i], workspace[2*this->group_ + g], + workspace_bwd_data_sizes_[i], cudnn::dataType::zero, bottom_descs_[i], bottom_diff + bottom_offset_ * g)); } diff --git a/src/caffe/layers/cudnn_lcn_layer.cpp b/src/caffe/layers/cudnn_lcn_layer.cpp new file mode 100644 index 0000000..866d810 --- /dev/null +++ b/src/caffe/layers/cudnn_lcn_layer.cpp @@ -0,0 +1,77 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNLCNLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + LRNLayer::LayerSetUp(bottom, top); + + CUDNN_CHECK(cudnnCreate(&handle_)); + CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); + + // create a LRN handle + handles_setup_ = true; + + size_ = 
this->layer_param().lrn_param().local_size(); + pre_pad_ = (size_ - 1) / 2; + alpha_ = this->layer_param().lrn_param().alpha(); + beta_ = this->layer_param().lrn_param().beta(); + k_ = this->layer_param().lrn_param().k(); +} + +template +void CuDNNLCNLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + LRNLayer::Reshape(bottom, top); + cudnn::setTensor4dDesc(&bottom_desc_, bottom[0]->num(), + this->channels_, this->height_, this->width_); + cudnn::setTensor4dDesc(&top_desc_, bottom[0]->num(), + this->channels_, this->height_, this->width_); + CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, k_)); + + // allocate / reallocate tempData buffers + size_t totalSizeInBytes = sizeof(Dtype)*bottom[0]->num()* \ + this->channels_*this->height_*this->width_; + + if (totalSizeInBytes > tempDataSize) { + tempDataSize = totalSizeInBytes; + + cudaFree(tempData1); + cudaFree(tempData2); + + // allocate new buffers + CUDA_CHECK(cudaMalloc(&tempData1, totalSizeInBytes)); + CUDA_CHECK(cudaMalloc(&tempData2, totalSizeInBytes)); + } +} + +template +CuDNNLCNLayer::~CuDNNLCNLayer() { + // Check that handles have been setup before destroying. 
+ if (!handles_setup_) { return; } + + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); + + // destroy LRN handle + cudnnDestroy(handle_); + + // free temp buffers + cudaFree(tempData1); + cudaFree(tempData2); +} + +INSTANTIATE_CLASS(CuDNNLCNLayer); + +} // namespace caffe +#endif diff --git a/src/caffe/layers/cudnn_lcn_layer.cu b/src/caffe/layers/cudnn_lcn_layer.cu new file mode 100644 index 0000000..c07ade7 --- /dev/null +++ b/src/caffe/layers/cudnn_lcn_layer.cu @@ -0,0 +1,50 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNLCNLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + + CUDNN_CHECK(cudnnDivisiveNormalizationForward( + handle_, norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS, + cudnn::dataType::one, + bottom_desc_, bottom_data, + NULL, // srcMeansData + this->tempData1, this->tempData2, + cudnn::dataType::zero, + top_desc_, top_data) ); +} + +template +void CuDNNLCNLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + CUDNN_CHECK(cudnnDivisiveNormalizationBackward( + handle_, norm_desc_, CUDNN_DIVNORM_PRECOMPUTED_MEANS, + cudnn::dataType::one, + bottom_desc_, bottom_data, + NULL, top_diff, // NULL - srcMeansData + this->tempData1, this->tempData2, + cudnn::dataType::zero, + bottom_desc_, bottom_diff, + NULL) ); +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNLCNLayer); + +} // namespace caffe +#endif diff --git a/src/caffe/layers/cudnn_lrn_layer.cpp 
b/src/caffe/layers/cudnn_lrn_layer.cpp new file mode 100644 index 0000000..6e99214 --- /dev/null +++ b/src/caffe/layers/cudnn_lrn_layer.cpp @@ -0,0 +1,57 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNLRNLayer::LayerSetUp(const vector*>& bottom, + const vector*>& top) { + LRNLayer::LayerSetUp(bottom, top); + + CUDNN_CHECK(cudnnCreate(&handle_)); + CUDNN_CHECK(cudnnCreateLRNDescriptor(&norm_desc_)); + cudnn::createTensor4dDesc(&bottom_desc_); + cudnn::createTensor4dDesc(&top_desc_); + + // create a LRN handle + handles_setup_ = true; + + size_ = this->layer_param().lrn_param().local_size(); + alpha_ = this->layer_param().lrn_param().alpha(); + beta_ = this->layer_param().lrn_param().beta(); + k_ = this->layer_param().lrn_param().k(); +} + +template +void CuDNNLRNLayer::Reshape(const vector*>& bottom, + const vector*>& top) { + LRNLayer::Reshape(bottom, top); + cudnn::setTensor4dDesc(&bottom_desc_, bottom[0]->num(), + this->channels_, this->height_, this->width_); + cudnn::setTensor4dDesc(&top_desc_, bottom[0]->num(), + this->channels_, this->height_, this->width_); + CUDNN_CHECK(cudnnSetLRNDescriptor(norm_desc_, size_, alpha_, beta_, k_)); +} + +template +CuDNNLRNLayer::~CuDNNLRNLayer() { + // Check that handles have been setup before destroying. 
+ if (!handles_setup_) { return; } + + cudnnDestroyTensorDescriptor(bottom_desc_); + cudnnDestroyTensorDescriptor(top_desc_); + + // destroy LRN handle + cudnnDestroy(handle_); +} + +INSTANTIATE_CLASS(CuDNNLRNLayer); + +} // namespace caffe +#endif diff --git a/src/caffe/layers/cudnn_lrn_layer.cu b/src/caffe/layers/cudnn_lrn_layer.cu new file mode 100644 index 0000000..f992303 --- /dev/null +++ b/src/caffe/layers/cudnn_lrn_layer.cu @@ -0,0 +1,48 @@ +#ifdef USE_CUDNN +#include + +#include "caffe/filler.hpp" +#include "caffe/layer.hpp" +#include "caffe/util/im2col.hpp" +#include "caffe/util/math_functions.hpp" +#include "caffe/vision_layers.hpp" + +namespace caffe { + +template +void CuDNNLRNLayer::Forward_gpu(const vector*>& bottom, + const vector*>& top) { + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* top_data = top[0]->mutable_gpu_data(); + + CUDNN_CHECK(cudnnLRNCrossChannelForward( + handle_, norm_desc_, CUDNN_LRN_CROSS_CHANNEL_DIM1, + cudnn::dataType::one, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + top_desc_, top_data) ); +} + +template +void CuDNNLRNLayer::Backward_gpu(const vector*>& top, + const vector& propagate_down, const vector*>& bottom) { + const Dtype* top_diff = top[0]->gpu_diff(); + const Dtype* top_data = top[0]->gpu_data(); + const Dtype* bottom_data = bottom[0]->gpu_data(); + Dtype* bottom_diff = bottom[0]->mutable_gpu_diff(); + + CUDNN_CHECK(cudnnLRNCrossChannelBackward( + handle_, norm_desc_, CUDNN_LRN_CROSS_CHANNEL_DIM1, + cudnn::dataType::one, + top_desc_, top_data, + top_desc_, top_diff, + bottom_desc_, bottom_data, + cudnn::dataType::zero, + bottom_desc_, bottom_diff) ); +} + +INSTANTIATE_LAYER_GPU_FUNCS(CuDNNLRNLayer); + +} // namespace caffe + +#endif diff --git a/src/caffe/layers/lrn_layer.cpp b/src/caffe/layers/lrn_layer.cpp index 36c1ace..d18a04e 100644 --- a/src/caffe/layers/lrn_layer.cpp +++ b/src/caffe/layers/lrn_layer.cpp @@ -254,6 +254,5 @@ STUB_GPU_BACKWARD(LRNLayer, CrossChannelBackward); #endif 
INSTANTIATE_CLASS(LRNLayer); -REGISTER_LAYER_CLASS(LRN); } // namespace caffe diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index f52c941..af01b47 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -721,6 +721,12 @@ message LRNParameter { } optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; optional float k = 5 [default = 1.]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 6 [default = DEFAULT]; } message MemoryDataParameter { diff --git a/src/caffe/test/test_lrn_layer.cpp b/src/caffe/test/test_lrn_layer.cpp index c4e2f8e..78cf2d9 100644 --- a/src/caffe/test/test_lrn_layer.cpp +++ b/src/caffe/test/test_lrn_layer.cpp @@ -246,5 +246,201 @@ TYPED_TEST(LRNLayerTest, TestGradientWithinChannel) { this->blob_top_vec_); } +#ifdef USE_CUDNN +template +class CuDNNLRNLayerTest : public GPUDeviceTest { + protected: + CuDNNLRNLayerTest() + : epsilon_(Dtype(1e-5)), + blob_bottom_(new Blob()), + blob_top_(new Blob()) {} + virtual void SetUp() { + Caffe::set_random_seed(1701); + blob_bottom_->Reshape(2, 7, 3, 3); + // fill the values + FillerParameter filler_param; + GaussianFiller filler(filler_param); + filler.Fill(this->blob_bottom_); + blob_bottom_vec_.push_back(blob_bottom_); + blob_top_vec_.push_back(blob_top_); + } + virtual ~CuDNNLRNLayerTest() { delete blob_bottom_; delete blob_top_; } + void ReferenceLRNForward(const Blob& blob_bottom, + const LayerParameter& layer_param, Blob* blob_top); + + Dtype epsilon_; + Blob* const blob_bottom_; + Blob* const blob_top_; + vector*> blob_bottom_vec_; + vector*> blob_top_vec_; +}; + +template +void CuDNNLRNLayerTest::ReferenceLRNForward( + const Blob& blob_bottom, const LayerParameter& layer_param, + Blob* blob_top) { + typedef TypeParam Dtype; + blob_top->Reshape(blob_bottom.num(), blob_bottom.channels(), + blob_bottom.height(), blob_bottom.width()); + Dtype* top_data = blob_top->mutable_cpu_data(); + LRNParameter lrn_param = 
layer_param.lrn_param(); + Dtype alpha = lrn_param.alpha(); + Dtype beta = lrn_param.beta(); + int size = lrn_param.local_size(); + switch (lrn_param.norm_region()) { + case LRNParameter_NormRegion_ACROSS_CHANNELS: + for (int n = 0; n < blob_bottom.num(); ++n) { + for (int c = 0; c < blob_bottom.channels(); ++c) { + for (int h = 0; h < blob_bottom.height(); ++h) { + for (int w = 0; w < blob_bottom.width(); ++w) { + int c_start = c - (size - 1) / 2; + int c_end = min(c_start + size, blob_bottom.channels()); + c_start = max(c_start, 0); + Dtype scale = 1.; + for (int i = c_start; i < c_end; ++i) { + Dtype value = blob_bottom.data_at(n, i, h, w); + scale += value * value * alpha / size; + } + *(top_data + blob_top->offset(n, c, h, w)) = + blob_bottom.data_at(n, c, h, w) / pow(scale, beta); + } + } + } + } + break; + case LRNParameter_NormRegion_WITHIN_CHANNEL: + for (int n = 0; n < blob_bottom.num(); ++n) { + for (int c = 0; c < blob_bottom.channels(); ++c) { + for (int h = 0; h < blob_bottom.height(); ++h) { + int h_start = h - (size - 1) / 2; + int h_end = min(h_start + size, blob_bottom.height()); + h_start = max(h_start, 0); + for (int w = 0; w < blob_bottom.width(); ++w) { + Dtype scale = 1.; + int w_start = w - (size - 1) / 2; + int w_end = min(w_start + size, blob_bottom.width()); + w_start = max(w_start, 0); + for (int nh = h_start; nh < h_end; ++nh) { + for (int nw = w_start; nw < w_end; ++nw) { + Dtype value = blob_bottom.data_at(n, c, nh, nw); + scale += value * value * alpha / (size * size); + } + } + *(top_data + blob_top->offset(n, c, h, w)) = + blob_bottom.data_at(n, c, h, w) / pow(scale, beta); + } + } + } + } + break; + default: + LOG(FATAL) << "Unknown normalization region."; + } +} + +TYPED_TEST_CASE(CuDNNLRNLayerTest, TestDtypes); + +TYPED_TEST(CuDNNLRNLayerTest, TestForwardAcrossChannelsCuDNN) { + // typedef typename TypeParam::Dtype Dtype; + LayerParameter layer_param; + CuDNNLRNLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, 
this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + Blob top_reference; + this->ReferenceLRNForward(*(this->blob_bottom_), layer_param, + &top_reference); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i], + this->epsilon_); + } +} + +TYPED_TEST(CuDNNLRNLayerTest, TestForwardAcrossChannelsLargeRegionCuDNN) { + typedef TypeParam Dtype; + LayerParameter layer_param; + layer_param.mutable_lrn_param()->set_local_size(15); + CuDNNLRNLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + Blob top_reference; + this->ReferenceLRNForward(*(this->blob_bottom_), layer_param, + &top_reference); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i], + this->epsilon_); + } +} + +TYPED_TEST(CuDNNLRNLayerTest, TestGradientAcrossChannelsCuDNN) { + typedef TypeParam Dtype; + LayerParameter layer_param; + CuDNNLRNLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = 1.; + } + vector propagate_down(this->blob_bottom_vec_.size(), true); + layer.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(CuDNNLRNLayerTest, TestForwardWithinChannel) { + typedef TypeParam Dtype; + LayerParameter layer_param; + layer_param.mutable_lrn_param()->set_norm_region( + LRNParameter_NormRegion_WITHIN_CHANNEL); + layer_param.mutable_lrn_param()->set_local_size(3); + CuDNNLCNLayer layer(layer_param); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + 
layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + Blob top_reference; + this->ReferenceLRNForward(*(this->blob_bottom_), layer_param, + &top_reference); + for (int i = 0; i < this->blob_bottom_->count(); ++i) { + EXPECT_NEAR(this->blob_top_->cpu_data()[i], top_reference.cpu_data()[i], + this->epsilon_); + } +} + +TYPED_TEST(CuDNNLRNLayerTest, TestGradientWithinChannel) { + typedef TypeParam Dtype; + LayerParameter layer_param; + layer_param.mutable_lrn_param()->set_norm_region( + LRNParameter_NormRegion_WITHIN_CHANNEL); + layer_param.mutable_lrn_param()->set_local_size(3); + CuDNNLCNLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = 1.; + } + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +TYPED_TEST(CuDNNLRNLayerTest, TestGradientAcrossChannelsLargeRegionCuDNN) { + typedef TypeParam Dtype; + LayerParameter layer_param; + layer_param.mutable_lrn_param()->set_local_size(15); + CuDNNLRNLayer layer(layer_param); + GradientChecker checker(1e-2, 1e-2); + layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_); + layer.Forward(this->blob_bottom_vec_, this->blob_top_vec_); + for (int i = 0; i < this->blob_top_->count(); ++i) { + this->blob_top_->mutable_cpu_diff()[i] = 1.; + } + vector propagate_down(this->blob_bottom_vec_.size(), true); + layer.Backward(this->blob_top_vec_, propagate_down, + this->blob_bottom_vec_); + checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_, + this->blob_top_vec_); +} + +#endif } // namespace caffe -- 2.7.4