Add LogLayer
authorJeff Donahue <jeff.donahue@gmail.com>
Fri, 2 Jan 2015 07:07:44 +0000 (23:07 -0800)
committerJeff Donahue <jeff.donahue@gmail.com>
Wed, 3 Jun 2015 01:41:01 +0000 (18:41 -0700)
include/caffe/neuron_layers.hpp
include/caffe/util/math_functions.hpp
include/caffe/util/mkl_alternate.hpp
src/caffe/layers/log_layer.cpp [new file with mode: 0644]
src/caffe/proto/caffe.proto
src/caffe/test/test_neuron_layer.cpp
src/caffe/util/math_functions.cpp
src/caffe/util/math_functions.cu

index 9cf233f..c2e0774 100644 (file)
@@ -268,6 +268,72 @@ class ExpLayer : public NeuronLayer<Dtype> {
 };
 
 /**
+ * @brief Computes @f$ y = log_{\gamma}(\alpha x + \beta) @f$,
+ *        as specified by the scale @f$ \alpha @f$, shift @f$ \beta @f$,
+ *        and base @f$ \gamma @f$.
+ */
+template <typename Dtype>
+class LogLayer : public NeuronLayer<Dtype> {
+ public:
+  /**
+   * @param param provides LogParameter log_param,
+   *     with LogLayer options:
+   *   - scale (\b optional, default 1) the scale @f$ \alpha @f$
+   *   - shift (\b optional, default 0) the shift @f$ \beta @f$
+   *   - base (\b optional, default -1 for a value of @f$ e \approx 2.718 @f$)
+   *         the base @f$ \gamma @f$
+   */
+  explicit LogLayer(const LayerParameter& param)
+      : NeuronLayer<Dtype>(param) {}
+  virtual void LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  virtual inline const char* type() const { return "Log"; }
+
+ protected:
+  /**
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$
+   * @param top output Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the computed outputs @f$
+   *        y = log_{\gamma}(\alpha x + \beta)
+   *      @f$
+   */
+  virtual void Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+  virtual void Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top);
+
+  /**
+   * @brief Computes the error gradient w.r.t. the exp inputs.
+   *
+   * @param top output Blob vector (length 1), providing the error gradient with
+   *      respect to the outputs
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      containing error gradients @f$ \frac{\partial E}{\partial y} @f$
+   *      with respect to computed outputs @f$ y @f$
+   * @param propagate_down see Layer::Backward.
+   * @param bottom input Blob vector (length 1)
+   *   -# @f$ (N \times C \times H \times W) @f$
+   *      the inputs @f$ x @f$; Backward fills their diff with
+   *      gradients @f$
+   *        \frac{\partial E}{\partial x} =
+   *            \frac{\partial E}{\partial y} y \alpha \log_e(gamma)
+   *      @f$ if propagate_down[0]
+   */
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom);
+
+  Dtype base_scale_;
+  Dtype input_scale_, input_shift_;
+  Dtype backward_num_scale_;
+};
+
+/**
  * @brief Computes @f$ y = (\alpha x + \beta) ^ \gamma @f$,
  *        as specified by the scale @f$ \alpha @f$, shift @f$ \beta @f$,
  *        and power @f$ \gamma @f$.
index f43036f..2cacd8e 100644 (file)
@@ -89,6 +89,9 @@ template <typename Dtype>
 void caffe_exp(const int n, const Dtype* a, Dtype* y);
 
 template <typename Dtype>
+void caffe_log(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
 void caffe_abs(const int n, const Dtype* a, Dtype* y);
 
 template <typename Dtype>
@@ -204,6 +207,9 @@ template <typename Dtype>
 void caffe_gpu_exp(const int n, const Dtype* a, Dtype* y);
 
 template <typename Dtype>
+void caffe_gpu_log(const int n, const Dtype* a, Dtype* y);
+
+template <typename Dtype>
 void caffe_gpu_powx(const int n, const Dtype* a, const Dtype b, Dtype* y);
 
 // caffe_gpu_rng_uniform with two arguments generates integers in the range
index 32fdbf7..3355b66 100644 (file)
@@ -33,6 +33,7 @@ extern "C" {
 
 DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]);
 DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i]));
+DEFINE_VSL_UNARY_FUNC(Ln, y[i] = log(a[i]));
 DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i]));
 
 // A simple way to define the vsl unary functions with singular parameter b.
diff --git a/src/caffe/layers/log_layer.cpp b/src/caffe/layers/log_layer.cpp
new file mode 100644 (file)
index 0000000..45f7395
--- /dev/null
@@ -0,0 +1,136 @@
+#include <algorithm>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/neuron_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+namespace caffe {
+
+template <typename Dtype>
+void LogLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
+      const vector<Blob<Dtype>*>& top) {
+  NeuronLayer<Dtype>::LayerSetUp(bottom, top);
+  const Dtype base = this->layer_param_.log_param().base();
+  if (base != Dtype(-1)) {
+    CHECK_GT(base, 0) << "base must be strictly positive.";
+  }
+  // If base == -1, interpret the base as e and set log_base = 1 exactly.
+  // Otherwise, calculate its log explicitly.
+  const Dtype log_base = (base == Dtype(-1)) ? Dtype(1) : log(base);
+  CHECK(!isnan(log_base))
+      << "NaN result: log(base) = log(" << base << ") = " << log_base;
+  CHECK(!isinf(log_base))
+      << "Inf result: log(base) = log(" << base << ") = " << log_base;
+  base_scale_ = Dtype(1) / log_base;
+  CHECK(!isnan(base_scale_))
+      << "NaN result: 1/log(base) = 1/log(" << base << ") = " << base_scale_;
+  CHECK(!isinf(base_scale_))
+      << "Inf result: 1/log(base) = 1/log(" << base << ") = " << base_scale_;
+  input_scale_ = this->layer_param_.log_param().scale();
+  input_shift_ = this->layer_param_.log_param().shift();
+  backward_num_scale_ = input_scale_ / log_base;
+}
+
+template <typename Dtype>
+void LogLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const int count = bottom[0]->count();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  Dtype* top_data = top[0]->mutable_cpu_data();
+  if (input_scale_ == Dtype(1) && input_shift_ == Dtype(0)) {
+    caffe_log(count, bottom_data, top_data);
+  } else {
+    caffe_copy(count, bottom_data, top_data);
+    if (input_scale_ != Dtype(1)) {
+      caffe_scal(count, input_scale_, top_data);
+    }
+    if (input_shift_ != Dtype(0)) {
+      caffe_add_scalar(count, input_shift_, top_data);
+    }
+    caffe_log(count, top_data, top_data);
+  }
+  if (base_scale_ != Dtype(1)) {
+    caffe_scal(count, base_scale_, top_data);
+  }
+}
+
+template <typename Dtype>
+void LogLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) { return; }
+  const int count = bottom[0]->count();
+  const Dtype* bottom_data = bottom[0]->cpu_data();
+  const Dtype* top_diff = top[0]->cpu_diff();
+  Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
+  caffe_copy(count, bottom_data, bottom_diff);
+  if (input_scale_ != Dtype(1)) {
+    caffe_scal(count, input_scale_, bottom_diff);
+  }
+  if (input_shift_ != Dtype(0)) {
+    caffe_add_scalar(count, input_shift_, bottom_diff);
+  }
+  caffe_powx(count, bottom_diff, Dtype(-1), bottom_diff);
+  if (backward_num_scale_ != Dtype(1)) {
+    caffe_scal(count, backward_num_scale_, bottom_diff);
+  }
+  caffe_mul(count, top_diff, bottom_diff, bottom_diff);
+}
+
+#ifdef CPU_ONLY
+STUB_GPU(LogLayer);
+#else
+
+template <typename Dtype>
+void LogLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+    const vector<Blob<Dtype>*>& top) {
+  const int count = bottom[0]->count();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  Dtype* top_data = top[0]->mutable_gpu_data();
+  if (input_scale_ == Dtype(1) && input_shift_ == Dtype(0)) {
+    caffe_gpu_log(count, bottom_data, top_data);
+  } else {
+    caffe_copy(count, bottom_data, top_data);
+    if (input_scale_ != Dtype(1)) {
+      caffe_gpu_scal(count, input_scale_, top_data);
+    }
+    if (input_shift_ != Dtype(0)) {
+      caffe_gpu_add_scalar(count, input_shift_, top_data);
+    }
+    caffe_gpu_log(count, top_data, top_data);
+  }
+  if (base_scale_ != Dtype(1)) {
+    caffe_gpu_scal(count, base_scale_, top_data);
+  }
+}
+
+template <typename Dtype>
+void LogLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
+    const vector<bool>& propagate_down, const vector<Blob<Dtype>*>& bottom) {
+  if (!propagate_down[0]) { return; }
+  const int count = bottom[0]->count();
+  const Dtype* bottom_data = bottom[0]->gpu_data();
+  const Dtype* top_diff = top[0]->gpu_diff();
+  Dtype* bottom_diff = bottom[0]->mutable_gpu_diff();
+  caffe_copy(count, bottom_data, bottom_diff);
+  if (input_scale_ != Dtype(1)) {
+    caffe_gpu_scal(count, input_scale_, bottom_diff);
+  }
+  if (input_shift_ != Dtype(0)) {
+    caffe_gpu_add_scalar(count, input_shift_, bottom_diff);
+  }
+  caffe_gpu_powx(count, bottom_diff, Dtype(-1), bottom_diff);
+  if (backward_num_scale_ != Dtype(1)) {
+    caffe_gpu_scal(count, backward_num_scale_, bottom_diff);
+  }
+  caffe_gpu_mul(count, top_diff, bottom_diff, bottom_diff);
+}
+
+INSTANTIATE_LAYER_GPU_FUNCS(LogLayer);
+
+#endif
+
+INSTANTIATE_CLASS(LogLayer);
+REGISTER_LAYER_CLASS(Log);
+
+}  // namespace caffe
index 94f713e..619642f 100644 (file)
@@ -269,7 +269,7 @@ message ParamSpec {
 // NOTE
 // Update the next available ID when you add a new LayerParameter field.
 //
-// LayerParameter next available layer-specific ID: 134 (last added: reshape_param)
+// LayerParameter next available layer-specific ID: 135 (last added: log_param)
 message LayerParameter {
   optional string name = 1; // the layer name
   optional string type = 2; // the layer type
@@ -332,6 +332,7 @@ message LayerParameter {
   optional ImageDataParameter image_data_param = 115;
   optional InfogainLossParameter infogain_loss_param = 116;
   optional InnerProductParameter inner_product_param = 117;
+  optional LogParameter log_param = 134;
   optional LRNParameter lrn_param = 118;
   optional MemoryDataParameter memory_data_param = 119;
   optional MVNParameter mvn_param = 120;
@@ -607,6 +608,17 @@ message InnerProductParameter {
   optional int32 axis = 5 [default = 1];
 }
 
+// Message that stores parameters used by LogLayer
+message LogParameter {
+  // LogLayer computes outputs y = log_base(shift + scale * x), for base > 0.
+  // Or if base is set to the default (-1), base is set to e,
+  // so y = ln(shift + scale * x) = log_e(shift + scale * x)
+  optional float base = 1 [default = -1.0];
+  optional float scale = 2 [default = 1.0];
+  optional float shift = 3 [default = 0.0];
+}
+
+// Message that stores parameters used by LRNLayer
 message LRNParameter {
   optional uint32 local_size = 1 [default = 5];
   optional float alpha = 2 [default = 1.];
index 37b5471..c6e4d27 100644 (file)
@@ -117,6 +117,49 @@ class NeuronLayerTest : public MultiDeviceTest<TypeParam> {
           + slope_data[c] * std::min(bottom_data[i], (Dtype)(0)));
     }
   }
+
+  void LogBottomInit() {
+    FillerParameter filler_param;
+    GaussianFiller<Dtype> filler(filler_param);
+    filler.Fill(this->blob_bottom_);
+    Dtype* bottom_data = this->blob_bottom_->mutable_cpu_data();
+    caffe_exp(this->blob_bottom_->count(), bottom_data, bottom_data);
+  }
+
+  void TestLogForward(const float base, const float scale, const float shift) {
+    LogBottomInit();
+    LayerParameter layer_param;
+    layer_param.mutable_log_param()->set_base(base);
+    layer_param.mutable_log_param()->set_scale(scale);
+    layer_param.mutable_log_param()->set_shift(shift);
+    LogLayer<Dtype> layer(layer_param);
+    layer.SetUp(blob_bottom_vec_, blob_top_vec_);
+    layer.Forward(blob_bottom_vec_, blob_top_vec_);
+    const Dtype kDelta = 2e-4;
+    const Dtype* bottom_data = blob_bottom_->cpu_data();
+    const Dtype* top_data = blob_top_->cpu_data();
+    for (int i = 0; i < blob_bottom_->count(); ++i) {
+      const Dtype bottom_val = bottom_data[i];
+      const Dtype top_val = top_data[i];
+      if (base == -1) {
+        EXPECT_NEAR(top_val, log(shift + scale * bottom_val), kDelta);
+      } else {
+        EXPECT_NEAR(top_val, log(shift + scale * bottom_val) / log(base),
+                    kDelta);
+      }
+    }
+  }
+
+  void TestLogGradient(const float base, const float scale, const float shift) {
+    LogBottomInit();
+    LayerParameter layer_param;
+    layer_param.mutable_log_param()->set_base(base);
+    layer_param.mutable_log_param()->set_scale(scale);
+    layer_param.mutable_log_param()->set_shift(shift);
+    LogLayer<Dtype> layer(layer_param);
+    GradientChecker<Dtype> checker(1e-2, 1e-2);
+    checker.CheckGradientEltwise(&layer, blob_bottom_vec_, blob_top_vec_);
+  }
 };
 
 TYPED_TEST_CASE(NeuronLayerTest, TestDtypesAndDevices);
@@ -339,6 +382,88 @@ TYPED_TEST(NeuronLayerTest, TestExpGradientBase2Shift1Scale3) {
   this->TestExpGradient(kBase, kScale, kShift);
 }
 
+TYPED_TEST(NeuronLayerTest, TestLogLayer) {
+  typedef typename TypeParam::Dtype Dtype;
+  // Test default base of "-1" -- should actually set base := e.
+  const Dtype kBase = -1;
+  const Dtype kScale = 1;
+  const Dtype kShift = 0;
+  this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradient) {
+  typedef typename TypeParam::Dtype Dtype;
+  // Test default base of "-1" -- should actually set base := e.
+  const Dtype kBase = -1;
+  const Dtype kScale = 1;
+  const Dtype kShift = 0;
+  this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 1;
+  const Dtype kShift = 0;
+  this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 1;
+  const Dtype kShift = 0;
+  this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2Shift1) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 1;
+  const Dtype kShift = 1;
+  this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2Shift1) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 1;
+  const Dtype kShift = 1;
+  this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2Scale3) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 3;
+  const Dtype kShift = 0;
+  this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2Scale3) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 3;
+  const Dtype kShift = 0;
+  this->TestLogGradient(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogLayerBase2Shift1Scale3) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 3;
+  const Dtype kShift = 1;
+  this->TestLogForward(kBase, kScale, kShift);
+}
+
+TYPED_TEST(NeuronLayerTest, TestLogGradientBase2Shift1Scale3) {
+  typedef typename TypeParam::Dtype Dtype;
+  const Dtype kBase = 2;
+  const Dtype kScale = 3;
+  const Dtype kShift = 1;
+  this->TestLogGradient(kBase, kScale, kShift);
+}
+
 TYPED_TEST(NeuronLayerTest, TestDropoutHalf) {
   const float kDropoutRatio = 0.5;
   this->TestDropoutForward(kDropoutRatio);
index 13e17be..0aab6b1 100644 (file)
@@ -207,6 +207,16 @@ void caffe_exp<double>(const int n, const double* a, double* y) {
 }
 
 template <>
+void caffe_log<float>(const int n, const float* a, float* y) {
+  vsLn(n, a, y);
+}
+
+template <>
+void caffe_log<double>(const int n, const double* a, double* y) {
+  vdLn(n, a, y);
+}
+
+template <>
 void caffe_abs<float>(const int n, const float* a, float* y) {
     vsAbs(n, a, y);
 }
index 43e65eb..2631a07 100644 (file)
@@ -325,6 +325,27 @@ void caffe_gpu_exp<double>(const int N, const double* a, double* y) {
 }
 
 template <typename Dtype>
+__global__ void log_kernel(const int n, const Dtype* a, Dtype* y) {
+  CUDA_KERNEL_LOOP(index, n) {
+    y[index] = log(a[index]);
+  }
+}
+
+template <>
+void caffe_gpu_log<float>(const int N, const float* a, float* y) {
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  log_kernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, a, y);
+}
+
+template <>
+void caffe_gpu_log<double>(const int N, const double* a, double* y) {
+  // NOLINT_NEXT_LINE(whitespace/operators)
+  log_kernel<double><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
+      N, a, y);
+}
+
+template <typename Dtype>
 __global__ void powx_kernel(const int n, const Dtype* a,
     const Dtype alpha, Dtype* y) {
   CUDA_KERNEL_LOOP(index, n) {