add mnist autoencoder example necessities (sigmoid cross entropy loss layer, sparse gaussian filler)
author    Jeff Donahue <jeff.donahue@gmail.com>
          Tue, 15 Apr 2014 21:52:41 +0000 (14:52 -0700)
committer Jeff Donahue <jeff.donahue@gmail.com>
          Tue, 15 Apr 2014 21:52:41 +0000 (14:52 -0700)

examples/mnist/mnist_autoencoder_solver.prototxt [new file with mode: 0644]
examples/mnist/mnist_autoencoder_test.prototxt [new file with mode: 0644]
examples/mnist/mnist_autoencoder_train.prototxt [new file with mode: 0644]
examples/mnist/train_mnist_autoencoder.sh [new file with mode: 0755]
include/caffe/filler.hpp
include/caffe/vision_layers.hpp
src/caffe/layer_factory.cpp
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp [new file with mode: 0644]
src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu [new file with mode: 0644]
src/caffe/proto/caffe.proto

diff --git a/examples/mnist/mnist_autoencoder_solver.prototxt b/examples/mnist/mnist_autoencoder_solver.prototxt
new file mode 100644 (file)
index 0000000..b11b2c4
--- /dev/null
@@ -0,0 +1,14 @@
+train_net: "mnist_autoencoder_train.prototxt"
+test_net: "mnist_autoencoder_test.prototxt"
+test_iter: 50
+test_interval: 100
+base_lr: 0.0001
+lr_policy: "fixed"
+display: 20
+max_iter: 4500000
+weight_decay: 0.0005
+snapshot: 10000
+snapshot_prefix: "mnist_autoencoder_train"
+momentum: 0.9
+solver_mode: 1
+device_id: 1
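
For orientation (editor's sketch, not part of the commit): with Caffe's SGD solver these settings amount to momentum SGD with L2 weight decay, roughly

    v_{t+1} = \mu v_t - \alpha (\nabla L(w_t) + \lambda w_t), \qquad w_{t+1} = w_t + v_{t+1},

with \alpha = base_lr = 0.0001, \mu = momentum = 0.9, and \lambda = weight_decay = 0.0005; solver_mode: 1 selects GPU solving on device_id 1.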
diff --git a/examples/mnist/mnist_autoencoder_test.prototxt b/examples/mnist/mnist_autoencoder_test.prototxt
new file mode 100644 (file)
index 0000000..bec7a3c
--- /dev/null
@@ -0,0 +1,164 @@
+name: "MNISTAutoencoder"
+layers {
+  top: "data"
+  top: "label"
+  name: "data"
+  type: DATA
+  data_param {
+    source: "mnist-test-leveldb"
+    scale: 0.0039215684
+    batch_size: 100
+  }
+}
+layers {
+  bottom: "data"
+  top: "flatdata"
+  name: "flatdata"
+  type: FLATTEN
+}
+layers {
+  bottom: "data"
+  top: "encode1"
+  name: "encode1"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 1000
+  }
+}
+layers {
+  bottom: "encode1"
+  top: "encode1neuron"
+  name: "encode1neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode1neuron"
+  top: "encode2"
+  name: "encode2"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 500
+  }
+}
+layers {
+  bottom: "encode2"
+  top: "encode2neuron"
+  name: "encode2neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode2neuron"
+  top: "encode3"
+  name: "encode3"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 250
+  }
+}
+layers {
+  bottom: "encode3"
+  top: "encode3neuron"
+  name: "encode3neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode3neuron"
+  top: "encode4"
+  name: "encode4"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 30
+  }
+}
+layers {
+  bottom: "encode4"
+  top: "decode4"
+  name: "decode4"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 250
+  }
+}
+layers {
+  bottom: "decode4"
+  top: "decode4neuron"
+  name: "decode4neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode4neuron"
+  top: "decode3"
+  name: "decode3"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 500
+  }
+}
+layers {
+  bottom: "decode3"
+  top: "decode3neuron"
+  name: "decode3neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode3neuron"
+  top: "decode2"
+  name: "decode2"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 1000
+  }
+}
+layers {
+  bottom: "decode2"
+  top: "decode2neuron"
+  name: "decode2neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode2neuron"
+  top: "decode1"
+  name: "decode1"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 784
+  }
+}
+layers {
+  bottom: "decode1"
+  bottom: "flatdata"
+  name: "loss"
+  type: EUCLIDEAN_LOSS
+}
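
A quick reading of this net (editorial note): the encoder maps the flattened 784-pixel image through 1000 -> 500 -> 250 -> 30 sigmoid units and the decoder mirrors it back to 784 outputs. At test time reconstruction quality is scored with EUCLIDEAN_LOSS between decode1 and flatdata, which for Caffe's Euclidean loss layer is (up to its normalization convention)

    E = \frac{1}{2N} \sum_{n=1}^{N} \lVert \hat{x}_n - x_n \rVert_2^2,

with \hat{x}_n the decode1 output, x_n the flattened input, and N = batch_size = 100.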
diff --git a/examples/mnist/mnist_autoencoder_train.prototxt b/examples/mnist/mnist_autoencoder_train.prototxt
new file mode 100644 (file)
index 0000000..d5201eb
--- /dev/null
@@ -0,0 +1,236 @@
+name: "MNISTAutoencoder"
+layers {
+  top: "data"
+  top: "label"
+  name: "data"
+  type: DATA
+  data_param {
+    source: "mnist-train-leveldb"
+    scale: 0.0039215684
+    batch_size: 100
+  }
+}
+layers {
+  bottom: "data"
+  top: "flatdata"
+  name: "flatdata"
+  type: FLATTEN
+}
+layers {
+  bottom: "data"
+  top: "encode1"
+  name: "encode1"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 1000
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "encode1"
+  top: "encode1neuron"
+  name: "encode1neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode1neuron"
+  top: "encode2"
+  name: "encode2"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 500
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "encode2"
+  top: "encode2neuron"
+  name: "encode2neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode2neuron"
+  top: "encode3"
+  name: "encode3"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 250
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "encode3"
+  top: "encode3neuron"
+  name: "encode3neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode3neuron"
+  top: "encode4"
+  name: "encode4"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 30
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "encode4"
+  top: "decode4"
+  name: "decode4"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 250
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "decode4"
+  top: "decode4neuron"
+  name: "decode4neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode4neuron"
+  top: "decode3"
+  name: "decode3"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 500
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "decode3"
+  top: "decode3neuron"
+  name: "decode3neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode3neuron"
+  top: "decode2"
+  name: "decode2"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 1000
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "decode2"
+  top: "decode2neuron"
+  name: "decode2neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode2neuron"
+  top: "decode1"
+  name: "decode1"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 784
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "decode1"
+  bottom: "flatdata"
+  name: "loss"
+  type: SIGMOID_CROSS_ENTROPY_LOSS
+}
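
Two details worth spelling out (editorial note, not part of the diff): the data layer's scale of 0.0039215684 ≈ 1/255 puts pixel values in [0, 1], so flatdata can serve directly as the target t of the per-pixel sigmoid cross-entropy minimized by the final layer,

    E = -\frac{1}{N} \sum_{n=1}^{N} \sum_{k=1}^{784} \Big[ t_{nk} \log \sigma(x_{nk}) + (1 - t_{nk}) \log\big(1 - \sigma(x_{nk})\big) \Big], \qquad \sigma(x) = \frac{1}{1 + e^{-x}},

where x is the decode1 pre-activation. The train net also differs from the test net in initializing every INNER_PRODUCT layer with the new sparse Gaussian filler (std: 1, sparse: 15).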
diff --git a/examples/mnist/train_mnist_autoencoder.sh b/examples/mnist/train_mnist_autoencoder.sh
new file mode 100755 (executable)
index 0000000..af2245e
--- /dev/null
@@ -0,0 +1,4 @@
+#!/bin/bash
+TOOLS=../../build/tools
+
+GLOG_logtostderr=1 $TOOLS/train_net.bin mnist_autoencoder_solver.prototxt
diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp
index 50a397e..242f11a 100644 (file)
@@ -41,6 +41,8 @@ class ConstantFiller : public Filler<Dtype> {
     for (int i = 0; i < count; ++i) {
       data[i] = value;
     }
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+         << "Sparsity not supported by this Filler.";
   }
 };
 
@@ -53,6 +55,8 @@ class UniformFiller : public Filler<Dtype> {
     CHECK(blob->count());
     caffe_rng_uniform<Dtype>(blob->count(), Dtype(this->filler_param_.min()),
         Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+         << "Sparsity not supported by this Filler.";
   }
 };
 
@@ -66,7 +70,28 @@ class GaussianFiller : public Filler<Dtype> {
     CHECK(blob->count());
     caffe_rng_gaussian<Dtype>(blob->count(), Dtype(this->filler_param_.mean()),
         Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
+    int sparse = this->filler_param_.sparse();
+    CHECK_GE(sparse, -1);
+    if (sparse >= 0) {
+      // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
+      // These have num == channels == 1; height is number of inputs; width is
+      // number of outputs.  The 'sparse' variable specifies the mean number
+      // of non-zero input weights for a given output.
+      CHECK_EQ(blob->num(), 1);
+      CHECK_EQ(blob->channels(), 1);
+      int num_inputs = blob->height();
+      Dtype non_zero_probability = Dtype(sparse) / Dtype(num_inputs);
+      rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
+      int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
+      caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
+      for (int i = 0; i < blob->count(); ++i) {
+        data[i] *= mask[i];
+      }
+    }
   }
+
+ protected:
+  shared_ptr<SyncedMemory> rand_vec_;
 };
 
 template <typename Dtype>
@@ -91,6 +116,8 @@ class PositiveUnitballFiller : public Filler<Dtype> {
         data[i * dim + j] /= sum;
       }
     }
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+         << "Sparsity not supported by this Filler.";
   }
 };
 
@@ -113,6 +140,8 @@ class XavierFiller : public Filler<Dtype> {
     Dtype scale = sqrt(Dtype(3) / fan_in);
     caffe_rng_uniform<Dtype>(blob->count(), -scale, scale,
         blob->mutable_cpu_data());
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+         << "Sparsity not supported by this Filler.";
   }
 };
 
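
To make the new sparse option concrete (editor's sketch): each Gaussian-drawn weight is kept with probability p = sparse / num_inputs and zeroed otherwise,

    w_{ij} = m_{ij}\, g_{ij}, \qquad g_{ij} \sim \mathcal{N}(\text{mean}, \text{std}^2), \qquad m_{ij} \sim \mathrm{Bernoulli}\!\left(\tfrac{\text{sparse}}{\text{num\_inputs}}\right),

so each output unit keeps about sparse non-zero incoming weights in expectation (15 in the train net above).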
diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 5af7b28..7cd5159 100644 (file)
@@ -135,6 +135,34 @@ class SigmoidLayer : public NeuronLayer<Dtype> {
 };
 
 template <typename Dtype>
+class SigmoidCrossEntropyLossLayer : public Layer<Dtype> {
+ public:
+  explicit SigmoidCrossEntropyLossLayer(const LayerParameter& param)
+      : Layer<Dtype>(param),
+          sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
+          sigmoid_output_(new Blob<Dtype>()) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
+  // sigmoid_output stores the output of the sigmoid layer.
+  shared_ptr<Blob<Dtype> > sigmoid_output_;
+  // Vector holders to call the underlying sigmoid layer forward and backward.
+  vector<Blob<Dtype>*> sigmoid_bottom_vec_;
+  vector<Blob<Dtype>*> sigmoid_top_vec_;
+};
+
+template <typename Dtype>
 class TanHLayer : public NeuronLayer<Dtype> {
  public:
   explicit TanHLayer(const LayerParameter& param)
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index d30ffee..cb45751 100644 (file)
@@ -64,6 +64,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new ReLULayer<Dtype>(param);
   case LayerParameter_LayerType_SIGMOID:
     return new SigmoidLayer<Dtype>(param);
+  case LayerParameter_LayerType_SIGMOID_CROSS_ENTROPY_LOSS:
+    return new SigmoidCrossEntropyLossLayer<Dtype>(param);
   case LayerParameter_LayerType_SOFTMAX:
     return new SoftmaxLayer<Dtype>(param);
   case LayerParameter_LayerType_SOFTMAX_LOSS:
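
A minimal sketch (editorial and hypothetical, not from the commit) of how the new factory case is exercised; it assumes GetLayer's declaration is visible through the usual Caffe headers of this era and instantiates Dtype = float:

    #include "caffe/layer.hpp"
    #include "caffe/proto/caffe.pb.h"
    #include "caffe/vision_layers.hpp"

    // Hypothetical helper: build the loss layer by hand the same way
    // Net<Dtype> does through GetLayer().
    caffe::Layer<float>* MakeSigmoidCrossEntropyLoss() {
      caffe::LayerParameter param;
      param.set_name("loss");
      param.set_type(caffe::LayerParameter_LayerType_SIGMOID_CROSS_ENTROPY_LOSS);
      // Dispatches on the enum case added above and returns a
      // SigmoidCrossEntropyLossLayer<float>*.
      return caffe::GetLayer<float>(param);
    }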
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
new file mode 100644 (file)
index 0000000..7e3af42
--- /dev/null
@@ -0,0 +1,66 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::SetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) <<
+      "SigmoidCrossEntropyLoss Layer takes two blobs as input.";
+  CHECK_EQ(top->size(), 0) <<
+      "SigmoidCrossEntropyLoss Layer takes no blob as output.";
+  sigmoid_bottom_vec_.clear();
+  sigmoid_bottom_vec_.push_back(bottom[0]);
+  sigmoid_top_vec_.push_back(sigmoid_output_.get());
+  sigmoid_layer_->SetUp(sigmoid_bottom_vec_, &sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  // The forward pass computes the sigmoid outputs.
+  sigmoid_bottom_vec_[0] = bottom[0];
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, &sigmoid_top_vec_);
+  // Compute the loss (negative log likelihood)
+  int count = bottom[0]->count();
+  int num = bottom[0]->num();
+  // Stable version of loss computation from input data
+  const Dtype* input_data = bottom[0]->cpu_data();
+  const Dtype* ground_truth = bottom[1]->cpu_data();
+  Dtype loss = 0;
+  for (int i = 0; i < count; ++i) {
+    loss -= input_data[i] * (ground_truth[i] - (input_data[i] >= 0)) -
+        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+  }
+  return loss / num;
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
+    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  // First, compute the diff
+  int count = (*bottom)[0]->count();
+  int num = (*bottom)[0]->num();
+  const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+  const Dtype* ground_truth = (*bottom)[1]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  caffe_sub(count, sigmoid_output_data, ground_truth, bottom_diff);
+  // Scale down gradient
+  caffe_scal(count, Dtype(1) / num, bottom_diff);
+}
+
+INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
+
+
+}  // namespace caffe
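
The Forward_cpu loop is the standard numerically stable form of the per-element cross-entropy; written out (editorial note), with t the target and x the input logit,

    -\big[ t x - \log(1 + e^{x}) \big] =
      \begin{cases} -\big[ x\,(t - 1) - \log(1 + e^{-x}) \big], & x \ge 0 \\
                    -\big[ x\,t - \log(1 + e^{x}) \big],        & x < 0 \end{cases}

so the exponent passed to exp() is never positive; the expression input_data[i] - 2 * input_data[i] * (input_data[i] >= 0) is exactly -|x|. The loss is summed over all count elements and divided by num, i.e. averaged over the batch but summed over the 784 pixels.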
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
new file mode 100644 (file)
index 0000000..64bc476
--- /dev/null
@@ -0,0 +1,54 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cfloat>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  // The forward pass computes the sigmoid outputs.
+  sigmoid_bottom_vec_[0] = bottom[0];
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, &sigmoid_top_vec_);
+  // Compute the loss (negative log likelihood)
+  int count = bottom[0]->count();
+  int num = bottom[0]->num();
+  // Stable version of loss computation from input data
+  const Dtype* input_data = bottom[0]->cpu_data();
+  const Dtype* ground_truth = bottom[1]->cpu_data();
+  Dtype loss = 0;
+  for (int i = 0; i < count; ++i) {
+    loss -= input_data[i] * (ground_truth[i] - (input_data[i] >= 0)) -
+        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+  }
+  return loss / num;
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
+    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  // First, compute the diff
+  int count = (*bottom)[0]->count();
+  int num = (*bottom)[0]->num();
+  const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+  const Dtype* ground_truth = (*bottom)[1]->gpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  caffe_gpu_copy(count, sigmoid_output_data, bottom_diff);
+  caffe_gpu_axpy(count, Dtype(-1), ground_truth, bottom_diff);
+  // Scale down gradient
+  caffe_gpu_scal(count, Dtype(1) / num, bottom_diff);
+}
+
+INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
+
+
+}  // namespace caffe
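
Both Backward implementations use the closed-form gradient of the loss above with respect to the logits,

    \frac{\partial E}{\partial x_{nk}} = \frac{1}{N}\big( \sigma(x_{nk}) - t_{nk} \big),

computed from the cached sigmoid_output_ via caffe_sub on the CPU and caffe_gpu_copy plus caffe_gpu_axpy on the GPU, followed by the 1/num scaling. (Note that Forward_gpu still evaluates the scalar loss on the CPU over cpu_data; only the gradient uses device memory.)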
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index da7824c..19953bf 100644 (file)
@@ -34,8 +34,11 @@ message FillerParameter {
   optional float value = 2 [default = 0]; // the value in constant filler
   optional float min = 3 [default = 0]; // the min value in uniform filler
   optional float max = 4 [default = 1]; // the max value in uniform filler
-  optional float mean = 5 [default = 0]; // the mean value in gaussian filler
-  optional float std = 6 [default = 1]; // the std value in gaussian filler
+  optional float mean = 5 [default = 0]; // the mean value in Gaussian filler
+  optional float std = 6 [default = 1]; // the std value in Gaussian filler
+  // The expected number of non-zero input weights for a given output in
+  // Gaussian filler -- the default -1 means don't perform sparsification.
+  optional int32 sparse = 7 [default = -1];
 }
 
 message NetParameter {
@@ -129,6 +132,7 @@ message LayerParameter {
     POWER = 26;
     RELU = 18;
     SIGMOID = 19;
+    SIGMOID_CROSS_ENTROPY_LOSS = 1000;
     SOFTMAX = 20;
     SOFTMAX_LOSS = 21;
     SPLIT = 22;