implemented padding aware im2col and col2im functions
authorlinmin <mavenlin@gmail.com>
Wed, 12 Feb 2014 04:13:51 +0000 (12:13 +0800)
committerEvan Shelhamer <shelhamer@imaginarynumber.net>
Wed, 26 Feb 2014 23:41:57 +0000 (15:41 -0800)
include/caffe/util/im2col.hpp
include/caffe/vision_layers.hpp
src/caffe/layers/conv_layer.cpp
src/caffe/layers/im2col_layer.cpp
src/caffe/util/im2col.cpp
src/caffe/util/im2col.cu

index 83c01dd..dd01bdc 100644 (file)
@@ -25,6 +25,26 @@ void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int psize, const int stride,
     Dtype* data_im);
 
+template <typename Dtype>
+void padded_im2col_cpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    Dtype* data_col);
+
+template <typename Dtype>
+void padded_col2im_cpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int psize, const int pad, const int stride,
+    Dtype* data_im);
+
+template <typename Dtype>
+void padded_im2col_gpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    Dtype* data_col);
+
+template <typename Dtype>
+void padded_col2im_gpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int psize, const int pad, const int stride,
+    Dtype* data_im);
+
 }  // namespace caffe
 
 #endif  // CAFFE_UTIL_IM2COL_HPP_
index 1861535..ca54f2b 100644 (file)
@@ -272,9 +272,9 @@ class Im2colLayer : public Layer<Dtype> {
   int CHANNELS_;
   int HEIGHT_;
   int WIDTH_;
+  int PAD_;
 };
 
-
 template <typename Dtype>
 class PoolingLayer : public Layer<Dtype> {
  public:
@@ -326,6 +326,7 @@ class ConvolutionLayer : public Layer<Dtype> {
   int STRIDE_;
   int NUM_;
   int CHANNELS_;
+  int PAD_;
   int HEIGHT_;
   int WIDTH_;
   int NUM_OUTPUT_;
index f2608be..8f06524 100644 (file)
@@ -18,6 +18,7 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   KSIZE_ = this->layer_param_.kernelsize();
   STRIDE_ = this->layer_param_.stride();
   GROUP_ = this->layer_param_.group();
+  PAD_ = this->layer_param_.pad();
   NUM_ = bottom[0]->num();
   CHANNELS_ = bottom[0]->channels();
   HEIGHT_ = bottom[0]->height();
@@ -27,8 +28,8 @@ void ConvolutionLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_EQ(CHANNELS_ % GROUP_, 0);
   // The im2col result buffer would only hold one image at a time to avoid
   // overly large memory usage.
-  int height_out = (HEIGHT_ - KSIZE_) / STRIDE_ + 1;
-  int width_out = (WIDTH_ - KSIZE_) / STRIDE_ + 1;
+  int height_out = (HEIGHT_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1;
+  int width_out = (WIDTH_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1;
   col_buffer_.Reshape(1, CHANNELS_ * KSIZE_ * KSIZE_, height_out, width_out);
   // Set the parameters
   CHECK_EQ(NUM_OUTPUT_ % GROUP_, 0)
@@ -87,8 +88,13 @@ void ConvolutionLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   int top_offset = M_ * N_;
   for (int n = 0; n < NUM_; ++n) {
     // First, im2col
-    im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, col_data);
+    } else {
+      padded_im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+    }
     // Second, innerproduct with groups
     for (int g = 0; g < GROUP_; ++g) {
       caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
@@ -117,8 +123,13 @@ void ConvolutionLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   int top_offset = M_ * N_;
   for (int n = 0; n < NUM_; ++n) {
     // First, im2col
-    im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, col_data);
+    } else {
+      padded_im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+    }
     // Second, innerproduct with groups
     for (int g = 0; g < GROUP_; ++g) {
       caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
@@ -166,8 +177,13 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   for (int n = 0; n < NUM_; ++n) {
     // since we saved memory in the forward pass by not storing all col data,
     // we will need to recompute them.
-    im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, col_data);
+    } else {
+      padded_im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+    }
     // gradient w.r.t. weight. Note that we will accumulate diffs.
     for (int g = 0; g < GROUP_; ++g) {
       caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
@@ -184,8 +200,13 @@ Dtype ConvolutionLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
           (Dtype)0., col_diff + col_offset * g);
       }
       // col2im back to the data
-      col2im_cpu(col_diff, CHANNELS_, HEIGHT_,
-        WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+      if (PAD_ == 0) {
+       col2im_cpu(col_diff, CHANNELS_, HEIGHT_,
+         WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+      } else {
+       padded_col2im_cpu(col_diff, CHANNELS_, HEIGHT_,
+         WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+      }
     }
   }
   return Dtype(0.);
@@ -224,8 +245,13 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   for (int n = 0; n < NUM_; ++n) {
     // since we saved memory in the forward pass by not storing all col data,
     // we will need to recompute them.
-    im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, col_data);
+    } else {
+      padded_im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+    }
     // gradient w.r.t. weight. Note that we will accumulate diffs.
     for (int g = 0; g < GROUP_; ++g) {
       caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
@@ -242,8 +268,13 @@ Dtype ConvolutionLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
           (Dtype)0., col_diff + col_offset * g);
       }
       // col2im back to the data
-      col2im_gpu(col_diff, CHANNELS_, HEIGHT_,
+      if (PAD_ == 0) {
+       col2im_gpu(col_diff, CHANNELS_, HEIGHT_,
           WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+      } else {
+       padded_col2im_gpu(col_diff, CHANNELS_, HEIGHT_,
+         WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+      }
     }
   }
   return Dtype(0.);
index 976c844..7aced1d 100644 (file)
@@ -16,11 +16,12 @@ void Im2colLayer<Dtype>::SetUp(const vector<Blob<Dtype>*>& bottom,
   CHECK_EQ(top->size(), 1) << "Im2col Layer takes a single blob as output.";
   KSIZE_ = this->layer_param_.kernelsize();
   STRIDE_ = this->layer_param_.stride();
+  PAD_ = this->layer_param_.pad();
   CHANNELS_ = bottom[0]->channels();
   HEIGHT_ = bottom[0]->height();
   WIDTH_ = bottom[0]->width();
   (*top)[0]->Reshape(bottom[0]->num(), CHANNELS_ * KSIZE_ * KSIZE_,
-      (HEIGHT_ - KSIZE_) / STRIDE_ + 1, (WIDTH_ - KSIZE_) / STRIDE_ + 1);
+      (HEIGHT_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1, (WIDTH_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1);
 };
 
 template <typename Dtype>
@@ -29,8 +30,13 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->cpu_data();
   Dtype* top_data = (*top)[0]->mutable_cpu_data();
   for (int n = 0; n < bottom[0]->num(); ++n) {
-    im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
+    } else {
+      padded_im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, top_data + (*top)[0]->offset(n));
+    }
   }
 }
 
@@ -40,8 +46,13 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   for (int n = 0; n < bottom[0]->num(); ++n) {
-    im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
+    } else {
+      padded_im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, top_data + (*top)[0]->offset(n));
+    }
   }
 }
 
@@ -51,8 +62,13 @@ Dtype Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->cpu_diff();
   Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
   for (int n = 0; n < top[0]->num(); ++n) {
-    col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+    } else {
+      padded_col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+    }
   }
   return Dtype(0.);
 }
@@ -64,8 +80,13 @@ Dtype Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   const Dtype* top_diff = top[0]->gpu_diff();
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
   for (int n = 0; n < top[0]->num(); ++n) {
-    col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+    if (PAD_ == 0) {
+      col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
         WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+    } else {
+      padded_col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+       WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+    }
   }
   return Dtype(0.);
 }
index db79bb2..bf5db05 100644 (file)
@@ -29,6 +29,31 @@ void im2col_cpu(const Dtype* data_im, const int channels,
   }
 }
 
+template <typename Dtype>
+void padded_im2col_cpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    Dtype* data_col) {
+  int height_col = (height + 2 * pad - ksize) / stride + 1;
+  int width_col = (width + 2 * pad - ksize) / stride + 1;
+  int channels_col = channels * ksize * ksize;
+  for (int c = 0; c < channels_col; ++c) {
+    int w_offset = c % ksize;
+    int h_offset = (c / ksize) % ksize;
+    int c_im = c / ksize / ksize;
+    for (int h = 0; h < height_col; ++h) {
+      for (int w = 0; w < width_col; ++w) {
+       int h_pad = h * stride - pad + h_offset;
+       int w_pad = w * stride - pad + w_offset;
+       if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
+         data_col[(c * height_col + h) * width_col + w] =
+           data_im[(c_im * height + h_pad) * width + w_pad];
+       else
+         data_col[(c * height_col + h) * width_col + w] = 0;
+      }
+    }
+  }
+}
+
 // Explicit instantiation
 template void im2col_cpu<float>(const float* data_im, const int channels,
     const int height, const int width, const int ksize, const int stride,
@@ -36,6 +61,12 @@ template void im2col_cpu<float>(const float* data_im, const int channels,
 template void im2col_cpu<double>(const double* data_im, const int channels,
     const int height, const int width, const int ksize, const int stride,
     double* data_col);
+template void padded_im2col_cpu<float>(const float* data_im, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    float* data_col);
+template void padded_im2col_cpu<double>(const double* data_im, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    double* data_col);
 
 template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
@@ -58,6 +89,29 @@ void col2im_cpu(const Dtype* data_col, const int channels,
   }
 }
 
+template <typename Dtype>
+void padded_col2im_cpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    Dtype* data_im) {
+  memset(data_im, 0, sizeof(Dtype) * height * width * channels);
+  int height_col = (height + 2 * pad - ksize) / stride + 1;
+  int width_col = (width + 2 * pad - ksize) / stride + 1;
+  int channels_col = channels * ksize * ksize;
+  for (int c = 0; c < channels_col; ++c) {
+    int w_offset = c % ksize;
+    int h_offset = (c / ksize) % ksize;
+    int c_im = c / ksize / ksize;
+    for (int h = 0; h < height_col; ++h) {
+      for (int w = 0; w < width_col; ++w) {
+       int h_pad = h * stride - pad + h_offset;
+       int w_pad = w * stride - pad + w_offset;
+       if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
+         data_im[(c_im * height + h_pad) * width + w_pad] += data_col[(c * height_col + h) * width_col + w];
+      }
+    }
+  }
+}
+
 // Explicit instantiation
 template void col2im_cpu<float>(const float* data_col, const int channels,
     const int height, const int width, const int psize, const int stride,
@@ -65,5 +119,11 @@ template void col2im_cpu<float>(const float* data_col, const int channels,
 template void col2im_cpu<double>(const double* data_col, const int channels,
     const int height, const int width, const int psize, const int stride,
     double* data_im);
+template void padded_col2im_cpu<float>(const float* data_col, const int channels,
+    const int height, const int width, const int psize, const int pad, const int stride,
+    float* data_im);
+template void padded_col2im_cpu<double>(const double* data_col, const int channels,
+    const int height, const int width, const int psize, const int pad, const int stride,
+    double* data_im);
 
 }  // namespace caffe
index 0b0c8b8..0903964 100644 (file)
@@ -35,6 +35,32 @@ __global__ void im2col_gpu_kernel(const int n, const Dtype* data_im,
 }
 
 template <typename Dtype>
+__global__ void padded_im2col_gpu_kernel(const int n, const Dtype* data_im,
+  const int height, const int width, const int ksize, const int pad,
+  const int stride, const int height_col, const int width_col, Dtype* data_col) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    int w_out = index % width_col;
+    index /= width_col;
+    int h_out = index % height_col;
+    int channel_in = index / height_col;
+    int channel_out = channel_in * ksize * ksize;
+    int h_in = h_out * stride - pad;
+    int w_in = w_out * stride - pad;
+    data_col += (channel_out * height_col + h_out) * width_col + w_out;
+    data_im += (channel_in * height + h_in) * width + w_in;
+    for (int i = 0; i < ksize; ++i) {
+      for (int j = 0; j < ksize; ++j) {
+       int h = h_in + i;
+       int w = w_in + j;
+       *data_col = (h >= 0 && w >= 0 && h < width && w < height) ? data_im[i * width + j] : 0;
+       data_col += height_col * width_col;
+      }
+    }
+  }
+}
+
+template <typename Dtype>
 void im2col_gpu(const Dtype* data_im, const int channels,
     const int height, const int width, const int ksize, const int stride,
     Dtype* data_col) {
@@ -49,6 +75,21 @@ void im2col_gpu(const Dtype* data_im, const int channels,
   CUDA_POST_KERNEL_CHECK;
 }
 
+template <typename Dtype>
+void padded_im2col_gpu(const Dtype* data_im, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    Dtype* data_col) {
+  // We are going to launch channels * height_col * width_col kernels, each
+  // kernel responsible for copying a single-channel grid.
+  int height_col = (height + 2 * pad - ksize) / stride + 1;
+  int width_col = (width + 2 * pad - ksize) / stride + 1;
+  int num_kernels = channels * height_col * width_col;
+  padded_im2col_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+    num_kernels, data_im, height, width, ksize, pad, stride, height_col, width_col,
+    data_col);
+  CUDA_POST_KERNEL_CHECK;
+}
+
 
 // Explicit instantiation
 template void im2col_gpu<float>(const float* data_im, const int channels,
@@ -57,6 +98,12 @@ template void im2col_gpu<float>(const float* data_im, const int channels,
 template void im2col_gpu<double>(const double* data_im, const int channels,
     const int height, const int width, const int ksize, const int stride,
     double* data_col);
+template void padded_im2col_gpu<float>(const float* data_im, const int channels,
+  const int height, const int width, const int ksize, const int pad, const int stride,
+    float* data_col);
+template void padded_im2col_gpu<double>(const double* data_im, const int channels,
+  const int height, const int width, const int ksize, const int pad, const int stride,
+    double* data_col);
 
 template <typename Dtype>
 __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
@@ -96,6 +143,43 @@ __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
 }
 
 template <typename Dtype>
+__global__ void padded_col2im_gpu_kernel(const int n, const Dtype* data_col,
+  const int height, const int width, const int channels, const int ksize, const int pad,
+  const int stride, const int height_col, const int width_col, Dtype* data_im) {
+  int index = threadIdx.x + blockIdx.x * blockDim.x;
+  if (index < n) {
+    Dtype val = 0;
+    int w = index % width + pad;
+    int h = (index / width) % height + pad;
+    int c = index / (width * height);
+    // compute the start and end of the output
+    int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+    int w_col_end = min(w / stride + 1, width_col);
+    int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+    int h_col_end = min(h / stride + 1, height_col);
+    /*
+    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+        // the col location: [c * width * height + h_out, w_out]
+        int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride);
+        val += data_col[(c_col * height_col + h_col) * width_col + w_col];
+      }
+    }
+    */
+    // equivalent implementation
+    int offset = (c * ksize * ksize + h * ksize + w) * height_col * width_col;
+    int coeff_h_col = (1 - stride * ksize * height_col) * width_col;
+    int coeff_w_col = (1 - stride * height_col * width_col);
+    for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+      for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+        val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
+      }
+    }
+    data_im[index] = val;
+  }
+}
+
+template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
     const int height, const int width, const int ksize, const int stride,
     Dtype* data_im) {
@@ -111,6 +195,22 @@ void col2im_gpu(const Dtype* data_col, const int channels,
   CUDA_POST_KERNEL_CHECK;
 }
 
+template <typename Dtype>
+void padded_col2im_gpu(const Dtype* data_col, const int channels,
+    const int height, const int width, const int ksize, const int pad, const int stride,
+    Dtype* data_im) {
+  //CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
+  int height_col = (height + 2 * pad - ksize) / stride + 1;
+  int width_col = (width + 2 * pad - ksize) / stride + 1;
+  int num_kernels = channels * height * width;
+  // To avoid involving atomic operations, we will launch one kernel per
+  // bottom dimension, and then in the kernel add up the top dimensions.
+  padded_col2im_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+      num_kernels, data_col, height, width, channels, ksize, pad, stride,
+      height_col, width_col, data_im);
+  CUDA_POST_KERNEL_CHECK;
+}
+
 
 // Explicit instantiation
 template void col2im_gpu<float>(const float* data_col, const int channels,
@@ -119,6 +219,12 @@ template void col2im_gpu<float>(const float* data_col, const int channels,
 template void col2im_gpu<double>(const double* data_col, const int channels,
     const int height, const int width, const int psize, const int stride,
     double* data_im);
+template void padded_col2im_gpu<float>(const float* data_col, const int channels,
+    const int height, const int width, const int psize, const int pad, const int stride,
+    float* data_im);
+template void padded_col2im_gpu<double>(const double* data_col, const int channels,
+    const int height, const int width, const int psize, const int pad, const int stride,
+    double* data_im);
 
 
 }  // namespace caffe