From b1fa19961e266f900fa5283a913b57c054509b5d Mon Sep 17 00:00:00 2001 From: Peter Goldsborough Date: Thu, 10 Jan 2019 13:50:41 -0800 Subject: [PATCH] Fix bug in torch::load and unpack torch::optim::detail namespace (#15926) Summary: The optimizer buffers were not cleared before new entries were added to them during deserialization, so successive calls to `torch::load` with the same optimizer would just append to the buffer container. Also moved the `serialize()` function from `torch::optim::detail` into `torch::optim` so users can use it for custom optimizers. Fixes #15792 ezyang Pull Request resolved: https://github.com/pytorch/pytorch/pull/15926 Differential Revision: D13623615 Pulled By: goldsborough fbshipit-source-id: e193091f25f56a95f2a9648af312cb7caa45f300 --- torch/csrc/api/include/torch/optim/adagrad.h | 4 ++-- torch/csrc/api/include/torch/optim/adam.h | 8 ++++---- torch/csrc/api/include/torch/optim/lbfgs.h | 4 ++-- torch/csrc/api/include/torch/optim/optimizer.h | 3 +++ torch/csrc/api/include/torch/optim/rmsprop.h | 6 +++--- torch/csrc/api/include/torch/optim/serialize.h | 7 +++---- torch/csrc/api/src/optim/serialize.cpp | 4 +--- torch/csrc/api/src/optim/sgd.cpp | 4 ++-- 8 files changed, 20 insertions(+), 20 deletions(-) diff --git a/torch/csrc/api/include/torch/optim/adagrad.h b/torch/csrc/api/include/torch/optim/adagrad.h index 265710b..1174dec 100644 --- a/torch/csrc/api/include/torch/optim/adagrad.h +++ b/torch/csrc/api/include/torch/optim/adagrad.h @@ -49,8 +49,8 @@ class TORCH_API Adagrad : public Optimizer { template static void serialize(Self& self, Archive& archive) { - TORCH_OPTIM_SERIALIZE(sum_buffers); - TORCH_OPTIM_SERIALIZE(step_buffers); + _TORCH_OPTIM_SERIALIZE(sum_buffers); + _TORCH_OPTIM_SERIALIZE(step_buffers); } }; } // namespace optim diff --git a/torch/csrc/api/include/torch/optim/adam.h b/torch/csrc/api/include/torch/optim/adam.h index 50872f9..f21f87d 100644 --- a/torch/csrc/api/include/torch/optim/adam.h +++ 
b/torch/csrc/api/include/torch/optim/adam.h @@ -52,10 +52,10 @@ class TORCH_API Adam : public Optimizer { template static void serialize(Self& self, Archive& archive) { - TORCH_OPTIM_SERIALIZE(step_buffers); - TORCH_OPTIM_SERIALIZE(exp_average_buffers); - TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers); - TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers); + _TORCH_OPTIM_SERIALIZE(step_buffers); + _TORCH_OPTIM_SERIALIZE(exp_average_buffers); + _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers); + _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers); } }; } // namespace optim diff --git a/torch/csrc/api/include/torch/optim/lbfgs.h b/torch/csrc/api/include/torch/optim/lbfgs.h index c5c5dee..33b877b 100644 --- a/torch/csrc/api/include/torch/optim/lbfgs.h +++ b/torch/csrc/api/include/torch/optim/lbfgs.h @@ -65,8 +65,8 @@ class TORCH_API LBFGS : public LossClosureOptimizer { archive("H_diag", self.H_diag, /*is_buffer=*/true); archive("prev_flat_grad", self.prev_flat_grad, /*is_buffer=*/true); archive("prev_loss", self.prev_loss, /*is_buffer=*/true); - detail::serialize(archive, "old_dirs", self.old_dirs); - detail::serialize(archive, "old_stps", self.old_stps); + optim::serialize(archive, "old_dirs", self.old_dirs); + optim::serialize(archive, "old_stps", self.old_stps); } }; } // namespace optim diff --git a/torch/csrc/api/include/torch/optim/optimizer.h b/torch/csrc/api/include/torch/optim/optimizer.h index eb835b4..570a040 100644 --- a/torch/csrc/api/include/torch/optim/optimizer.h +++ b/torch/csrc/api/include/torch/optim/optimizer.h @@ -53,7 +53,10 @@ class TORCH_API OptimizerBase { /// Returns the number of parameters referenced by the optimizer. size_t size() const noexcept; + /// Serializes the optimizer state into the given `archive`. virtual void save(serialize::OutputArchive& archive) const; + + /// Deserializes the optimizer state from the given `archive`. 
virtual void load(serialize::InputArchive& archive); protected: diff --git a/torch/csrc/api/include/torch/optim/rmsprop.h b/torch/csrc/api/include/torch/optim/rmsprop.h index a2ad58f..7f80c71 100644 --- a/torch/csrc/api/include/torch/optim/rmsprop.h +++ b/torch/csrc/api/include/torch/optim/rmsprop.h @@ -56,9 +56,9 @@ class TORCH_API RMSprop : public Optimizer { template static void serialize(Self& self, Archive& archive) { - TORCH_OPTIM_SERIALIZE(square_average_buffers); - TORCH_OPTIM_SERIALIZE(momentum_buffers); - TORCH_OPTIM_SERIALIZE(grad_average_buffers); + _TORCH_OPTIM_SERIALIZE(square_average_buffers); + _TORCH_OPTIM_SERIALIZE(momentum_buffers); + _TORCH_OPTIM_SERIALIZE(grad_average_buffers); } }; } // namespace optim diff --git a/torch/csrc/api/include/torch/optim/serialize.h b/torch/csrc/api/include/torch/optim/serialize.h index 0aec75c..0bd1ee3 100644 --- a/torch/csrc/api/include/torch/optim/serialize.h +++ b/torch/csrc/api/include/torch/optim/serialize.h @@ -11,7 +11,6 @@ namespace torch { namespace optim { -namespace detail { // Note: These functions are all called `serialize()` so they can be called // inside a template where the archive type is a template type and can thus be @@ -49,6 +48,7 @@ void serialize( serialize::InputArchive& archive, const std::string& key, BufferContainer& buffers) { + buffers.clear(); torch::Tensor size_tensor; archive.read(key + "/size", size_tensor); const size_t size = size_tensor.item(); @@ -59,9 +59,8 @@ void serialize( } } -#define TORCH_OPTIM_SERIALIZE(name) \ - torch::optim::detail::serialize(archive, #name, self.name) +#define _TORCH_OPTIM_SERIALIZE(name) \ + torch::optim::serialize(archive, #name, self.name) -} // namespace detail } // namespace optim } // namespace torch diff --git a/torch/csrc/api/src/optim/serialize.cpp b/torch/csrc/api/src/optim/serialize.cpp index 02d1000..c6797e7 100644 --- a/torch/csrc/api/src/optim/serialize.cpp +++ b/torch/csrc/api/src/optim/serialize.cpp @@ -11,7 +11,6 @@ namespace torch 
{ namespace optim { -namespace detail { void serialize( serialize::OutputArchive& archive, const std::string& key, @@ -28,13 +27,12 @@ void serialize( serialize::InputArchive& archive, const std::string& key, std::vector& steps) { + steps.clear(); std::vector tensors; serialize(archive, key, tensors); - steps.clear(); for (const auto& step : tensors) { steps.push_back(step.item()); } } -} // namespace detail } // namespace optim } // namespace torch diff --git a/torch/csrc/api/src/optim/sgd.cpp b/torch/csrc/api/src/optim/sgd.cpp index d7f43e8..828acd3 100644 --- a/torch/csrc/api/src/optim/sgd.cpp +++ b/torch/csrc/api/src/optim/sgd.cpp @@ -49,11 +49,11 @@ void SGD::step() { } void SGD::save(serialize::OutputArchive& archive) const { - detail::serialize(archive, "momentum_buffers", momentum_buffers); + optim::serialize(archive, "momentum_buffers", momentum_buffers); } void SGD::load(serialize::InputArchive& archive) { - detail::serialize(archive, "momentum_buffers", momentum_buffers); + optim::serialize(archive, "momentum_buffers", momentum_buffers); } } // namespace optim } // namespace torch -- 2.7.4