From: Jeff Donahue
Date: Tue, 15 Apr 2014 21:52:41 +0000 (-0700)
Subject: add mnist autoencoder example necessities (sigmoid cross entropy loss
X-Git-Tag: submit/tizen/20180823.020014~692^2~36^2~8
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=a767caf273b3332cc33092da061fb498a3008443;p=platform%2Fupstream%2Fcaffeonacl.git

add mnist autoencoder example necessities (sigmoid cross entropy loss
layer, sparse gaussian filler)
---
diff --git a/examples/mnist/mnist_autoencoder_solver.prototxt b/examples/mnist/mnist_autoencoder_solver.prototxt
new file mode 100644
index 0000000..b11b2c4
--- /dev/null
+++ b/examples/mnist/mnist_autoencoder_solver.prototxt
@@ -0,0 +1,14 @@
+train_net: "mnist_autoencoder_train.prototxt"
+test_net: "mnist_autoencoder_test.prototxt"
+test_iter: 50
+test_interval: 100
+base_lr: 0.0001
+lr_policy: "fixed"
+display: 20
+max_iter: 4500000
+weight_decay: 0.0005
+snapshot: 10000
+snapshot_prefix: "alexnet_train"
+momentum: 0.9
+solver_mode: 1
+device_id: 1
diff --git a/examples/mnist/mnist_autoencoder_test.prototxt b/examples/mnist/mnist_autoencoder_test.prototxt
new file mode 100644
index 0000000..bec7a3c
--- /dev/null
+++ b/examples/mnist/mnist_autoencoder_test.prototxt
@@ -0,0 +1,164 @@
+name: "MNISTAutoencoder"
+layers {
+  top: "data"
+  top: "label"
+  name: "data"
+  type: DATA
+  data_param {
+    source: "mnist-test-leveldb"
+    scale: 0.0039215684
+    batch_size: 100
+  }
+}
+layers {
+  bottom: "data"
+  top: "flatdata"
+  name: "flatdata"
+  type: FLATTEN
+}
+layers {
+  bottom: "data"
+  top: "encode1"
+  name: "encode1"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 1000
+  }
+}
+layers {
+  bottom: "encode1"
+  top: "encode1neuron"
+  name: "encode1neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode1neuron"
+  top: "encode2"
+  name: "encode2"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 500
+  }
+}
+layers {
+  bottom: "encode2"
+  top: "encode2neuron"
+  name: "encode2neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode2neuron"
+  top: "encode3"
+  name: "encode3"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 250
+  }
+}
+layers {
+  bottom: "encode3"
+  top: "encode3neuron"
+  name: "encode3neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "encode3neuron"
+  top: "encode4"
+  name: "encode4"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 30
+  }
+}
+layers {
+  bottom: "encode4"
+  top: "decode4"
+  name: "decode4"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 250
+  }
+}
+layers {
+  bottom: "decode4"
+  top: "decode4neuron"
+  name: "decode4neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode4neuron"
+  top: "decode3"
+  name: "decode3"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 500
+  }
+}
+layers {
+  bottom: "decode3"
+  top: "decode3neuron"
+  name: "decode3neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode3neuron"
+  top: "decode2"
+  name: "decode2"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 1000
+  }
+}
+layers {
+  bottom: "decode2"
+  top: "decode2neuron"
+  name: "decode2neuron"
+  type: SIGMOID
+}
+layers {
+  bottom: "decode2neuron"
+  top: "decode1"
"decode1" + name: "decode1" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 784 + } +} +layers { + bottom: "decode1" + bottom: "flatdata" + name: "loss" + type: EUCLIDEAN_LOSS +} diff --git a/examples/mnist/mnist_autoencoder_train.prototxt b/examples/mnist/mnist_autoencoder_train.prototxt new file mode 100644 index 0000000..d5201eb --- /dev/null +++ b/examples/mnist/mnist_autoencoder_train.prototxt @@ -0,0 +1,236 @@ +name: "MNISTAutoencoder" +layers { + top: "data" + top: "label" + name: "data" + type: DATA + data_param { + source: "mnist-train-leveldb" + scale: 0.0039215684 + batch_size: 100 + } +} +layers { + bottom: "data" + top: "flatdata" + name: "flatdata" + type: FLATTEN +} +layers { + bottom: "data" + top: "encode1" + name: "encode1" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "gaussian" + std: 1 + sparse: 15 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "encode1" + top: "encode1neuron" + name: "encode1neuron" + type: SIGMOID +} +layers { + bottom: "encode1neuron" + top: "encode2" + name: "encode2" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 500 + weight_filler { + type: "gaussian" + std: 1 + sparse: 15 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "encode2" + top: "encode2neuron" + name: "encode2neuron" + type: SIGMOID +} +layers { + bottom: "encode2neuron" + top: "encode3" + name: "encode3" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 250 + weight_filler { + type: "gaussian" + std: 1 + sparse: 15 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "encode3" + top: "encode3neuron" + name: "encode3neuron" + type: SIGMOID +} +layers { + bottom: "encode3neuron" + top: "encode4" + name: "encode4" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 30 + weight_filler { + type: "gaussian" + std: 1 + sparse: 15 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "encode4" + top: "decode4" + name: "decode4" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 250 + weight_filler { + type: "gaussian" + std: 1 + sparse: 15 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "decode4" + top: "decode4neuron" + name: "decode4neuron" + type: SIGMOID +} +layers { + bottom: "decode4neuron" + top: "decode3" + name: "decode3" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 500 + weight_filler { + type: "gaussian" + std: 1 + sparse: 15 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "decode3" + top: "decode3neuron" + name: "decode3neuron" + type: SIGMOID +} +layers { + bottom: "decode3neuron" + top: "decode2" + name: "decode2" + type: INNER_PRODUCT + blobs_lr: 1 + blobs_lr: 1 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 1000 + weight_filler { + type: "gaussian" + std: 1 + sparse: 15 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + bottom: "decode2" + top: "decode2neuron" + name: "decode2neuron" + type: 
+}
+layers {
+  bottom: "decode2neuron"
+  top: "decode1"
+  name: "decode1"
+  type: INNER_PRODUCT
+  blobs_lr: 1
+  blobs_lr: 1
+  weight_decay: 1
+  weight_decay: 0
+  inner_product_param {
+    num_output: 784
+    weight_filler {
+      type: "gaussian"
+      std: 1
+      sparse: 15
+    }
+    bias_filler {
+      type: "constant"
+      value: 0
+    }
+  }
+}
+layers {
+  bottom: "decode1"
+  bottom: "flatdata"
+  name: "loss"
+  type: SIGMOID_CROSS_ENTROPY_LOSS
+}
diff --git a/examples/mnist/train_mnist_autoencoder.sh b/examples/mnist/train_mnist_autoencoder.sh
new file mode 100755
index 0000000..af2245e
--- /dev/null
+++ b/examples/mnist/train_mnist_autoencoder.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+TOOLS=../../build/tools
+
+GLOG_logtostderr=1 $TOOLS/train_net.bin mnist_autoencoder_solver.prototxt
diff --git a/include/caffe/filler.hpp b/include/caffe/filler.hpp
index 50a397e..242f11a 100644
--- a/include/caffe/filler.hpp
+++ b/include/caffe/filler.hpp
@@ -41,6 +41,8 @@ class ConstantFiller : public Filler<Dtype> {
     for (int i = 0; i < count; ++i) {
       data[i] = value;
     }
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+        << "Sparsity not supported by this Filler.";
   }
 };
 
@@ -53,6 +55,8 @@ class UniformFiller : public Filler<Dtype> {
     CHECK(blob->count());
     caffe_rng_uniform(blob->count(), Dtype(this->filler_param_.min()),
         Dtype(this->filler_param_.max()), blob->mutable_cpu_data());
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+        << "Sparsity not supported by this Filler.";
   }
 };
 
@@ -66,7 +70,28 @@ class GaussianFiller : public Filler<Dtype> {
     CHECK(blob->count());
     caffe_rng_gaussian(blob->count(), Dtype(this->filler_param_.mean()),
         Dtype(this->filler_param_.std()), blob->mutable_cpu_data());
+    int sparse = this->filler_param_.sparse();
+    CHECK_GE(sparse, -1);
+    if (sparse >= 0) {
+      // Sparse initialization is implemented for "weight" blobs; i.e. matrices.
+      // These have num == channels == 1; height is number of inputs; width is
+      // number of outputs. The 'sparse' variable specifies the mean number
+      // of non-zero input weights for a given output.
+      CHECK_EQ(blob->num(), 1);
+      CHECK_EQ(blob->channels(), 1);
+      int num_inputs = blob->height();
+      Dtype non_zero_probability = Dtype(sparse) / Dtype(num_inputs);
+      rand_vec_.reset(new SyncedMemory(blob->count() * sizeof(int)));
+      int* mask = reinterpret_cast<int*>(rand_vec_->mutable_cpu_data());
+      caffe_rng_bernoulli(blob->count(), non_zero_probability, mask);
+      for (int i = 0; i < blob->count(); ++i) {
+        data[i] *= mask[i];
+      }
+    }
   }
+
+ protected:
+  shared_ptr<SyncedMemory> rand_vec_;
 };
 
 template <typename Dtype>
@@ -91,6 +116,8 @@ class PositiveUnitballFiller : public Filler<Dtype> {
         data[i * dim + j] /= sum;
       }
     }
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+        << "Sparsity not supported by this Filler.";
   }
 };
 
@@ -113,6 +140,8 @@ class XavierFiller : public Filler<Dtype> {
     Dtype scale = sqrt(Dtype(3) / fan_in);
     caffe_rng_uniform(blob->count(), -scale, scale,
         blob->mutable_cpu_data());
+    CHECK_EQ(this->filler_param_.sparse(), -1)
+        << "Sparsity not supported by this Filler.";
   }
 };
 
diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp
index 5af7b28..7cd5159 100644
--- a/include/caffe/vision_layers.hpp
+++ b/include/caffe/vision_layers.hpp
@@ -135,6 +135,34 @@ class SigmoidLayer : public NeuronLayer<Dtype> {
 };
 
 template <typename Dtype>
+class SigmoidCrossEntropyLossLayer : public Layer<Dtype> {
+ public:
+  explicit SigmoidCrossEntropyLossLayer(const LayerParameter& param)
+      : Layer<Dtype>(param),
+          sigmoid_layer_(new SigmoidLayer<Dtype>(param)),
+          sigmoid_output_(new Blob<Dtype>()) {}
+  virtual void SetUp(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+
+ protected:
+  virtual Dtype Forward_cpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual Dtype Forward_gpu(const vector<Blob<Dtype>*>& bottom,
+      vector<Blob<Dtype>*>* top);
+  virtual void Backward_cpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+  virtual void Backward_gpu(const vector<Blob<Dtype>*>& top,
+      const bool propagate_down, vector<Blob<Dtype>*>* bottom);
+
+  shared_ptr<SigmoidLayer<Dtype> > sigmoid_layer_;
+  // sigmoid_output stores the output of the sigmoid layer.
+  shared_ptr<Blob<Dtype> > sigmoid_output_;
+  // Vector holders to call the underlying sigmoid layer forward and backward.
+  vector<Blob<Dtype>*> sigmoid_bottom_vec_;
+  vector<Blob<Dtype>*> sigmoid_top_vec_;
+};
+
+template <typename Dtype>
 class TanHLayer : public NeuronLayer<Dtype> {
  public:
   explicit TanHLayer(const LayerParameter& param)
diff --git a/src/caffe/layer_factory.cpp b/src/caffe/layer_factory.cpp
index d30ffee..cb45751 100644
--- a/src/caffe/layer_factory.cpp
+++ b/src/caffe/layer_factory.cpp
@@ -64,6 +64,8 @@ Layer<Dtype>* GetLayer(const LayerParameter& param) {
     return new ReLULayer<Dtype>(param);
   case LayerParameter_LayerType_SIGMOID:
     return new SigmoidLayer<Dtype>(param);
+  case LayerParameter_LayerType_SIGMOID_CROSS_ENTROPY_LOSS:
+    return new SigmoidCrossEntropyLossLayer<Dtype>(param);
   case LayerParameter_LayerType_SOFTMAX:
     return new SoftmaxLayer<Dtype>(param);
   case LayerParameter_LayerType_SOFTMAX_LOSS:
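
Note on the sparse Gaussian filler added above: each weight is drawn from a Gaussian and then zeroed with probability 1 - sparse/num_inputs, so every output unit keeps roughly `sparse` non-zero incoming weights on average. The snippet below is a minimal standalone sketch of that idea in plain C++ (<random>), not the Caffe Filler/SyncedMemory/caffe_rng API; the function name, zero mean, and fixed seed are invented for illustration, and sparse >= 0 is assumed.

#include <random>
#include <vector>

// Sketch: draw num_inputs x num_outputs weights from N(0, std_dev^2), then
// keep each one with probability sparse / num_inputs (Bernoulli mask),
// mirroring what GaussianFiller::Fill does when filler_param_.sparse() >= 0.
std::vector<float> sparse_gaussian_init(int num_inputs, int num_outputs,
                                        float std_dev, int sparse) {
  std::vector<float> weights(static_cast<size_t>(num_inputs) * num_outputs);
  std::mt19937 gen(1701);  // fixed seed, for the sketch only
  std::normal_distribution<float> gauss(0.0f, std_dev);
  // Probability that a given input weight stays non-zero.
  const double non_zero_probability =
      static_cast<double>(sparse) / static_cast<double>(num_inputs);
  std::bernoulli_distribution keep(non_zero_probability);
  for (float& w : weights) {
    w = gauss(gen);   // Gaussian draw, as caffe_rng_gaussian does
    if (!keep(gen)) {
      w = 0.0f;       // zeroed by the Bernoulli mask, as in the filler
    }
  }
  return weights;
}

For the encode1 layer in the train net above (784 inputs, 1000 outputs, std 1, sparse 15), this leaves about 15 non-zero input weights per output unit after initialization.
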
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
new file mode 100644
index 0000000..7e3af42
--- /dev/null
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cpp
@@ -0,0 +1,66 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::SetUp(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  CHECK_EQ(bottom.size(), 2) <<
+      "SigmoidCrossEntropyLoss Layer takes two blobs as input.";
+  CHECK_EQ(top->size(), 0) <<
+      "SigmoidCrossEntropyLoss Layer takes no blob as output.";
+  sigmoid_bottom_vec_.clear();
+  sigmoid_bottom_vec_.push_back(bottom[0]);
+  sigmoid_top_vec_.push_back(sigmoid_output_.get());
+  sigmoid_layer_->SetUp(sigmoid_bottom_vec_, &sigmoid_top_vec_);
+}
+
+template <typename Dtype>
+Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_cpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  // The forward pass computes the sigmoid outputs.
+  sigmoid_bottom_vec_[0] = bottom[0];
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, &sigmoid_top_vec_);
+  // Compute the loss (negative log likelihood)
+  int count = bottom[0]->count();
+  int num = bottom[0]->num();
+  // Stable version of loss computation from input data
+  const Dtype* input_data = bottom[0]->cpu_data();
+  const Dtype* ground_truth = bottom[1]->cpu_data();
+  Dtype loss = 0;
+  for (int i = 0; i < count; ++i) {
+    loss -= input_data[i] * (ground_truth[i] - (input_data[i] >= 0)) -
+        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+  }
+  return loss / num;
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Backward_cpu(
+    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  // First, compute the diff
+  int count = (*bottom)[0]->count();
+  int num = (*bottom)[0]->num();
+  const Dtype* sigmoid_output_data = sigmoid_output_->cpu_data();
+  const Dtype* ground_truth = (*bottom)[1]->cpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_cpu_diff();
+  caffe_sub(count, sigmoid_output_data, ground_truth, bottom_diff);
+  // Scale down gradient
+  caffe_scal(count, Dtype(1) / num, bottom_diff);
+}
+
+INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
new file mode 100644
index 0000000..64bc476
--- /dev/null
+++ b/src/caffe/layers/sigmoid_cross_entropy_loss_layer.cu
@@ -0,0 +1,54 @@
+// Copyright 2014 BVLC and contributors.
+
+#include <algorithm>
+#include <cmath>
+#include <vector>
+
+#include "caffe/layer.hpp"
+#include "caffe/vision_layers.hpp"
+#include "caffe/util/math_functions.hpp"
+
+using std::max;
+
+namespace caffe {
+
+template <typename Dtype>
+Dtype SigmoidCrossEntropyLossLayer<Dtype>::Forward_gpu(
+    const vector<Blob<Dtype>*>& bottom, vector<Blob<Dtype>*>* top) {
+  // The forward pass computes the sigmoid outputs.
+  sigmoid_bottom_vec_[0] = bottom[0];
+  sigmoid_layer_->Forward(sigmoid_bottom_vec_, &sigmoid_top_vec_);
+  // Compute the loss (negative log likelihood)
+  int count = bottom[0]->count();
+  int num = bottom[0]->num();
+  // Stable version of loss computation from input data
+  const Dtype* input_data = bottom[0]->cpu_data();
+  const Dtype* ground_truth = bottom[1]->cpu_data();
+  Dtype loss = 0;
+  for (int i = 0; i < count; ++i) {
+    loss -= input_data[i] * (ground_truth[i] - (input_data[i] >= 0)) -
+        log(1 + exp(input_data[i] - 2 * input_data[i] * (input_data[i] >= 0)));
+  }
+  return loss / num;
+}
+
+template <typename Dtype>
+void SigmoidCrossEntropyLossLayer<Dtype>::Backward_gpu(
+    const vector<Blob<Dtype>*>& top, const bool propagate_down,
+    vector<Blob<Dtype>*>* bottom) {
+  // First, compute the diff
+  int count = (*bottom)[0]->count();
+  int num = (*bottom)[0]->num();
+  const Dtype* sigmoid_output_data = sigmoid_output_->gpu_data();
+  const Dtype* ground_truth = (*bottom)[1]->gpu_data();
+  Dtype* bottom_diff = (*bottom)[0]->mutable_gpu_diff();
+  caffe_gpu_copy(count, sigmoid_output_data, bottom_diff);
+  caffe_gpu_axpy(count, Dtype(-1), ground_truth, bottom_diff);
+  // Scale down gradient
+  caffe_gpu_scal(count, Dtype(1) / num, bottom_diff);
+}
+
+INSTANTIATE_CLASS(SigmoidCrossEntropyLossLayer);
+
+
+}  // namespace caffe
diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto
index da7824c..19953bf 100644
--- a/src/caffe/proto/caffe.proto
+++ b/src/caffe/proto/caffe.proto
@@ -34,8 +34,11 @@ message FillerParameter {
   optional float value = 2 [default = 0]; // the value in constant filler
   optional float min = 3 [default = 0]; // the min value in uniform filler
   optional float max = 4 [default = 1]; // the max value in uniform filler
-  optional float mean = 5 [default = 0]; // the mean value in gaussian filler
-  optional float std = 6 [default = 1]; // the std value in gaussian filler
+  optional float mean = 5 [default = 0]; // the mean value in Gaussian filler
+  optional float std = 6 [default = 1]; // the std value in Gaussian filler
+  // The expected number of non-zero input weights for a given output in
+  // Gaussian filler -- the default -1 means don't perform sparsification.
+  optional int32 sparse = 7 [default = -1];
 }
 
 message NetParameter {
@@ -129,6 +132,7 @@ message LayerParameter {
     POWER = 26;
     RELU = 18;
     SIGMOID = 19;
+    SIGMOID_CROSS_ENTROPY_LOSS = 1000;
     SOFTMAX = 20;
     SOFTMAX_LOSS = 21;
     SPLIT = 22;
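
Note on the loss computation in Forward_cpu/Forward_gpu above: the per-element loss is the usual sigmoid cross-entropy -[t*log(p) + (1-t)*log(1-p)] with p = sigmoid(x), rewritten via the (x >= 0) branching as max(x, 0) - x*t + log(1 + exp(-|x|)), so the saturated sigmoid is never passed to log. The sketch below is illustrative only (plain standalone C++, not the Caffe layer), with helper names made up for the example; it compares the naive and stable forms.

#include <algorithm>
#include <cmath>
#include <cstdio>

// Naive form: -[t*log(p) + (1-t)*log(1-p)] with p = sigmoid(x).
// For large |x|, p rounds to exactly 0 or 1 in floating point and log() blows up.
double naive_loss(double x, double t) {
  double p = 1.0 / (1.0 + std::exp(-x));
  return -(t * std::log(p) + (1.0 - t) * std::log(1.0 - p));
}

// Stable rewrite, algebraically identical:
//   loss = max(x, 0) - x*t + log(1 + exp(-|x|))
// which is what the layer's (x >= 0) expressions compute per element.
double stable_loss(double x, double t) {
  return std::max(x, 0.0) - x * t + std::log(1.0 + std::exp(-std::fabs(x)));
}

int main() {
  // Moderate input: both forms agree (~0.1269).
  std::printf("x=2,  t=1: naive=%f stable=%f\n",
              naive_loss(2.0, 1.0), stable_loss(2.0, 1.0));
  // Saturated input: the naive form returns inf, the stable form ~40.
  std::printf("x=40, t=0: naive=%f stable=%f\n",
              naive_loss(40.0, 0.0), stable_loss(40.0, 0.0));
  return 0;
}

The corresponding gradient with respect to x is sigmoid(x) - t, which is exactly what Backward_cpu (caffe_sub) and Backward_gpu (copy plus axpy with -1) compute before scaling by 1/num.
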