const int height, const int width, const int psize, const int stride,
Dtype* data_im);
+template <typename Dtype>
+void padded_im2col_cpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ Dtype* data_col);
+
+template <typename Dtype>
+void padded_col2im_cpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int psize, const int pad, const int stride,
+ Dtype* data_im);
+
+template <typename Dtype>
+void padded_im2col_gpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ Dtype* data_col);
+
+template <typename Dtype>
+void padded_col2im_gpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int psize, const int pad, const int stride,
+ Dtype* data_im);
+
} // namespace caffe
#endif // CAFFE_UTIL_IM2COL_HPP_
int CHANNELS_;
int HEIGHT_;
int WIDTH_;
+ int PAD_;
};
-
template <typename Dtype>
class PoolingLayer : public Layer<Dtype> {
public:
int STRIDE_;
int NUM_;
int CHANNELS_;
+ int PAD_;
int HEIGHT_;
int WIDTH_;
int NUM_OUTPUT_;
KSIZE_ = this->layer_param_.kernelsize();
STRIDE_ = this->layer_param_.stride();
GROUP_ = this->layer_param_.group();
+ PAD_ = this->layer_param_.pad();
NUM_ = bottom[0]->num();
CHANNELS_ = bottom[0]->channels();
HEIGHT_ = bottom[0]->height();
CHECK_EQ(CHANNELS_ % GROUP_, 0);
// The im2col result buffer would only hold one image at a time to avoid
// overly large memory usage.
- int height_out = (HEIGHT_ - KSIZE_) / STRIDE_ + 1;
- int width_out = (WIDTH_ - KSIZE_) / STRIDE_ + 1;
+ int height_out = (HEIGHT_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1;
+ int width_out = (WIDTH_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1;
col_buffer_.Reshape(1, CHANNELS_ * KSIZE_ * KSIZE_, height_out, width_out);
// Set the parameters
CHECK_EQ(NUM_OUTPUT_ % GROUP_, 0)
int top_offset = M_ * N_;
for (int n = 0; n < NUM_; ++n) {
// First, im2col
- im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, col_data);
+ } else {
+ padded_im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+ }
// Second, innerproduct with groups
for (int g = 0; g < GROUP_; ++g) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
int top_offset = M_ * N_;
for (int n = 0; n < NUM_; ++n) {
// First, im2col
- im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, col_data);
+ } else {
+ padded_im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+ }
// Second, innerproduct with groups
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, M_, N_, K_,
for (int n = 0; n < NUM_; ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
- im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, col_data);
+ } else {
+ padded_im2col_cpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+ }
// gradient w.r.t. weight. Note that we will accumulate diffs.
for (int g = 0; g < GROUP_; ++g) {
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
(Dtype)0., col_diff + col_offset * g);
}
// col2im back to the data
- col2im_cpu(col_diff, CHANNELS_, HEIGHT_,
- WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ if (PAD_ == 0) {
+ col2im_cpu(col_diff, CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ } else {
+ padded_col2im_cpu(col_diff, CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ }
}
}
return Dtype(0.);
for (int n = 0; n < NUM_; ++n) {
// since we saved memory in the forward pass by not storing all col data,
// we will need to recompute them.
- im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, col_data);
+ } else {
+ padded_im2col_gpu(bottom_data + (*bottom)[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, col_data);
+ }
// gradient w.r.t. weight. Note that we will accumulate diffs.
for (int g = 0; g < GROUP_; ++g) {
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasTrans, M_, K_, N_,
(Dtype)0., col_diff + col_offset * g);
}
// col2im back to the data
- col2im_gpu(col_diff, CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ col2im_gpu(col_diff, CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ } else {
+ padded_col2im_gpu(col_diff, CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ }
}
}
return Dtype(0.);
CHECK_EQ(top->size(), 1) << "Im2col Layer takes a single blob as output.";
KSIZE_ = this->layer_param_.kernelsize();
STRIDE_ = this->layer_param_.stride();
+ PAD_ = this->layer_param_.pad();
CHANNELS_ = bottom[0]->channels();
HEIGHT_ = bottom[0]->height();
WIDTH_ = bottom[0]->width();
(*top)[0]->Reshape(bottom[0]->num(), CHANNELS_ * KSIZE_ * KSIZE_,
- (HEIGHT_ - KSIZE_) / STRIDE_ + 1, (WIDTH_ - KSIZE_) / STRIDE_ + 1);
+ (HEIGHT_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1, (WIDTH_ + 2 * PAD_ - KSIZE_) / STRIDE_ + 1);
};
template <typename Dtype>
const Dtype* bottom_data = bottom[0]->cpu_data();
Dtype* top_data = (*top)[0]->mutable_cpu_data();
for (int n = 0; n < bottom[0]->num(); ++n) {
- im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
+ } else {
+ padded_im2col_cpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, top_data + (*top)[0]->offset(n));
+ }
}
}
const Dtype* bottom_data = bottom[0]->gpu_data();
Dtype* top_data = (*top)[0]->mutable_gpu_data();
for (int n = 0; n < bottom[0]->num(); ++n) {
- im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, top_data + (*top)[0]->offset(n));
+ } else {
+ padded_im2col_gpu(bottom_data + bottom[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, top_data + (*top)[0]->offset(n));
+ }
}
}
const Dtype* top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
for (int n = 0; n < top[0]->num(); ++n) {
- col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ } else {
+ padded_col2im_cpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ }
}
return Dtype(0.);
}
const Dtype* top_diff = top[0]->gpu_diff();
Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
for (int n = 0; n < top[0]->num(); ++n) {
- col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+ if (PAD_ == 0) {
+ col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
WIDTH_, KSIZE_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ } else {
+ padded_col2im_gpu(top_diff + top[0]->offset(n), CHANNELS_, HEIGHT_,
+ WIDTH_, KSIZE_, PAD_, STRIDE_, bottom_diff + (*bottom)[0]->offset(n));
+ }
}
return Dtype(0.);
}
}
}
+template <typename Dtype>
+void padded_im2col_cpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ Dtype* data_col) {
+ int height_col = (height + 2 * pad - ksize) / stride + 1;
+ int width_col = (width + 2 * pad - ksize) / stride + 1;
+ int channels_col = channels * ksize * ksize;
+ for (int c = 0; c < channels_col; ++c) {
+ int w_offset = c % ksize;
+ int h_offset = (c / ksize) % ksize;
+ int c_im = c / ksize / ksize;
+ for (int h = 0; h < height_col; ++h) {
+ for (int w = 0; w < width_col; ++w) {
+ int h_pad = h * stride - pad + h_offset;
+ int w_pad = w * stride - pad + w_offset;
+ if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
+ data_col[(c * height_col + h) * width_col + w] =
+ data_im[(c_im * height + h_pad) * width + w_pad];
+ else
+ data_col[(c * height_col + h) * width_col + w] = 0;
+ }
+ }
+ }
+}
+
// Explicit instantiation
template void im2col_cpu<float>(const float* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
template void im2col_cpu<double>(const double* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
double* data_col);
+template void padded_im2col_cpu<float>(const float* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ float* data_col);
+template void padded_im2col_cpu<double>(const double* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ double* data_col);
template <typename Dtype>
void col2im_cpu(const Dtype* data_col, const int channels,
}
}
+template <typename Dtype>
+void padded_col2im_cpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ Dtype* data_im) {
+ memset(data_im, 0, sizeof(Dtype) * height * width * channels);
+ int height_col = (height + 2 * pad - ksize) / stride + 1;
+ int width_col = (width + 2 * pad - ksize) / stride + 1;
+ int channels_col = channels * ksize * ksize;
+ for (int c = 0; c < channels_col; ++c) {
+ int w_offset = c % ksize;
+ int h_offset = (c / ksize) % ksize;
+ int c_im = c / ksize / ksize;
+ for (int h = 0; h < height_col; ++h) {
+ for (int w = 0; w < width_col; ++w) {
+ int h_pad = h * stride - pad + h_offset;
+ int w_pad = w * stride - pad + w_offset;
+ if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width)
+ data_im[(c_im * height + h_pad) * width + w_pad] += data_col[(c * height_col + h) * width_col + w];
+ }
+ }
+ }
+}
+
// Explicit instantiation
template void col2im_cpu<float>(const float* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
template void col2im_cpu<double>(const double* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
double* data_im);
+template void padded_col2im_cpu<float>(const float* data_col, const int channels,
+ const int height, const int width, const int psize, const int pad, const int stride,
+ float* data_im);
+template void padded_col2im_cpu<double>(const double* data_col, const int channels,
+ const int height, const int width, const int psize, const int pad, const int stride,
+ double* data_im);
} // namespace caffe
}
template <typename Dtype>
+__global__ void padded_im2col_gpu_kernel(const int n, const Dtype* data_im,
+ const int height, const int width, const int ksize, const int pad,
+ const int stride, const int height_col, const int width_col, Dtype* data_col) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < n) {
+ int w_out = index % width_col;
+ index /= width_col;
+ int h_out = index % height_col;
+ int channel_in = index / height_col;
+ int channel_out = channel_in * ksize * ksize;
+ int h_in = h_out * stride - pad;
+ int w_in = w_out * stride - pad;
+ data_col += (channel_out * height_col + h_out) * width_col + w_out;
+ data_im += (channel_in * height + h_in) * width + w_in;
+ for (int i = 0; i < ksize; ++i) {
+ for (int j = 0; j < ksize; ++j) {
+ int h = h_in + i;
+ int w = w_in + j;
+ *data_col = (h >= 0 && w >= 0 && h < width && w < height) ? data_im[i * width + j] : 0;
+ data_col += height_col * width_col;
+ }
+ }
+ }
+}
+
+template <typename Dtype>
void im2col_gpu(const Dtype* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
Dtype* data_col) {
CUDA_POST_KERNEL_CHECK;
}
+template <typename Dtype>
+void padded_im2col_gpu(const Dtype* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ Dtype* data_col) {
+ // We are going to launch channels * height_col * width_col kernels, each
+ // kernel responsible for copying a single-channel grid.
+ int height_col = (height + 2 * pad - ksize) / stride + 1;
+ int width_col = (width + 2 * pad - ksize) / stride + 1;
+ int num_kernels = channels * height_col * width_col;
+ padded_im2col_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+ num_kernels, data_im, height, width, ksize, pad, stride, height_col, width_col,
+ data_col);
+ CUDA_POST_KERNEL_CHECK;
+}
+
// Explicit instantiation
template void im2col_gpu<float>(const float* data_im, const int channels,
template void im2col_gpu<double>(const double* data_im, const int channels,
const int height, const int width, const int ksize, const int stride,
double* data_col);
+template void padded_im2col_gpu<float>(const float* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ float* data_col);
+template void padded_im2col_gpu<double>(const double* data_im, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ double* data_col);
template <typename Dtype>
__global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
}
template <typename Dtype>
+__global__ void padded_col2im_gpu_kernel(const int n, const Dtype* data_col,
+ const int height, const int width, const int channels, const int ksize, const int pad,
+ const int stride, const int height_col, const int width_col, Dtype* data_im) {
+ int index = threadIdx.x + blockIdx.x * blockDim.x;
+ if (index < n) {
+ Dtype val = 0;
+ int w = index % width + pad;
+ int h = (index / width) % height + pad;
+ int c = index / (width * height);
+ // compute the start and end of the output
+ int w_col_start = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+ int w_col_end = min(w / stride + 1, width_col);
+ int h_col_start = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+ int h_col_end = min(h / stride + 1, height_col);
+ /*
+ for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+ for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+ // the col location: [c * width * height + h_out, w_out]
+ int c_col = c * ksize * ksize + (h - h_col * stride) * ksize + (w - w_col * stride);
+ val += data_col[(c_col * height_col + h_col) * width_col + w_col];
+ }
+ }
+ */
+ // equivalent implementation
+ int offset = (c * ksize * ksize + h * ksize + w) * height_col * width_col;
+ int coeff_h_col = (1 - stride * ksize * height_col) * width_col;
+ int coeff_w_col = (1 - stride * height_col * width_col);
+ for (int h_col = h_col_start; h_col < h_col_end; ++h_col) {
+ for (int w_col = w_col_start; w_col < w_col_end; ++w_col) {
+ val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col];
+ }
+ }
+ data_im[index] = val;
+ }
+}
+
+template <typename Dtype>
void col2im_gpu(const Dtype* data_col, const int channels,
const int height, const int width, const int ksize, const int stride,
Dtype* data_im) {
CUDA_POST_KERNEL_CHECK;
}
+template <typename Dtype>
+void padded_col2im_gpu(const Dtype* data_col, const int channels,
+ const int height, const int width, const int ksize, const int pad, const int stride,
+ Dtype* data_im) {
+ //CUDA_CHECK(cudaMemset(data_im, 0, sizeof(Dtype) * height * width * channels));
+ int height_col = (height + 2 * pad - ksize) / stride + 1;
+ int width_col = (width + 2 * pad - ksize) / stride + 1;
+ int num_kernels = channels * height * width;
+ // To avoid involving atomic operations, we will launch one kernel per
+ // bottom dimension, and then in the kernel add up the top dimensions.
+ padded_col2im_gpu_kernel<Dtype><<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
+ num_kernels, data_col, height, width, channels, ksize, pad, stride,
+ height_col, width_col, data_im);
+ CUDA_POST_KERNEL_CHECK;
+}
+
// Explicit instantiation
template void col2im_gpu<float>(const float* data_col, const int channels,
template void col2im_gpu<double>(const double* data_col, const int channels,
const int height, const int width, const int psize, const int stride,
double* data_im);
+template void padded_col2im_gpu<float>(const float* data_col, const int channels,
+ const int height, const int width, const int psize, const int pad, const int stride,
+ float* data_im);
+template void padded_col2im_gpu<double>(const double* data_col, const int channels,
+ const int height, const int width, const int psize, const int pad, const int stride,
+ double* data_im);
} // namespace caffe