add support for N-D dilated convolution
authorFisher Yu <i@yf.io>
Mon, 28 Dec 2015 04:48:30 +0000 (20:48 -0800)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Mon, 28 Dec 2015 23:07:39 +0000 (15:07 -0800)
include/caffe/layers/base_conv_layer.hpp
include/caffe/util/im2col.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/im2col_layer.cpp
src/caffe/layers/im2col_layer.cu
src/caffe/test/test_im2col_kernel.cu
src/caffe/test/test_im2col_layer.cpp
src/caffe/util/im2col.cpp
src/caffe/util/im2col.cu

index db471b5..0160a83 100644 (file)
@@ -106,7 +106,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
     } else {
       im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
           col_buffer_shape_.data(), kernel_shape_.cpu_data(),
-          pad_.cpu_data(), stride_.cpu_data(), col_buff);
+          pad_.cpu_data(), stride_.cpu_data(), dilation_.cpu_data(), col_buff);
     }
   }
   inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) {
@@ -120,7 +120,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
     } else {
       col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
           col_buffer_shape_.data(), kernel_shape_.cpu_data(),
-          pad_.cpu_data(), stride_.cpu_data(), data);
+          pad_.cpu_data(), stride_.cpu_data(), dilation_.cpu_data(), data);
     }
   }
 #ifndef CPU_ONLY
@@ -136,7 +136,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
       im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_,
           conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
           kernel_shape_.gpu_data(), pad_.gpu_data(),
-          stride_.gpu_data(), col_buff);
+          stride_.gpu_data(), dilation_.gpu_data(), col_buff);
     }
   }
   inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) {
@@ -151,7 +151,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
       col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_,
           conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
           kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
-          data);
+          dilation_.gpu_data(), data);
     }
   }
 #endif
index 748b65c..a35bc6e 100644 (file)
@@ -7,7 +7,7 @@ template <typename Dtype>
 void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col);
+    const int* dilation, Dtype* data_col);
 
 template <typename Dtype>
 void im2col_cpu(const Dtype* data_im, const int channels,
@@ -20,7 +20,7 @@ template <typename Dtype>
 void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im);
+    const int* dilation, Dtype* data_im);
 
 template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
@@ -33,7 +33,7 @@ template <typename Dtype>
 void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
     const int col_size, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col);
+    const int* dilation, Dtype* data_col);
 
 template <typename Dtype>
 void im2col_gpu(const Dtype* data_im, const int channels,
@@ -46,7 +46,7 @@ template <typename Dtype>
 void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
     const int im_size, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im);
+    const int* dilation, Dtype* data_im);
 
 template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
index 6b1d1c1..4d912d2 100644 (file)
@@ -39,12 +39,14 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(
     const LayerParameter& param) {
   ConvolutionParameter conv_param = param.convolution_param();
   ConvolutionParameter_Engine engine = conv_param.engine();
+#ifdef USE_CUDNN
   bool use_dilation = false;
   for (int i = 0; i < conv_param.dilation_size(); ++i) {
     if (conv_param.dilation(i) > 1) {
       use_dilation = true;
     }
   }
+#endif
   if (engine == ConvolutionParameter_Engine_DEFAULT) {
     engine = ConvolutionParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
index 19ae301..2fb9b3c 100644 (file)
@@ -153,7 +153,7 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
           bottom[0]->shape().data() + channel_axis_,
           top[0]->shape().data() + channel_axis_,
           kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(),
-          top_data + n * top_dim_);
+          dilation_.cpu_data(), top_data + n * top_dim_);
     }
   }
 }
@@ -178,7 +178,7 @@ void Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
           bottom[0]->shape().data() + channel_axis_,
           top[0]->shape().data() + channel_axis_,
           kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(),
-          bottom_diff + n * bottom_dim_);
+          dilation_.cpu_data(), bottom_diff + n * bottom_dim_);
     }
   }
 }
index d90075d..792c97f 100644 (file)
@@ -26,7 +26,7 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
           num_kernels, bottom[0]->gpu_shape() + channel_axis_,
           top[0]->gpu_shape() + channel_axis_,
           kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
-          top_data + n * top_dim_);
+          dilation_.gpu_data(), top_data + n * top_dim_);
     }
   }
 }
@@ -51,7 +51,7 @@ void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
           bottom[0]->gpu_shape() + channel_axis_,
           top[0]->gpu_shape() + channel_axis_,
           kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
-          bottom_diff + n * bottom_dim_);
+          dilation_.gpu_data(), bottom_diff + n * bottom_dim_);
     }
   }
 }
index 15e06aa..5d8f01f 100644 (file)
@@ -26,7 +26,7 @@ template <typename Dtype, int num_axes>
 __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col);
+    const int* dilation, Dtype* data_col);
 
 extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
 
@@ -35,7 +35,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
  protected:
   Im2colKernelTest()
         // big so launches > 1024 threads
-      : blob_bottom_(new Blob<Dtype>(5, 500, 10, 10)),
+      : blob_bottom_(new Blob<Dtype>(5, 500, 15, 15)),
         blob_kernel_shape_(new Blob<int>()),
         blob_stride_(new Blob<int>()),
         blob_pad_(new Blob<int>()),
@@ -56,7 +56,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
     channels_ = blob_bottom_->channels();
     pad_ = 0;
     stride_ = 2;
-    dilation_ = 1;
+    dilation_ = 3;
     kernel_size_ = 3;
     height_col_ = (height_ + 2 * pad_ -
         (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1;
@@ -176,6 +176,7 @@ TYPED_TEST(Im2colKernelTest, TestND) {
         this->blob_top_cpu_->shape().data() + 1,
         this->blob_kernel_shape_->cpu_data(),
         this->blob_pad_->cpu_data(), this->blob_stride_->cpu_data(),
+        this->blob_dilation_->cpu_data(),
         top_data_cpu + this->blob_top_cpu_->offset(n));
   }
 
@@ -194,7 +195,7 @@ TYPED_TEST(Im2colKernelTest, TestND) {
           num_kernels, bottom_data_gpu + this->blob_bottom_->offset(n),
           this->blob_bottom_->gpu_shape() + 1, this->blob_top_->gpu_shape() + 1,
           this->blob_kernel_shape_->gpu_data(), this->blob_pad_->gpu_data(),
-          this->blob_stride_->gpu_data(),
+          this->blob_stride_->gpu_data(), this->blob_dilation_->gpu_data(),
           top_data_gpu + this->blob_top_->offset(n));
       CUDA_POST_KERNEL_CHECK;
     }
index 932d3f2..24885e6 100644 (file)
@@ -17,7 +17,7 @@ class Im2colLayerTest : public MultiDeviceTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;
  protected:
   Im2colLayerTest()
-      : blob_bottom_(new Blob<Dtype>(2, 3, 10, 9)),
+      : blob_bottom_(new Blob<Dtype>(2, 3, 10, 11)),
         blob_top_(new Blob<Dtype>()) {
     // fill the values
     Caffe::set_random_seed(1701);
@@ -43,12 +43,13 @@ TYPED_TEST(Im2colLayerTest, TestSetup) {
       layer_param.mutable_convolution_param();
   convolution_param->add_kernel_size(3);
   convolution_param->add_stride(2);
+  convolution_param->add_dilation(3);
   Im2colLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_->num(), 2);
   EXPECT_EQ(this->blob_top_->channels(), 27);
   EXPECT_EQ(this->blob_top_->height(), 2);
-  EXPECT_EQ(this->blob_top_->width(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 3);
 }
 
 TYPED_TEST(Im2colLayerTest, TestForward) {
@@ -89,6 +90,7 @@ TYPED_TEST(Im2colLayerTest, TestGradientForceND) {
       layer_param.mutable_convolution_param();
   convolution_param->add_kernel_size(3);
   convolution_param->add_stride(2);
+  convolution_param->add_dilation(3);
   convolution_param->set_force_nd_im2col(true);
   Im2colLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-2);
@@ -123,6 +125,8 @@ TYPED_TEST(Im2colLayerTest, TestRectGradient) {
   convolution_param->set_kernel_h(5);
   convolution_param->set_kernel_w(3);
   convolution_param->add_stride(2);
+  convolution_param->add_dilation(1);
+  convolution_param->add_dilation(3);
   Im2colLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-2);
   checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
index 1e578e7..6e5ea87 100644 (file)
@@ -49,7 +49,7 @@ template <typename Dtype>
 inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
     const int num_spatial_axes, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_output) {
+    const int* dilation, Dtype* data_output) {
   if (!im2col) {
     int im_size = im_shape[0];
     for (int i = 0; i < num_spatial_axes; ++i) {
@@ -81,7 +81,8 @@ inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
       bool is_padding = false;
       for (int d_i = 0; d_i < num_spatial_axes; ++d_i) {
         const int d = d_iter[d_i];
-        const int d_im = d * stride[d_i] - pad[d_i] + d_offset[d_i];
+        const int d_im = d * stride[d_i] - pad[d_i] +
+            d_offset[d_i] * dilation[d_i];
         is_padding |= d_im < 0 || d_im >= im_shape[d_i + 1];
         index_col *= col_shape[d_i + 1];
         index_col += d;
@@ -119,10 +120,10 @@ template <typename Dtype>
 void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col) {
+    const int* dilation, Dtype* data_col) {
   const bool kIm2Col = true;
   im2col_nd_core_cpu(data_im, kIm2Col, num_spatial_axes, im_shape, col_shape,
-                  kernel_shape, pad, stride, data_col);
+                  kernel_shape, pad, stride, dilation, data_col);
 }
 
 // Explicit instantiation
@@ -130,12 +131,12 @@ template void im2col_nd_cpu<float>(const float* data_im,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_col);
+    const int* dilation, float* data_col);
 template void im2col_nd_cpu<double>(const double* data_im,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_col);
+    const int* dilation, double* data_col);
 
 template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
@@ -182,10 +183,10 @@ template <typename Dtype>
 void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im) {
+    const int* dilation, Dtype* data_im) {
   const bool kIm2Col = false;
   im2col_nd_core_cpu(data_col, kIm2Col, num_spatial_axes, im_shape, col_shape,
-                     kernel_shape, pad, stride, data_im);
+                     kernel_shape, pad, stride, dilation, data_im);
 }
 
 // Explicit instantiation
@@ -193,12 +194,12 @@ template void col2im_nd_cpu<float>(const float* data_col,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_im);
+    const int* dilation, float* data_im);
 template void col2im_nd_cpu<double>(const double* data_col,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_im);
+    const int* dilation, double* data_im);
 
 
 }  // namespace caffe
index cdcaac5..a8f30a0 100644 (file)
@@ -75,9 +75,29 @@ template <typename Dtype, int num_axes>
 __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col) {
+    const int* dilation, Dtype* data_col) {
   int d_temp[num_axes];  // NOLINT(runtime/arrays)
   int d_iter[num_axes];  // NOLINT(runtime/arrays)
+
+  __shared__ int shared_dilation[num_axes];
+  __shared__ int shared_kernel_shape[num_axes];
+  __shared__ int shared_pad[num_axes];
+  __shared__ int shared_stride[num_axes];
+  __shared__ int shared_col_shape[num_axes + 1];
+  __shared__ int shared_im_shape[num_axes + 1];
+
+  if (threadIdx.x < num_axes) {
+    shared_dilation[threadIdx.x] = dilation[threadIdx.x];
+    shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x];
+    shared_pad[threadIdx.x] = pad[threadIdx.x];
+    shared_stride[threadIdx.x] = stride[threadIdx.x];
+  }
+  if (threadIdx.x < num_axes + 1) {
+    shared_col_shape[threadIdx.x] = col_shape[threadIdx.x];
+    shared_im_shape[threadIdx.x] = im_shape[threadIdx.x];
+  }
+  __syncthreads();
+
   int i;
   CUDA_KERNEL_LOOP(index, n) {
     // Initialize channel_in, computed in the loop below, with intermediate
@@ -85,19 +105,19 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     int channel_in = index;
     int channel_out = 1;
     for (i = num_axes - 1; i >= 0; --i) {
-      d_temp[i] = channel_in % col_shape[i + 1];
-      channel_in /= col_shape[i + 1];
-      channel_out *= kernel_shape[i];
+      d_temp[i] = channel_in % shared_col_shape[i + 1];
+      channel_in /= shared_col_shape[i + 1];
+      channel_out *= shared_kernel_shape[i];
     }
     channel_out *= channel_in;
     int data_col_inc = 1;
     for (i = 0; i < num_axes; ++i) {
-      channel_out *= col_shape[i + 1];
+      channel_out *= shared_col_shape[i + 1];
       channel_out += d_temp[i];
-      d_temp[i] = d_temp[i] * stride[i] - pad[i];
-      channel_in *= im_shape[i + 1];
+      d_temp[i] = d_temp[i] * shared_stride[i] - shared_pad[i];
+      channel_in *= shared_im_shape[i + 1];
       channel_in += d_temp[i];
-      data_col_inc *= col_shape[i + 1];
+      data_col_inc *= shared_col_shape[i + 1];
       d_iter[i] = 0;
     }
     Dtype* data_col_ptr = data_col + channel_out;
@@ -106,15 +126,15 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     do {
       bool in_range = true;
       for (i = 0; i < num_axes; ++i) {
-        const int d_iter_im = d_iter[i] + d_temp[i];
-        in_range &= d_iter_im >= 0 && d_iter_im < im_shape[i + 1];
+        const int d_iter_im = d_iter[i] * shared_dilation[i] + d_temp[i];
+        in_range &= d_iter_im >= 0 && d_iter_im < shared_im_shape[i + 1];
         if (!in_range) { break; }
       }
       if (in_range) {
-        int data_im_offset = d_iter[0];
+        int data_im_offset = d_iter[0] * shared_dilation[0];
         for (i = 1; i < num_axes; ++i) {
-          data_im_offset *= im_shape[i + 1];
-          data_im_offset += d_iter[i];
+          data_im_offset *= shared_im_shape[i + 1];
+          data_im_offset += d_iter[i] * shared_dilation[i];
         }
         *data_col_ptr = data_im_ptr[data_im_offset];
       } else {
@@ -123,7 +143,7 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
       data_col_ptr += data_col_inc;
       incremented = false;
       for (i = num_axes - 1; i >= 0; --i) {
-        const int d_max = kernel_shape[i];
+        const int d_max = shared_kernel_shape[i];
         if (d_iter[i] == d_max - 1) {
           d_iter[i] = 0;
         } else {  // d_iter[i] < d_max - 1
@@ -140,67 +160,69 @@ template <typename Dtype>
 void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
     const int num_kernels, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col) {
+    const int* dilation, Dtype* data_col) {
+  // num_axes should be smaller than block size
+  DCHECK_LT(num_spatial_axes, CAFFE_CUDA_NUM_THREADS);
   switch (num_spatial_axes) {
   case 1:
     im2col_nd_gpu_kernel<Dtype, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 2:
     im2col_nd_gpu_kernel<Dtype, 2>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 3:
     im2col_nd_gpu_kernel<Dtype, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 4:
     im2col_nd_gpu_kernel<Dtype, 4>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 5:
     im2col_nd_gpu_kernel<Dtype, 5>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 6:
     im2col_nd_gpu_kernel<Dtype, 6>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 7:
     im2col_nd_gpu_kernel<Dtype, 7>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 8:
     im2col_nd_gpu_kernel<Dtype, 8>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 9:
     im2col_nd_gpu_kernel<Dtype, 9>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 10:
     im2col_nd_gpu_kernel<Dtype, 10>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   default:
     LOG(FATAL) << "im2col_nd_gpu does not support computation with "
@@ -214,12 +236,12 @@ template void im2col_nd_gpu<float>(const float* data_im,
     const int num_spatial_axes, const int col_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_col);
+    const int* dilation, float* data_col);
 template void im2col_nd_gpu<double>(const double* data_im,
     const int num_spatial_axes, const int col_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_col);
+    const int* dilation, double* data_col);
 
 template <typename Dtype>
 __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
@@ -300,27 +322,50 @@ template <typename Dtype, int num_axes>
 __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im) {
+    const int* dilation, Dtype* data_im) {
   int d_im[num_axes];  // NOLINT(runtime/arrays)
   int d_col_iter[num_axes];  // NOLINT(runtime/arrays)
   int d_col_start[num_axes];  // NOLINT(runtime/arrays)
   int d_col_end[num_axes];  // NOLINT(runtime/arrays)
+
+  __shared__ int shared_dilation[num_axes];
+  __shared__ int shared_kernel_shape[num_axes];
+  __shared__ int shared_pad[num_axes];
+  __shared__ int shared_stride[num_axes];
+  __shared__ int shared_col_shape[num_axes + 1];
+  __shared__ int shared_im_shape[num_axes + 1];
+
+  if (threadIdx.x < num_axes) {
+    shared_dilation[threadIdx.x] = dilation[threadIdx.x];
+    shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x];
+    shared_pad[threadIdx.x] = pad[threadIdx.x];
+    shared_stride[threadIdx.x] = stride[threadIdx.x];
+  }
+  if (threadIdx.x < num_axes + 1) {
+    shared_col_shape[threadIdx.x] = col_shape[threadIdx.x];
+    shared_im_shape[threadIdx.x] = im_shape[threadIdx.x];
+  }
+  __syncthreads();
+
   CUDA_KERNEL_LOOP(index, n) {
     // Initialize channel_in, computed in the loop below, with intermediate
     // computations used to compute the spatial indices.
     int c_im = index;
     // Calculate d_im (image dimensions).
     for (int i = num_axes - 1; i >= 0; --i) {
-      d_im[i] = c_im % im_shape[i + 1] + pad[i];
-      c_im /= im_shape[i + 1];
+      d_im[i] = c_im % shared_im_shape[i + 1] + shared_pad[i];
+      c_im /= shared_im_shape[i + 1];
     }
     // Calculate col start/end indices.
     bool done = false;
     for (int i = 0; i < num_axes; ++i) {
+      const int kernel_extent =
+          shared_dilation[i] * (shared_kernel_shape[i] - 1) + 1;
       d_col_start[i] = d_col_iter[i] =
-          (d_im[i] < kernel_shape[i]) ?
-          0 : (d_im[i] - kernel_shape[i]) / stride[i] + 1;
-      d_col_end[i] = min(d_im[i] / stride[i] + 1, col_shape[i + 1]);
+          (d_im[i] < kernel_extent) ? 0 :
+          (d_im[i] - kernel_extent) / shared_stride[i] + 1;
+      d_col_end[i] =
+          min(d_im[i] / shared_stride[i] + 1, shared_col_shape[i + 1]);
       if (d_col_start[i] >= d_col_end[i]) {
         // Skip computation if the dimension is 0 at any spatial axis --
         // final val will be 0.
@@ -335,21 +380,32 @@ __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
     // Loop over the col to compute the output val.
     Dtype val = 0;
     bool incremented = true;
+    bool skip = false;
     do {
       // Compute the final offset.
       int final_offset = 0;
       int kernel_shape_prod = 1;
+      int kernel_index;
       for (int i = num_axes - 1; i >= 0; --i) {
-        final_offset +=
-            (d_im[i] - d_col_iter[i] * stride[i]) * kernel_shape_prod;
-        kernel_shape_prod *= kernel_shape[i];
+        kernel_index = d_im[i] - d_col_iter[i] * shared_stride[i];
+        if (kernel_index % shared_dilation[i]) {
+          skip = true;
+          break;
+        } else {
+          kernel_index /= shared_dilation[i];
+          final_offset += kernel_index * kernel_shape_prod;
+          kernel_shape_prod *= shared_kernel_shape[i];
+        }
       }
-      final_offset += kernel_shape_prod * c_im;
-      for (int i = 0; i < num_axes; ++i) {
-        final_offset *= col_shape[i + 1];
-        final_offset += d_col_iter[i];
+      if (!skip) {
+        final_offset += kernel_shape_prod * c_im;
+        for (int i = 0; i < num_axes; ++i) {
+          final_offset *= shared_col_shape[i + 1];
+          final_offset += d_col_iter[i];
+        }
+        val += data_col[final_offset];
       }
-      val += data_col[final_offset];
+      skip = false;
       incremented = false;
       for (int i = num_axes - 1; i >= 0; --i) {
         const int d_max = d_col_end[i];
@@ -370,67 +426,69 @@ template <typename Dtype>
 void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
     const int im_size, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im) {
+    const int* dilation, Dtype* data_im) {
+  // num_axes should be smaller than block size
+  DCHECK_LT(num_spatial_axes, CAFFE_CUDA_NUM_THREADS);
   switch (num_spatial_axes) {
   case 1:
     col2im_nd_gpu_kernel<Dtype, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 2:
     col2im_nd_gpu_kernel<Dtype, 2>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 3:
     col2im_nd_gpu_kernel<Dtype, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 4:
     col2im_nd_gpu_kernel<Dtype, 4>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 5:
     col2im_nd_gpu_kernel<Dtype, 5>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 6:
     col2im_nd_gpu_kernel<Dtype, 6>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 7:
     col2im_nd_gpu_kernel<Dtype, 7>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 8:
     col2im_nd_gpu_kernel<Dtype, 8>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 9:
     col2im_nd_gpu_kernel<Dtype, 9>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   case 10:
     col2im_nd_gpu_kernel<Dtype, 10>  // NOLINT_NEXT_LINE(whitespace/operators)
           <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
           im_size, data_col, im_shape, col_shape,
-          kernel_shape, pad, stride, data_im);
+          kernel_shape, pad, stride, dilation, data_im);
     break;
   default:
     LOG(FATAL) << "col2im_nd_gpu does not support computation with "
@@ -444,11 +502,11 @@ template void col2im_nd_gpu<float>(const float* data_col,
     const int num_spatial_axes, const int im_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_im);
+    const int* dilation, float* data_im);
 template void col2im_nd_gpu<double>(const double* data_col,
     const int num_spatial_axes, const int im_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_im);
+    const int* dilation, double* data_im);
 
 }  // namespace caffe