Fix bug in torch::load and unpack torch::optim::detail namespace (#15926)
authorPeter Goldsborough <psag@fb.com>
Thu, 10 Jan 2019 21:50:41 +0000 (13:50 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 10 Jan 2019 21:55:50 +0000 (13:55 -0800)
Summary:
Wasn't clearing optimizer buffers before adding new entries to it during deserialization. Successive calls to `torch::load` with the same optimizer would just append to the buffer container. Also moved `serialize()` function from `torch::optim::detail` into `torch::optim` so users can use it for custom optimizers.

Fixes #15792

ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15926

Differential Revision: D13623615

Pulled By: goldsborough

fbshipit-source-id: e193091f25f56a95f2a9648af312cb7caa45f300

torch/csrc/api/include/torch/optim/adagrad.h
torch/csrc/api/include/torch/optim/adam.h
torch/csrc/api/include/torch/optim/lbfgs.h
torch/csrc/api/include/torch/optim/optimizer.h
torch/csrc/api/include/torch/optim/rmsprop.h
torch/csrc/api/include/torch/optim/serialize.h
torch/csrc/api/src/optim/serialize.cpp
torch/csrc/api/src/optim/sgd.cpp

index 265710b..1174dec 100644 (file)
@@ -49,8 +49,8 @@ class TORCH_API Adagrad : public Optimizer {
 
   template <typename Self, typename Archive>
   static void serialize(Self& self, Archive& archive) {
-    TORCH_OPTIM_SERIALIZE(sum_buffers);
-    TORCH_OPTIM_SERIALIZE(step_buffers);
+    _TORCH_OPTIM_SERIALIZE(sum_buffers);
+    _TORCH_OPTIM_SERIALIZE(step_buffers);
   }
 };
 } // namespace optim
index 50872f9..f21f87d 100644 (file)
@@ -52,10 +52,10 @@ class TORCH_API Adam : public Optimizer {
 
   template <typename Self, typename Archive>
   static void serialize(Self& self, Archive& archive) {
-    TORCH_OPTIM_SERIALIZE(step_buffers);
-    TORCH_OPTIM_SERIALIZE(exp_average_buffers);
-    TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
-    TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
+    _TORCH_OPTIM_SERIALIZE(step_buffers);
+    _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
+    _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
+    _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
   }
 };
 } // namespace optim
index c5c5dee..33b877b 100644 (file)
@@ -65,8 +65,8 @@ class TORCH_API LBFGS : public LossClosureOptimizer {
     archive("H_diag", self.H_diag, /*is_buffer=*/true);
     archive("prev_flat_grad", self.prev_flat_grad, /*is_buffer=*/true);
     archive("prev_loss", self.prev_loss, /*is_buffer=*/true);
-    detail::serialize(archive, "old_dirs", self.old_dirs);
-    detail::serialize(archive, "old_stps", self.old_stps);
+    optim::serialize(archive, "old_dirs", self.old_dirs);
+    optim::serialize(archive, "old_stps", self.old_stps);
   }
 };
 } // namespace optim
index eb835b4..570a040 100644 (file)
@@ -53,7 +53,10 @@ class TORCH_API OptimizerBase {
   /// Returns the number of parameters referenced by the optimizer.
   size_t size() const noexcept;
 
+  /// Serializes the optimizer state into the given `archive`.
   virtual void save(serialize::OutputArchive& archive) const;
+
+  /// Deserializes the optimizer state from the given `archive`.
   virtual void load(serialize::InputArchive& archive);
 
  protected:
index a2ad58f..7f80c71 100644 (file)
@@ -56,9 +56,9 @@ class TORCH_API RMSprop : public Optimizer {
 
   template <typename Self, typename Archive>
   static void serialize(Self& self, Archive& archive) {
-    TORCH_OPTIM_SERIALIZE(square_average_buffers);
-    TORCH_OPTIM_SERIALIZE(momentum_buffers);
-    TORCH_OPTIM_SERIALIZE(grad_average_buffers);
+    _TORCH_OPTIM_SERIALIZE(square_average_buffers);
+    _TORCH_OPTIM_SERIALIZE(momentum_buffers);
+    _TORCH_OPTIM_SERIALIZE(grad_average_buffers);
   }
 };
 } // namespace optim
index 0aec75c..0bd1ee3 100644 (file)
@@ -11,7 +11,6 @@
 
 namespace torch {
 namespace optim {
-namespace detail {
 
 // Note: These functions are all called `serialize()` so they can be called
 // inside a template where the archive type is a template type and can thus be
@@ -49,6 +48,7 @@ void serialize(
     serialize::InputArchive& archive,
     const std::string& key,
     BufferContainer& buffers) {
+  buffers.clear();
   torch::Tensor size_tensor;
   archive.read(key + "/size", size_tensor);
   const size_t size = size_tensor.item<int64_t>();
@@ -59,9 +59,8 @@ void serialize(
   }
 }
 
-#define TORCH_OPTIM_SERIALIZE(name) \
-  torch::optim::detail::serialize(archive, #name, self.name)
+#define _TORCH_OPTIM_SERIALIZE(name) \
+  torch::optim::serialize(archive, #name, self.name)
 
-} // namespace detail
 } // namespace optim
 } // namespace torch
index 02d1000..c6797e7 100644 (file)
@@ -11,7 +11,6 @@
 
 namespace torch {
 namespace optim {
-namespace detail {
 void serialize(
     serialize::OutputArchive& archive,
     const std::string& key,
@@ -28,13 +27,12 @@ void serialize(
     serialize::InputArchive& archive,
     const std::string& key,
     std::vector<int64_t>& steps) {
+  steps.clear();
   std::vector<torch::Tensor> tensors;
   serialize(archive, key, tensors);
-  steps.clear();
   for (const auto& step : tensors) {
     steps.push_back(step.item<int64_t>());
   }
 }
-} // namespace detail
 } // namespace optim
 } // namespace torch
index d7f43e8..828acd3 100644 (file)
@@ -49,11 +49,11 @@ void SGD::step() {
 }
 
 void SGD::save(serialize::OutputArchive& archive) const {
-  detail::serialize(archive, "momentum_buffers", momentum_buffers);
+  optim::serialize(archive, "momentum_buffers", momentum_buffers);
 }
 
 void SGD::load(serialize::InputArchive& archive) {
-  detail::serialize(archive, "momentum_buffers", momentum_buffers);
+  optim::serialize(archive, "momentum_buffers", momentum_buffers);
 }
 } // namespace optim
 } // namespace torch