mask should be const in backward pass
authorJeff Donahue <jeff.donahue@gmail.com>
Sat, 24 May 2014 23:14:04 +0000 (16:14 -0700)
committerJeff Donahue <jeff.donahue@gmail.com>
Sun, 25 May 2014 01:24:50 +0000 (18:24 -0700)
src/caffe/layers/pooling_layer.cu

index 00c7846..a91b137 100644 (file)
@@ -197,9 +197,9 @@ Dtype PoolingLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
 
 template <typename Dtype>
 __global__ void MaxPoolBackward(const int nthreads, const Dtype* top_diff,
-    const int num, const int channels, const int height,
+    const int* mask, const int num, const int channels, const int height,
     const int width, const int pooled_height, const int pooled_width,
-    const int kernel_size, const int stride, Dtype* bottom_diff, int* mask) {
+    const int kernel_size, const int stride, Dtype* bottom_diff) {
   CUDA_KERNEL_LOOP(index, nthreads) {
     // find out the local index
     // find out the local offset
@@ -300,18 +300,18 @@ void PoolingLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
   int count = (*bottom)[0]->count();
   CUDA_CHECK(cudaMemset(bottom_diff, 0, sizeof(Dtype) * count));
-  int* mask;
+  const int* mask;
   switch (this->layer_param_.pooling_param().pool()) {
   case PoolingParameter_PoolMethod_MAX:
-    mask = static_cast<int*>(max_idx_->mutable_gpu_data());
+    mask = static_cast<const int*>(max_idx_->gpu_data());
     // Since we have the mask we only need count top_diff
     count = top[0]->count();
     caffe_gpu_set(count, Dtype(0.), bottom_diff);
     // NOLINT_NEXT_LINE(whitespace/operators)
     MaxPoolBackward<Dtype><<<CAFFE_GET_BLOCKS(count), CAFFE_CUDA_NUM_THREADS>>>(
-        count, top_diff, top[0]->num(), channels_,
+        count, top_diff, mask, top[0]->num(), channels_,
         height_, width_, pooled_height_, pooled_width_,
-        kernel_size_, stride_, bottom_diff, mask);
+        kernel_size_, stride_, bottom_diff);
     break;
   case PoolingParameter_PoolMethod_AVE:
     // NOLINT_NEXT_LINE(whitespace/operators)