strategize cuDNN convolution
author Evan Shelhamer <shelhamer@imaginarynumber.net>
Tue, 2 Sep 2014 05:05:43 +0000 (22:05 -0700)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
Sun, 7 Sep 2014 17:25:23 +0000 (19:25 +0200)
include/caffe/vision_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/cudnn_conv_layer.cpp [new file with mode: 0644]
src/caffe/layers/cudnn_conv_layer.cu [new file with mode: 0644]
src/caffe/test/test_convolution_layer.cpp

index 2dca2bf..4269163 100644 (file)
@@ -63,6 +63,36 @@ class ConvolutionLayer : public Layer<Dtype> {
   Blob<Dtype> bias_multiplier_;
 };
 
+#ifdef USE_CUDNN
+/**
+ * @brief cuDNN implementation of ConvolutionLayer.
+ *        Only the GPU forward/backward passes are overridden, so CPU mode
+ *        falls back to the inherited ConvolutionLayer implementation.
+ *
+ * NOTE(review): handle_ and stream_ are not initialized by the constructor;
+ * they are allocated in LayerSetUp and released in the destructor, so
+ * destroying a layer that was never set up looks unsafe. Confirm.
+ */
+template <typename Dtype>
+class CuDNNConvolutionLayer : public ConvolutionLayer<Dtype> {
+ public:
+  explicit CuDNNConvolutionLayer(const LayerParameter& param)
+      : ConvolutionLayer<Dtype>(param) {}
+  // Runs the base-class setup, then creates the cuDNN handles, CUDA streams
+  // and descriptors used by the GPU passes.
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  // Destroys the descriptors, handles and streams created in LayerSetUp.
+  virtual ~CuDNNConvolutionLayer();
+
+ protected:
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  // Arrays of group_ * CUDNN_STREAMS_PER_GROUP cuDNN handles and CUDA
+  // streams; handle_[k] is bound to stream_[k].
+  cudnnHandle_t* handle_;
+  cudaStream_t*  stream_;
+  // One input / output tensor descriptor per bottom / top blob.
+  vector<cudnnTensor4dDescriptor_t> bottom_descs_, top_descs_;
+  // Bias descriptor; only created when the layer has a bias term.
+  cudnnTensor4dDescriptor_t    bias_desc_;
+  // Filter descriptor sized for a single group.
+  cudnnFilterDescriptor_t      filter_desc_;
+  // One convolution descriptor per bottom blob.
+  vector<cudnnConvolutionDescriptor_t> conv_descs_;
+  // Element offsets between consecutive groups within the bottom, top,
+  // weight and bias buffers.
+  int bottom_offset_, top_offset_, weight_offset_, bias_offset_;
+};
+#endif
+
 /**
  * @brief A helper for image operations that rearranges image regions into
  *        column vectors.  Used by ConvolutionLayer to perform convolution
index b7b1098..ef1b756 100644 (file)
@@ -22,6 +22,10 @@ ConvolutionLayer<Dtype>* GetConvolutionLayer(const string& name,
   }
   if (engine == ConvolutionParameter_Engine_CAFFE) {
     return new ConvolutionLayer<Dtype>(param);
+#ifdef USE_CUDNN
+  } else if (engine == ConvolutionParameter_Engine_CUDNN) {
+    return new CuDNNConvolutionLayer<Dtype>(param);
+#endif
   } else {
     LOG(FATAL) << "Layer " << name << " has unknown engine.";
   }
index 1a1248f..81ad4f9 100644 (file)
@@ -66,11 +66,11 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_EQ(channels_ % group_, 0);
   // The im2col result buffer would only hold one image at a time to avoid
   // overly large memory usage.
-  int height_out =
+  height_out_ =
       (height_ + 2 * pad_h_ - kernel_h_) / stride_h_ + 1;
-  int width_out = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
+  width_out_ = (width_ + 2 * pad_w_ - kernel_w_) / stride_w_ + 1;
   col_buffer_.Reshape(
-      1, channels_ * kernel_h_ * kernel_w_, height_out, width_out);
+      1, channels_ * kernel_h_ * kernel_w_, height_out_, width_out_);
   // Set the parameters
   CHECK_EQ(num_output_ % group_, 0)
       << "Number of output should be multiples of group.";
@@ -78,9 +78,9 @@ void ConvolutionLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   // Figure out the dimensions for individual gemms.
   M_ = num_output_ / group_;
   K_ = channels_ * kernel_h_ * kernel_w_ / group_;
-  N_ = height_out * width_out;
+  N_ = height_out_ * width_out_;
   for (int top_id = 0; top_id < top->size(); ++top_id) {
-    (*top)[top_id]->Reshape(num_, num_output_, height_out, width_out);
+    (*top)[top_id]->Reshape(num_, num_output_, height_out_, width_out_);
   }
   // Check if we need to set up the weights
   if (this->blobs_.size() > 0) {
diff --git a/src/caffe/layers/cudnn_conv_layer.cpp b/src/caffe/layers/cudnn_conv_layer.cpp
new file mode 100644 (file)
index 0000000..a27d19c
--- /dev/null
@@ -0,0 +1,113 @@
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+// Set to three for the benefit of the backward pass, which
+// can use separate streams for calculating the gradient w.r.t.
+// bias, filter weights, and bottom data for each group independently
+#define CUDNN_STREAMS_PER_GROUP 3
+
+/**
+ * TODO(dox) explain cuDNN interface
+ */
+template <typename Dtype>
+void CuDNNConvolutionLayer<Dtype>::LayerSetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  ConvolutionLayer<Dtype>::LayerSetUp(bottom, top);
+  // Initialize CUDA streams and cuNN.
+  stream_         = new cudaStream_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+  handle_         = new cudnnHandle_t[this->group_ * CUDNN_STREAMS_PER_GROUP];
+
+  for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
+    // TODO(cudnn) check
+    cudaError_t err = cudaStreamCreate(&stream_[g]);
+    CHECK_EQ(err, cudaSuccess) << "Error creating a CUDA stream.";
+
+    // TODO(cudnn) check
+    cudnnStatus_t stat;
+    stat = cudnnCreate(&handle_[g]);
+    CHECK_EQ(stat, CUDNN_STATUS_SUCCESS) << "Could not create a CUDNN handle.";
+    stat = cudnnSetStream(handle_[g], stream_[g]);
+    CHECK_EQ(stat, CUDNN_STATUS_SUCCESS) << "Could not set CUDNN stream.";
+  }
+
+  // Set the indexing parameters.
+  bottom_offset_ = (this->channels_ / this->group_)
+      * this->height_ * this->width_;
+  top_offset_ = (this->num_output_ / this->group_)
+      * this->height_out_ * this->width_out_;
+  weight_offset_ = (this->num_output_ / this->group_)
+      * (this->channels_ / this->group_) * this->kernel_h_ * this->kernel_w_;
+  bias_offset_ = (this->num_output_ / this->group_);
+
+  // Create filter descriptor.
+  cudnn::createFilterDesc<Dtype>(&filter_desc_,
+      this->num_output_ / this->group_, this->channels_ / this->group_,
+      this->kernel_h_, this->kernel_w_);
+
+  // Create tensor descriptor(s) for data and corresponding convolution(s).
+  for (int i = 0; i < bottom.size(); i++) {
+    cudnnTensor4dDescriptor_t bottom_desc;
+    cudnn::createTensor4dDesc<Dtype>(&bottom_desc,
+        this->num_,
+        this->channels_ / this->group_,
+        this->height_, this->width_,
+        this->channels_ * this->height_ * this->width_,
+        this->height_ * this->width_,
+        this->width_, 1);
+    bottom_descs_.push_back(bottom_desc);
+    cudnnTensor4dDescriptor_t top_desc;
+    cudnn::createTensor4dDesc<Dtype>(&top_desc,
+        this->num_,
+        this->num_output_ / this->group_,
+        this->height_out_, this->width_out_,
+        this->num_output_ * this->height_out_ * this->width_out_,
+        this->height_out_ * this->width_out_,
+        this->width_out_, 1);
+    top_descs_.push_back(top_desc);
+    cudnnConvolutionDescriptor_t conv_desc;
+    cudnn::createConvolutionDesc<Dtype>(&conv_desc, bottom_desc,
+        filter_desc_, this->pad_h_, this->pad_w_,
+        this->stride_h_, this->stride_w_);
+    conv_descs_.push_back(conv_desc);
+  }
+
+  // Tensor descriptor for bias.
+  if (this->bias_term_) {
+    cudnn::createTensor4dDesc<Dtype>(&bias_desc_,
+        1, this->num_output_ / this->group_, 1, 1);
+  }
+}
+
+template <typename Dtype>
+CuDNNConvolutionLayer<Dtype>::~CuDNNConvolutionLayer() {
+  for (int i = 0; i < bottom_descs_.size(); i++) {
+    cudnnDestroyTensor4dDescriptor(bottom_descs_[i]);
+    cudnnDestroyTensor4dDescriptor(top_descs_[i]);
+    cudnnDestroyConvolutionDescriptor(conv_descs_[i]);
+  }
+  if (this->bias_term_) {
+    cudnnDestroyTensor4dDescriptor(bias_desc_);
+  }
+  cudnnDestroyFilterDescriptor(filter_desc_);
+
+  for (int g = 0; g < this->group_ * CUDNN_STREAMS_PER_GROUP; g++) {
+    cudaStreamDestroy(stream_[g]);
+    cudnnDestroy(handle_[g]);
+  }
+
+  delete [] stream_;
+  delete [] handle_;
+}
+
+INSTANTIATE_CLASS(CuDNNConvolutionLayer);
+
+}   // namespace caffe
+#endif
diff --git a/src/caffe/layers/cudnn_conv_layer.cu b/src/caffe/layers/cudnn_conv_layer.cu
new file mode 100644 (file)
index 0000000..9a65455
--- /dev/null
@@ -0,0 +1,121 @@
+#ifdef USE_CUDNN
+#include <vector>
+
+#include "caffe/filler.hpp"
+#include "caffe/layer.hpp"
+#include "caffe/util/im2col.hpp"
+#include "caffe/util/math_functions.hpp"
+#include "caffe/vision_layers.hpp"
+
+namespace caffe {
+
+// Empty kernel. The GPU passes launch it without a stream argument after
+// issuing the per-group cuDNN calls, as a way of synchronizing the work
+// that went into the groups' separate streams.
+__global__ void sync_conv_groups() { }
+
+/**
+ * @brief GPU forward pass: for every bottom/top pair, runs the cuDNN
+ *        convolution once per group with that group's handle, then adds
+ *        the bias when the layer has a bias term.
+ */
+template <typename Dtype>
+void CuDNNConvolutionLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  for (int i = 0; i < bottom.size(); ++i) {
+    const Dtype* bottom_data = bottom[i]->gpu_data();
+    Dtype* top_data = (*top)[i]->mutable_gpu_data();
+    const Dtype* weight = this->blobs_[0]->gpu_data();
+
+    // Forward through cuDNN in parallel over groups.
+    for (int g = 0; g < this->group_; ++g) {
+      // This group's slice of each buffer.
+      const Dtype* group_bottom = bottom_data + bottom_offset_ * g;
+      const Dtype* group_weight = weight + weight_offset_ * g;
+      Dtype* group_top = top_data + top_offset_ * g;
+
+      // Filters.
+      cudnnStatus_t stat = cudnnConvolutionForward(handle_[g],
+          bottom_descs_[i], group_bottom,
+          filter_desc_, group_weight,
+          conv_descs_[i],
+          top_descs_[i], group_top,
+          CUDNN_RESULT_NO_ACCUMULATE);
+      CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+          << "Error in cudnnConvolutionForward";
+
+      // Bias.
+      if (!this->bias_term_) {
+        continue;
+      }
+      const Dtype* bias_data = this->blobs_[1]->gpu_data();
+      Dtype alpha = 1.;
+      stat = cudnnAddTensor4d(handle_[g], CUDNN_ADD_SAME_C, &alpha,
+                              bias_desc_, bias_data + bias_offset_ * g,
+                              top_descs_[i], group_top);
+      CHECK_EQ(stat, CUDNN_STATUS_SUCCESS) << "Error in cudnnAddTensor4d";
+    }
+
+    // Synchronize the work across groups, each of which went into its own
+    // stream, by launching an empty kernel into the default (null) stream.
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    sync_conv_groups<<<1, 1>>>();
+  }
+}
+
+/**
+ * @brief GPU backward pass through cuDNN.
+ *
+ * Zeroes the weight and bias diffs when the corresponding parameter
+ * gradient is requested, then for every top blob and every group
+ * accumulates the bias and filter gradients and writes the bottom-data
+ * gradient. Each of the three gradient kinds uses its own block of group_
+ * handles: [0, group_) for bias, [group_, 2 * group_) for filters and
+ * [2 * group_, 3 * group_) for bottom data.
+ *
+ * @param top            blobs whose gpu_diff holds the incoming gradient
+ * @param propagate_down per-bottom flags; the data gradient is computed
+ *                       only for bottoms whose flag is set
+ * @param bottom         blobs providing the forward input (gpu_data) and
+ *                       receiving the data gradient (gpu_diff)
+ */
+template <typename Dtype>
+void CuDNNConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
+  const Dtype* weight = NULL;
+  Dtype* weight_diff = NULL;
+  if (this->param_propagate_down_[0]) {
+    weight = this->blobs_[0]->gpu_data();
+    weight_diff = this->blobs_[0]->mutable_gpu_diff();
+    caffe_gpu_set(this->blobs_[0]->count(), Dtype(0), weight_diff);
+  }
+  Dtype* bias_diff = NULL;
+  if (this->bias_term_ && this->param_propagate_down_[1]) {
+    bias_diff = this->blobs_[1]->mutable_gpu_diff();
+    caffe_gpu_set(this->blobs_[1]->count(), Dtype(0), bias_diff);
+  }
+  for (int i = 0; i < top.size(); ++i) {
+    const Dtype* top_diff = top[i]->gpu_diff();
+    // Backward through cuDNN in parallel over groups and gradients.
+    for (int g = 0; g < this->group_; g++) {
+      cudnnStatus_t stat;
+
+      // Gradient w.r.t. bias.
+      if (this->bias_term_ && this->param_propagate_down_[1]) {
+        stat = cudnnConvolutionBackwardBias(handle_[0*this->group_ + g],
+            top_descs_[i],  top_diff + top_offset_ * g,
+            bias_desc_, bias_diff + bias_offset_ * g,
+            CUDNN_RESULT_ACCUMULATE);
+        CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+            << "Error in cudnnConvolutionBackwardBias.";
+      }
+
+      // Gradient w.r.t. weights.
+      if (this->param_propagate_down_[0]) {
+        const Dtype* bottom_data = (*bottom)[i]->gpu_data();
+        stat = cudnnConvolutionBackwardFilter(handle_[1*this->group_ + g],
+            bottom_descs_[i], bottom_data + bottom_offset_ * g,
+            top_descs_[i],    top_diff + top_offset_ * g,
+            conv_descs_[i],
+            filter_desc_, weight_diff + weight_offset_ * g,
+            CUDNN_RESULT_ACCUMULATE);
+        CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+            << "Error in cudnnConvolutionBackwardFilter.";
+      }
+
+      // Gradient w.r.t. bottom data.
+      if (propagate_down[i]) {
+        // The data gradient needs the filter values even when the weight
+        // gradient itself is not requested, so fetch them here if the
+        // block above did not; otherwise a NULL pointer reaches cuDNN.
+        if (weight == NULL) {
+          weight = this->blobs_[0]->gpu_data();
+        }
+        Dtype* bottom_diff = (*bottom)[i]->mutable_gpu_diff();
+        stat = cudnnConvolutionBackwardData(handle_[2*this->group_ + g],
+            filter_desc_, weight + weight_offset_ * g,
+            top_descs_[i],    top_diff + top_offset_ * g,
+            conv_descs_[i],
+            bottom_descs_[i], bottom_diff + bottom_offset_ * g,
+            CUDNN_RESULT_NO_ACCUMULATE);
+        CHECK_EQ(stat, CUDNN_STATUS_SUCCESS)
+            << "Error in cudnnConvolutionBackwardData.";
+      }
+    }
+
+    // Synchronize the work across groups, each of which went into its own
+    // stream, by launching an empty kernel into the default (null) stream.
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    sync_conv_groups<<<1, 1>>>();
+  }
+}
+
+INSTANTIATE_CLASS(CuDNNConvolutionLayer);
+
+}  // namespace caffe
+#endif
index 5a7ea80..0e7a8da 100644 (file)
@@ -302,4 +302,295 @@ TYPED_TEST(ConvolutionLayerTest, TestGradientGroup) {
       &(this->blob_top_vec_));
 }
 
+#ifdef USE_CUDNN
+
+// Test fixture for CuDNNConvolutionLayer. Owns two 2 x 3 x 6 x 4 bottom
+// blobs and two initially empty top blobs. SetUp registers only the first
+// bottom/top pair in the blob vectors; tests that need a second pair push
+// blob_bottom_2_ / blob_top_2_ themselves.
+template <typename Dtype>
+class CuDNNConvolutionLayerTest : public ::testing::Test {
+ protected:
+  CuDNNConvolutionLayerTest()
+      : blob_bottom_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_bottom_2_(new Blob<Dtype>(2, 3, 6, 4)),
+        blob_top_(new Blob<Dtype>()),
+        blob_top_2_(new Blob<Dtype>()) {}
+  virtual void SetUp() {
+    // fill the values
+    // NOTE(review): a value is set on the filler parameter but a Gaussian
+    // filler is used; presumably the value is ignored. Confirm.
+    FillerParameter filler_param;
+    filler_param.set_value(1.);
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    filler.Fill(this->blob_bottom_2_);
+    blob_bottom_vec_.push_back(blob_bottom_);
+    blob_top_vec_.push_back(blob_top_);
+  }
+
+  // The fixture owns all four blobs.
+  virtual ~CuDNNConvolutionLayerTest() {
+    delete blob_bottom_;
+    delete blob_bottom_2_;
+    delete blob_top_;
+    delete blob_top_2_;
+  }
+
+  Blob<Dtype>* const blob_bottom_;
+  Blob<Dtype>* const blob_bottom_2_;
+  Blob<Dtype>* const blob_top_;
+  Blob<Dtype>* const blob_top_2_;
+  // Bottom / top vectors handed to the layer under test.
+  vector<Blob<Dtype>*> blob_bottom_vec_;
+  vector<Blob<Dtype>*> blob_top_vec_;
+};
+
+TYPED_TEST_CASE(CuDNNConvolutionLayerTest, TestDtypes);
+
+TYPED_TEST(CuDNNConvolutionLayerTest, TestSetupCuDNN) {
+  Caffe::set_mode(Caffe::GPU);
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_kernel_size(3);
+  convolution_param->set_stride(2);
+  convolution_param->set_num_output(4);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  shared_ptr<Layer<TypeParam> > layer(
+      new CuDNNConvolutionLayer<TypeParam>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 4);
+  EXPECT_EQ(this->blob_top_->height(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 4);
+  EXPECT_EQ(this->blob_top_2_->height(), 2);
+  EXPECT_EQ(this->blob_top_2_->width(), 1);
+  // setting group should not change the shape
+  convolution_param->set_num_output(3);
+  convolution_param->set_group(3);
+  layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  EXPECT_EQ(this->blob_top_->num(), 2);
+  EXPECT_EQ(this->blob_top_->channels(), 3);
+  EXPECT_EQ(this->blob_top_->height(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 1);
+  EXPECT_EQ(this->blob_top_2_->num(), 2);
+  EXPECT_EQ(this->blob_top_2_->channels(), 3);
+  EXPECT_EQ(this->blob_top_2_->height(), 2);
+  EXPECT_EQ(this->blob_top_2_->width(), 1);
+}
+
+TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionCuDNN) {
+  // We will simply see if the convolution layer carries out averaging well.
+  Caffe::set_mode(Caffe::GPU);
+  shared_ptr<ConstantFiller<TypeParam> > filler;
+  FillerParameter filler_param;
+  filler_param.set_value(1.);
+  filler.reset(new ConstantFiller<TypeParam>(filler_param));
+  filler->Fill(this->blob_bottom_);
+  filler_param.set_value(2.);
+  filler.reset(new ConstantFiller<TypeParam>(filler_param));
+  filler->Fill(this->blob_bottom_2_);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_kernel_size(3);
+  convolution_param->set_stride(2);
+  convolution_param->set_num_output(4);
+  convolution_param->mutable_weight_filler()->set_type("constant");
+  convolution_param->mutable_weight_filler()->set_value(1);
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+  shared_ptr<Layer<TypeParam> > layer(
+      new CuDNNConvolutionLayer<TypeParam>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // After the convolution, the output should all have output values 27.1
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], 27.1, 1e-4);
+  }
+  top_data = this->blob_top_2_->cpu_data();
+  for (int i = 0; i < this->blob_top_2_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], 54.1, 1e-4);
+  }
+}
+
+TYPED_TEST(CuDNNConvolutionLayerTest, TestSimpleConvolutionGroupCuDNN) {
+  // We will simply see if the convolution layer carries out averaging well.
+  Caffe::set_mode(Caffe::GPU);
+  FillerParameter filler_param;
+  filler_param.set_value(1.);
+  ConstantFiller<TypeParam> filler(filler_param);
+  filler.Fill(this->blob_bottom_);
+  TypeParam* bottom_data = this->blob_bottom_->mutable_cpu_data();
+  for (int n = 0; n < this->blob_bottom_->num(); ++n) {
+    for (int c = 0; c < this->blob_bottom_->channels(); ++c) {
+      for (int h = 0; h < this->blob_bottom_->height(); ++h) {
+        for (int w = 0; w < this->blob_bottom_->width(); ++w) {
+          bottom_data[this->blob_bottom_->offset(n, c, h, w)] = c;
+        }
+      }
+    }
+  }
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_kernel_size(3);
+  convolution_param->set_stride(2);
+  convolution_param->set_num_output(3);
+  convolution_param->set_group(3);
+  convolution_param->mutable_weight_filler()->set_type("constant");
+  convolution_param->mutable_weight_filler()->set_value(1);
+  convolution_param->mutable_bias_filler()->set_type("constant");
+  convolution_param->mutable_bias_filler()->set_value(0.1);
+  shared_ptr<Layer<TypeParam> > layer(
+      new CuDNNConvolutionLayer<TypeParam>(layer_param));
+  layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // After the convolution, the output should all have output values 9.1
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  for (int n = 0; n < this->blob_top_->num(); ++n) {
+    for (int c = 0; c < this->blob_top_->channels(); ++c) {
+      for (int h = 0; h < this->blob_top_->height(); ++h) {
+        for (int w = 0; w < this->blob_top_->width(); ++w) {
+          TypeParam data = top_data[this->blob_top_->offset(n, c, h, w)];
+          EXPECT_NEAR(data, c * 9 + 0.1, 1e-4);
+        }
+      }
+    }
+  }
+}
+
+TYPED_TEST(CuDNNConvolutionLayerTest, TestSobelConvolutionCuDNN) {
+  // Test separable convolution by computing the Sobel operator
+  // as a single filter then comparing the result
+  // as the convolution of two rectangular filters.
+  Caffe::set_mode(Caffe::GPU);
+  // Fill bottoms with identical Gaussian noise.
+  shared_ptr<GaussianFiller<TypeParam> > filler;
+  FillerParameter filler_param;
+  filler_param.set_value(1.);
+  filler.reset(new GaussianFiller<TypeParam>(filler_param));
+  filler->Fill(this->blob_bottom_);
+  this->blob_bottom_2_->CopyFrom(*this->blob_bottom_);
+  // Compute Sobel G_x operator as 3 x 3 convolution.
+  LayerParameter layer_param;
+  ConvolutionParameter* convolution_param =
+      layer_param.mutable_convolution_param();
+  convolution_param->set_kernel_size(3);
+  convolution_param->set_stride(2);
+  convolution_param->set_num_output(1);
+  convolution_param->set_bias_term(false);
+  shared_ptr<Layer<TypeParam> > layer(
+      new CuDNNConvolutionLayer<TypeParam>(layer_param));
+  // Install the weights by hand before SetUp instead of using a filler.
+  layer->blobs().resize(1);
+  layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 3, 3));
+  TypeParam* weights = layer->blobs()[0]->mutable_cpu_data();
+  // The same 3 x 3 Sobel kernel is written for each of the 3 channels.
+  for (int c = 0; c < 3; ++c) {
+    int i = c * 9;  // 3 x 3 filter
+    weights[i +  0] = -1;
+    weights[i +  1] =  0;
+    weights[i +  2] =  1;
+    weights[i +  3] = -2;
+    weights[i +  4] =  0;
+    weights[i +  5] =  2;
+    weights[i +  6] = -1;
+    weights[i +  7] =  0;
+    weights[i +  8] =  1;
+  }
+  // Full-filter result lands in blob_top_ (registered by the fixture).
+  layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  layer->Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
+  // Compute Sobel G_x operator as separable 3 x 1 and 1 x 3 convolutions.
+  // (1) the [1 2 1] column filter
+  vector<Blob<TypeParam>*> sep_blob_bottom_vec;
+  vector<Blob<TypeParam>*> sep_blob_top_vec;
+  shared_ptr<Blob<TypeParam> > blob_sep(new Blob<TypeParam>());
+  sep_blob_bottom_vec.push_back(this->blob_bottom_2_);
+  sep_blob_top_vec.push_back(this->blob_top_2_);
+  // Switch from the square kernel_size / stride fields to the per-axis ones.
+  convolution_param->clear_kernel_size();
+  convolution_param->clear_stride();
+  convolution_param->set_kernel_h(3);
+  convolution_param->set_kernel_w(1);
+  convolution_param->set_stride_h(2);
+  convolution_param->set_stride_w(1);
+  convolution_param->set_num_output(1);
+  convolution_param->set_bias_term(false);
+  layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param));
+  layer->blobs().resize(1);
+  layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 3, 1));
+  TypeParam* weights_1 = layer->blobs()[0]->mutable_cpu_data();
+  for (int c = 0; c < 3; ++c) {
+    int i = c * 3;  // 3 x 1 filter
+    weights_1[i +  0] = 1;
+    weights_1[i +  1] = 2;
+    weights_1[i +  2] = 1;
+  }
+  layer->SetUp(sep_blob_bottom_vec, &(sep_blob_top_vec));
+  layer->Forward(sep_blob_bottom_vec, &(sep_blob_top_vec));
+  // (2) the [-1 0 1] row filter
+  // The stage-1 output is copied into blob_sep, which becomes the stage-2
+  // input; blob_top_2_ is then reused as the stage-2 output.
+  blob_sep->CopyFrom(*this->blob_top_2_, false, true);
+  sep_blob_bottom_vec.clear();
+  sep_blob_bottom_vec.push_back(blob_sep.get());
+  convolution_param->set_kernel_h(1);
+  convolution_param->set_kernel_w(3);
+  convolution_param->set_stride_h(1);
+  convolution_param->set_stride_w(2);
+  convolution_param->set_num_output(1);
+  convolution_param->set_bias_term(false);
+  layer.reset(new CuDNNConvolutionLayer<TypeParam>(layer_param));
+  layer->blobs().resize(1);
+  // NOTE(review): stage 1 used num_output 1, so blob_sep presumably has a
+  // single channel, yet this weight blob is allocated with 3 channels.
+  // Confirm the intended shape.
+  layer->blobs()[0].reset(new Blob<TypeParam>(1, 3, 1, 3));
+  TypeParam* weights_2 = layer->blobs()[0]->mutable_cpu_data();
+  for (int c = 0; c < 3; ++c) {
+    int i = c * 3;  // 1 x 3 filter
+    weights_2[i +  0] = -1;
+    weights_2[i +  1] =  0;
+    weights_2[i +  2] =  1;
+  }
+  layer->SetUp(sep_blob_bottom_vec, &(sep_blob_top_vec));
+  layer->Forward(sep_blob_bottom_vec, &(sep_blob_top_vec));
+  // Test equivalence of full and separable filters.
+  const TypeParam* top_data = this->blob_top_->cpu_data();
+  const TypeParam* sep_top_data = this->blob_top_2_->cpu_data();
+  for (int i = 0; i < this->blob_top_->count(); ++i) {
+    EXPECT_NEAR(top_data[i], sep_top_data[i], 1e-4);
+  }
+}
+
+TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientCuDNN) {
+  // Exhaustive numerical gradient check with two bottom/top pairs and no
+  // grouping.
+  Caffe::set_mode(Caffe::GPU);
+  this->blob_bottom_vec_.push_back(this->blob_bottom_2_);
+  this->blob_top_vec_.push_back(this->blob_top_2_);
+
+  LayerParameter layer_param;
+  ConvolutionParameter* conv_param = layer_param.mutable_convolution_param();
+  conv_param->set_num_output(2);
+  conv_param->set_kernel_size(3);
+  conv_param->set_stride(2);
+  conv_param->mutable_weight_filler()->set_type("gaussian");
+  conv_param->mutable_bias_filler()->set_type("gaussian");
+
+  CuDNNConvolutionLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+TYPED_TEST(CuDNNConvolutionLayerTest, TestGradientGroupCuDNN) {
+  // Exhaustive numerical gradient check for a grouped convolution
+  // (3 groups, one output per group) on the fixture's single bottom.
+  Caffe::set_mode(Caffe::GPU);
+
+  LayerParameter layer_param;
+  ConvolutionParameter* conv_param = layer_param.mutable_convolution_param();
+  conv_param->set_num_output(3);
+  conv_param->set_group(3);
+  conv_param->set_kernel_size(3);
+  conv_param->set_stride(2);
+  conv_param->mutable_weight_filler()->set_type("gaussian");
+  conv_param->mutable_bias_filler()->set_type("gaussian");
+
+  CuDNNConvolutionLayer<TypeParam> layer(layer_param);
+  GradientChecker<TypeParam> checker(1e-2, 1e-3);
+  checker.CheckGradientExhaustive(&layer, &(this->blob_bottom_vec_),
+      &(this->blob_top_vec_));
+}
+
+#endif
+
 }  // namespace caffe