From 0dade9862c62a427e5d48d74b1316a079c107726 Mon Sep 17 00:00:00 2001
From: Peter Goldsborough
Date: Tue, 11 Dec 2018 22:38:14 -0800
Subject: [PATCH] Fix serialization (#15033)

Summary:
Fixes a bug where (de-)serializing a hierarchy of submodules, in which one
submodule has no parameters but its submodules do, failed to load properly.
The root cause was that the old protobuf format could not store modules with
empty parameter lists, so such intermediate modules were skipped entirely.
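
A minimal sketch of the hierarchy in question (the `Inner`/`Middle`/`Outer`
names and the `weight` buffer are illustrative only, not part of this patch):

```cpp
#include <torch/torch.h>

#include <memory>
#include <sstream>

// Leaf module that owns state (a single buffer).
struct Inner : torch::nn::Module {
  Inner() {
    register_buffer("weight", torch::ones(3));
  }
};

// Intermediate module with no parameters or buffers of its own.
struct Middle : torch::nn::Module {
  Middle() {
    register_module("inner", std::make_shared<Inner>());
  }
};

// Root module.
struct Outer : torch::nn::Module {
  Outer() {
    register_module("middle", std::make_shared<Middle>());
  }
};

int main() {
  auto saved = std::make_shared<Outer>();
  std::stringstream stream;
  torch::save(saved, stream);

  // With this patch the stateless "middle" level round-trips too;
  // previously its subtree ("middle.inner.weight") was skipped on
  // load and never restored.
  auto loaded = std::make_shared<Outer>();
  torch::load(loaded, stream);
}
```

The regression test added below exercises the same shape, plus an empty
sibling module.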

Fixes https://github.com/pytorch/pytorch/issues/14891

soumith ezyang ebetica
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15033

Differential Revision: D13411322

Pulled By: goldsborough

fbshipit-source-id: 2ef73b2aa93fa9e46b1cbe1fd47d9f134d6016d5
---
 test/cpp/api/serialize.cpp               | 35 ++++++++++++++++++++++++++++++++++-
 torch/csrc/api/include/torch/nn/module.h | 10 ++++++++++
 torch/csrc/api/include/torch/nn/pimpl.h  | 18 ++++++------------
 torch/csrc/api/src/nn/module.cpp         | 27 +++++++++++++++++++--------
 4 files changed, 69 insertions(+), 21 deletions(-)

diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp
index f6e2b03..dab0dc4 100644
--- a/test/cpp/api/serialize.cpp
+++ b/test/cpp/api/serialize.cpp
@@ -215,7 +215,9 @@ TEST(SerializeTest, Optim) {
 TEST(SerializeTest, XOR_CUDA) {
   torch::manual_seed(0);
   // We better be able to save and load a XOR model!
-  auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda=false) {
+  auto getLoss = [](Sequential model,
+                    uint32_t batch_size,
+                    bool is_cuda = false) {
     auto inputs = torch::empty({batch_size, 2});
     auto labels = torch::empty({batch_size});
     if (is_cuda) {
@@ -269,3 +271,34 @@ TEST(SerializeTest, XOR_CUDA) {
   loss = getLoss(model3, 100, true);
   ASSERT_LT(loss.item<float>(), 0.1);
 }
+
+TEST(
+    SerializeTest,
+    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
+  struct C : torch::nn::Module {
+    C() {
+      register_buffer("foo", torch::ones(5, torch::kInt32));
+    }
+  };
+  struct B : torch::nn::Module {};
+  struct A : torch::nn::Module {
+    A() {
+      register_module("b", std::make_shared<B>());
+      register_module("c", std::make_shared<C>());
+    }
+  };
+  struct M : torch::nn::Module {
+    M() {
+      register_module("a", std::make_shared<A>());
+    }
+  };
+
+  auto out = std::make_shared<M>();
+  std::stringstream ss;
+  torch::save(out, ss);
+  auto in = std::make_shared<M>();
+  torch::load(in, ss);
+
+  const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
+  ASSERT_EQ(output, 5);
+}
diff --git a/torch/csrc/api/include/torch/nn/module.h b/torch/csrc/api/include/torch/nn/module.h
index c298419..d027040 100644
--- a/torch/csrc/api/include/torch/nn/module.h
+++ b/torch/csrc/api/include/torch/nn/module.h
@@ -495,6 +495,16 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
   bool is_training_{true};
 };
 
+/// Serializes a `Module` pointer into an `OutputArchive`.
+TORCH_API serialize::OutputArchive& operator<<(
+    serialize::OutputArchive& archive,
+    const std::shared_ptr<nn::Module>& module);
+
+/// Deserializes a `Module` from an `InputArchive`.
+TORCH_API serialize::InputArchive& operator>>(
+    serialize::InputArchive& archive,
+    const std::shared_ptr<nn::Module>& module);
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename ModuleType>
diff --git a/torch/csrc/api/include/torch/nn/pimpl.h b/torch/csrc/api/include/torch/nn/pimpl.h
index 58998be..790a63d 100644
--- a/torch/csrc/api/include/torch/nn/pimpl.h
+++ b/torch/csrc/api/include/torch/nn/pimpl.h
@@ -57,12 +57,8 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
       typename Head,
       typename... Tail,
       typename = typename std::enable_if<
-          !(
-              torch::detail::is_module_holder_of<Head, ContainedType>::value
-              && (sizeof...(Tail) == 0)
-          )
-      >::type
-  >
+          !(torch::detail::is_module_holder_of<Head, ContainedType>::value &&
+            (sizeof...(Tail) == 0))>::type>
   explicit ModuleHolder(Head&& head, Tail&&... tail)
       : impl_(new Contained(
             std::forward<Head>(head),
@@ -160,22 +156,20 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
   }
 };
 
-/// Serializes an `OptimizerBase` into an `OutputArchive`.
+/// Serializes a `ModuleHolder` into an `OutputArchive`.
 template <typename ModuleType>
 serialize::OutputArchive& operator<<(
     serialize::OutputArchive& archive,
     const nn::ModuleHolder<ModuleType>& module) {
-  module->save(archive);
-  return archive;
+  return archive << module.ptr();
 }
 
-/// Deserializes a `Tensor` from an `InputArchive`.
+/// Deserializes a `ModuleHolder` from an `InputArchive`.
 template <typename ModuleType>
 serialize::InputArchive& operator>>(
     serialize::InputArchive& archive,
     nn::ModuleHolder<ModuleType>& module) {
-  module->load(archive);
-  return archive;
+  return archive >> module.ptr();
 }
 
 } // namespace nn
diff --git a/torch/csrc/api/src/nn/module.cpp b/torch/csrc/api/src/nn/module.cpp
index 38c4789..cc17506 100644
--- a/torch/csrc/api/src/nn/module.cpp
+++ b/torch/csrc/api/src/nn/module.cpp
@@ -292,14 +292,9 @@ void Module::load(serialize::InputArchive& archive) {
     archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true);
   }
   for (const auto& child : children_) {
-    // Modules that have no state at all (parameters or buffers) are currently
-    // not stored in Protobuf at all, so we can just skip them.
-    if (!child.value()->parameters_.is_empty() ||
-        !child.value()->buffers_.is_empty()) {
-      serialize::InputArchive child_archive;
-      archive.read(child.key(), child_archive);
-      child.value()->load(child_archive);
-    }
+    serialize::InputArchive child_archive;
+    archive.read(child.key(), child_archive);
+    child.value()->load(child_archive);
   }
 }
 
@@ -356,5 +351,21 @@ std::shared_ptr<Module> Module::shared_from_this_checked() const {
   }
   return std::const_pointer_cast<Module>(ptr);
 }
+
+serialize::OutputArchive& operator<<(
+    serialize::OutputArchive& archive,
+    const std::shared_ptr<nn::Module>& module) {
+  AT_CHECK(module != nullptr, "Cannot serialize empty module");
+  module->save(archive);
+  return archive;
+}
+
+serialize::InputArchive& operator>>(
+    serialize::InputArchive& archive,
+    const std::shared_ptr<nn::Module>& module) {
+  AT_CHECK(module != nullptr, "Cannot deserialize empty module");
+  module->load(archive);
+  return archive;
+}
 } // namespace nn
 } // namespace torch
-- 
2.7.4