From 197d11a0e1be7ad35714eb38d9b391e1cd39af39 Mon Sep 17 00:00:00 2001
From: Evan Shelhamer
Date: Thu, 27 Oct 2016 00:41:03 -0700
Subject: [PATCH] sigmoid cross-entropy loss: add GPU forward for full GPU mode

close #3004
---
 .../layers/sigmoid_cross_entropy_loss_layer.hpp |  2 ++
 .../layers/sigmoid_cross_entropy_loss_layer.cpp |  2 +-
 .../layers/sigmoid_cross_entropy_loss_layer.cu  | 36 ++++++++++++++++++++--
 3 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp b/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
index 598dca5..6452ea5 100644
--- a/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
+++ b/include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
@@ -59,6 +59,8 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   /// @copydoc SigmoidCrossEntropyLossLayer
   virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
       const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);

   /**
    * @brief Computes the sigmoid cross-entropy loss error gradient w.r.t. the
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
index 10ac947..eb77a9c 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
@@ -68,7 +68,7 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
 }

 #ifdef CPU_ONLY
-STUB_GPU_BACKWARD(SigmoidCrossEntropyLossLayer, Backward);
+STUB_GPU(SigmoidCrossEntropyLossLayer);
 #endif

 INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
index 046cb9d..7cb982d 100644
--- a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
@@ -6,6 +6,39 @@
 namespace caffe {

 template <typename Dtype>
+__global__ void SigmoidCrossEntropyLossForwardGPU(const int nthreads,
+          const Dtype* input_data, const Dtype* target, Dtype* loss) {
+  CUDA_KERNEL_LOOP(i, nthreads) {
+    loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
+        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+  }
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
+  // The forward pass computes the sigmoid outputs.
+  sigmoid_bottom_vec_[0] = bottom[0];
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
+  // Compute the loss (negative log likelihood)
+  const int count = bottom[0]->count();
+  const int num = bottom[0]->num();
+  // Stable version of loss computation from input data
+  const Dtype* input_data = bottom[0]->gpu_data();
+  const Dtype* target = bottom[1]->gpu_data();
+  // Since this memory is not used for anything until it is overwritten
+  // on the backward pass, we use it here to avoid having to allocate new GPU
+  // memory to accumulate intermediate results in the kernel.
+  Dtype* loss_data = bottom[0]->mutable_gpu_diff();
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  SigmoidCrossEntropyLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
+      CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data);
+  Dtype loss;
+  caffe_gpu_asum(count, loss_data, &loss);
+  top[0]->mutable_cpu_data()[0] = loss / num;
+}
+
+template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
     const vector<Blob<Dtype>*>& top, const vector<bool>& propagate_down,
     const vector<Blob<Dtype>*>& bottom) {
@@ -28,7 +61,6 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
   }
 }

-INSTANTIATE_LAYER_GPU_BACKWARD(SigmoidCrossEntropyLossLayer);
-
+INSTANTIATE_LAYER_GPU_FUNCS(SigmoidCrossEntropyLossLayer);

 }  // namespace caffe
--
2.7.4
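
A note on the kernel arithmetic, as a sketch in LaTeX rather than anything taken from the patch itself. Here x stands for input_data[i], t for target[i] (assumed to lie in [0, 1]), and [x >= 0] for the indicator the kernel computes with (input_data[i] >= 0). The per-element sigmoid cross-entropy is

    \ell(x, t) = -\bigl[ t \log \sigma(x) + (1 - t) \log(1 - \sigma(x)) \bigr]
               = (1 - t)\,x + \log\bigl(1 + e^{-x}\bigr).

Evaluating e^{-x} directly overflows for large negative x, so the "stable version" stores the equivalent signed value

    \mathrm{loss}[i] = x\,\bigl(t - [x \ge 0]\bigr)
                       - \log\bigl(1 + e^{\,x - 2x[x \ge 0]}\bigr) = -\ell(x, t),

whose exponent x - 2x[x >= 0] = -|x| is never positive. Because each loss[i] is non-positive for targets in [0, 1], caffe_gpu_asum recovers \sum_i \ell(x_i, t_i) as a sum of absolute values, and Forward_gpu divides by num to report the loss averaged over the batch.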