template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive) {
- TORCH_OPTIM_SERIALIZE(sum_buffers);
- TORCH_OPTIM_SERIALIZE(step_buffers);
+ _TORCH_OPTIM_SERIALIZE(sum_buffers);
+ _TORCH_OPTIM_SERIALIZE(step_buffers);
}
};
} // namespace optim
template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive) {
- TORCH_OPTIM_SERIALIZE(step_buffers);
- TORCH_OPTIM_SERIALIZE(exp_average_buffers);
- TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
- TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
+ _TORCH_OPTIM_SERIALIZE(step_buffers);
+ _TORCH_OPTIM_SERIALIZE(exp_average_buffers);
+ _TORCH_OPTIM_SERIALIZE(exp_average_sq_buffers);
+ _TORCH_OPTIM_SERIALIZE(max_exp_average_sq_buffers);
}
};
} // namespace optim
archive("H_diag", self.H_diag, /*is_buffer=*/true);
archive("prev_flat_grad", self.prev_flat_grad, /*is_buffer=*/true);
archive("prev_loss", self.prev_loss, /*is_buffer=*/true);
- detail::serialize(archive, "old_dirs", self.old_dirs);
- detail::serialize(archive, "old_stps", self.old_stps);
+ optim::serialize(archive, "old_dirs", self.old_dirs);
+ optim::serialize(archive, "old_stps", self.old_stps);
}
};
} // namespace optim
/// Returns the number of parameters referenced by the optimizer.
size_t size() const noexcept;
+ /// Serializes the optimizer state into the given `archive`.
virtual void save(serialize::OutputArchive& archive) const;
+
+ /// Deserializes the optimizer state from the given `archive`.
virtual void load(serialize::InputArchive& archive);
protected:
template <typename Self, typename Archive>
static void serialize(Self& self, Archive& archive) {
- TORCH_OPTIM_SERIALIZE(square_average_buffers);
- TORCH_OPTIM_SERIALIZE(momentum_buffers);
- TORCH_OPTIM_SERIALIZE(grad_average_buffers);
+ _TORCH_OPTIM_SERIALIZE(square_average_buffers);
+ _TORCH_OPTIM_SERIALIZE(momentum_buffers);
+ _TORCH_OPTIM_SERIALIZE(grad_average_buffers);
}
};
} // namespace optim
namespace torch {
namespace optim {
-namespace detail {
// Note: These functions are all called `serialize()` so they can be called
// inside a template where the archive type is a template type and can thus be
serialize::InputArchive& archive,
const std::string& key,
BufferContainer& buffers) {
+ buffers.clear();
torch::Tensor size_tensor;
archive.read(key + "/size", size_tensor);
const size_t size = size_tensor.item<int64_t>();
}
}
-#define TORCH_OPTIM_SERIALIZE(name) \
- torch::optim::detail::serialize(archive, #name, self.name)
+#define _TORCH_OPTIM_SERIALIZE(name) \
+ torch::optim::serialize(archive, #name, self.name)
-} // namespace detail
} // namespace optim
} // namespace torch
namespace torch {
namespace optim {
-namespace detail {
void serialize(
serialize::OutputArchive& archive,
const std::string& key,
serialize::InputArchive& archive,
const std::string& key,
std::vector<int64_t>& steps) {
+ steps.clear();
std::vector<torch::Tensor> tensors;
serialize(archive, key, tensors);
- steps.clear();
for (const auto& step : tensors) {
steps.push_back(step.item<int64_t>());
}
}
-} // namespace detail
} // namespace optim
} // namespace torch
}
void SGD::save(serialize::OutputArchive& archive) const {
- detail::serialize(archive, "momentum_buffers", momentum_buffers);
+ optim::serialize(archive, "momentum_buffers", momentum_buffers);
}
void SGD::load(serialize::InputArchive& archive) {
- detail::serialize(archive, "momentum_buffers", momentum_buffers);
+ optim::serialize(archive, "momentum_buffers", momentum_buffers);
}
} // namespace optim
} // namespace torch