--- /dev/null
+#ifndef CAFFE_CUDNN_DECONV_LAYER_HPP_
+#define CAFFE_CUDNN_DECONV_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/deconv_layer.hpp"
+
+namespace caffe {
+
+#ifdef USE_CUDNN
+/*
+ * @brief cuDNN implementation of DeconvolutionLayer.
+ *        Falls back to DeconvolutionLayer for CPU mode.
+ *
+ * cuDNN accelerates deconvolution by running its convolution kernels in
+ * reverse: the forward pass runs the backward-data kernel, while the backward
+ * pass runs the forward, backward-filter, and backward-bias kernels for the
+ * gradients w.r.t. the inputs, filters, and biases. Caffe + cuDNN further
+ * speeds up the computation through forward parallelism across groups and
+ * backward parallelism across gradients.
+ */
+template <typename Dtype>
+class CuDNNDeconvolutionLayer : public DeconvolutionLayer<Dtype> {
+ public:
+ explicit CuDNNDeconvolutionLayer(const LayerParameter& param)
+ : DeconvolutionLayer<Dtype>(param), handles_setup_(false) {}
+ virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual ~CuDNNDeconvolutionLayer();
+
+ protected:
+ virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+ const vector<Blob<Dtype>*>& top);
+ virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom);
+
+ bool handles_setup_;
+ cudnnHandle_t* handle_;
+ cudaStream_t* stream_;
+
+  // algorithms for forward and backward convolutions
+  cudnnConvolutionFwdAlgo_t* fwd_algo_;
+  cudnnConvolutionBwdFilterAlgo_t* bwd_filter_algo_;
+  cudnnConvolutionBwdDataAlgo_t* bwd_data_algo_;
+
+ vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
+ cudnnTensorDescriptor_t bias_desc_;
+ cudnnFilterDescriptor_t filter_desc_;
+ vector<cudnnConvolutionDescriptor_t> conv_descs_;
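+  // Per-group offsets (in elements) into the bottom, top, and bias data.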
+ int bottom_offset_, top_offset_, bias_offset_;
+
+ size_t *workspace_fwd_sizes_;
+ size_t *workspace_bwd_data_sizes_;
+ size_t *workspace_bwd_filter_sizes_;
+ size_t workspaceSizeInBytes; // size of underlying storage
+ void *workspaceData; // underlying storage
+ void **workspace; // aliases into workspaceData
+};
+#endif
+
+} // namespace caffe
+
+#endif // CAFFE_CUDNN_DECONV_LAYER_HPP_
#include "caffe/layer.hpp"
#include "caffe/layer_factory.hpp"
#include "caffe/layers/conv_layer.hpp"
+#include "caffe/layers/deconv_layer.hpp"
#include "caffe/layers/lrn_layer.hpp"
#include "caffe/layers/pooling_layer.hpp"
#include "caffe/layers/relu_layer.hpp"
#ifdef USE_CUDNN
#include "caffe/layers/cudnn_conv_layer.hpp"
+#include "caffe/layers/cudnn_deconv_layer.hpp"
#include "caffe/layers/cudnn_lcn_layer.hpp"
#include "caffe/layers/cudnn_lrn_layer.hpp"
#include "caffe/layers/cudnn_pooling_layer.hpp"
REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer);
+// Get deconvolution layer according to engine.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetDeconvolutionLayer(const LayerParameter& param) {
+ ConvolutionParameter conv_param = param.convolution_param();
+ ConvolutionParameter_Engine engine = conv_param.engine();
+#ifdef USE_CUDNN
+ bool use_dilation = false;
+ for (int i = 0; i < conv_param.dilation_size(); ++i) {
+ if (conv_param.dilation(i) > 1) {
+ use_dilation = true;
+ }
+ }
+#endif
+ if (engine == ConvolutionParameter_Engine_DEFAULT) {
+ engine = ConvolutionParameter_Engine_CAFFE;
+#ifdef USE_CUDNN
+ if (!use_dilation) {
+ engine = ConvolutionParameter_Engine_CUDNN;
+ }
+#endif
+ }
+ if (engine == ConvolutionParameter_Engine_CAFFE) {
+ return shared_ptr<Layer<Dtype> >(new DeconvolutionLayer<Dtype>(param));
+#ifdef USE_CUDNN
+ } else if (engine == ConvolutionParameter_Engine_CUDNN) {
+ if (use_dilation) {
+ LOG(FATAL) << "CuDNN doesn't support the dilated deconvolution at Layer "
+ << param.name();
+ }
+ return shared_ptr<Layer<Dtype> >(new CuDNNDeconvolutionLayer<Dtype>(param));
+#endif
+ } else {
+ LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
+ throw; // Avoids missing return warning
+ }
+}
+
+REGISTER_LAYER_CREATOR(Deconvolution, GetDeconvolutionLayer);
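+
+// A prototxt layer can select the engine explicitly, e.g. to use dilation
+// through the CAFFE engine. A minimal sketch, assuming only the standard
+// ConvolutionParameter fields (layer and blob names are illustrative):
+//
+//   layer {
+//     name: "upsample"
+//     type: "Deconvolution"
+//     bottom: "feat"
+//     top: "upsampled"
+//     convolution_param {
+//       num_output: 64
+//       kernel_size: 4
+//       stride: 2
+//       dilation: 2
+//       engine: CAFFE
+//     }
+//   }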
+
// Get pooling layer according to engine.
template <typename Dtype>
shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
--- /dev/null
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layers/cudnn_deconv_layer.hpp"
+
+namespace caffe {
+
+// Set to three for the benefit of the backward pass, which
+// can use separate streams for calculating the gradient w.r.t.
+// bias, filter weights, and bottom data for each group independently
+#define CUDNN_STREAMS_PER_GROUP 3
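+
+// Handles and streams are indexed as handle_[s * group_ + g], i.e. stream s
+// of group g. Backward_gpu uses s = 0 for the bias gradient, s = 1 for the
+// filter gradient, and s = 2 for the bottom gradient; Forward_gpu only uses
+// the first group_ handles.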
+
+/**
+ * TODO(dox) explain cuDNN interface
+ */
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::LayerSetUp(
+ const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+ DeconvolutionLayer<Dtype>::LayerSetUp(bottom, top);
+ // Initialize CUDA streams and cuDNN.
+ stream_ = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+ handle_ = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+ // Initialize algorithm arrays
+ fwd_algo_ = new cudnnConvolutionFwdAlgo_t[bottom.size()];
+  bwd_filter_algo_ = new cudnnConvolutionBwdFilterAlgo_t[bottom.size()];
+ bwd_data_algo_ = new cudnnConvolutionBwdDataAlgo_t[bottom.size()];
+
+ // initialize size arrays
+ workspace_fwd_sizes_ = new size_t[bottom.size()];
+ workspace_bwd_filter_sizes_ = new size_t[bottom.size()];
+ workspace_bwd_data_sizes_ = new size_t[bottom.size()];
+
+ // workspace data
+ workspaceSizeInBytes = 0;
+ workspaceData = NULL;
+ workspace = new void*[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+ for (size_t i = 0; i < bottom.size(); ++i) {
+ // initialize all to default algorithms
+ fwd_algo_[i] = (cudnnConvolutionFwdAlgo_t)0;
+ bwd_filter_algo_[i] = (cudnnConvolutionBwdFilterAlgo_t)0;
+ bwd_data_algo_[i] = (cudnnConvolutionBwdDataAlgo_t)0;
+ // default algorithms don't require workspace
+ workspace_fwd_sizes_[i] = 0;
+ workspace_bwd_data_sizes_[i] = 0;
+ workspace_bwd_filter_sizes_[i] = 0;
+ }
+
+ for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
+ CUDA_CHECK(cudaStreamCreate(&stream_[g]));
+ CUDNN_CHECK(cudnnCreate(&handle_[g]));
+ CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g]));
+ workspace[g] = NULL;
+ }
+
+ // Set the indexing parameters.
+ bias_offset_ = (this->num_output_ / this->group_);
+
+ // Create filter descriptor.
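+  // Note: channels_ and num_output_ are swapped relative to
+  // CuDNNConvolutionLayer. Deconvolution drives cuDNN's convolution kernels
+  // in reverse, so channels_ is the filter's output-channel dimension here.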
+ const int* kernel_shape_data = this->kernel_shape_.cpu_data();
+ const int kernel_h = kernel_shape_data[0];
+ const int kernel_w = kernel_shape_data[1];
+ cudnn::createFilterDesc<Dtype>(&filter_desc_,
+ this->channels_ / this->group_,
+ this->num_output_ / this->group_,
+ kernel_h,
+ kernel_w);
+
+ // Create tensor descriptor(s) for data and corresponding convolution(s).
+ for (int i = 0; i < bottom.size(); i++) {
+ cudnnTensorDescriptor_t bottom_desc;
+ cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
+ bottom_descs_.push_back(bottom_desc);
+ cudnnTensorDescriptor_t top_desc;
+ cudnn::createTensor4dDesc<Dtype>(&top_desc);
+ top_descs_.push_back(top_desc);
+ cudnnConvolutionDescriptor_t conv_desc;
+ cudnn::createConvolutionDesc<Dtype>(&conv_desc);
+ conv_descs_.push_back(conv_desc);
+ }
+
+ // Tensor descriptor for bias.
+ if (this->bias_term_) {
+ cudnn::createTensor4dDesc<Dtype>(&bias_desc_);
+ }
+
+ handles_setup_ = true;
+}
+
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::Reshape(
+ const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+ DeconvolutionLayer<Dtype>::Reshape(bottom, top);
+ CHECK_EQ(2, this->num_spatial_axes_)
+ << "CuDNNDeconvolutionLayer input must have 2 spatial axes "
+ << "(e.g., height and width). "
+ << "Use 'engine: CAFFE' for general ND convolution.";
+ bottom_offset_ = this->bottom_dim_ / this->group_;
+ top_offset_ = this->top_dim_ / this->group_;
+ const int height = bottom[0]->shape(this->channel_axis_ + 1);
+ const int width = bottom[0]->shape(this->channel_axis_ + 2);
+ const int height_out = top[0]->shape(this->channel_axis_ + 1);
+ const int width_out = top[0]->shape(this->channel_axis_ + 2);
+ const int* pad_data = this->pad_.cpu_data();
+ const int pad_h = pad_data[0];
+ const int pad_w = pad_data[1];
+ const int* stride_data = this->stride_.cpu_data();
+ const int stride_h = stride_data[0];
+ const int stride_w = stride_data[1];
+
+  // Specify workspace limit for kernels directly until we have a
+  // planning strategy and a rewrite of Caffe's GPU memory management.
+  size_t workspace_limit_bytes = 8*1024*1024;  // 8 MiB
+
+ for (int i = 0; i < bottom.size(); i++) {
+ cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
+ this->num_,
+ this->channels_ / this->group_,
+ height,
+ width,
+ this->channels_ * height * width,
+ height * width,
+ width,
+ 1);
+ cudnn::setTensor4dDesc<Dtype>(&top_descs_[i],
+ this->num_,
+ this->num_output_ / this->group_,
+ height_out,
+ width_out,
+ this->num_output_ * height_out * width_out,
+ height_out * width_out,
+ width_out,
+ 1);
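+    // Describe the convolution from the top blob to the bottom blob;
+    // Forward_gpu then computes the deconvolution with cuDNN's
+    // backward-data kernel over this same descriptor.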
+ cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i],
+ top_descs_[i],
+ filter_desc_,
+ pad_h,
+ pad_w,
+ stride_h,
+ stride_w);
+
+ // choose forward and backward algorithms + workspace(s)
+ CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
+ handle_[0],
+ top_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ bottom_descs_[i],
+ CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes,
+ &fwd_algo_[i]));
+
+    // We have found that CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM is
+    // buggy. Thus, if this algo was chosen, use Winograd instead. If Winograd
+    // is not supported, or its workspace is larger than the threshold, fall
+    // back to implicit GEMM.
+ if (fwd_algo_[i] == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
+ size_t winograd_workspace_size;
+ cudnnStatus_t status = cudnnGetConvolutionForwardWorkspaceSize(
+ handle_[0],
+ top_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ bottom_descs_[i],
+ CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+ &winograd_workspace_size);
+ if (status != CUDNN_STATUS_SUCCESS ||
+ winograd_workspace_size >= workspace_limit_bytes) {
+ fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ } else {
+ fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
+ }
+ }
+
+ CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
+ handle_[0],
+ top_descs_[i],
+ filter_desc_,
+ conv_descs_[i],
+ bottom_descs_[i],
+ fwd_algo_[i],
+ &(workspace_fwd_sizes_[i])));
+
+ // choose backward algorithm for filter
+ CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
+ handle_[0],
+ top_descs_[i],
+ bottom_descs_[i],
+ conv_descs_[i],
+ filter_desc_,
+ CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes,
+ &bwd_filter_algo_[i]));
+
+ // get workspace for backwards filter algorithm
+ CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+ handle_[0],
+ top_descs_[i],
+ bottom_descs_[i],
+ conv_descs_[i],
+ filter_desc_,
+ bwd_filter_algo_[i],
+ &workspace_bwd_filter_sizes_[i]));
+
+ // choose backward algo for data
+ CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
+ handle_[0],
+ filter_desc_,
+ bottom_descs_[i],
+ conv_descs_[i],
+ top_descs_[i],
+ CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+ workspace_limit_bytes,
+ &bwd_data_algo_[i]));
+
+ // get workspace size
+ CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
+ handle_[0],
+ filter_desc_,
+ bottom_descs_[i],
+ conv_descs_[i],
+ top_descs_[i],
+ bwd_data_algo_[i],
+ &workspace_bwd_data_sizes_[i]));
+ }
+
+ // reduce over all workspace sizes to get a maximum to allocate / reallocate
+ size_t total_workspace_fwd = 0;
+ size_t total_workspace_bwd_data = 0;
+ size_t total_workspace_bwd_filter = 0;
+
+ for (size_t i = 0; i < bottom.size(); i++) {
+ total_workspace_fwd = std::max(total_workspace_fwd,
+ workspace_fwd_sizes_[i]);
+ total_workspace_bwd_data = std::max(total_workspace_bwd_data,
+ workspace_bwd_data_sizes_[i]);
+ total_workspace_bwd_filter = std::max(total_workspace_bwd_filter,
+ workspace_bwd_filter_sizes_[i]);
+ }
+ // get max over all operations
+ size_t max_workspace = std::max(total_workspace_fwd,
+ total_workspace_bwd_data);
+ max_workspace = std::max(max_workspace, total_workspace_bwd_filter);
+ // ensure all groups have enough workspace
+ size_t total_max_workspace = max_workspace *
+ (this->group_ * CUDNN_STREAMS_PER_GROUP);
+
+ // this is the total amount of storage needed over all groups + streams
+ if (total_max_workspace > workspaceSizeInBytes) {
+ DLOG(INFO) << "Reallocating workspace storage: " << total_max_workspace;
+ workspaceSizeInBytes = total_max_workspace;
+
+ // free the existing workspace and allocate a new (larger) one
+ cudaFree(this->workspaceData);
+
+ cudaError_t err = cudaMalloc(&(this->workspaceData), workspaceSizeInBytes);
+ if (err != cudaSuccess) {
+ // force zero memory path
+ for (int i = 0; i < bottom.size(); i++) {
+ workspace_fwd_sizes_[i] = 0;
+ workspace_bwd_filter_sizes_[i] = 0;
+ workspace_bwd_data_sizes_[i] = 0;
+        // Implicit GEMM requires no cuDNN workspace, matching the zeroed
+        // sizes above.
+        fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ bwd_filter_algo_[i] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+ bwd_data_algo_[i] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+ }
+
+ // NULL out all workspace pointers
+ for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+ workspace[g] = NULL;
+ }
+ // NULL out underlying data
+ workspaceData = NULL;
+ workspaceSizeInBytes = 0;
+ }
+
+ // if we succeed in the allocation, set pointer aliases for workspaces
+ for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+ workspace[g] = reinterpret_cast<char *>(workspaceData) + g*max_workspace;
+ }
+ }
+
+ // Tensor descriptor for bias.
+ if (this->bias_term_) {
+ cudnn::setTensor4dDesc<Dtype>(
+ &bias_desc_, 1, this->num_output_ / this->group_, 1, 1);
+ }
+}
+
+template <typename Dtype>
+CuDNNDeconvolutionLayer<Dtype>::~CuDNNDeconvolutionLayer() {
+ // Check that handles have been setup before destroying.
+ if (!handles_setup_) { return; }
+
+ for (int i = 0; i < bottom_descs_.size(); i++) {
+ cudnnDestroyTensorDescriptor(bottom_descs_[i]);
+ cudnnDestroyTensorDescriptor(top_descs_[i]);
+ cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
+ }
+ if (this->bias_term_) {
+ cudnnDestroyTensorDescriptor(bias_desc_);
+ }
+ cudnnDestroyFilterDescriptor(filter_desc_);
+
+ for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
+ cudaStreamDestroy(stream_[g]);
+ cudnnDestroy(handle_[g]);
+ }
+
+ cudaFree(workspaceData);
+ delete [] workspace;
+ delete [] stream_;
+ delete [] handle_;
+ delete [] fwd_algo_;
+ delete [] bwd_filter_algo_;
+ delete [] bwd_data_algo_;
+ delete [] workspace_fwd_sizes_;
+ delete [] workspace_bwd_data_sizes_;
+ delete [] workspace_bwd_filter_sizes_;
+}
+
+INSTANTIATE_CLASS(CuDNNDeconvolutionLayer);
+
+} // namespace caffe
+#endif
--- /dev/null
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/layers/cudnn_deconv_layer.hpp"
+
+namespace caffe {
+
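+// An empty kernel; launching it into the default (null) stream below acts as
+// a cheap synchronization point across the per-group streams.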
+__global__ void sync_deconv_groups() {}
+
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::Forward_gpu(
+ const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+ const Dtype* weight = this->blobs_[0]->gpu_data();
+ for (int i = 0; i < bottom.size(); ++i) {
+ const Dtype* bottom_data = bottom[i]->gpu_data();
+ Dtype* top_data = top[i]->mutable_gpu_data();
+
+ // Forward through cuDNN in parallel over groups.
+ for (int g = 0; g < this->group_; g++) {
+      // Filters: compute the deconvolution output as cuDNN's backward-data
+      // pass.
+ CUDNN_CHECK(cudnnConvolutionBackwardData(
+ handle_[g],
+ cudnn::dataType<Dtype>::one,
+ filter_desc_,
+ weight + this->weight_offset_ * g,
+ bottom_descs_[i],
+ bottom_data + bottom_offset_ * g,
+ conv_descs_[i],
+ bwd_data_algo_[i],
+ workspace[g],
+ workspace_bwd_data_sizes_[i],
+ cudnn::dataType<Dtype>::zero,
+ top_descs_[i],
+ top_data + top_offset_ * g));
+
+ // Bias.
+ if (this->bias_term_) {
+ const Dtype* bias_data = this->blobs_[1]->gpu_data();
+ CUDNN_CHECK(cudnnAddTensor(handle_[g],
+ cudnn::dataType<Dtype>::one,
+ bias_desc_,
+ bias_data + bias_offset_ * g,
+ cudnn::dataType<Dtype>::one,
+ top_descs_[i],
+ top_data + top_offset_ * g));
+ }
+ }
+
+ // Synchronize the work across groups, each of which went into its own
+ // stream, by launching an empty kernel into the default (null) stream.
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ sync_deconv_groups<<<1, 1>>>();
+ }
+}
+
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::Backward_gpu(
+ const vector<Blob<Dtype>*>& top,
+ const vector<bool>& propagate_down,
+ const vector<Blob<Dtype>*>& bottom) {
+ const Dtype* weight = NULL;
+ Dtype* weight_diff = NULL;
+ if (this->param_propagate_down_[0]) {
+ weight = this->blobs_[0]->gpu_data();
+ weight_diff = this->blobs_[0]->mutable_gpu_diff();
+ }
+ Dtype* bias_diff = NULL;
+ if (this->bias_term_ && this->param_propagate_down_[1]) {
+ bias_diff = this->blobs_[1]->mutable_gpu_diff();
+ }
+ for (int i = 0; i < top.size(); ++i) {
+ const Dtype* top_diff = top[i]->gpu_diff();
+ // Backward through cuDNN in parallel over groups and gradients.
+ for (int g = 0; g < this->group_; g++) {
+ // Gradient w.r.t. bias.
+ if (this->bias_term_ && this->param_propagate_down_[1]) {
+ CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0 * this->group_ + g],
+ cudnn::dataType<Dtype>::one,
+ top_descs_[i],
+ top_diff + top_offset_ * g,
+ cudnn::dataType<Dtype>::one,
+ bias_desc_,
+ bias_diff + bias_offset_ * g));
+ }
+
+ // Gradient w.r.t. weights.
+ if (this->param_propagate_down_[0]) {
+ const Dtype* bottom_data = bottom[i]->gpu_data();
+ CUDNN_CHECK(cudnnConvolutionBackwardFilter(
+ handle_[1 * this->group_ + g],
+ cudnn::dataType<Dtype>::one,
+ top_descs_[i],
+ top_diff + top_offset_ * g,
+ bottom_descs_[i],
+ bottom_data + bottom_offset_ * g,
+ conv_descs_[i],
+ bwd_filter_algo_[i],
+ workspace[1 * this->group_ + g],
+ workspace_bwd_filter_sizes_[i],
+ cudnn::dataType<Dtype>::one,
+ filter_desc_,
+ weight_diff + this->weight_offset_ * g));
+ }
+
+ // Gradient w.r.t. bottom data.
+ if (propagate_down[i]) {
+ if (weight == NULL) {
+ weight = this->blobs_[0]->gpu_data();
+ }
+ Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
+ CUDNN_CHECK(
+ cudnnConvolutionForward(handle_[2 * this->group_ + g],
+ cudnn::dataType<Dtype>::one,
+ top_descs_[i],
+ top_diff + top_offset_ * g,
+ filter_desc_,
+ weight + this->weight_offset_ * g,
+ conv_descs_[i],
+ fwd_algo_[i],
+ workspace[2 * this->group_ + g],
+ workspace_fwd_sizes_[i],
+ cudnn::dataType<Dtype>::zero,
+ bottom_descs_[i],
+ bottom_diff + bottom_offset_ * g));
+ }
+ }
+
+ // Synchronize the work across groups, each of which went into its own
+ // stream, by launching an empty kernel into the default (null) stream.
+ // NOLINT_NEXT_LINE(whitespace/operators)
+ sync_deconv_groups<<<1, 1>>>();
+ }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(CuDNNDeconvolutionLayer);
+
+} // namespace caffe
+#endif
#endif
INSTANTIATE_CLASS(DeconvolutionLayer);
-REGISTER_LAYER_CLASS(Deconvolution);
} // namespace caffe
#include "caffe/common.hpp"
#include "caffe/filler.hpp"
#include "caffe/layers/deconv_layer.hpp"
+#include "caffe/layers/cudnn_deconv_layer.hpp"
#include "caffe/test/test_caffe_main.hpp"
#include "caffe/test/test_gradient_check_util.hpp"
this->blob_top_vec_);
}
+#ifdef USE_CUDNN
+
+// Since ConvolutionLayerTest checks the shared conv/deconv code in detail,
+// we'll just do a simple forward test and a gradient check.
+template <typename TypeParam>
+class CuDNNDeconvolutionLayerTest : public MultiDeviceTest<TypeParam> {
+ typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+ CuDNNDeconvolutionLayerTest()
+ : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
+ blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)),
+ blob_top_(new Blob<Dtype>()),
+ blob_top_2_(new Blob<Dtype>()) {}
+ virtual void SetUp() {
+ // fill the values
+ FillerParameter filler_param;
+ filler_param.set_value(1.);
+ GaussianFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ filler.Fill(this->blob_bottom_2_);
+ blob_bottom_vec_.push_back(blob_bottom_);
+ blob_top_vec_.push_back(blob_top_);
+ }
+
+ virtual ~CuDNNDeconvolutionLayerTest() {
+ delete blob_bottom_;
+ delete blob_bottom_2_;
+ delete blob_top_;
+ delete blob_top_2_;
+ }
+
+ Blob<Dtype>* const blob_bottom_;
+ Blob<Dtype>* const blob_bottom_2_;
+ Blob<Dtype>* const blob_top_;
+ Blob<Dtype>* const blob_top_2_;
+ vector<Blob<Dtype>*> blob_bottom_vec_;
+ vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(CuDNNDeconvolutionLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestSetup) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ ConvolutionParameter* convolution_param =
+ layer_param.mutable_convolution_param();
+ convolution_param->add_kernel_size(3);
+ convolution_param->add_stride(2);
+ convolution_param->set_num_output(4);
+ this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+ this->blob_top_vec_.push_back(this->blob_top_2_);
+ shared_ptr<Layer<Dtype> > layer(
+ new CuDNNDeconvolutionLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->num(), 2);
+ EXPECT_EQ(this->blob_top_->channels(), 4);
+ EXPECT_EQ(this->blob_top_->height(), 13);
+ EXPECT_EQ(this->blob_top_->width(), 9);
+ EXPECT_EQ(this->blob_top_2_->num(), 2);
+ EXPECT_EQ(this->blob_top_2_->channels(), 4);
+ EXPECT_EQ(this->blob_top_2_->height(), 13);
+ EXPECT_EQ(this->blob_top_2_->width(), 9);
+ // setting group should not change the shape
+ convolution_param->set_num_output(3);
+ convolution_param->set_group(3);
+ layer.reset(new CuDNNDeconvolutionLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ EXPECT_EQ(this->blob_top_->num(), 2);
+ EXPECT_EQ(this->blob_top_->channels(), 3);
+ EXPECT_EQ(this->blob_top_->height(), 13);
+ EXPECT_EQ(this->blob_top_->width(), 9);
+ EXPECT_EQ(this->blob_top_2_->num(), 2);
+ EXPECT_EQ(this->blob_top_2_->channels(), 3);
+ EXPECT_EQ(this->blob_top_2_->height(), 13);
+ EXPECT_EQ(this->blob_top_2_->width(), 9);
+}
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestSimpleCuDNNDeconvolution) {
+ typedef typename TypeParam::Dtype Dtype;
+ this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+ this->blob_top_vec_.push_back(this->blob_top_2_);
+ LayerParameter layer_param;
+ ConvolutionParameter* convolution_param =
+ layer_param.mutable_convolution_param();
+ convolution_param->add_kernel_size(3);
+ convolution_param->add_stride(2);
+ convolution_param->set_num_output(4);
+ convolution_param->mutable_weight_filler()->set_type("constant");
+ convolution_param->mutable_weight_filler()->set_value(1);
+ convolution_param->mutable_bias_filler()->set_type("constant");
+ convolution_param->mutable_bias_filler()->set_value(0.1);
+ shared_ptr<Layer<Dtype> > layer(
+ new CuDNNDeconvolutionLayer<Dtype>(layer_param));
+ layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ // constant-fill the bottom blobs
+ FillerParameter filler_param;
+ filler_param.set_value(1.);
+ ConstantFiller<Dtype> filler(filler_param);
+ filler.Fill(this->blob_bottom_);
+ filler.Fill(this->blob_bottom_2_);
+ layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ // simply check that accumulation works with overlapping filters
+ const Dtype* top_data = this->blob_top_->cpu_data();
+ for (int n = 0; n < this->blob_top_->num(); ++n) {
+ for (int c = 0; c < this->blob_top_->channels(); ++c) {
+ for (int h = 0; h < this->blob_top_->height(); ++h) {
+ for (int w = 0; w < this->blob_top_->width(); ++w) {
+ Dtype expected = 3.1;
+ bool h_overlap = h % 2 == 0 && h > 0
+ && h < this->blob_top_->height() - 1;
+ bool w_overlap = w % 2 == 0 && w > 0
+ && w < this->blob_top_->width() - 1;
+ if (h_overlap && w_overlap) {
+ expected += 9;
+ } else if (h_overlap || w_overlap) {
+ expected += 3;
+ }
+ EXPECT_NEAR(top_data[this->blob_top_->offset(n, c, h, w)],
+ expected, 1e-4);
+ }
+ }
+ }
+ }
+}
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestGradient) {
+ typedef typename TypeParam::Dtype Dtype;
+ LayerParameter layer_param;
+ ConvolutionParameter* convolution_param =
+ layer_param.mutable_convolution_param();
+ this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+ this->blob_top_vec_.push_back(this->blob_top_2_);
+ convolution_param->add_kernel_size(2);
+ convolution_param->add_stride(1);
+ convolution_param->set_num_output(1);
+ convolution_param->mutable_weight_filler()->set_type("gaussian");
+ convolution_param->mutable_bias_filler()->set_type("gaussian");
+ CuDNNDeconvolutionLayer<Dtype> layer(layer_param);
+ GradientChecker<Dtype> checker(1e-2, 1e-3);
+ checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+ this->blob_top_vec_);
+}
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestNDAgainst2D) {
+ typedef typename TypeParam::Dtype Dtype;
+ const int kernel_h = 11;
+ const int kernel_w = 13;
+ vector<int> bottom_shape(4);
+ bottom_shape[0] = 15;
+ bottom_shape[1] = 12;
+ bottom_shape[2] = kernel_h * 2;
+ bottom_shape[3] = kernel_w * 2;
+ FillerParameter filler_param;
+ GaussianFiller<Dtype> filler(filler_param);
+ for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+ this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+ filler.Fill(this->blob_bottom_vec_[i]);
+ }
+ LayerParameter layer_param;
+ ConvolutionParameter* convolution_param =
+ layer_param.mutable_convolution_param();
+ convolution_param->set_num_output(18);
+ convolution_param->set_bias_term(false);
+ convolution_param->set_group(6);
+ convolution_param->set_kernel_h(kernel_h);
+ convolution_param->set_kernel_w(kernel_w);
+ convolution_param->mutable_weight_filler()->set_type("gaussian");
+ Blob<Dtype> weights;
+ Blob<Dtype> top_diff;
+ // Shape and fill weights and top_diff.
+ bool copy_diff;
+ bool reshape;
+ {
+ CuDNNDeconvolutionLayer<Dtype> layer(layer_param);
+ layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ top_diff.ReshapeLike(*this->blob_top_);
+ filler.Fill(&top_diff);
+ ASSERT_EQ(1, layer.blobs().size());
+ copy_diff = false; reshape = true;
+ weights.CopyFrom(*layer.blobs()[0], copy_diff, reshape);
+ }
+ vector<bool> propagate_down(1, true);
+ Blob<Dtype> result_2d;
+ Blob<Dtype> backward_result_2d;
+ Blob<Dtype> backward_weight_result_2d;
+ // Test with 2D im2col
+ {
+ caffe_set(this->blob_top_->count(), Dtype(0),
+ this->blob_top_->mutable_cpu_data());
+ caffe_set(this->blob_bottom_->count(), Dtype(0),
+ this->blob_bottom_->mutable_cpu_diff());
+ caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+ // Do SetUp and Forward; save Forward result in result_2d.
+ convolution_param->set_force_nd_im2col(false);
+ CuDNNDeconvolutionLayer<Dtype> layer_2d(layer_param);
+ layer_2d.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(1, layer_2d.blobs().size());
+ copy_diff = false; reshape = false;
+ layer_2d.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+ layer_2d.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ copy_diff = false; reshape = true;
+ result_2d.CopyFrom(*this->blob_top_, copy_diff, reshape);
+ // Copy pre-generated top diff into actual top diff;
+ // do Backward and save result in backward_result_2d.
+ ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+ caffe_copy(top_diff.count(), top_diff.cpu_data(),
+ this->blob_top_->mutable_cpu_diff());
+ layer_2d.Backward(this->blob_top_vec_, propagate_down,
+ this->blob_bottom_vec_);
+ copy_diff = true; reshape = true;
+ backward_result_2d.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+ backward_weight_result_2d.CopyFrom(weights, copy_diff, reshape);
+ }
+ Blob<Dtype> result_nd;
+ Blob<Dtype> backward_result_nd;
+ Blob<Dtype> backward_weight_result_nd;
+ // Test with ND im2col
+ {
+ caffe_set(this->blob_top_->count(), Dtype(0),
+ this->blob_top_->mutable_cpu_data());
+ caffe_set(this->blob_bottom_->count(), Dtype(0),
+ this->blob_bottom_->mutable_cpu_diff());
+ caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+ // Do SetUp and Forward; save Forward result in result_nd.
+ convolution_param->set_force_nd_im2col(true);
+ CuDNNDeconvolutionLayer<Dtype> layer_nd(layer_param);
+ layer_nd.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+ ASSERT_EQ(1, layer_nd.blobs().size());
+ copy_diff = false; reshape = false;
+ layer_nd.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+ layer_nd.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+ copy_diff = false; reshape = true;
+ result_nd.CopyFrom(*this->blob_top_, copy_diff, reshape);
+ // Copy pre-generated top diff into actual top diff;
+ // do Backward and save result in backward_result_nd.
+ ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+ caffe_copy(top_diff.count(), top_diff.cpu_data(),
+ this->blob_top_->mutable_cpu_diff());
+ layer_nd.Backward(this->blob_top_vec_, propagate_down,
+ this->blob_bottom_vec_);
+ copy_diff = true; reshape = true;
+ backward_result_nd.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+ backward_weight_result_nd.CopyFrom(weights, copy_diff, reshape);
+ }
+ ASSERT_EQ(result_nd.count(), result_2d.count());
+ for (int i = 0; i < result_2d.count(); ++i) {
+ EXPECT_NEAR(result_2d.cpu_data()[i], result_nd.cpu_data()[i], 1e-4);
+ }
+ ASSERT_EQ(backward_result_nd.count(), backward_result_2d.count());
+ for (int i = 0; i < backward_result_2d.count(); ++i) {
+ EXPECT_EQ(backward_result_2d.cpu_diff()[i],
+ backward_result_nd.cpu_diff()[i]);
+ }
+ ASSERT_EQ(backward_weight_result_nd.count(),
+ backward_weight_result_2d.count());
+ for (int i = 0; i < backward_weight_result_2d.count(); ++i) {
+ EXPECT_EQ(backward_weight_result_2d.cpu_diff()[i],
+ backward_weight_result_nd.cpu_diff()[i]);
+ }
+}
+
+#endif
+
} // namespace caffe