sigmoid cross-entropy loss: normalize loss by different schemes
author Evan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 17 Nov 2016 04:39:42 +0000 (20:39 -0800)
committer Evan Shelhamer <shelhamer@imaginarynumber.net>
Thu, 17 Nov 2016 06:21:41 +0000 (22:21 -0800)
sig-ce loss now handles all the same normalization modes as the softmax loss;
see #3296 for more detail.

this preserves the default normalization for sig-ce loss: batch size.
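
for illustration, a minimal prototxt sketch of how the new setting could be
used (the blob names "score" and "label" are placeholders; the type string and
loss_param fields follow caffe.proto as changed below). leaving normalization
unset keeps the historical batch-size default for this layer:

layer {
  name: "loss"
  type: "SigmoidCrossEntropyLoss"
  bottom: "score"
  bottom: "label"
  top: "loss"
  loss_param {
    ignore_label: -1      # optional: outputs whose target is -1 are skipped
    normalization: VALID  # normalize by the count of non-ignored outputs
  }
}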

include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
src/caffe/proto/caffe.proto

include/caffe/layers/sigmoid_cross_entropy_loss_layer.hpp
index a9fe33c..3d92524 100644
@@ -97,6 +97,13 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
       const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
 
+  /// Read the normalization mode parameter and compute the normalizer based
+  /// on the blob size.  If normalization_mode is VALID, the count of valid
+  /// outputs will be read from valid_count, unless it is -1 in which case
+  /// all outputs are assumed to be valid.
+  virtual Dtype get_normalizer(
+      LossParameter_NormalizationMode normalization_mode, int valid_count);
+
   /// The internal SigmoidLayer used to map predictions to probabilities.
   shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
   /// sigmoid_output stores the output of the SigmoidLayer.
@@ -110,6 +117,10 @@ class SigmoidCrossEntropyLossLayer : public LossLayer<Dtype> {
   bool has_ignore_label_;
   /// The label indicating that an instance should be ignored.
   int ignore_label_;
+  /// How to normalize the loss.
+  LossParameter_NormalizationMode normalization_;
+  Dtype normalizer_;
+  int outer_num_, inner_num_;
 };
 
 }  // namespace caffe
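
As a concrete example of get_normalizer() declared above: for a bottom blob of
shape 2 x 3 x 4 x 5 (a hypothetical batch of 2), outer_num_ = 2 and
inner_num_ = 3*4*5 = 60, so FULL normalizes by 2*60 = 120, BATCH_SIZE by 2,
and NONE by 1; VALID also normalizes by 120 when no ignore_label is set
(valid_count == -1), or by, say, 100 if 20 of the 120 outputs carry the
ignore label.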
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
index 21b64c2..99fa3eb 100644
@@ -1,3 +1,4 @@
+#include <algorithm>
 #include <vector>
 
 #include "caffe/layers/sigmoid_cross_entropy_loss_layer.hpp"
@@ -20,17 +21,60 @@ void SigmoidCrossEntropyLossLayer<Dtype>::LayerSetUp(
   if (has_ignore_label_) {
     ignore_label_ = this->layer_param_.loss_param().ignore_label();
   }
+  if (this->layer_param_.loss_param().has_normalization()) {
+    normalization_ = this->layer_param_.loss_param().normalization();
+  } else if (this->layer_param_.loss_param().has_normalize()) {
+    normalization_ = this->layer_param_.loss_param().normalize() ?
+                     LossParameter_NormalizationMode_VALID :
+                     LossParameter_NormalizationMode_BATCH_SIZE;
+  } else {
+    normalization_ = LossParameter_NormalizationMode_BATCH_SIZE;
+  }
 }
 
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Reshape(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
   LossLayer<Dtype>::Reshape(bottom, top);
+  outer_num_ = bottom[0]->shape(0);  // batch size
+  inner_num_ = bottom[0]->count(1);  // instance size: |output| == |target|
   CHECK_EQ(bottom[0]->count(), bottom[1]->count()) <<
       "SIGMOID_CROSS_ENTROPY_LOSS layer inputs must have the same count.";
   sigmoid_layer_->Reshape(sigmoid_bottom_vec_, sigmoid_top_vec_);
 }
 
+// TODO(shelhamer) loss normalization should be pulled up into LossLayer,
+// instead of duplicated here and in SoftMaxWithLossLayer
+template <typename Dtype>
+Dtype SigmoidCrossEntropyLossLayer<Dtype>::get_normalizer(
+    LossParameter_NormalizationMode normalization_mode, int valid_count) {
+  Dtype normalizer;
+  switch (normalization_mode) {
+    case LossParameter_NormalizationMode_FULL:
+      normalizer = Dtype(outer_num_ * inner_num_);
+      break;
+    case LossParameter_NormalizationMode_VALID:
+      if (valid_count == -1) {
+        normalizer = Dtype(outer_num_ * inner_num_);
+      } else {
+        normalizer = Dtype(valid_count);
+      }
+      break;
+    case LossParameter_NormalizationMode_BATCH_SIZE:
+      normalizer = Dtype(outer_num_);
+      break;
+    case LossParameter_NormalizationMode_NONE:
+      normalizer = Dtype(1);
+      break;
+    default:
+      LOG(FATAL) << "Unknown normalization mode: "
+          << LossParameter_NormalizationMode_Name(normalization_mode);
+  }
+  // Some users will have no labels for some examples in order to 'turn off' a
+  // particular loss in a multi-task setup. The max prevents NaNs in that case.
+  return std::max(Dtype(1.0), normalizer);
+}
+
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
@@ -38,21 +82,22 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
   sigmoid_bottom_vec_[0] = bottom[0];
   sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
   // Compute the loss (negative log likelihood)
-  const int count = bottom[0]->count();
-  const int num = bottom[0]->num();
   // Stable version of loss computation from input data
   const Dtype* input_data = bottom[0]->cpu_data();
   const Dtype* target = bottom[1]->cpu_data();
+  int valid_count = 0;
   Dtype loss = 0;
-  for (int i = 0; i < count; ++i) {
+  for (int i = 0; i < bottom[0]->count(); ++i) {
     const int target_value = static_cast<int>(target[i]);
     if (has_ignore_label_ && target_value == ignore_label_) {
       continue;
     }
     loss -= input_data[i] * (target[i] - (input_data[i] >= 0)) -
         log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+    ++valid_count;
   }
-  top[0]->mutable_cpu_data()[0] = loss / num;
+  normalizer_ = get_normalizer(normalization_, valid_count);
+  top[0]->mutable_cpu_data()[0] = loss / normalizer_;
 }
 
 template <typename Dtype>
@@ -66,14 +111,10 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
   if (propagate_down[0]) {
     // First, compute the diff
     const int count = bottom[0]->count();
-    const int num = bottom[0]->num();
     const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
     const Dtype* target = bottom[1]->cpu_data();
     Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
     caffe_sub(count, sigmoid_output_data, target, bottom_diff);
-    // Scale down gradient
-    const Dtype loss_weight = top[0]->cpu_diff()[0];
-    caffe_scal(count, loss_weight / num, bottom_diff);
     // Zero out gradient of ignored targets.
     if (has_ignore_label_) {
       for (int i = 0; i < count; ++i) {
@@ -83,6 +124,9 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
         }
       }
     }
+    // Scale down gradient
+    Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer_;
+    caffe_scal(count, loss_weight, bottom_diff);
   }
 }
 
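For reference, the per-element loss and gradient implemented in Forward_cpu
and Backward_cpu above correspond to the following (a sketch, writing x for
the raw input and t in [0, 1] for the target; the max/|x| form is the
numerically stable rewrite used in the code, and the sum over non-ignored
elements is divided by the normalizer):

\ell(x, t) = -\bigl[\, t \log \sigma(x) + (1 - t) \log\bigl(1 - \sigma(x)\bigr) \,\bigr]
           = \max(x, 0) - t x + \log\bigl(1 + e^{-|x|}\bigr),
\qquad
\frac{\partial \ell}{\partial x} = \sigma(x) - t
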
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
index 39eb050..b9877e6 100644
@@ -5,26 +5,38 @@
 
 namespace caffe {
 
+
 template <typename Dtype>
 __global__ void SigmoidCrossEntropyLossForwardGPU(const int nthreads,
-          const Dtype* input_data, const Dtype* target, Dtype* loss) {
+          const Dtype* input_data, const Dtype* target, Dtype* loss,
+          const bool has_ignore_label_, const int ignore_label_,
+          Dtype* counts) {
   CUDA_KERNEL_LOOP(i, nthreads) {
-    loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
-        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+    const int target_value = static_cast<int>(target[i]);
+    if (has_ignore_label_ && target_value == ignore_label_) {
+      loss[i] = 0;
+      counts[i] = 0;
+    } else {
+      loss[i] = input_data[i] * (target[i] - (input_data[i] >= 0)) -
+          log(1 + exp(input_data[i] - 2 * input_data[i] *
+          (input_data[i] >= 0)));
+      counts[i] = 1;
+    }
   }
 }
 
 template <typename Dtype>
-__global__ void SigmoidCrossEntropyLossIgnoreGPU(const int count,
-    const int ignore_label, const Dtype* target, Dtype* reference) {
-  CUDA_KERNEL_LOOP(index, count) {
-    const int target_value = static_cast<int>(target[index]);
+__global__ void SigmoidCrossEntropyLossIgnoreDiffGPU(const int count,
+    const int ignore_label, const Dtype* target, Dtype* diff) {
+  CUDA_KERNEL_LOOP(i, count) {
+    const int target_value = static_cast<int>(target[i]);
     if (target_value == ignore_label) {
-      reference[index] = 0;
+      diff[i] = 0;
     }
   }
 }
 
+
 template <typename Dtype>
 void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
     const vector<Blob<Dtype>*>& bottom, const vector<Blob<Dtype>*>& top) {
@@ -33,7 +45,6 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
   sigmoid_layer_->Forward(sigmoid_bottom_vec_, sigmoid_top_vec_);
   // Compute the loss (negative log likelihood)
   const int count = bottom[0]->count();
-  const int num = bottom[0]->num();
   // Stable version of loss computation from input data
   const Dtype* input_data = bottom[0]->gpu_data();
   const Dtype* target = bottom[1]->gpu_data();
@@ -41,18 +52,23 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
   // on the backward pass, we use it here to avoid having to allocate new GPU
   // memory to accumulate intermediate results in the kernel.
   Dtype* loss_data = bottom[0]->mutable_gpu_diff();
+  Dtype* count_data = bottom[1]->mutable_gpu_diff();
+  Dtype valid_count;
   // NOLINT_NEXT_LINE(whitespace/operators)
   SigmoidCrossEntropyLossForwardGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
-      CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data);
-  // Zero out loss of ignored targets.
-  if (has_ignore_label_) {
-    // NOLINT_NEXT_LINE(whitespace/operators)
-    SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
-      CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, loss_data);
+      CAFFE_CUDA_NUM_THREADS>>>(count, input_data, target, loss_data,
+      has_ignore_label_, ignore_label_, count_data);
+  // Only launch another CUDA kernel if we actually need the valid count.
+  if (normalization_ == LossParameter_NormalizationMode_VALID &&
+      has_ignore_label_) {
+    caffe_gpu_asum(count, count_data, &valid_count);
+  } else {
+    valid_count = count;
   }
   Dtype loss;
   caffe_gpu_asum(count, loss_data, &loss);
-  top[0]->mutable_cpu_data()[0] = loss / num;
+  normalizer_ = get_normalizer(normalization_, valid_count);
+  top[0]->mutable_cpu_data()[0] = loss / normalizer_;
 }
 
 template <typename Dtype>
@@ -66,21 +82,20 @@ void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
   if (propagate_down[0]) {
     // First, compute the diff
     const int count = bottom[0]->count();
-    const int num = bottom[0]->num();
     const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
     const Dtype* target = bottom[1]->gpu_data();
     Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
     caffe_copy(count, sigmoid_output_data, bottom_diff);
     caffe_gpu_axpy(count, Dtype(-1), target, bottom_diff);
-    // Scale down gradient
-    const Dtype loss_weight = top[0]->cpu_diff()[0];
-    caffe_gpu_scal(count, loss_weight / num, bottom_diff);
     // Zero out gradient of ignored targets.
     if (has_ignore_label_) {
       // NOLINT_NEXT_LINE(whitespace/operators)
-      SigmoidCrossEntropyLossIgnoreGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
+      SigmoidCrossEntropyLossIgnoreDiffGPU<Dtype><<<CAFFE_GET_BLOCKS(count),
         CAFFE_CUDA_NUM_THREADS>>>(count, ignore_label_, target, bottom_diff);
     }
+    // Scale down gradient
+    Dtype loss_weight = top[0]->cpu_diff()[0] / normalizer_;
+    caffe_gpu_scal(count, loss_weight, bottom_diff);
   }
 }
 
src/caffe/proto/caffe.proto
index 6940a70..0b2768b 100644
@@ -434,7 +434,7 @@ message LossParameter {
   optional int32 ignore_label = 1;
   // How to normalize the loss for loss layers that aggregate across batches,
   // spatial dimensions, or other dimensions.  Currently only implemented in
-  // SoftmaxWithLoss layer.
+  // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers.
   enum NormalizationMode {
     // Divide by the number of examples in the batch times spatial dimensions.
     // Outputs that receive the ignore label will NOT be ignored in computing
@@ -448,6 +448,8 @@ message LossParameter {
     // Do not normalize the loss.
     NONE = 3;
   }
+  // For historical reasons, the default normalization for
+  // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID.
   optional NormalizationMode normalization = 3 [default = VALID];
   // Deprecated.  Ignored if normalization is specified.  If normalization
   // is not specified, then setting this to false will be equivalent to