Implement CuDNN-based deconvolution layer and test
author    Bo Wang <david.b.wang@gmail.com>
          Sun, 10 Sep 2017 04:57:20 +0000 (21:57 -0700)
committer Bo Wang <david.b.wang@gmail.com>
          Sun, 10 Sep 2017 04:57:20 +0000 (21:57 -0700)
include/caffe/layers/cudnn_deconv_layer.hpp [new file with mode: 0644]
src/caffe/layer_factory.cpp
src/caffe/layers/cudnn_deconv_layer.cpp [new file with mode: 0644]
src/caffe/layers/cudnn_deconv_layer.cu [new file with mode: 0644]
src/caffe/layers/deconv_layer.cpp
src/caffe/test/test_deconvolution_layer.cpp

diff --git a/include/caffe/layers/cudnn_deconv_layer.hpp b/include/caffe/layers/cudnn_deconv_layer.hpp
new file mode 100644 (file)
index 0000000..34095e5
--- /dev/null
@@ -0,0 +1,68 @@
+#ifndef CAFFE_CUDNN_DECONV_LAYER_HPP_
+#define CAFFE_CUDNN_DECONV_LAYER_HPP_
+
+#include <vector>
+
+#include "caffe/blob.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/proto/caffe.pb.h"
+
+#include "caffe/layers/deconv_layer.hpp"
+
+namespace caffe {
+
+#ifdef USE_CUDNN
+/**
+ * @brief cuDNN implementation of DeconvolutionLayer.
+ *        Falls back to DeconvolutionLayer for CPU mode.
+ *
+ * cuDNN accelerates deconvolution by mapping it onto cuDNN's convolution
+ * kernels with the data roles swapped: the backward-data kernel computes
+ * the forward pass, while the forward and backward-filter kernels compute
+ * the gradients. Caffe + cuDNN also parallelizes across groups and streams.
+ */
+template <typename Dtype>
+class CuDNNDeconvolutionLayer : public DeconvolutionLayer<Dtype> {
+ public:
+  explicit CuDNNDeconvolutionLayer(const LayerParameter& param)
+    : DeconvolutionLayer<Dtype>(param), handles_setup_(false) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+                          const vector<Blob<Dtype>*>& top);
+  virtual void Reshape(const vector<Blob<Dtype>*>& bottom,
+                       const vector<Blob<Dtype>*>& top);
+  virtual ~CuDNNDeconvolutionLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+                           const vector<Blob<Dtype>*>& top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+                            const vector<bool>& propagate_down,
+                            const vector<Blob<Dtype>*>& bottom);
+
+  bool handles_setup_;
+  cudnnHandle_t* handle_;
+  cudaStream_t*  stream_;
+
+  // Algorithms for the forward and backward convolutions.
+  cudnnConvolutionFwdAlgo_t* fwd_algo_;
+  cudnnConvolutionBwdFilterAlgo_t* bwd_filter_algo_;
+  cudnnConvolutionBwdDataAlgo_t* bwd_data_algo_;
+
+  vector<cudnnTensorDescriptor_t> bottom_descs_, top_descs_;
+  cudnnTensorDescriptor_t bias_desc_;
+  cudnnFilterDescriptor_t filter_desc_;
+  vector<cudnnConvolutionDescriptor_t> conv_descs_;
+  int bottom_offset_, top_offset_, bias_offset_;
+
+  size_t* workspace_fwd_sizes_;
+  size_t* workspace_bwd_data_sizes_;
+  size_t* workspace_bwd_filter_sizes_;
+  size_t workspaceSizeInBytes;  // size of underlying storage
+  void* workspaceData;  // underlying storage
+  void** workspace;  // aliases into workspaceData
+};
+#endif
+
+}  // namespace caffe
+
+#endif  // CAFFE_CUDNN_DECONV_LAYER_HPP_
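
Usage sketch (illustration only, not part of the diff). Because the class
overrides only the GPU entry points, CPU-mode calls fall through to the
inherited DeconvolutionLayer implementation; the shapes and fillers below are
arbitrary:

    #include <vector>

    #include "caffe/blob.hpp"
    #include "caffe/common.hpp"
    #include "caffe/filler.hpp"
    #include "caffe/layers/cudnn_deconv_layer.hpp"

    using namespace caffe;  // brevity only

    void DemoCpuFallbackAndGpuPath() {
      LayerParameter param;
      ConvolutionParameter* conv = param.mutable_convolution_param();
      conv->add_kernel_size(3);
      conv->add_stride(2);
      conv->set_num_output(4);
      conv->mutable_weight_filler()->set_type("gaussian");

      Blob<float> bottom(2, 3, 6, 4), top;
      std::vector<Blob<float>*> bottoms(1, &bottom), tops(1, &top);
      FillerParameter filler_param;
      GaussianFiller<float> filler(filler_param);
      filler.Fill(&bottom);

      CuDNNDeconvolutionLayer<float> layer(param);
      layer.SetUp(bottoms, tops);

      Caffe::set_mode(Caffe::CPU);   // Forward_cpu is not overridden, so this
      layer.Forward(bottoms, tops);  // runs DeconvolutionLayer::Forward_cpu.

      Caffe::set_mode(Caffe::GPU);   // Forward_gpu (cudnn_deconv_layer.cu,
      layer.Forward(bottoms, tops);  // below) drives cuDNN.
    }
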
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index f14253a..9f9026b 100644 (file)
@@ -8,6 +8,7 @@
 #include "caffe/layer.hpp"
 #include "caffe/layer_factory.hpp"
 #include "caffe/layers/conv_layer.hpp"
+#include "caffe/layers/deconv_layer.hpp"
 #include "caffe/layers/lrn_layer.hpp"
 #include "caffe/layers/pooling_layer.hpp"
 #include "caffe/layers/relu_layer.hpp"
@@ -18,6 +19,7 @@
 
 #ifdef USE_CUDNN
 #include "caffe/layers/cudnn_conv_layer.hpp"
+#include "caffe/layers/cudnn_deconv_layer.hpp"
 #include "caffe/layers/cudnn_lcn_layer.hpp"
 #include "caffe/layers/cudnn_lrn_layer.hpp"
 #include "caffe/layers/cudnn_pooling_layer.hpp"
@@ -73,6 +75,45 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(
 
 REGISTER_LAYER_CREATOR(Convolution, GetConvolutionLayer);
 
+// Get deconvolution layer according to engine.
+template <typename Dtype>
+shared_ptr<Layer<Dtype> > GetDeconvolutionLayer(const LayerParameter& param) {
+  const ConvolutionParameter& conv_param = param.convolution_param();
+  ConvolutionParameter_Engine engine = conv_param.engine();
+#ifdef USE_CUDNN
+  bool use_dilation = false;
+  for (int i = 0; i < conv_param.dilation_size(); ++i) {
+    if (conv_param.dilation(i) > 1) {
+      use_dilation = true;
+    }
+  }
+#endif
+  if (engine == ConvolutionParameter_Engine_DEFAULT) {
+    engine = ConvolutionParameter_Engine_CAFFE;
+#ifdef USE_CUDNN
+    if (!use_dilation) {
+      engine = ConvolutionParameter_Engine_CUDNN;
+    }
+#endif
+  }
+  if (engine == ConvolutionParameter_Engine_CAFFE) {
+    return shared_ptr<Layer<Dtype> >(new DeconvolutionLayer<Dtype>(param));
+#ifdef USE_CUDNN
+  } else if (engine == ConvolutionParameter_Engine_CUDNN) {
+    if (use_dilation) {
+      LOG(FATAL) << "CuDNN doesn't support the dilated deconvolution at Layer "
+                 << param.name();
+    }
+    return shared_ptr<Layer<Dtype> >(new CuDNNDeconvolutionLayer<Dtype>(param));
+#endif
+  } else {
+    LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
+    throw;  // Avoids missing return warning
+  }
+}
+
+REGISTER_LAYER_CREATOR(Deconvolution, GetDeconvolutionLayer);
+
 // Get pooling layer according to engine.
 template <typename Dtype>
 shared_ptr<Layer<Dtype> > GetPoolingLayer(const LayerParameter& param) {
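
Engine-selection sketch (illustration only, not part of the diff). This is how
the creator above behaves when LayerRegistry instantiates a "Deconvolution"
layer in a USE_CUDNN build; the parameter values are arbitrary:

    #include "caffe/layer_factory.hpp"
    #include "caffe/proto/caffe.pb.h"

    using namespace caffe;  // brevity only

    void DemoEngineSelection() {
      LayerParameter param;
      param.set_name("upsample");
      param.set_type("Deconvolution");
      ConvolutionParameter* conv = param.mutable_convolution_param();
      conv->set_num_output(4);
      conv->add_kernel_size(3);

      // engine == DEFAULT resolves to CUDNN because no dilation > 1 is set.
      shared_ptr<Layer<float> > a = LayerRegistry<float>::CreateLayer(param);

      // With dilation > 1, DEFAULT stays on the CAFFE engine; requesting
      // engine: CUDNN explicitly would hit the LOG(FATAL) above.
      conv->add_dilation(2);
      shared_ptr<Layer<float> > b = LayerRegistry<float>::CreateLayer(param);
    }
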
diff --git a/src/caffe/layers/cudnn_deconv_layer.cpp b/src/caffe/layers/cudnn_deconv_layer.cpp
new file mode 100644 (file)
index 0000000..260da5c
--- /dev/null
@@ -0,0 +1,327 @@
+#ifdef USE_CUDNN
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layers/cudnn_deconv_layer.hpp"
+
+namespace caffe {
+
+// Set to three for the benefit of the backward pass, which
+// can use separate streams for calculating the gradient w.r.t.
+// bias, filter weights, and bottom data for each group independently.
+#define CUDNN_STREAMS_PER_GROUP 3
+
+/**
+ * TODO(dox) explain cuDNN interface
+ */
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  DeconvolutionLayer<Dtype>::LayerSetUp(bottom, top);
+  // Initialize CUDA streams and cuDNN.
+  stream_         = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+  handle_         = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+  // Initialize algorithm arrays
+  fwd_algo_        = new cudnnConvolutionFwdAlgo_t[bottom.size()];
+  bwd_filter_algo_ = new cudnnConvolutionBwdFilterAlgo_t[bottom.size()];
+  bwd_data_algo_   = new cudnnConvolutionBwdDataAlgo_t[bottom.size()];
+
+  // initialize size arrays
+  workspace_fwd_sizes_ = new size_t[bottom.size()];
+  workspace_bwd_filter_sizes_ = new size_t[bottom.size()];
+  workspace_bwd_data_sizes_ = new size_t[bottom.size()];
+
+  // workspace data
+  workspaceSizeInBytes = 0;
+  workspaceData = NULL;
+  workspace = new void*[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+  for (size_t i = 0; i < bottom.size(); ++i) {
+    // initialize all to default algorithms
+    fwd_algo_[i] = (cudnnConvolutionFwdAlgo_t)0;
+    bwd_filter_algo_[i] = (cudnnConvolutionBwdFilterAlgo_t)0;
+    bwd_data_algo_[i] = (cudnnConvolutionBwdDataAlgo_t)0;
+    // default algorithms don't require workspace
+    workspace_fwd_sizes_[i] = 0;
+    workspace_bwd_data_sizes_[i] = 0;
+    workspace_bwd_filter_sizes_[i] = 0;
+  }
+
+  for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
+    CUDA_CHECK(cudaStreamCreate(&stream_[g]));
+    CUDNN_CHECK(cudnnCreate(&handle_[g]));
+    CUDNN_CHECK(cudnnSetStream(handle_[g], stream_[g]));
+    workspace[g] = NULL;
+  }
+
+  // Set the indexing parameters.
+  bias_offset_ = (this->num_output_ / this->group_);
+
+  // Create filter descriptor.
+  const int* kernel_shape_data = this->kernel_shape_.cpu_data();
+  const int kernel_h = kernel_shape_data[0];
+  const int kernel_w = kernel_shape_data[1];
+  cudnn::createFilterDesc<Dtype>(&filter_desc_,
+                                 this->channels_ / this->group_,
+                                 this->num_output_ / this->group_,
+                                 kernel_h,
+                                 kernel_w);
+
+  // Create tensor descriptor(s) for data and corresponding convolution(s).
+  for (int i = 0; i < bottom.size(); i++) {
+    cudnnTensorDescriptor_t bottom_desc;
+    cudnn::createTensor4dDesc<Dtype>(&bottom_desc);
+    bottom_descs_.push_back(bottom_desc);
+    cudnnTensorDescriptor_t top_desc;
+    cudnn::createTensor4dDesc<Dtype>(&top_desc);
+    top_descs_.push_back(top_desc);
+    cudnnConvolutionDescriptor_t conv_desc;
+    cudnn::createConvolutionDesc<Dtype>(&conv_desc);
+    conv_descs_.push_back(conv_desc);
+  }
+
+  // Tensor descriptor for bias.
+  if (this->bias_term_) {
+    cudnn::createTensor4dDesc<Dtype>(&bias_desc_);
+  }
+
+  handles_setup_ = true;
+}
+
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::Reshape(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  DeconvolutionLayer<Dtype>::Reshape(bottom, top);
+  CHECK_EQ(2, this->num_spatial_axes_)
+      << "CuDNNDeconvolutionLayer input must have 2 spatial axes "
+      << "(e.g., height and width). "
+      << "Use 'engine: CAFFE' for general ND convolution.";
+  bottom_offset_ = this->bottom_dim_ / this->group_;
+  top_offset_ = this->top_dim_ / this->group_;
+  const int height = bottom[0]->shape(this->channel_axis_ + 1);
+  const int width = bottom[0]->shape(this->channel_axis_ + 2);
+  const int height_out = top[0]->shape(this->channel_axis_ + 1);
+  const int width_out = top[0]->shape(this->channel_axis_ + 2);
+  const int* pad_data = this->pad_.cpu_data();
+  const int pad_h = pad_data[0];
+  const int pad_w = pad_data[1];
+  const int* stride_data = this->stride_.cpu_data();
+  const int stride_h = stride_data[0];
+  const int stride_w = stride_data[1];
+
+  // Specify workspace limit for kernels directly until we have a
+  // planning strategy and a rewrite of Caffe's GPU memory management.
+  size_t workspace_limit_bytes = 8 * 1024 * 1024;
+
+  for (int i = 0; i < bottom.size(); i++) {
+    cudnn::setTensor4dDesc<Dtype>(&bottom_descs_[i],
+                                  this->num_,
+                                  this->channels_ / this->group_,
+                                  height,
+                                  width,
+                                  this->channels_ * height * width,
+                                  height * width,
+                                  width,
+                                  1);
+    cudnn::setTensor4dDesc<Dtype>(&top_descs_[i],
+                                  this->num_,
+                                  this->num_output_ / this->group_,
+                                  height_out,
+                                  width_out,
+                                  this->num_output_ * height_out * width_out,
+                                  height_out * width_out,
+                                  width_out,
+                                  1);
+    cudnn::setConvolutionDesc<Dtype>(&conv_descs_[i],
+                                     top_descs_[i],
+                                     filter_desc_,
+                                     pad_h,
+                                     pad_w,
+                                     stride_h,
+                                     stride_w);
+
+    // choose forward and backward algorithms + workspace(s)
+    CUDNN_CHECK(cudnnGetConvolutionForwardAlgorithm(
+        handle_[0],
+        top_descs_[i],
+        filter_desc_,
+        conv_descs_[i],
+        bottom_descs_[i],
+        CUDNN_CONVOLUTION_FWD_SPECIFY_WORKSPACE_LIMIT,
+        workspace_limit_bytes,
+        &fwd_algo_[i]));
+
+    // CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM has proven buggy for
+    // this use, so if cuDNN selects it, switch to Winograd instead. If
+    // Winograd is unsupported here or its workspace exceeds the limit,
+    // fall back to CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM.
+    if (fwd_algo_[i] == CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM) {
+      size_t winograd_workspace_size;
+      cudnnStatus_t status = cudnnGetConvolutionForwardWorkspaceSize(
+          handle_[0],
+          top_descs_[i],
+          filter_desc_,
+          conv_descs_[i],
+          bottom_descs_[i],
+          CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD,
+          &winograd_workspace_size);
+      if (status != CUDNN_STATUS_SUCCESS ||
+          winograd_workspace_size >= workspace_limit_bytes) {
+        fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+      } else {
+        fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD;
+      }
+    }
+
+    CUDNN_CHECK(cudnnGetConvolutionForwardWorkspaceSize(
+        handle_[0],
+        top_descs_[i],
+        filter_desc_,
+        conv_descs_[i],
+        bottom_descs_[i],
+        fwd_algo_[i],
+        &(workspace_fwd_sizes_[i])));
+
+    // choose backward algorithm for filter
+    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterAlgorithm(
+        handle_[0],
+        top_descs_[i],
+        bottom_descs_[i],
+        conv_descs_[i],
+        filter_desc_,
+        CUDNN_CONVOLUTION_BWD_FILTER_SPECIFY_WORKSPACE_LIMIT,
+        workspace_limit_bytes,
+        &bwd_filter_algo_[i]));
+
+    // get workspace for backwards filter algorithm
+    CUDNN_CHECK(cudnnGetConvolutionBackwardFilterWorkspaceSize(
+        handle_[0],
+        top_descs_[i],
+        bottom_descs_[i],
+        conv_descs_[i],
+        filter_desc_,
+        bwd_filter_algo_[i],
+        &workspace_bwd_filter_sizes_[i]));
+
+    // choose backward algo for data
+    CUDNN_CHECK(cudnnGetConvolutionBackwardDataAlgorithm(
+        handle_[0],
+        filter_desc_,
+        bottom_descs_[i],
+        conv_descs_[i],
+        top_descs_[i],
+        CUDNN_CONVOLUTION_BWD_DATA_SPECIFY_WORKSPACE_LIMIT,
+        workspace_limit_bytes,
+        &bwd_data_algo_[i]));
+
+    // get workspace size
+    CUDNN_CHECK(cudnnGetConvolutionBackwardDataWorkspaceSize(
+        handle_[0],
+        filter_desc_,
+        bottom_descs_[i],
+        conv_descs_[i],
+        top_descs_[i],
+        bwd_data_algo_[i],
+        &workspace_bwd_data_sizes_[i]));
+  }
+
+  // reduce over all workspace sizes to get a maximum to allocate / reallocate
+  size_t total_workspace_fwd = 0;
+  size_t total_workspace_bwd_data = 0;
+  size_t total_workspace_bwd_filter = 0;
+
+  for (size_t i = 0; i < bottom.size(); i++) {
+    total_workspace_fwd        = std::max(total_workspace_fwd,
+                                     workspace_fwd_sizes_[i]);
+    total_workspace_bwd_data   = std::max(total_workspace_bwd_data,
+                                     workspace_bwd_data_sizes_[i]);
+    total_workspace_bwd_filter = std::max(total_workspace_bwd_filter,
+                                     workspace_bwd_filter_sizes_[i]);
+  }
+  // get max over all operations
+  size_t max_workspace = std::max(total_workspace_fwd,
+                             total_workspace_bwd_data);
+  max_workspace = std::max(max_workspace, total_workspace_bwd_filter);
+  // ensure all groups have enough workspace
+  size_t total_max_workspace = max_workspace *
+                               (this->group_ * CUDNN_STREAMS_PER_GROUP);
+
+  // this is the total amount of storage needed over all groups + streams
+  if (total_max_workspace > workspaceSizeInBytes) {
+    DLOG(INFO) << "Reallocating workspace storage: " << total_max_workspace;
+    workspaceSizeInBytes = total_max_workspace;
+
+    // free the existing workspace and allocate a new (larger) one
+    cudaFree(this->workspaceData);
+
+    cudaError_t err = cudaMalloc(&(this->workspaceData), workspaceSizeInBytes);
+    if (err != cudaSuccess) {
+      // force zero memory path
+      for (int i = 0; i < bottom.size(); i++) {
+        workspace_fwd_sizes_[i] = 0;
+        workspace_bwd_filter_sizes_[i] = 0;
+        workspace_bwd_data_sizes_[i] = 0;
+        fwd_algo_[i] = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+        bwd_filter_algo_[i] = CUDNN_CONVOLUTION_BWD_FILTER_ALGO_0;
+        bwd_data_algo_[i] = CUDNN_CONVOLUTION_BWD_DATA_ALGO_0;
+      }
+
+      // NULL out all workspace pointers
+      for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+        workspace[g] = NULL;
+      }
+      // NULL out underlying data
+      workspaceData = NULL;
+      workspaceSizeInBytes = 0;
+    }
+
+    // Set per-stream aliases; zeroed sizes keep them unused on failure.
+    for (int g = 0; g < (this->group_ * CUDNN_STREAMS_PER_GROUP); g++) {
+      workspace[g] = reinterpret_cast<char *>(workspaceData) + g*max_workspace;
+    }
+  }
+
+  // Tensor descriptor for bias.
+  if (this->bias_term_) {
+    cudnn::setTensor4dDesc<Dtype>(
+        &bias_desc_, 1, this->num_output_ / this->group_, 1, 1);
+  }
+}
+
+template <typename Dtype>
+CuDNNDeconvolutionLayer<Dtype>::~CuDNNDeconvolutionLayer() {
+  // Check that handles have been setup before destroying.
+  if (!handles_setup_) { return; }
+
+  for (int i = 0; i < bottom_descs_.size(); i++) {
+    cudnnDestroyTensorDescriptor(bottom_descs_[i]);
+    cudnnDestroyTensorDescriptor(top_descs_[i]);
+    cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
+  }
+  if (this->bias_term_) {
+    cudnnDestroyTensorDescriptor(bias_desc_);
+  }
+  cudnnDestroyFilterDescriptor(filter_desc_);
+
+  for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
+    cudaStreamDestroy(stream_[g]);
+    cudnnDestroy(handle_[g]);
+  }
+
+  cudaFree(workspaceData);
+  delete [] workspace;
+  delete [] stream_;
+  delete [] handle_;
+  delete [] fwd_algo_;
+  delete [] bwd_filter_algo_;
+  delete [] bwd_data_algo_;
+  delete [] workspace_fwd_sizes_;
+  delete [] workspace_bwd_data_sizes_;
+  delete [] workspace_bwd_filter_sizes_;
+}
+
+INSTANTIATE_CLASS(CuDNNDeconvolutionLayer);
+
+}   // namespace caffe
+#endif
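
Workspace-sizing sketch (illustration only, with invented numbers; the real
sizes come from the cudnnGet*WorkspaceSize queries in Reshape above). The
logic reduces to: one slice of the largest per-kernel workspace for every
(group, stream) pair, carved out of a single allocation:

    #include <algorithm>
    #include <cstddef>
    #include <cstdio>

    int main() {
      const int group = 2;
      const int streams_per_group = 3;  // CUDNN_STREAMS_PER_GROUP
      // Invented maxima over all bottom blobs for the three kernels.
      std::size_t fwd = 1u << 20, bwd_data = 3u << 20, bwd_filter = 2u << 20;
      std::size_t max_ws = std::max(fwd, std::max(bwd_data, bwd_filter));
      // Each group/stream gets its own slice so all can run concurrently;
      // workspace[g] aliases byte offset g * max_ws of one cudaMalloc block.
      std::size_t total = max_ws * group * streams_per_group;  // 18 MiB here
      std::printf("%zu bytes total, %zu per stream\n", total, max_ws);
      return 0;
    }
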
diff --git a/src/caffe/layers/cudnn_deconv_layer.cu b/src/caffe/layers/cudnn_deconv_layer.cu
new file mode 100644 (file)
index 0000000..eb1df32
--- /dev/null
@@ -0,0 +1,138 @@
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/layers/cudnn_deconv_layer.hpp"
+
+namespace caffe {
+
+__global__ void sync_deconv_groups() {}
+
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  const Dtype* weight = this->blobs_[0]->gpu_data();
+  for (int i = 0; i < bottom.size(); ++i) {
+    const Dtype* bottom_data = bottom[i]->gpu_data();
+    Dtype* top_data = top[i]->mutable_gpu_data();
+
+    // Forward through cuDNN in parallel over groups.
+    for (int g = 0; g < this->group_; g++) {
+      // Filters.
+      CUDNN_CHECK(cudnnConvolutionBackwardData(
+          handle_[g],
+          cudnn::dataType<Dtype>::one,
+          filter_desc_,
+          weight + this->weight_offset_ * g,
+          bottom_descs_[i],
+          bottom_data + bottom_offset_ * g,
+          conv_descs_[i],
+          bwd_data_algo_[i],
+          workspace[g],
+          workspace_bwd_data_sizes_[i],
+          cudnn::dataType<Dtype>::zero,
+          top_descs_[i],
+          top_data + top_offset_ * g));
+
+      // Bias.
+      if (this->bias_term_) {
+        const Dtype* bias_data = this->blobs_[1]->gpu_data();
+        CUDNN_CHECK(cudnnAddTensor(handle_[g],
+                                   cudnn::dataType<Dtype>::one,
+                                   bias_desc_,
+                                   bias_data + bias_offset_ * g,
+                                   cudnn::dataType<Dtype>::one,
+                                   top_descs_[i],
+                                   top_data + top_offset_ * g));
+      }
+    }
+
+    // Synchronize the work across groups, each of which went into its own
+    // stream, by launching an empty kernel into the default (null) stream.
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    sync_deconv_groups<<<1, 1>>>();
+  }
+}
+
+template <typename Dtype>
+void CuDNNDeconvolutionLayer<Dtype>::Backward_gpu(
+    const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down,
+    const vector<Blob<Dtype>*>& bottom) {
+  const Dtype* weight = NULL;
+  Dtype* weight_diff = NULL;
+  if (this->param_propagate_down_[0]) {
+    weight = this->blobs_[0]->gpu_data();
+    weight_diff = this->blobs_[0]->mutable_gpu_diff();
+  }
+  Dtype* bias_diff = NULL;
+  if (this->bias_term_ && this->param_propagate_down_[1]) {
+    bias_diff = this->blobs_[1]->mutable_gpu_diff();
+  }
+  for (int i = 0; i < top.size(); ++i) {
+    const Dtype* top_diff = top[i]->gpu_diff();
+    // Backward through cuDNN in parallel over groups and gradients.
+    for (int g = 0; g < this->group_; g++) {
+      // Gradient w.r.t. bias.
+      if (this->bias_term_ && this->param_propagate_down_[1]) {
+        CUDNN_CHECK(cudnnConvolutionBackwardBias(handle_[0 * this->group_ + g],
+                                                 cudnn::dataType<Dtype>::one,
+                                                 top_descs_[i],
+                                                 top_diff + top_offset_ * g,
+                                                 cudnn::dataType<Dtype>::one,
+                                                 bias_desc_,
+                                                 bias_diff + bias_offset_ * g));
+      }
+
+      // Gradient w.r.t. weights.
+      if (this->param_propagate_down_[0]) {
+        const Dtype* bottom_data = bottom[i]->gpu_data();
+        CUDNN_CHECK(cudnnConvolutionBackwardFilter(
+            handle_[1 * this->group_ + g],
+            cudnn::dataType<Dtype>::one,
+            top_descs_[i],
+            top_diff + top_offset_ * g,
+            bottom_descs_[i],
+            bottom_data + bottom_offset_ * g,
+            conv_descs_[i],
+            bwd_filter_algo_[i],
+            workspace[1 * this->group_ + g],
+            workspace_bwd_filter_sizes_[i],
+            cudnn::dataType<Dtype>::one,
+            filter_desc_,
+            weight_diff + this->weight_offset_ * g));
+      }
+
+      // Gradient w.r.t. bottom data.
+      if (propagate_down[i]) {
+        if (weight == NULL) {
+          weight = this->blobs_[0]->gpu_data();
+        }
+        Dtype* bottom_diff = bottom[i]->mutable_gpu_diff();
+        CUDNN_CHECK(
+            cudnnConvolutionForward(handle_[2 * this->group_ + g],
+                                    cudnn::dataType<Dtype>::one,
+                                    top_descs_[i],
+                                    top_diff + top_offset_ * g,
+                                    filter_desc_,
+                                    weight + this->weight_offset_ * g,
+                                    conv_descs_[i],
+                                    fwd_algo_[i],
+                                    workspace[2 * this->group_ + g],
+                                    workspace_fwd_sizes_[i],
+                                    cudnn::dataType<Dtype>::zero,
+                                    bottom_descs_[i],
+                                    bottom_diff + bottom_offset_ * g));
+      }
+    }
+
+    // Synchronize the work across groups, each of which went into its own
+    // stream, by launching an empty kernel into the default (null) stream.
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    sync_deconv_groups<<<1, 1>>>();
+  }
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(CuDNNDeconvolutionLayer);
+
+}  // namespace caffe
+#endif
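
Summary of the kernel role swap above (derived from the argument order of the
calls; comments only, nothing new is proposed):

    // Deconvolution phase       cuDNN kernel                    bottom is  top is
    // Forward                   cudnnConvolutionBackwardData    dy         dx
    // Backward w.r.t. weights   cudnnConvolutionBackwardFilter  dy         x
    // Backward w.r.t. bottom    cudnnConvolutionForward         y          x
    //
    // Handle/stream banks used in Backward_gpu (with group_ == G):
    //   handle_[0*G + g]  bias gradient    (cudnnConvolutionBackwardBias)
    //   handle_[1*G + g]  weight gradient  (cudnnConvolutionBackwardFilter)
    //   handle_[2*G + g]  bottom gradient  (cudnnConvolutionForward)
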
diff --git a/src/caffe/layers/deconv_layer.cpp b/src/caffe/layers/deconv_layer.cpp
index 20a460f..b86472b 100644 (file)
@@ -79,6 +79,5 @@ STUB_GPU(DeconvolutionLayer);
 #endif
 
 INSTANTIATE_CLASS(DeconvolutionLayer);
-REGISTER_LAYER_CLASS(Deconvolution);
 
 }  // namespace caffe
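
The removal above is required because a layer type may be registered only
once: Deconvolution now goes through the creator form in layer_factory.cpp.
For reference, the two existing Caffe registration macros (comments only):

    // REGISTER_LAYER_CLASS(Deconvolution);
    //   Registers the class itself; the factory always instantiates
    //   DeconvolutionLayer<Dtype> (the line deleted above).
    //
    // REGISTER_LAYER_CREATOR(Deconvolution, GetDeconvolutionLayer);
    //   Registers a creator that picks an implementation per LayerParameter
    //   (the form added in layer_factory.cpp above).
    //
    // Registering both would trip the duplicate-type CHECK in
    // LayerRegistry<Dtype>::AddCreator at startup.
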
diff --git a/src/caffe/test/test_deconvolution_layer.cpp b/src/caffe/test/test_deconvolution_layer.cpp
index c4b09ad..0067907 100644 (file)
@@ -6,6 +6,7 @@
 #include "caffe/common.hpp"
 #include "caffe/filler.hpp"
 #include "caffe/layers/deconv_layer.hpp"
+#include "caffe/layers/cudnn_deconv_layer.hpp"
 
 #include "caffe/test/test_caffe_main.hpp"
 #include "caffe/test/test_gradient_check_util.hpp"
@@ -301,4 +302,268 @@ TYPED_TEST(DeconvolutionLayerTest, TestGradient3D) {
       this->blob_top_vec_);
 }
 
+#ifdef USE_CUDNN
+
+// Since ConvolutionLayerTest checks the shared conv/deconv code in detail,
+// we'll just do a simple forward test and a gradient check.
+template <typename TypeParam>
+class CuDNNDeconvolutionLayerTest : public MultiDeviceTest<TypeParam> {
+  typedef typename TypeParam::Dtype Dtype;
+
+ protected:
+  CuDNNDeconvolutionLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_top_(new Blob<Dtype>()),
+        blob_top_2_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    // fill the values
+    FillerParameter filler_param;
+    filler_param.set_value(1.);
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    filler.Fill(this->blob_bottom_2_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+
+  virtual ~CuDNNDeconvolutionLayerTest() {
+    delete blob_bottom_;
+    delete blob_bottom_2_;
+    delete blob_top_;
+    delete blob_top_2_;
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_bottom_2_;
+  Blob<Dtype>* const blob_top_;
+  Blob<Dtype>* const blob_top_2_;
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(CuDNNDeconvolutionLayerTest, TestDtypesAndDevices);
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestSetup) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  shared_ptr<Layer<Dtype> > layer(
+      new CuDNNDeconvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 4);
+  EXPECT_EQ(this->blob_top_->height(), 13);
+  EXPECT_EQ(this->blob_top_->width(), 9);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 4);
+  EXPECT_EQ(this->blob_top_2_->height(), 13);
+  EXPECT_EQ(this->blob_top_2_->width(), 9);
+  // setting group should not change the shape
+  convolution_param->set_num_output(3);
+  convolution_param->set_group(3);
+  layer.reset(new CuDNNDeconvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 13);
+  EXPECT_EQ(this->blob_top_->width(), 9);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 3);
+  EXPECT_EQ(this->blob_top_2_->height(), 13);
+  EXPECT_EQ(this->blob_top_2_->width(), 9);
+}
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestSimpleCuDNNDeconvolution) {
+  typedef typename TypeParam::Dtype Dtype;
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->add_kernel_size(3);
+  convolution_param->add_stride(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("constant");
+  convolution_param->mutable_weight_filler()->set_value(1);
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+  shared_ptr<Layer<Dtype> > layer(
+      new CuDNNDeconvolutionLayer<Dtype>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+  // constant-fill the bottom blobs
+  FillerParameter filler_param;
+  filler_param.set_value(1.);
+  ConstantFiller<Dtype> filler(filler_param);
+  filler.Fill(this->blob_bottom_);
+  filler.Fill(this->blob_bottom_2_);
+  layer->Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+  // simply check that accumulation works with overlapping filters
+  const Dtype* top_data = this->blob_top_->cpu_data();
+  for (int n = 0; n < this->blob_top_->num(); ++n) {
+    for (int c = 0; c < this->blob_top_->channels(); ++c) {
+      for (int h = 0; h < this->blob_top_->height(); ++h) {
+        for (int w = 0; w < this->blob_top_->width(); ++w) {
+          Dtype expected = 3.1;
+          bool h_overlap = h % 2 == 0 && h > 0
+            && h < this->blob_top_->height() - 1;
+          bool w_overlap = w % 2 == 0 && w > 0
+            && w < this->blob_top_->width() - 1;
+          if (h_overlap && w_overlap) {
+            expected += 9;
+          } else if (h_overlap || w_overlap) {
+            expected += 3;
+          }
+          EXPECT_NEAR(top_data[this->blob_top_->offset(n, c, h, w)],
+              expected, 1e-4);
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  convolution_param->add_kernel_size(2);
+  convolution_param->add_stride(1);
+  convolution_param->set_num_output(1);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  convolution_param->mutable_bias_filler()->set_type("gaussian");
+  CuDNNDeconvolutionLayer<Dtype> layer(layer_param);
+  GradientChecker<Dtype> checker(1e-2, 1e-3);
+  checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
+      this->blob_top_vec_);
+}
+
+TYPED_TEST(CuDNNDeconvolutionLayerTest, TestNDAgainst2D) {
+  typedef typename TypeParam::Dtype Dtype;
+  const int kernel_h = 11;
+  const int kernel_w = 13;
+  vector<int> bottom_shape(4);
+  bottom_shape[0] = 15;
+  bottom_shape[1] = 12;
+  bottom_shape[2] = kernel_h * 2;
+  bottom_shape[3] = kernel_w * 2;
+  FillerParameter filler_param;
+  GaussianFiller<Dtype> filler(filler_param);
+  for (int i = 0; i < this->blob_bottom_vec_.size(); ++i) {
+    this->blob_bottom_vec_[i]->Reshape(bottom_shape);
+    filler.Fill(this->blob_bottom_vec_[i]);
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_num_output(18);
+  convolution_param->set_bias_term(false);
+  convolution_param->set_group(6);
+  convolution_param->set_kernel_h(kernel_h);
+  convolution_param->set_kernel_w(kernel_w);
+  convolution_param->mutable_weight_filler()->set_type("gaussian");
+  Blob<Dtype> weights;
+  Blob<Dtype> top_diff;
+  // Shape and fill weights and top_diff.
+  bool copy_diff;
+  bool reshape;
+  {
+    CuDNNDeconvolutionLayer<Dtype> layer(layer_param);
+    layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    top_diff.ReshapeLike(*this->blob_top_);
+    filler.Fill(&top_diff);
+    ASSERT_EQ(1, layer.blobs().size());
+    copy_diff = false; reshape = true;
+    weights.CopyFrom(*layer.blobs()[0], copy_diff, reshape);
+  }
+  vector<bool> propagate_down(1, true);
+  Blob<Dtype> result_2d;
+  Blob<Dtype> backward_result_2d;
+  Blob<Dtype> backward_weight_result_2d;
+  // Test with 2D im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_2d.
+    convolution_param->set_force_nd_im2col(false);
+    CuDNNDeconvolutionLayer<Dtype> layer_2d(layer_param);
+    layer_2d.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_2d.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_2d.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_2d.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_2d.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_2d.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_2d.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_2d.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_2d.CopyFrom(weights, copy_diff, reshape);
+  }
+  Blob<Dtype> result_nd;
+  Blob<Dtype> backward_result_nd;
+  Blob<Dtype> backward_weight_result_nd;
+  // Test with ND im2col
+  {
+    caffe_set(this->blob_top_->count(), Dtype(0),
+              this->blob_top_->mutable_cpu_data());
+    caffe_set(this->blob_bottom_->count(), Dtype(0),
+              this->blob_bottom_->mutable_cpu_diff());
+    caffe_set(weights.count(), Dtype(0), weights.mutable_cpu_diff());
+    // Do SetUp and Forward; save Forward result in result_nd.
+    convolution_param->set_force_nd_im2col(true);
+    CuDNNDeconvolutionLayer<Dtype> layer_nd(layer_param);
+    layer_nd.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
+    ASSERT_EQ(1, layer_nd.blobs().size());
+    copy_diff = false; reshape = false;
+    layer_nd.blobs()[0]->CopyFrom(weights, copy_diff, reshape);
+    layer_nd.Forward(this->blob_bottom_vec_, this->blob_top_vec_);
+    copy_diff = false; reshape = true;
+    result_nd.CopyFrom(*this->blob_top_, copy_diff, reshape);
+    // Copy pre-generated top diff into actual top diff;
+    // do Backward and save result in backward_result_nd.
+    ASSERT_EQ(this->blob_top_->shape(), top_diff.shape());
+    caffe_copy(top_diff.count(), top_diff.cpu_data(),
+               this->blob_top_->mutable_cpu_diff());
+    layer_nd.Backward(this->blob_top_vec_, propagate_down,
+                      this->blob_bottom_vec_);
+    copy_diff = true; reshape = true;
+    backward_result_nd.CopyFrom(*this->blob_bottom_, copy_diff, reshape);
+    backward_weight_result_nd.CopyFrom(weights, copy_diff, reshape);
+  }
+  ASSERT_EQ(result_nd.count(), result_2d.count());
+  for (int i = 0; i < result_2d.count(); ++i) {
+    EXPECT_NEAR(result_2d.cpu_data()[i], result_nd.cpu_data()[i], 1e-4);
+  }
+  ASSERT_EQ(backward_result_nd.count(), backward_result_2d.count());
+  for (int i = 0; i < backward_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_result_2d.cpu_diff()[i],
+              backward_result_nd.cpu_diff()[i]);
+  }
+  ASSERT_EQ(backward_weight_result_nd.count(),
+            backward_weight_result_2d.count());
+  for (int i = 0; i < backward_weight_result_2d.count(); ++i) {
+    EXPECT_EQ(backward_weight_result_2d.cpu_diff()[i],
+              backward_weight_result_nd.cpu_diff()[i]);
+  }
+}
+
+#endif
+
 }  // namespace caffe