sigmoid cross-entropy loss: add GPU forward for full GPU mode
author Evan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 27 Oct 2016 07:41:03 +0000 (00:41 -0700)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 27 Oct 2016 08:37:36 +0000 (01:37 -0700)
close #3004
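
The GPU path previously implemented only Backward, so in GPU mode the
forward loss was still computed by the CPU path. This adds a Forward_gpu
that uses the same numerically stable formulation as Forward_cpu and
reuses the bottom diff blob as scratch space for the per-element terms.

For reference, the stable rewrite used by the new kernel comes from
splitting the per-element log likelihood on the sign of the input
(sketched here in LaTeX notation):

  \ell = t x - \log(1 + e^x)
       = \begin{cases}
           t x - \log(1 + e^x)          & x < 0 \\
           (t - 1) x - \log(1 + e^{-x}) & x \ge 0
         \end{cases}

Both cases collapse into the single branch-free expression
x * (t - (x >= 0)) - log(1 + exp(x - 2 * x * (x >= 0))) used in the
kernel, so exp() is always called with a non-positive argument and
cannot overflow.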

include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu

include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
index 598dca5..6452ea5 100644 (file)
@@ -59,6 +59,8 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   /// @copydoc SigmoidCrossEntropyLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
 
   /**
    * @brief Computes the sigmoid cross-entropy loss error gradient w.r.t. the
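
Before this change the layer inherited the base Layer's default
Forward_gpu, which simply falls back to the CPU implementation, so even
GPU-mode nets ran this forward pass on the host. Roughly, paraphrasing
the default in include/caffe/layer.hpp (a sketch, not the verbatim
source):

  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) {
    // Base-class default: fall back to the CPU forward pass.
    return Forward_cpu(bottom, top);
  }

Declaring an override here lets the Layer::Forward wrapper reach the new
GPU kernel instead.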
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
index 10ac947..eb77a9c 100644 (file)
@@ -68,7 +68,7 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
 }
 
 #ifdef CPU_ONLY
-STUB_GPU_BACKWARD(SigmoidCrossEntropyLossLayer, Backward);
+STUB_GPU(SigmoidCrossEntropyLossLayer);
 #endif
 
 INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
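
With both directions now implemented on the GPU, the CPU_ONLY build must
stub out the whole pair rather than just Backward. As a rough sketch of
what STUB_GPU(SigmoidCrossEntropyLossLayer) expands to (paraphrasing
include/caffe/util/device_alternate.hpp, not the verbatim macro):

  template <typename Dtype>
  void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
      const vector<Blob<Dtype>*>& bottom,
      const vector<Blob<Dtype>*>& top) { NO_GPU; }

  template <typename Dtype>
  void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
      const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
      const vector<Blob<Dtype>*>& bottom) { NO_GPU; }

where NO_GPU aborts with a "cannot use GPU in CPU-only Caffe" fatal log.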
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
index 046cb9d..7cb982d 100644 (file)
@@ -6,6 +6,44 @@
 namespace caffe {
 
 template <typename Dtype>
+__global__ void SigmoidCrossEntropyLossForwardGPU(const int nthreads,
+          const Dtype* input_data, const Dtype* target, Dtype* loss) {
+  CUDA_KERNEL_LOOP(i, nthreads) {
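+    // Equivalent to loss[i] = target[i] * x - log(1 + exp(x)) with
+    // x = input_data[i], rewritten so exp() never receives a positive
+    // argument; each loss[i] is a log likelihood and is <= 0.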
+    loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
+        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+  }
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  // The forward pass computes the sigmoid outputs.
+  sigmoid_bottom_vec_[0] = bottom[0];
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+  // Compute the loss (negative log likelihood)
+  const int count = bottom[0]->count();
+  const int num = bottom[0]->num();
+  // Stable version of loss computation from input data
+  const Dtype* input_data = bottom[0]->gpu_data();
+  const Dtype* target = bottom[1]->gpu_data();
+  // Since this memory is not used for anything until it is overwritten
+  // on the backward pass, we use it here to avoid having to allocate new GPU
+  // memory to accumulate intermediate results in the kernel.
+  Dtype* loss_data = bottom[0]->mutable_gpu_diff();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  SigmoidCrossEntropyLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
+      CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data);
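+  // Every loss_data element is <= 0, so the absolute-value sum computed
+  // by asum is the (positive) total negative log likelihood.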
+  Dtype loss;
+  caffe_gpu_asum(count, loss_data, &loss);
+  top[0]->mutable_cpu_data()[0] = loss / num;
+}
+
+template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
     const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     const vector<Blob<Dtype>*>& bottom) {
@@ -28,7 +61,6 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
   }
 }
 
-INSTANTIATE_LAYER_GPU_BACKWARD(SigmoidCrossEntropyLossLayer);
-
+INSTANTIATE_LAYER_GPU_FUNCS(SigmoidCrossEntropyLossLayer);
 
 }  // namespace caffe
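
As a quick sanity check of the formulation (illustrative only, not part
of the commit), the branch-free expression can be compared against the
naive t * x - log(1 + exp(x)) on the host in single precision, where
exp() overflows past roughly x = 88:

  #include <cmath>
  #include <cstdio>

  int main() {
    const float t = 0.7f;  // arbitrary target in [0, 1]
    const float xs[] = { -100.0f, -2.0f, 0.0f, 2.0f, 100.0f };
    for (float x : xs) {
      // Naive form: exp(x) overflows to inf for large positive x.
      float naive = t * x - std::log(1.0f + std::exp(x));
      // Stable form used by SigmoidCrossEntropyLossForwardGPU.
      float stable = x * (t - (x >= 0)) -
          std::log(1.0f + std::exp(x - 2 * x * (x >= 0)));
      std::printf("x = %6.1f  naive = %12.4f  stable = %12.4f\n",
                  x, naive, stable);
    }
    return 0;
  }

At x = 100 the naive form prints -inf while the stable form stays
finite; everywhere else the two agree.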