// NOTE
// Update the next available ID when you add a new SolverParameter field.
//
-// SolverParameter next available ID: 35 (last added: stepvalue)
+// SolverParameter next available ID: 36 (last added: clip_gradients)
message SolverParameter {
//////////////////////////////////////////////////////////////////////////////
// Specifying the train and test networks
optional int32 stepsize = 13;
// the stepsize for learning rate policy "multistep"
repeated int32 stepvalue = 34;
+
+ // Set clip_gradients to a value >= 0 to scale parameter gradients down to
+ // that L2 norm whenever their actual L2 norm is larger. A negative value
+ // (the default) disables gradient clipping.
+ optional float clip_gradients = 35 [default = -1];
+
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
}
}
+// Scales down parameter gradients when their combined L2 norm exceeds the
+// solver's clip_gradients setting. A negative setting makes this a no-op.
+// Only parameters whose param_owners() entry is negative contribute to the
+// norm and get rescaled.
+template <typename Dtype>
+void SGDSolver<Dtype>::ClipGradients() {
+  const Dtype max_norm = this->param_.clip_gradients();
+  if (max_norm < 0) {
+    return;  // clipping is switched off
+  }
+  const vector<shared_ptr<Blob<Dtype> > >& params = this->net_->params();
+  // Sum the sumsq_diff() of every parameter with a negative owner index.
+  Dtype total_sumsq = 0;
+  for (int p = 0; p < params.size(); ++p) {
+    if (this->net_->param_owners()[p] >= 0) { continue; }
+    total_sumsq += params[p]->sumsq_diff();
+  }
+  const Dtype grad_norm = std::sqrt(total_sumsq);
+  // Expressed as a negated '>' so that a NaN norm leaves the diffs untouched.
+  if (!(grad_norm > max_norm)) {
+    return;
+  }
+  // Uniform factor that brings the combined norm down to max_norm.
+  const Dtype shrink = max_norm / grad_norm;
+  LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm "
+      << grad_norm << " > " << max_norm << ") "
+      << "by scale factor " << shrink;
+  for (int p = 0; p < params.size(); ++p) {
+    if (this->net_->param_owners()[p] >= 0) { continue; }
+    params[p]->scale_diff(shrink);
+  }
+}
template <typename Dtype>
void SGDSolver<Dtype>::ComputeUpdateValue() {
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
+ ClipGradients();
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
+ SGDSolver<Dtype>::ClipGradients();
Dtype momentum = this->param_.momentum();
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
if (this->param_.display() && this->iter_ % this->param_.display() == 0) {
LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate;
}
+ SGDSolver<Dtype>::ClipGradients();
Dtype weight_decay = this->param_.weight_decay();
string regularization_type = this->param_.regularization_type();
switch (Caffe::mode()) {