return;
}
const Dtype* top_diff = top[0]->cpu_diff();
- const Dtype* top_data = top[0]->cpu_data();
- const Dtype* bottom_data = (*bottom)[0]->cpu_data();
Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
// Different pooling methods. We explicitly do the switch outside the for
// loop to save time, although this results in more code.
bottom_diff[mask[ph * pooled_width_ + pw]]+=top_diff[ph * pooled_width_ + pw];
}
}
- // offset
- bottom_data += (*bottom)[0]->offset(0, 1);
- top_data += top[0]->offset(0, 1);
bottom_diff += (*bottom)[0]->offset(0, 1);
top_diff += top[0]->offset(0, 1);
mask += top[0]->offset(0, 1);
}
}
// offset
- bottom_data += (*bottom)[0]->offset(0, 1);
- top_data += top[0]->offset(0, 1);
bottom_diff += (*bottom)[0]->offset(0, 1);
top_diff += top[0]->offset(0, 1);
}
return Dtype(0.);
}
+// template <typename Dtype>
+// __global__ void MaxPoolBackward(const int nthreads, const Dtype* bottom_data,
+// const Dtype* top_data, const Dtype* top_diff,
+// const int num, const int channels, const int height,
+// const int width, const int pooled_height, const int pooled_width,
+// const int ksize, const int stride, Dtype* bottom_diff, int* mask) {
+// int index = threadIdx.x + blockIdx.x * blockDim.x;
+// if (index < nthreads) {
+// // find out the local index
+// // find out the local offset
+// int w = index % width;
+// int h = (index / width) % height;
+// int c = (index / width / height) % channels;
+// int n = index / width / height / channels;
+// int phstart = (h < ksize) ? 0 : (h - ksize) / stride + 1;
+// int phend = min(h / stride + 1, pooled_height);
+// int pwstart = (w < ksize) ? 0 : (w - ksize) / stride + 1;
+// int pwend = min(w / stride + 1, pooled_width);
+// Dtype gradient = 0;
+// Dtype bottom_datum =
+// bottom_data[((n * channels + c) * height + h) * width + w];
+// top_data += (n * channels + c) * pooled_height * pooled_width;
+// top_diff += (n * channels + c) * pooled_height * pooled_width;
+// //bottom_diff[index] += top_diff[mask[index]];
+// for (int ph = phstart; ph < phend; ++ph) {
+// for (int pw = pwstart; pw < pwend; ++pw) {
+// gradient += top_diff[ph * pooled_width + pw] *
+// (bottom_datum == top_data[ph * pooled_width + pw]);
+// }
+// }
+// bottom_diff[index] = gradient;
+// } // (if index < nthreads)
+// }
+
// Atomically performs `*address += val` and returns the previous value.
// Private helper for MaxPoolBackward: several pooled outputs may route their
// gradient to the same input element, so the accumulation must be atomic.
// Overloads are provided for float and double only.
// NOTE(review): assumes the kernel is instantiated only for float and double
// -- confirm against the file's template instantiations.
static __device__ __forceinline__ float MaxPoolAtomicAdd(float* address,
    float val) {
  return atomicAdd(address, val);
}

static __device__ __forceinline__ double MaxPoolAtomicAdd(double* address,
    double val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600
  // Native double-precision atomicAdd is available from SM60 onwards.
  return atomicAdd(address, val);
#else
  // Compare-and-swap emulation for devices without native double atomicAdd.
  unsigned long long int* address_as_ull =
      reinterpret_cast<unsigned long long int*>(address);
  unsigned long long int old = *address_as_ull;
  unsigned long long int assumed;
  do {
    assumed = old;
    old = atomicCAS(address_as_ull, assumed,
        __double_as_longlong(val + __longlong_as_double(assumed)));
    // Integer comparison on purpose: avoids an endless loop on NaN.
  } while (assumed != old);
  return __longlong_as_double(old);
#endif
}

// Max-pooling backward pass driven by the argmax mask.
//
// One loop iteration handles one pooled (top) element `index` and adds
// top_diff[index] into the bottom element recorded in mask[index].
//
// Parameters:
//   nthreads      number of pooled elements to process; `index` is decomposed
//                 as ((n * channels + c) * pooled_height + ph) * pooled_width
//                 + pw
//   top_diff      gradient w.r.t. the pooled output, read at `index`
//   channels, height, width
//                 bottom dimensions; height * width is the size of one
//                 (n, c) plane of bottom_diff
//   pooled_height, pooled_width
//                 top dimensions; their product is the size of one (n, c)
//                 plane of top_diff / mask
//   bottom_diff   gradient w.r.t. the input; accumulated into, not overwritten
//   mask          mask[index] is used as an offset into the (n, c) plane of
//                 bottom_diff that `index` belongs to
//   num, kernel_size, stride
//                 unused here; kept so existing launch sites stay valid
//
// NOTE(review): the kernel only adds to bottom_diff, so the caller presumably
// zeroes it before launch -- confirm at the launch site.
template <typename Dtype>
__global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
    const int num, const int channels, const int height,
    const int width, const int pooled_height, const int pooled_width,
    const int kernel_size, const int stride, Dtype* bottom_diff, int* mask) {
  CUDA_KERNEL_LOOP(index, nthreads) {
    // Flat (n, c) plane number; equal to n * channels + c with
    // c = plane % channels and n = plane / channels.
    const int plane = index / pooled_width / pooled_height;
    // Use a local pointer rather than advancing the `bottom_diff` parameter:
    // if the loop macro runs this body more than once per thread, offsets
    // applied to the parameter itself would pile up across iterations.
    Dtype* const bottom_diff_slice = bottom_diff + plane * height * width;
    // Atomic because overlapping pooling windows can select the same input
    // element, so several threads may target the same address.
    MaxPoolAtomicAdd(bottom_diff_slice + mask[index], top_diff[index]);
  }
}
-
template <typename Dtype>
__global__ void AvePoolBackward(const int nthreads, const Dtype* top_diff,
const int num, const int channels, const int height,