From 4c93b3dc555891ae0ad75092b6c0f77508740ecf Mon Sep 17 00:00:00 2001
From: Mausoom Sarkar
Date: Tue, 13 Oct 2015 18:35:32 +0530
Subject: [PATCH] Moved the loop inside PReLUParamBackward to do the reduction inside the kernel

Now PReLU backward is taking the same time as forward
Code cleanup
Removed unnecessary code
Fixed indent
merge if(channel_shared_)
---
 src/caffe/layers/prelu_layer.cu | 44 ++++++++++++++++++++++----------------------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/src/caffe/layers/prelu_layer.cu b/src/caffe/layers/prelu_layer.cu
index e1f2004..1225334 100644
--- a/src/caffe/layers/prelu_layer.cu
+++ b/src/caffe/layers/prelu_layer.cu
@@ -31,10 +31,15 @@ __global__ void PReLUBackward(const int n, const int channels, const int dim,
 }
 
 // CUDA kernel for element-wise parameter backward
 template <typename Dtype>
-__global__ void PReLUParamBackward(const int n, const Dtype* in_diff,
+__global__ void PReLUParamBackward(const int n,
+    const int rows, const int rowPitch, const Dtype* in_diff,
     const Dtype* in_data, Dtype* out_diff) {
   CUDA_KERNEL_LOOP(index, n) {
     out_diff[index] = in_diff[index] * in_data[index] * (in_data[index] <= 0);
+    for ( int k = 1; k < rows; k++ ) {
+      out_diff[index] += in_diff[index + k*rowPitch]
+          * in_data[index + k*rowPitch] * (in_data[index + k*rowPitch] <= 0);
+    }
   }
 }
@@ -82,29 +87,24 @@ void PReLULayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   if (this->param_propagate_down_[0]) {
     Dtype* slope_diff = this->blobs_[0]->mutable_gpu_diff();
     int cdim = channels * dim;
-    Dtype dsum = 0.;
-    for (int n = 0; n < bottom[0]->num(); ++n) {
-      // compute element-wise diff
-      // NOLINT_NEXT_LINE(whitespace/operators)
-      PReLUParamBackward<Dtype><<<CAFFE_GET_BLOCKS(cdim),
-          CAFFE_CUDA_NUM_THREADS>>>(
-          cdim, top_diff + top[0]->offset(n),
-          bottom_data + bottom[0]->offset(n),
-          backward_buff_.mutable_gpu_diff());
-      CUDA_POST_KERNEL_CHECK;
-      if (channel_shared_) {
-        Dtype d;
-        caffe_gpu_dot<Dtype>(channels * dim, backward_buff_.gpu_diff(),
-            multiplier_.gpu_data(), &d);
-        dsum += d;
-      } else {
-        caffe_gpu_gemv<Dtype>(CblasNoTrans, channels, dim, 1.,
-            backward_buff_.gpu_diff(), multiplier_.gpu_data(), 1.,
-            slope_diff);
-      }
-    }
+
+    // compute element-wise diff
+    // NOLINT_NEXT_LINE(whitespace/operators)
+    PReLUParamBackward<Dtype><<<CAFFE_GET_BLOCKS(cdim),
+        CAFFE_CUDA_NUM_THREADS>>>(
+        cdim, bottom[0]->num(), top[0]->offset(1), top_diff ,
+        bottom_data ,
+        backward_buff_.mutable_gpu_diff());
+    CUDA_POST_KERNEL_CHECK;
     if (channel_shared_) {
+      Dtype dsum;
+      caffe_gpu_dot<Dtype>(channels * dim, backward_buff_.gpu_diff(),
+          multiplier_.gpu_data(), &dsum);
       caffe_gpu_add_scalar(this->blobs_[0]->count(), Dtype(dsum), slope_diff);
+    } else {
+      caffe_gpu_gemv<Dtype>(CblasNoTrans, channels, dim, 1.,
+          backward_buff_.gpu_diff(), multiplier_.gpu_data(), 1.,
+          slope_diff);
     }
   }
   // Propagate to bottom
-- 
2.7.4