implement GPU version of Softmax
authorRonghang Hu <huronghang@hotmail.com>
Sat, 16 Aug 2014 21:52:19 +0000 (14:52 -0700)
committerJonathan L Long <jonlong@cs.berkeley.edu>
Tue, 19 Aug 2014 08:05:22 +0000 (01:05 -0700)
src/caffe/layers/softmax_layer.cpp
src/caffe/layers/softmax_layer.cu

index ab74dc7..29767ac 100644 (file)
@@ -18,7 +18,7 @@ void SoftmaxLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
   for (int i = 0; i < sum_multiplier_.count(); ++i) {
     multiplier_data[i] = 1.;
   }
-  scale_.Reshape(1, 1, bottom[0]->height(), bottom[0]->width());
+  scale_.Reshape(bottom[0]->num(), 1, bottom[0]->height(), bottom[0]->width());
 }
 
 template <typename Dtype>
index 6b85309..f97eafc 100644 (file)
 namespace caffe {
 
 template <typename Dtype>
-__global__ void kernel_get_max(const int num, const int dim,
-    const Dtype* data, Dtype* out) {
-  CUDA_KERNEL_LOOP(index, num) {
+__global__ void kernel_channel_max(const int num, const int channels,
+    const int spatial_dim, const Dtype* data, Dtype* out) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
     Dtype maxval = -FLT_MAX;
-    for (int i = 0; i < dim; ++i) {
-      maxval = max(data[index * dim + i], maxval);
+    for (int c = 0; c < channels; ++c) {
+      maxval = max(data[(n * channels + c) * spatial_dim + s], maxval);
     }
     out[index] = maxval;
   }
 }
 
 template <typename Dtype>
-__global__ void kernel_softmax_div(const int num, const int dim,
-    const Dtype* scale, Dtype* data) {
-  CUDA_KERNEL_LOOP(index, num * dim) {
-    int n = index / dim;
-    data[index] /= scale[n];
+__global__ void kernel_channel_subtract(const int num, const int channels,
+    const int spatial_dim, Dtype* data, const Dtype* channel_max) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    for (int c = 0; c < channels; ++c) {
+      data[(n * channels + c) * spatial_dim + s] -= channel_max[index];
+    }
   }
 }
 
 template <typename Dtype>
-__global__ void kernel_exp(const int num, const Dtype* data, Dtype* out) {
-  CUDA_KERNEL_LOOP(index, num) {
+__global__ void kernel_exp(const int count, const Dtype* data, Dtype* out) {
+  CUDA_KERNEL_LOOP(index, count) {
     out[index] = exp(data[index]);
   }
 }
 
 template <typename Dtype>
+__global__ void kernel_channel_sum(const int num, const int channels,
+    const int spatial_dim, const Dtype* data, Dtype* channel_sum) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    Dtype sum = 0;
+    for (int c = 0; c < channels; ++c) {
+      sum += data[(n * channels + c) * spatial_dim + s];
+    }
+    channel_sum[index] = sum;
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_channel_div(const int num, const int channels,
+    const int spatial_dim, Dtype* data, const Dtype* channel_sum) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    for (int c = 0; c < channels; ++c) {
+      data[(n * channels + c) * spatial_dim + s] /= channel_sum[index];
+    }
+  }
+}
+
+template <typename Dtype>
+__global__ void kernel_channel_dot(const int num, const int channels,
+    const int spatial_dim, const Dtype* data_1, const Dtype* data_2,
+    Dtype* channel_dot) {
+  CUDA_KERNEL_LOOP(index, num * spatial_dim) {
+    int n = index / spatial_dim;
+    int s = index % spatial_dim;
+    Dtype dot = 0;
+    for (int c = 0; c < channels; ++c) {
+      dot += (data_1[(n * channels + c) * spatial_dim + s]
+          * data_2[(n * channels + c) * spatial_dim + s]);
+    }
+    channel_dot[index] = dot;
+  }
+}
+
+template <typename Dtype>
 void SoftmaxLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
     vector<Blob<Dtype>*>* top) {
   const Dtype* bottom_data = bottom[0]->gpu_data();
   Dtype* top_data = (*top)[0]->mutable_gpu_data();
   Dtype* scale_data = scale_.mutable_gpu_data();
   int num = bottom[0]->num();
-  int dim = bottom[0]->count() / bottom[0]->num();
+  int channels = bottom[0]->channels();
+  int spatial_dim = bottom[0]->height() * bottom[0]->width();
   caffe_copy(bottom[0]->count(), bottom_data, top_data);
-  // we need to subtract the max to avoid numerical issues, compute the exp,
+  // We need to subtract the max to avoid numerical issues, compute the exp,
   // and then normalize.
-  // Compute max
+  // compute max
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_max<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
+  // subtract
   // NOLINT_NEXT_LINE(whitespace/operators)
-  kernel_get_max<Dtype><<<CAFFE_GET_BLOCKS(num), CAFFE_CUDA_NUM_THREADS>>>(
-      num, dim, bottom_data, scale_data);
-  // subtraction
-  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
-      scale_data, sum_multiplier_.gpu_data(), 1., top_data);
-  // Perform exponentiation
+  kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
+  // exponentiate
   // NOLINT_NEXT_LINE(whitespace/operators)
-  kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * dim), CAFFE_CUDA_NUM_THREADS>>>(
-      num * dim, top_data, top_data);
+  kernel_exp<Dtype><<<CAFFE_GET_BLOCKS(num * channels * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num * channels * spatial_dim, top_data,
+      top_data);
   // sum after exp
-  caffe_gpu_gemv<Dtype>(CblasNoTrans, num, dim, 1., top_data,
-      sum_multiplier_.gpu_data(), 0., scale_data);
-  // Do division
   // NOLINT_NEXT_LINE(whitespace/operators)
-  kernel_softmax_div<Dtype><<<CAFFE_GET_BLOCKS(num * dim),
-                              CAFFE_CUDA_NUM_THREADS>>>(
-      num, dim, scale_data, top_data);
+  kernel_channel_sum<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
+  // divide
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_div<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_data,
+      scale_data);
 }
 
-// TODO(Yangqing): implement the GPU version of softmax.
 template <typename Dtype>
 void SoftmaxLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
     const vector<bool>& propagate_down, vector<Blob<Dtype>*>* bottom) {
   const Dtype* top_diff = top[0]->gpu_diff();
   const Dtype* top_data = top[0]->gpu_data();
   Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  Dtype* scale_data = scale_.mutable_gpu_data();
   int num = top[0]->num();
-  int dim = top[0]->count() / top[0]->num();
+  int channels = top[0]->channels();
+  int spatial_dim = top[0]->height() * top[0]->width();
   caffe_copy(top[0]->count(), top_diff, bottom_diff);
-  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff
-  // cuda dot returns the result to cpu, so we temporarily change the pointer
-  // mode
-  CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
-      CUBLAS_POINTER_MODE_DEVICE));
-  Dtype* scale_data = scale_.mutable_gpu_data();
-  for (int i = 0; i < num; ++i) {
-    caffe_gpu_dot<Dtype>(dim, top_diff + i * dim,
-        top_data + i * dim, scale_data + i);
-  }
-  CUBLAS_CHECK(cublasSetPointerMode(Caffe::cublas_handle(),
-      CUBLAS_POINTER_MODE_HOST));
-  // subtraction
-  caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num, dim, 1, -1.,
-      scale_.gpu_data(), sum_multiplier_.gpu_data(), 1., bottom_diff);
+  // Compute inner1d(top_diff, top_data) and subtract them from the bottom diff.
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_dot<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, top_diff, top_data,
+      scale_data);
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  kernel_channel_subtract<Dtype><<<CAFFE_GET_BLOCKS(num * spatial_dim),
+      CAFFE_CUDA_NUM_THREADS>>>(num, channels, spatial_dim, bottom_diff,
+      scale_data);
   // elementwise multiplication
   caffe_gpu_mul<Dtype>(top[0]->count(), bottom_diff, top_data, bottom_diff);
 }