From 18c795ebe8401cb82c9f8350664de665f1ec8733 Mon Sep 17 00:00:00 2001
From: Fisher Yu
Date: Sun, 27 Dec 2015 20:48:30 -0800
Subject: [PATCH] add support for N-D dilated convolution

---
 include/caffe/layers/base_conv_layer.hpp |   8 +-
 include/caffe/util/im2col.hpp            |   8 +-
 src/caffe/layer_factory.cpp              |   2 +
 src/caffe/layers/im2col_layer.cpp        |   4 +-
 src/caffe/layers/im2col_layer.cu         |   4 +-
 src/caffe/test/test_im2col_kernel.cu     |   9 +-
 src/caffe/test/test_im2col_layer.cpp     |   8 +-
 src/caffe/util/im2col.cpp                |  21 ++--
 src/caffe/util/im2col.cu                 | 166 +++++++++++++++++++++----------
 9 files changed, 148 insertions(+), 82 deletions(-)

diff --git a/include/caffe/layers/base_conv_layer.hpp b/include/caffe/layers/base_conv_layer.hpp
index db471b5..0160a83 100644
--- a/include/caffe/layers/base_conv_layer.hpp
+++ b/include/caffe/layers/base_conv_layer.hpp
@@ -106,7 +106,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
     } else {
       im2col_nd_cpu(data, num_spatial_axes_, conv_input_shape_.cpu_data(),
           col_buffer_shape_.data(), kernel_shape_.cpu_data(),
-          pad_.cpu_data(), stride_.cpu_data(), col_buff);
+          pad_.cpu_data(), stride_.cpu_data(), dilation_.cpu_data(), col_buff);
     }
   }
   inline void conv_col2im_cpu(const Dtype* col_buff, Dtype* data) {
@@ -120,7 +120,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
     } else {
       col2im_nd_cpu(col_buff, num_spatial_axes_, conv_input_shape_.cpu_data(),
           col_buffer_shape_.data(), kernel_shape_.cpu_data(),
-          pad_.cpu_data(), stride_.cpu_data(), data);
+          pad_.cpu_data(), stride_.cpu_data(), dilation_.cpu_data(), data);
     }
   }
 #ifndef CPU_ONLY
@@ -136,7 +136,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
       im2col_nd_gpu(data, num_spatial_axes_, num_kernels_im2col_,
           conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
           kernel_shape_.gpu_data(), pad_.gpu_data(),
-          stride_.gpu_data(), col_buff);
+          stride_.gpu_data(), dilation_.gpu_data(), col_buff);
     }
   }
   inline void conv_col2im_gpu(const Dtype* col_buff, Dtype* data) {
@@ -151,7 +151,7 @@ class BaseConvolutionLayer : public Layer<Dtype> {
       col2im_nd_gpu(col_buff, num_spatial_axes_, num_kernels_col2im_,
           conv_input_shape_.gpu_data(), col_buffer_.gpu_shape(),
           kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
-          data);
+          dilation_.gpu_data(), data);
     }
   }
 #endif
diff --git a/include/caffe/util/im2col.hpp b/include/caffe/util/im2col.hpp
index 748b65c..a35bc6e 100644
--- a/include/caffe/util/im2col.hpp
+++ b/include/caffe/util/im2col.hpp
@@ -7,7 +7,7 @@ template <typename Dtype>
 void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col);
+    const int* dilation, Dtype* data_col);
 
 template <typename Dtype>
 void im2col_cpu(const Dtype* data_im, const int channels,
@@ -20,7 +20,7 @@ template <typename Dtype>
 void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im);
+    const int* dilation, Dtype* data_im);
 
 template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
@@ -33,7 +33,7 @@ template <typename Dtype>
 void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
     const int col_size, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col);
+    const int* dilation, Dtype* data_col);
 
 template <typename Dtype>
 void im2col_gpu(const Dtype* data_im, const int channels,
@@ -46,7 +46,7 @@ template <typename Dtype>
 void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
     const int im_size, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im);
+    const int* dilation, Dtype* data_im);
 
 template <typename Dtype>
 void col2im_gpu(const Dtype* data_col, const int channels,
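Every im2col/col2im entry point above gains one `dilation` entry per spatial axis. The quantity that changes throughout the patch is the effective kernel extent: a kernel of size k with dilation d spans d * (k - 1) + 1 input pixels. A minimal sketch of the resulting output-size rule (hypothetical helper, not part of the patch):

    // Output length along one spatial axis with dilation; mirrors the
    // height_col_ arithmetic in the updated tests below.
    inline int dilated_out_size(int in, int kernel, int pad, int stride,
                                int dilation) {
      const int kernel_extent = dilation * (kernel - 1) + 1;
      return (in + 2 * pad - kernel_extent) / stride + 1;
    }
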
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index 6b1d1c1..4d912d2 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -39,12 +39,14 @@ shared_ptr<Layer<Dtype> > GetConvolutionLayer(
     const LayerParameter& param) {
   ConvolutionParameter conv_param = param.convolution_param();
   ConvolutionParameter_Engine engine = conv_param.engine();
+#ifdef USE_CUDNN
   bool use_dilation = false;
   for (int i = 0; i < conv_param.dilation_size(); ++i) {
     if (conv_param.dilation(i) > 1) {
       use_dilation = true;
     }
   }
+#endif
   if (engine == ConvolutionParameter_Engine_DEFAULT) {
     engine = ConvolutionParameter_Engine_CAFFE;
 #ifdef USE_CUDNN
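The `use_dilation` scan is now compiled only under USE_CUDNN, where the rest of this function (outside the hunk shown) consults it so that dilated layers stay on the Caffe engine, since cuDNN at the time had no dilated convolution support. From user code, dilation is the repeated `dilation` field of ConvolutionParameter; a sketch of building a dilated layer through the generated protobuf API (values are illustrative; the calls are the same ones the updated tests below make):

    #include "caffe/proto/caffe.pb.h"

    // Illustrative construction of a dilated 3x3 convolution parameter.
    caffe::LayerParameter MakeDilatedConv() {
      caffe::LayerParameter param;
      param.set_type("Convolution");
      caffe::ConvolutionParameter* conv = param.mutable_convolution_param();
      conv->add_kernel_size(3);
      conv->add_stride(1);
      conv->add_dilation(2);  // effective extent: 2 * (3 - 1) + 1 = 5
      return param;
    }
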
diff --git a/src/caffe/layers/im2col_layer.cpp b/src/caffe/layers/im2col_layer.cpp
index 19ae301..2fb9b3c 100644
--- a/src/caffe/layers/im2col_layer.cpp
+++ b/src/caffe/layers/im2col_layer.cpp
@@ -153,7 +153,7 @@ void Im2colLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
           bottom[0]->shape().data() + channel_axis_,
           top[0]->shape().data() + channel_axis_,
           kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(),
-          top_data + n * top_dim_);
+          dilation_.cpu_data(), top_data + n * top_dim_);
     }
   }
 }
@@ -178,7 +178,7 @@ void Im2colLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
           bottom[0]->shape().data() + channel_axis_,
           top[0]->shape().data() + channel_axis_,
           kernel_shape_.cpu_data(), pad_.cpu_data(), stride_.cpu_data(),
-          bottom_diff + n * bottom_dim_);
+          dilation_.cpu_data(), bottom_diff + n * bottom_dim_);
     }
   }
 }
diff --git a/src/caffe/layers/im2col_layer.cu b/src/caffe/layers/im2col_layer.cu
index d90075d..792c97f 100644
--- a/src/caffe/layers/im2col_layer.cu
+++ b/src/caffe/layers/im2col_layer.cu
@@ -26,7 +26,7 @@ void Im2colLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
           num_kernels, bottom[0]->gpu_shape() + channel_axis_,
           top[0]->gpu_shape() + channel_axis_,
           kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
-          top_data + n * top_dim_);
+          dilation_.gpu_data(), top_data + n * top_dim_);
     }
   }
 }
@@ -51,7 +51,7 @@ void Im2colLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
           bottom[0]->gpu_shape() + channel_axis_,
           top[0]->gpu_shape() + channel_axis_,
           kernel_shape_.gpu_data(), pad_.gpu_data(), stride_.gpu_data(),
-          bottom_diff + n * bottom_dim_);
+          dilation_.gpu_data(), bottom_diff + n * bottom_dim_);
     }
   }
 }
diff --git a/src/caffe/test/test_im2col_kernel.cu b/src/caffe/test/test_im2col_kernel.cu
index 15e06aa..5d8f01f 100644
--- a/src/caffe/test/test_im2col_kernel.cu
+++ b/src/caffe/test/test_im2col_kernel.cu
@@ -26,7 +26,7 @@ template <typename Dtype, int num_axes>
 __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col);
+    const int* dilation, Dtype* data_col);
 
 extern cudaDeviceProp CAFFE_TEST_CUDA_PROP;
 
@@ -35,7 +35,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
  protected:
   Im2colKernelTest()
         // big so launches > 1024 threads
-      : blob_bottom_(new Blob<Dtype>(5, 500, 10, 10)),
+      : blob_bottom_(new Blob<Dtype>(5, 500, 15, 15)),
         blob_kernel_shape_(new Blob<int>()),
         blob_stride_(new Blob<int>()),
         blob_pad_(new Blob<int>()),
@@ -56,7 +56,7 @@ class Im2colKernelTest : public GPUDeviceTest<Dtype> {
     channels_ = blob_bottom_->channels();
     pad_ = 0;
     stride_ = 2;
-    dilation_ = 1;
+    dilation_ = 3;
     kernel_size_ = 3;
     height_col_ = (height_ + 2 * pad_ -
         (dilation_ * (kernel_size_ - 1) + 1)) / stride_ + 1;
@@ -176,6 +176,7 @@ TYPED_TEST(Im2colKernelTest, TestND) {
         this->blob_top_cpu_->shape().data() + 1,
         this->blob_kernel_shape_->cpu_data(),
         this->blob_pad_->cpu_data(), this->blob_stride_->cpu_data(),
+        this->blob_dilation_->cpu_data(),
         top_data_cpu + this->blob_top_cpu_->offset(n));
   }
 
@@ -194,7 +195,7 @@ TYPED_TEST(Im2colKernelTest, TestND) {
         num_kernels, bottom_data_gpu + this->blob_bottom_->offset(n),
         this->blob_bottom_->gpu_shape() + 1, this->blob_top_->gpu_shape() + 1,
         this->blob_kernel_shape_->gpu_data(), this->blob_pad_->gpu_data(),
-        this->blob_stride_->gpu_data(),
+        this->blob_stride_->gpu_data(), this->blob_dilation_->gpu_data(),
         top_data_gpu + this->blob_top_->offset(n));
     CUDA_POST_KERNEL_CHECK;
   }
diff --git a/src/caffe/test/test_im2col_layer.cpp b/src/caffe/test/test_im2col_layer.cpp
index 932d3f2..24885e6 100644
--- a/src/caffe/test/test_im2col_layer.cpp
+++ b/src/caffe/test/test_im2col_layer.cpp
@@ -17,7 +17,7 @@ class Im2colLayerTest : public MultiDeviceTest<TypeParam> {
   typedef typename TypeParam::Dtype Dtype;
  protected:
   Im2colLayerTest()
-      : blob_bottom_(new Blob<Dtype>(2, 3, 10, 9)),
+      : blob_bottom_(new Blob<Dtype>(2, 3, 10, 11)),
         blob_top_(new Blob<Dtype>()) {
     // fill the values
     Caffe::set_random_seed(1701);
@@ -43,12 +43,13 @@ TYPED_TEST(Im2colLayerTest, TestSetup) {
       layer_param.mutable_convolution_param();
   convolution_param->add_kernel_size(3);
   convolution_param->add_stride(2);
+  convolution_param->add_dilation(3);
   Im2colLayer<Dtype> layer(layer_param);
   layer.SetUp(this->blob_bottom_vec_, this->blob_top_vec_);
   EXPECT_EQ(this->blob_top_->num(), 2);
   EXPECT_EQ(this->blob_top_->channels(), 27);
   EXPECT_EQ(this->blob_top_->height(), 2);
-  EXPECT_EQ(this->blob_top_->width(), 2);
+  EXPECT_EQ(this->blob_top_->width(), 3);
 }
 
 TYPED_TEST(Im2colLayerTest, TestForward) {
@@ -89,6 +90,7 @@ TYPED_TEST(Im2colLayerTest, TestGradientForceND) {
       layer_param.mutable_convolution_param();
   convolution_param->add_kernel_size(3);
   convolution_param->add_stride(2);
+  convolution_param->add_dilation(3);
   convolution_param->set_force_nd_im2col(true);
   Im2colLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-2);
@@ -123,6 +125,8 @@ TYPED_TEST(Im2colLayerTest, TestRectGradient) {
   convolution_param->set_kernel_h(5);
   convolution_param->set_kernel_w(3);
   convolution_param->add_stride(2);
+  convolution_param->add_dilation(1);
+  convolution_param->add_dilation(3);
   Im2colLayer<Dtype> layer(layer_param);
   GradientChecker<Dtype> checker(1e-2, 1e-2);
   checker.CheckGradientExhaustive(&layer, this->blob_bottom_vec_,
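The revised test expectations follow directly from the extent rule. A standalone check of the numbers (assuming only the usual im2col sizing formula):

    #include <cassert>

    int main() {
      // TestSetup: 10 x 11 bottom, kernel 3, stride 2, dilation 3.
      // Extent = 3 * (3 - 1) + 1 = 7, so height = (10 - 7) / 2 + 1 = 2
      // and width = (11 - 7) / 2 + 1 = 3, the updated EXPECT_EQ values.
      assert((10 - 7) / 2 + 1 == 2);
      assert((11 - 7) / 2 + 1 == 3);
      // Im2colKernelTest: the bottom grows to 15 x 15 so the dilated
      // extent (7) still leaves room; height_col_ = (15 - 7) / 2 + 1 = 5.
      assert((15 - 7) / 2 + 1 == 5);
      return 0;
    }
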
diff --git a/src/caffe/util/im2col.cpp b/src/caffe/util/im2col.cpp
index 1e578e7..6e5ea87 100644
--- a/src/caffe/util/im2col.cpp
+++ b/src/caffe/util/im2col.cpp
@@ -49,7 +49,7 @@ template <typename Dtype>
 inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
     const int num_spatial_axes, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_output) {
+    const int* dilation, Dtype* data_output) {
   if (!im2col) {
     int im_size = im_shape[0];
     for (int i = 0; i < num_spatial_axes; ++i) {
@@ -81,7 +81,8 @@ inline void im2col_nd_core_cpu(const Dtype* data_input, const bool im2col,
       bool is_padding = false;
       for (int d_i = 0; d_i < num_spatial_axes; ++d_i) {
         const int d = d_iter[d_i];
-        const int d_im = d * stride[d_i] - pad[d_i] + d_offset[d_i];
+        const int d_im = d * stride[d_i] - pad[d_i] +
+            d_offset[d_i] * dilation[d_i];
         is_padding |= d_im < 0 || d_im >= im_shape[d_i + 1];
         index_col *= col_shape[d_i + 1];
         index_col += d;
@@ -119,10 +120,10 @@ template <typename Dtype>
 void im2col_nd_cpu(const Dtype* data_im, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col) {
+    const int* dilation, Dtype* data_col) {
   const bool kIm2Col = true;
   im2col_nd_core_cpu(data_im, kIm2Col, num_spatial_axes, im_shape, col_shape,
-                  kernel_shape, pad, stride, data_col);
+                  kernel_shape, pad, stride, dilation, data_col);
 }
 
 // Explicit instantiation
@@ -130,12 +131,12 @@ template void im2col_nd_cpu<float>(const float* data_im,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_col);
+    const int* dilation, float* data_col);
 template void im2col_nd_cpu<double>(const double* data_im,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_col);
+    const int* dilation, double* data_col);
 
 template <typename Dtype>
 void col2im_cpu(const Dtype* data_col, const int channels,
@@ -182,10 +183,10 @@ template <typename Dtype>
 void col2im_nd_cpu(const Dtype* data_col, const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im) {
+    const int* dilation, Dtype* data_im) {
   const bool kIm2Col = false;
   im2col_nd_core_cpu(data_col, kIm2Col, num_spatial_axes, im_shape, col_shape,
-                  kernel_shape, pad, stride, data_im);
+                  kernel_shape, pad, stride, dilation, data_im);
 }
 
 // Explicit instantiation
@@ -193,12 +194,12 @@ template void col2im_nd_cpu<float>(const float* data_col,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_im);
+    const int* dilation, float* data_im);
 template void col2im_nd_cpu<double>(const double* data_col,
     const int num_spatial_axes,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_im);
+    const int* dilation, double* data_im);
 
 }  // namespace caffe
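The entire CPU-side change reduces to one indexing rule: along each axis, column position d and kernel offset d_offset sample input coordinate d * stride - pad + d_offset * dilation, and out-of-range coordinates read as zero. A 1-D sketch of that rule (hypothetical standalone function, not the N-D implementation above):

    #include <vector>

    // 1-D illustration of the indexing in im2col_nd_core_cpu: kernel tap k
    // is displaced by k * dilation input pixels instead of k.
    std::vector<float> im2col_1d(const std::vector<float>& im, int kernel,
                                 int pad, int stride, int dilation) {
      const int len = static_cast<int>(im.size());
      const int extent = dilation * (kernel - 1) + 1;
      const int out = (len + 2 * pad - extent) / stride + 1;
      std::vector<float> col(kernel * out, 0.f);
      for (int k = 0; k < kernel; ++k) {
        for (int d = 0; d < out; ++d) {
          const int d_im = d * stride - pad + k * dilation;
          if (d_im >= 0 && d_im < len) {
            col[k * out + d] = im[d_im];  // otherwise stays 0 (padding)
          }
        }
      }
      return col;
    }
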
diff --git a/src/caffe/util/im2col.cu b/src/caffe/util/im2col.cu
index cdcaac5..a8f30a0 100644
--- a/src/caffe/util/im2col.cu
+++ b/src/caffe/util/im2col.cu
@@ -75,9 +75,29 @@ template <typename Dtype, int num_axes>
 __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col) {
+    const int* dilation, Dtype* data_col) {
   int d_temp[num_axes];  // NOLINT(runtime/arrays)
   int d_iter[num_axes];  // NOLINT(runtime/arrays)
+
+  __shared__ int shared_dilation[num_axes];
+  __shared__ int shared_kernel_shape[num_axes];
+  __shared__ int shared_pad[num_axes];
+  __shared__ int shared_stride[num_axes];
+  __shared__ int shared_col_shape[num_axes + 1];
+  __shared__ int shared_im_shape[num_axes + 1];
+
+  if (threadIdx.x < num_axes) {
+    shared_dilation[threadIdx.x] = dilation[threadIdx.x];
+    shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x];
+    shared_pad[threadIdx.x] = pad[threadIdx.x];
+    shared_stride[threadIdx.x] = stride[threadIdx.x];
+  }
+  if (threadIdx.x < num_axes + 1) {
+    shared_col_shape[threadIdx.x] = col_shape[threadIdx.x];
+    shared_im_shape[threadIdx.x] = im_shape[threadIdx.x];
+  }
+  __syncthreads();
+
   int i;
   CUDA_KERNEL_LOOP(index, n) {
     // Initialize channel_in, computed in the loop below, with intermediate
@@ -85,19 +105,19 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     int channel_in = index;
     int channel_out = 1;
     for (i = num_axes - 1; i >= 0; --i) {
-      d_temp[i] = channel_in % col_shape[i + 1];
-      channel_in /= col_shape[i + 1];
-      channel_out *= kernel_shape[i];
+      d_temp[i] = channel_in % shared_col_shape[i + 1];
+      channel_in /= shared_col_shape[i + 1];
+      channel_out *= shared_kernel_shape[i];
     }
     channel_out *= channel_in;
     int data_col_inc = 1;
     for (i = 0; i < num_axes; ++i) {
-      channel_out *= col_shape[i + 1];
+      channel_out *= shared_col_shape[i + 1];
       channel_out += d_temp[i];
-      d_temp[i] = d_temp[i] * stride[i] - pad[i];
-      channel_in *= im_shape[i + 1];
+      d_temp[i] = d_temp[i] * shared_stride[i] - shared_pad[i];
+      channel_in *= shared_im_shape[i + 1];
       channel_in += d_temp[i];
-      data_col_inc *= col_shape[i + 1];
+      data_col_inc *= shared_col_shape[i + 1];
       d_iter[i] = 0;
     }
     Dtype* data_col_ptr = data_col + channel_out;
@@ -106,15 +126,15 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
     do {
       bool in_range = true;
       for (i = 0; i < num_axes; ++i) {
-        const int d_iter_im = d_iter[i] + d_temp[i];
-        in_range &= d_iter_im >= 0 && d_iter_im < im_shape[i + 1];
+        const int d_iter_im = d_iter[i] * shared_dilation[i] + d_temp[i];
+        in_range &= d_iter_im >= 0 && d_iter_im < shared_im_shape[i + 1];
         if (!in_range) { break; }
       }
       if (in_range) {
-        int data_im_offset = d_iter[0];
+        int data_im_offset = d_iter[0] * shared_dilation[0];
         for (i = 1; i < num_axes; ++i) {
-          data_im_offset *= im_shape[i + 1];
-          data_im_offset += d_iter[i];
+          data_im_offset *= shared_im_shape[i + 1];
+          data_im_offset += d_iter[i] * shared_dilation[i];
         }
         *data_col_ptr = data_im_ptr[data_im_offset];
       } else {
@@ -123,7 +143,7 @@ __global__ void im2col_nd_gpu_kernel(const int n, const Dtype* data_im,
       data_col_ptr += data_col_inc;
       incremented = false;
       for (i = num_axes - 1; i >= 0; --i) {
-        const int d_max = kernel_shape[i];
+        const int d_max = shared_kernel_shape[i];
         if (d_iter[i] == d_max - 1) {
           d_iter[i] = 0;
         } else {  // d_iter[i] < d_max - 1
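Beyond threading dilation through, the kernel above (and col2im_nd_gpu_kernel below) now stages the small per-axis parameter arrays in shared memory: the first num_axes threads of each block copy them once, __syncthreads() publishes the copies, and the hot loops stop re-reading global memory. The pattern in isolation (hypothetical kernel, assuming num_axes <= blockDim.x, which the DCHECK_LT in the dispatchers below guarantees):

    // Isolated parameter-staging sketch, not the patch's kernel.
    template <int num_axes>
    __global__ void stage_params(const int* stride, int* out, int n) {
      __shared__ int shared_stride[num_axes];
      if (threadIdx.x < num_axes) {
        shared_stride[threadIdx.x] = stride[threadIdx.x];  // one copy per block
      }
      __syncthreads();  // make the staged values visible to all threads
      const int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        out[i] = shared_stride[i % num_axes];  // served from shared memory
      }
    }
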
@@ -140,67 +160,69 @@ template <typename Dtype>
 void im2col_nd_gpu(const Dtype* data_im, const int num_spatial_axes,
     const int num_kernels, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_col) {
+    const int* dilation, Dtype* data_col) {
+  // num_axes should be smaller than block size
+  DCHECK_LT(num_spatial_axes, CAFFE_CUDA_NUM_THREADS);
   switch (num_spatial_axes) {
   case 1:
     im2col_nd_gpu_kernel<Dtype, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 2:
     im2col_nd_gpu_kernel<Dtype, 2>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 3:
     im2col_nd_gpu_kernel<Dtype, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 4:
     im2col_nd_gpu_kernel<Dtype, 4>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 5:
     im2col_nd_gpu_kernel<Dtype, 5>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 6:
     im2col_nd_gpu_kernel<Dtype, 6>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 7:
     im2col_nd_gpu_kernel<Dtype, 7>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 8:
     im2col_nd_gpu_kernel<Dtype, 8>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 9:
     im2col_nd_gpu_kernel<Dtype, 9>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   case 10:
     im2col_nd_gpu_kernel<Dtype, 10>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(num_kernels), CAFFE_CUDA_NUM_THREADS>>>(
         num_kernels, data_im, im_shape, col_shape,
-        kernel_shape, pad, stride, data_col);
+        kernel_shape, pad, stride, dilation, data_col);
     break;
   default:
     LOG(FATAL) << "im2col_nd_gpu does not support computation with "
@@ -214,12 +236,12 @@ template void im2col_nd_gpu<float>(const float* data_im,
     const int num_spatial_axes, const int col_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_col);
+    const int* dilation, float* data_col);
 template void im2col_nd_gpu<double>(const double* data_im,
     const int num_spatial_axes, const int col_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_col);
+    const int* dilation, double* data_col);
 
 template <typename Dtype>
 __global__ void col2im_gpu_kernel(const int n, const Dtype* data_col,
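The ten-way switch exists because num_axes must be a compile-time constant: the kernel declares fixed-size local arrays such as int d_iter[num_axes] and the statically sized shared arrays above. The switch maps the runtime axis count onto a pre-instantiated kernel; a condensed illustration of the pattern (hypothetical names, not the patch's code):

    // Runtime-to-compile-time dispatch, as used above.
    template <int num_axes>
    void launch_nd() {
      int d_iter[num_axes];  // legal only for compile-time num_axes
      (void)d_iter;          // a real kernel launch would go here
    }

    void dispatch(int num_spatial_axes) {
      switch (num_spatial_axes) {
        case 1: launch_nd<1>(); break;
        case 2: launch_nd<2>(); break;
        case 3: launch_nd<3>(); break;
        // ...cases 4 through 10 in the real code...
        default: break;  // Caffe LOG(FATAL)s on unsupported axis counts
      }
    }
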
@@ -300,27 +322,50 @@ template <typename Dtype, int num_axes>
 __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im) {
+    const int* dilation, Dtype* data_im) {
   int d_im[num_axes];  // NOLINT(runtime/arrays)
   int d_col_iter[num_axes];  // NOLINT(runtime/arrays)
   int d_col_start[num_axes];  // NOLINT(runtime/arrays)
   int d_col_end[num_axes];  // NOLINT(runtime/arrays)
+
+  __shared__ int shared_dilation[num_axes];
+  __shared__ int shared_kernel_shape[num_axes];
+  __shared__ int shared_pad[num_axes];
+  __shared__ int shared_stride[num_axes];
+  __shared__ int shared_col_shape[num_axes + 1];
+  __shared__ int shared_im_shape[num_axes + 1];
+
+  if (threadIdx.x < num_axes) {
+    shared_dilation[threadIdx.x] = dilation[threadIdx.x];
+    shared_kernel_shape[threadIdx.x] = kernel_shape[threadIdx.x];
+    shared_pad[threadIdx.x] = pad[threadIdx.x];
+    shared_stride[threadIdx.x] = stride[threadIdx.x];
+  }
+  if (threadIdx.x < num_axes + 1) {
+    shared_col_shape[threadIdx.x] = col_shape[threadIdx.x];
+    shared_im_shape[threadIdx.x] = im_shape[threadIdx.x];
+  }
+  __syncthreads();
+
   CUDA_KERNEL_LOOP(index, n) {
     // Initialize channel_in, computed in the loop below, with intermediate
     // computations used to compute the spatial indices.
     int c_im = index;
     // Calculate d_im (image dimensions).
     for (int i = num_axes - 1; i >= 0; --i) {
-      d_im[i] = c_im % im_shape[i + 1] + pad[i];
-      c_im /= im_shape[i + 1];
+      d_im[i] = c_im % shared_im_shape[i + 1] + shared_pad[i];
+      c_im /= shared_im_shape[i + 1];
     }
     // Calculate col start/end indices.
     bool done = false;
     for (int i = 0; i < num_axes; ++i) {
+      const int kernel_extent =
+          shared_dilation[i] * (shared_kernel_shape[i] - 1) + 1;
       d_col_start[i] = d_col_iter[i] =
-          (d_im[i] < kernel_shape[i]) ?
-          0 : (d_im[i] - kernel_shape[i]) / stride[i] + 1;
-      d_col_end[i] = min(d_im[i] / stride[i] + 1, col_shape[i + 1]);
+          (d_im[i] < kernel_extent) ? 0 :
+          (d_im[i] - kernel_extent) / shared_stride[i] + 1;
+      d_col_end[i] =
+          min(d_im[i] / shared_stride[i] + 1, shared_col_shape[i + 1]);
       if (d_col_start[i] >= d_col_end[i]) {
         // Skip computation if the dimension is 0 at any spatial axis --
         // final val will be 0.
@@ -335,21 +380,32 @@ __global__ void col2im_nd_gpu_kernel(const int n, const Dtype* data_col,
     // Loop over the col to compute the output val.
     Dtype val = 0;
     bool incremented = true;
+    bool skip = false;
     do {
       // Compute the final offset.
       int final_offset = 0;
       int kernel_shape_prod = 1;
+      int kernel_index;
       for (int i = num_axes - 1; i >= 0; --i) {
-        final_offset +=
-            (d_im[i] - d_col_iter[i] * stride[i]) * kernel_shape_prod;
-        kernel_shape_prod *= kernel_shape[i];
+        kernel_index = d_im[i] - d_col_iter[i] * shared_stride[i];
+        if (kernel_index % shared_dilation[i]) {
+          skip = true;
+          break;
+        } else {
+          kernel_index /= shared_dilation[i];
+          final_offset += kernel_index * kernel_shape_prod;
+          kernel_shape_prod *= shared_kernel_shape[i];
+        }
       }
-      final_offset += kernel_shape_prod * c_im;
-      for (int i = 0; i < num_axes; ++i) {
-        final_offset *= col_shape[i + 1];
-        final_offset += d_col_iter[i];
+      if (!skip) {
+        final_offset += kernel_shape_prod * c_im;
+        for (int i = 0; i < num_axes; ++i) {
+          final_offset *= shared_col_shape[i + 1];
+          final_offset += d_col_iter[i];
+        }
+        val += data_col[final_offset];
       }
-      val += data_col[final_offset];
+      skip = false;
       incremented = false;
       for (int i = num_axes - 1; i >= 0; --i) {
         const int d_max = d_col_end[i];
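col2im needs more than a re-indexed read: when accumulating gradient back into image pixel d_im, a column position contributes only if the offset back to a kernel tap, d_im - d_col * stride, is a non-negative multiple of the dilation, hence the new kernel_index % shared_dilation[i] skip above. The 1-D form of that test (hypothetical helper):

    // 1-D restatement of the skip test in col2im_nd_gpu_kernel; d_im is
    // the pad-shifted image coordinate.
    bool contributes(int d_im, int d_col, int stride, int dilation,
                     int kernel) {
      const int kernel_index = d_im - d_col * stride;
      if (kernel_index < 0 || kernel_index % dilation != 0) {
        return false;  // pixel lies between dilated taps
      }
      return kernel_index / dilation < kernel;  // tap must exist
    }
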
@@ -370,67 +426,69 @@ template <typename Dtype>
 void col2im_nd_gpu(const Dtype* data_col, const int num_spatial_axes,
     const int im_size, const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    Dtype* data_im) {
+    const int* dilation, Dtype* data_im) {
+  // num_axes should be smaller than block size
+  DCHECK_LT(num_spatial_axes, CAFFE_CUDA_NUM_THREADS);
   switch (num_spatial_axes) {
   case 1:
     col2im_nd_gpu_kernel<Dtype, 1>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 2:
     col2im_nd_gpu_kernel<Dtype, 2>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 3:
     col2im_nd_gpu_kernel<Dtype, 3>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 4:
     col2im_nd_gpu_kernel<Dtype, 4>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 5:
     col2im_nd_gpu_kernel<Dtype, 5>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 6:
     col2im_nd_gpu_kernel<Dtype, 6>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 7:
     col2im_nd_gpu_kernel<Dtype, 7>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 8:
     col2im_nd_gpu_kernel<Dtype, 8>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 9:
     col2im_nd_gpu_kernel<Dtype, 9>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   case 10:
     col2im_nd_gpu_kernel<Dtype, 10>  // NOLINT_NEXT_LINE(whitespace/operators)
         <<<CAFFE_GET_BLOCKS(im_size), CAFFE_CUDA_NUM_THREADS>>>(
         im_size, data_col, im_shape, col_shape,
-        kernel_shape, pad, stride, data_im);
+        kernel_shape, pad, stride, dilation, data_im);
     break;
   default:
     LOG(FATAL) << "col2im_nd_gpu does not support computation with "
@@ -444,11 +502,11 @@ template void col2im_nd_gpu<float>(const float* data_col,
     const int num_spatial_axes, const int im_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    float* data_im);
+    const int* dilation, float* data_im);
 template void col2im_nd_gpu<double>(const double* data_col,
     const int num_spatial_axes, const int im_size,
     const int* im_shape, const int* col_shape,
     const int* kernel_shape, const int* pad, const int* stride,
-    double* data_im);
+    const int* dilation, double* data_im);
 
 }  // namespace caffe
-- 
2.7.4