vector<shared_ptr<Blob<Dtype> > >& net_params = this->net_->params();
history_.clear();
update_.clear();
+ temp_.clear();
for (int i = 0; i < net_params.size(); ++i) {
const Blob<Dtype>* net_param = net_params[i].get();
history_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
update_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
net_param->num(), net_param->channels(), net_param->height(),
net_param->width())));
+ temp_.push_back(shared_ptr<Blob<Dtype> >(new Blob<Dtype>(
+ net_param->num(), net_param->channels(), net_param->height(),
+ net_param->width())));
}
}
}
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
+ string regularization_type = this->param_.regularization_type();
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
if (local_decay) {
- // add weight decay
- caffe_axpy(net_params[param_id]->count(),
- local_decay * local_rate,
- net_params[param_id]->cpu_data(),
- history_[param_id]->mutable_cpu_data());
+ if (regularization_type == "L2") {
+ // add weight decay
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ net_params[param_id]->cpu_data(),
+ history_[param_id]->mutable_cpu_data());
+ } else if (regularization_type == "L1") {
+ caffe_cpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->cpu_data(),
+ temp_[param_id]->mutable_cpu_data());
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ temp_[param_id]->cpu_data(),
+ history_[param_id]->mutable_cpu_data());
+ } else {
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
}
// copy
caffe_copy(net_params[param_id]->count(),
net_params[param_id]->gpu_diff(), momentum,
history_[param_id]->mutable_gpu_data());
if (local_decay) {
- // add weight decay
- caffe_gpu_axpy(net_params[param_id]->count(),
- local_decay * local_rate,
- net_params[param_id]->gpu_data(),
- history_[param_id]->mutable_gpu_data());
+ if (regularization_type == "L2") {
+ // add weight decay
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ net_params[param_id]->gpu_data(),
+ history_[param_id]->mutable_gpu_data());
+ } else if (regularization_type == "L1") {
+ caffe_gpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->gpu_data(),
+ temp_[param_id]->mutable_gpu_data());
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ temp_[param_id]->gpu_data(),
+ history_[param_id]->mutable_gpu_data());
+ } else {
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
}
// copy
caffe_copy(net_params[param_id]->count(),
}
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
+ string regularization_type = this->param_.regularization_type();
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
net_params[param_id]->cpu_diff(), momentum,
this->history_[param_id]->mutable_cpu_data());
if (local_decay) {
- // add weight decay
- caffe_axpy(net_params[param_id]->count(),
- local_decay * local_rate,
- net_params[param_id]->cpu_data(),
- this->history_[param_id]->mutable_cpu_data());
+ if (regularization_type == "L2") {
+ // add weight decay
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ net_params[param_id]->cpu_data(),
+ this->history_[param_id]->mutable_cpu_data());
+ } else if (regularization_type == "L1") {
+ caffe_cpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->cpu_data(),
+ this->temp_[param_id]->mutable_cpu_data());
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ this->temp_[param_id]->cpu_data(),
+ this->history_[param_id]->mutable_cpu_data());
+ } else {
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
}
// compute udpate: step back then over step
caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
net_params[param_id]->gpu_diff(), momentum,
this->history_[param_id]->mutable_gpu_data());
if (local_decay) {
- // add weight decay
- caffe_gpu_axpy(net_params[param_id]->count(),
- local_decay * local_rate,
- net_params[param_id]->gpu_data(),
- this->history_[param_id]->mutable_gpu_data());
+ if (regularization_type == "L2") {
+ // add weight decay
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ net_params[param_id]->gpu_data(),
+ this->history_[param_id]->mutable_gpu_data());
+ } else if (regularization_type == "L1") {
+ caffe_gpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->gpu_data(),
+ this->temp_[param_id]->mutable_gpu_data());
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay * local_rate,
+ this->temp_[param_id]->gpu_data(),
+ this->history_[param_id]->mutable_gpu_data());
+ } else {
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
}
// compute udpate: step back then over step
caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum,
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
Dtype weight_decay = this->param_.weight_decay();
+ string regularization_type = this->param_.regularization_type();
switch (Caffe::mode()) {
case Caffe::CPU:
for (int param_id = 0; param_id < net_params.size(); ++param_id) {
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
if (local_decay) {
- // add weight decay
- caffe_axpy(net_params[param_id]->count(),
- local_decay,
- net_params[param_id]->cpu_data(),
- net_params[param_id]->mutable_cpu_diff());
+ // Apply the regularization term to the parameter diff. This solver
+ // accumulated weight decay into mutable_cpu_diff() before the L1/L2
+ // split was introduced; history_ must not be used as the target here.
+ if (regularization_type == "L2") {
+ // add weight decay: diff += local_decay * param
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay,
+ net_params[param_id]->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ } else if (regularization_type == "L1") {
+ // temp_ is scratch: it receives the output of caffe_cpu_sign on the
+ // parameter values, which is then scaled by local_decay into the diff.
+ caffe_cpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->cpu_data(),
+ this->temp_[param_id]->mutable_cpu_data());
+ caffe_axpy(net_params[param_id]->count(),
+ local_decay,
+ this->temp_[param_id]->cpu_data(),
+ net_params[param_id]->mutable_cpu_diff());
+ } else {
+ // Any other value is a configuration error; abort rather than train
+ // silently without regularization.
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
}
// compute square of gradient in update
Dtype local_decay = weight_decay * net_params_weight_decay[param_id];
if (local_decay) {
- // add weight decay
- caffe_gpu_axpy(net_params[param_id]->count(),
- local_decay,
- net_params[param_id]->gpu_data(),
- net_params[param_id]->mutable_gpu_diff());
+ // Apply the regularization term to the parameter diff. This solver
+ // accumulated weight decay into mutable_gpu_diff() before the L1/L2
+ // split was introduced; history_ must not be used as the target here.
+ if (regularization_type == "L2") {
+ // add weight decay: diff += local_decay * param
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay,
+ net_params[param_id]->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+ } else if (regularization_type == "L1") {
+ // temp_ is scratch: it receives the output of caffe_gpu_sign on the
+ // parameter values, which is then scaled by local_decay into the diff.
+ caffe_gpu_sign(net_params[param_id]->count(),
+ net_params[param_id]->gpu_data(),
+ this->temp_[param_id]->mutable_gpu_data());
+ caffe_gpu_axpy(net_params[param_id]->count(),
+ local_decay,
+ this->temp_[param_id]->gpu_data(),
+ net_params[param_id]->mutable_gpu_diff());
+ } else {
+ // Any other value is a configuration error; abort rather than train
+ // silently without regularization.
+ LOG(FATAL) << "Unknown regularization type: " << regularization_type;
+ }
}
// compute square of gradient in update