Fix serialization (#15033)
authorPeter Goldsborough <psag@fb.com>
Wed, 12 Dec 2018 06:38:14 +0000 (22:38 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 06:43:36 +0000 (22:43 -0800)
Summary:
Fixes a bug where (de-)/serializing a hierarchy of submodules where one submodule doesn't have any parameters, but its submodules do, doesn't get properly loaded. This had to do with the fact that the old protobuf format couldn't store empty parameters.

Fixes https://github.com/pytorch/pytorch/issues/14891

soumith ezyang ebetica
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15033

Differential Revision: D13411322

Pulled By: goldsborough

fbshipit-source-id: 2ef73b2aa93fa9e46b1cbe1fd47d9f134d6016d5

test/cpp/api/serialize.cpp
torch/csrc/api/include/torch/nn/module.h
torch/csrc/api/include/torch/nn/pimpl.h
torch/csrc/api/src/nn/module.cpp

index f6e2b03..dab0dc4 100644 (file)
@@ -215,7 +215,9 @@ TEST(SerializeTest, Optim) {
 TEST(SerializeTest, XOR_CUDA) {
   torch::manual_seed(0);
   // We better be able to save and load a XOR model!
-  auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda=false) {
+  auto getLoss = [](Sequential model,
+                    uint32_t batch_size,
+                    bool is_cuda = false) {
     auto inputs = torch::empty({batch_size, 2});
     auto labels = torch::empty({batch_size});
     if (is_cuda) {
@@ -269,3 +271,34 @@ TEST(SerializeTest, XOR_CUDA) {
   loss = getLoss(model3, 100, true);
   ASSERT_LT(loss.item<float>(), 0.1);
 }
+
+TEST(
+    SerializeTest,
+    CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
+  struct C : torch::nn::Module {
+    C() {
+      register_buffer("foo", torch::ones(5, torch::kInt32));
+    }
+  };
+  struct B : torch::nn::Module {};
+  struct A : torch::nn::Module {
+    A() {
+      register_module("b", std::make_shared<B>());
+      register_module("c", std::make_shared<C>());
+    }
+  };
+  struct M : torch::nn::Module {
+    M() {
+      register_module("a", std::make_shared<A>());
+    }
+  };
+
+  auto out = std::make_shared<M>();
+  std::stringstream ss;
+  torch::save(out, ss);
+  auto in = std::make_shared<M>();
+  torch::load(in, ss);
+
+  const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
+  ASSERT_EQ(output, 5);
+}
index c298419..d027040 100644 (file)
@@ -495,6 +495,16 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
   bool is_training_{true};
 };
 
+/// Serialize a `Module` pointer into an `OutputArchive`.
+TORCH_API serialize::OutputArchive& operator<<(
+    serialize::OutputArchive& archive,
+    const std::shared_ptr<nn::Module>& module);
+
+/// Deserializes a `Module` from an `InputArchive`.
+TORCH_API serialize::InputArchive& operator>>(
+    serialize::InputArchive& archive,
+    const std::shared_ptr<nn::Module>& module);
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 template <typename ModuleType>
index 58998be..790a63d 100644 (file)
@@ -57,12 +57,8 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
       typename Head,
       typename... Tail,
       typename = typename std::enable_if<
-      !(
-          torch::detail::is_module_holder_of<Head, ContainedType>::value
-          && (sizeof...(Tail) == 0)
-       )
-      >::type
-          >
+          !(torch::detail::is_module_holder_of<Head, ContainedType>::value &&
+            (sizeof...(Tail) == 0))>::type>
   explicit ModuleHolder(Head&& head, Tail&&... tail)
       : impl_(new Contained(
             std::forward<Head>(head),
@@ -160,22 +156,20 @@ class ModuleHolder : torch::detail::ModuleHolderIndicator {
   }
 };
 
-/// Serializes an `OptimizerBase` into an `OutputArchive`.
+/// Serializes a `ModuleHolder` into an `OutputArchive`.
 template <typename ModuleType>
 serialize::OutputArchive& operator<<(
     serialize::OutputArchive& archive,
     const nn::ModuleHolder<ModuleType>& module) {
-  module->save(archive);
-  return archive;
+  return archive << module.ptr();
 }
 
-/// Deserializes a `Tensor` from an `InputArchive`.
+/// Deserializes a `ModuleHolder` from an `InputArchive`.
 template <typename ModuleType>
 serialize::InputArchive& operator>>(
     serialize::InputArchive& archive,
     nn::ModuleHolder<ModuleType>& module) {
-  module->load(archive);
-  return archive;
+  return archive >> module.ptr();
 }
 
 } // namespace nn
index 38c4789..cc17506 100644 (file)
@@ -292,14 +292,9 @@ void Module::load(serialize::InputArchive& archive) {
     archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true);
   }
   for (const auto& child : children_) {
-    // Modules that have no state at all (parameters or buffers) are currently
-    // not stored in Protobuf at all, so we can just skip them.
-    if (!child.value()->parameters_.is_empty() ||
-        !child.value()->buffers_.is_empty()) {
-      serialize::InputArchive child_archive;
-      archive.read(child.key(), child_archive);
-      child.value()->load(child_archive);
-    }
+    serialize::InputArchive child_archive;
+    archive.read(child.key(), child_archive);
+    child.value()->load(child_archive);
   }
 }
 
@@ -356,5 +351,21 @@ std::shared_ptr<Module> Module::shared_from_this_checked() const {
   }
   return std::const_pointer_cast<Module>(ptr);
 }
+
+serialize::OutputArchive& operator<<(
+    serialize::OutputArchive& archive,
+    const std::shared_ptr<nn::Module>& module) {
+  AT_CHECK(module != nullptr, "Cannot serialize empty module");
+  module->save(archive);
+  return archive;
+}
+
+serialize::InputArchive& operator>>(
+    serialize::InputArchive& archive,
+    const std::shared_ptr<nn::Module>& module) {
+  AT_CHECK(module != nullptr, "Cannot deserialize empty module");
+  module->load(archive);
+  return archive;
+}
 } // namespace nn
 } // namespace torch