Make SGD match python (#15840)
authoran-kumar <an-kumar@users.noreply.github.com>
Thu, 10 Jan 2019 06:17:45 +0000 (22:17 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 10 Jan 2019 06:21:14 +0000 (22:21 -0800)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/15530
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15840

Differential Revision: D13608503

Pulled By: goldsborough

fbshipit-source-id: aad17c110d64cbe2c126bccd36d228e4108ffa9a

torch/csrc/api/src/optim/sgd.cpp

index 2a407a8..d7f43e8 100644 (file)
@@ -23,10 +23,10 @@ void SGD::step() {
       continue;
     }
 
-    auto update = options.learning_rate_ * p.grad();
+    auto update = p.grad();
 
     if (options.weight_decay_ > 0) {
-      update += options.learning_rate_ * options.weight_decay_ * p;
+      update += options.weight_decay_ * p;
     }
 
     if (options.momentum_ != 0) {
@@ -43,7 +43,7 @@ void SGD::step() {
     }
 
     NoGradGuard guard;
-    p.add_(-update);
+    p.add_(-options.learning_rate_ * update);
   }
   iteration_ += 1;
 }