// Copyright 2013 Yangqing Jia
#include "caffe/layer.hpp"
#include "caffe/vision_layers.hpp"
#include "caffe/util/math_functions.hpp"

namespace caffe {
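
// A brief, informal reference for the math implemented by the kernels below
// (read off the code itself, using the layer parameters size_, alpha_, beta_):
// for every spatial position (n, h, w) and channel c,
//   scale[c] = 1 + (alpha / size) * sum of in[c']^2 over the size-channel
//              window around c, clipped to the valid channel range
//   out[c]   = in[c] * scale[c]^(-beta)
// The backward kernel further down computes the matching gradient.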

template <typename Dtype>
__global__ void LRNFillScale(const int nthreads, const Dtype* in,
    const int num, const int channels, const int height,
    const int width, const int size, const Dtype alpha_over_size,
    Dtype* scale) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // find out the local offset
    int w = index % width;
    int h = (index / width) % height;
    int n = index / width / height;
    int offset = (n * channels * height + h) * width + w;
    int step = height * width;
    // move the pointers to the start of the channel column at (n, h, w)
    in += offset;
    scale += offset;
    int head = 0;
    int pre_pad = (size - 1) / 2;
    int post_pad = size - pre_pad - 1;
    Dtype accum_scale = 0;
    // fill the scale at [n, :, h, w]
    // accumulate values
    while (head < post_pad) {
      accum_scale += in[head * step] * in[head * step];
      ++head;
    }
    // until we reach size, nothing needs to be subtracted
    while (head < size) {
      accum_scale += in[head * step] * in[head * step];
      scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
      ++head;
    }
    // both add and subtract
    while (head < channels) {
      accum_scale += in[head * step] * in[head * step];
      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
      ++head;
    }
    // subtract only
    while (head < channels + post_pad) {
      accum_scale -= in[(head - size) * step] * in[(head - size) * step];
      scale[(head - post_pad) * step] = 1. + accum_scale * alpha_over_size;
      ++head;
    }
  }
}
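
// Note on the loops above: the window sum accum_scale is maintained
// incrementally, adding the channel that enters the window and subtracting the
// one that leaves it, so each output costs O(1) work instead of O(size).
// For example, with size = 5 the kernel uses pre_pad = 2 and post_pad = 2, and
// the value written for channel (head - post_pad) covers channels
// [head - size + 1, head], clipped to the valid channel range.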

// TODO: check if it would be faster to just put it into the previous kernel.
template <typename Dtype>
__global__ void LRNComputeOutput(const int nthreads, const Dtype* in,
    const Dtype* scale, const Dtype negative_beta, Dtype* out) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    out[index] = in[index] * pow(scale[index], negative_beta);
  }
}

template <typename Dtype>
void LRNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
    vector<Blob<Dtype>*>* top) {
  // First, compute scale
  const Dtype* bottom_data = bottom[0]->gpu_data();
  Dtype* top_data = (*top)[0]->mutable_gpu_data();
  Dtype* scale_data = scale_.mutable_gpu_data();
  // We will launch one kernel for each pixel location, and have the kernel
  // go through all the channels.
  int n_threads = num_ * height_ * width_;
  LRNFillScale<<<CAFFE_GET_BLOCKS(n_threads), CAFFE_CUDA_NUM_THREADS>>>(
      n_threads, bottom_data, num_, channels_, height_, width_, size_,
      alpha_ / size_, scale_data);
  CUDA_POST_KERNEL_CHECK;
  n_threads = bottom[0]->count();
  LRNComputeOutput<<<CAFFE_GET_BLOCKS(n_threads), CAFFE_CUDA_NUM_THREADS>>>(
      n_threads, bottom_data, scale_data, -beta_, top_data);
  CUDA_POST_KERNEL_CHECK;
}
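
// Launch-shape note (inferred from the code above): LRNFillScale runs with one
// thread per (n, h, w) column, each thread walking all channels_ entries of
// that column with stride height_ * width_, while LRNComputeOutput is a plain
// element-wise kernel launched over bottom[0]->count() elements.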

template <typename Dtype>
__global__ void LRNComputeDiff(const int nthreads, const Dtype* bottom_data,
    const Dtype* top_data, const Dtype* scale, const Dtype* top_diff,
    const int num, const int channels, const int height,
    const int width, const int size, const Dtype negative_beta,
    const Dtype cache_ratio,
    Dtype* bottom_diff) {
  int index = threadIdx.x + blockIdx.x * blockDim.x;
  if (index < nthreads) {
    // find out the local offset
    int w = index % width;
    int h = (index / width) % height;
    int n = index / width / height;
    int offset = (n * channels * height + h) * width + w;
    int step = height * width;
    // move all pointers to the start of the channel column at (n, h, w)
    bottom_data += offset;
    top_data += offset;
    scale += offset;
    top_diff += offset;
    bottom_diff += offset;
    int head = 0;
    int pre_pad = size - (size + 1) / 2;
    int post_pad = size - pre_pad - 1;
    Dtype accum_ratio = 0;
    // accumulate values
    while (head < post_pad) {
      accum_ratio += top_diff[head * step] * top_data[head * step] /
          scale[head * step];
      ++head;
    }
    // until we reach size, nothing needs to be subtracted
    while (head < size) {
      accum_ratio += top_diff[head * step] * top_data[head * step] /
          scale[head * step];
      bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
          * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
          bottom_data[(head - post_pad) * step] * accum_ratio;
      ++head;
    }
    // both add and subtract
    while (head < channels) {
      accum_ratio += top_diff[head * step] * top_data[head * step] /
          scale[head * step];
      accum_ratio -= top_diff[(head - size) * step] *
          top_data[(head - size) * step] / scale[(head - size) * step];
      bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
          * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
          bottom_data[(head - post_pad) * step] * accum_ratio;
      ++head;
    }
    // subtract only
    while (head < channels + post_pad) {
      accum_ratio -= top_diff[(head - size) * step] *
          top_data[(head - size) * step] / scale[(head - size) * step];
      bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step]
          * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio *
          bottom_data[(head - post_pad) * step] * accum_ratio;
      ++head;
    }
  }
}
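
// Informal sketch of the gradient computed above (matching the launch in
// Backward_gpu below, where cache_ratio = 2 * alpha * beta / size): for each
// channel c,
//   bottom_diff[c] = top_diff[c] * scale[c]^(-beta)
//       - cache_ratio * bottom_data[c]
//         * sum over channels c' whose window contains c of
//               top_diff[c'] * top_data[c'] / scale[c']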

template <typename Dtype>
Dtype LRNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
    const bool propagate_down, vector<Blob<Dtype>*>* bottom) {
  int n_threads = num_ * height_ * width_;
  LRNComputeDiff<<<CAFFE_GET_BLOCKS(n_threads), CAFFE_CUDA_NUM_THREADS>>>(
      n_threads, (*bottom)[0]->gpu_data(), top[0]->gpu_data(),
      scale_.gpu_data(), top[0]->gpu_diff(), num_, channels_, height_, width_,
      size_, -beta_, Dtype(2. * alpha_ * beta_ / size_),
      (*bottom)[0]->mutable_gpu_diff());
  return Dtype(0.);
}

INSTANTIATE_CLASS(LRNLayer);

}  // namespace caffe