TEST(SerializeTest, XOR_CUDA) {
torch::manual_seed(0);
// We better be able to save and load a XOR model!
- auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda=false) {
+ auto getLoss = [](Sequential model,
+ uint32_t batch_size,
+ bool is_cuda = false) {
auto inputs = torch::empty({batch_size, 2});
auto labels = torch::empty({batch_size});
if (is_cuda) {
loss = getLoss(model3, 100, true);
ASSERT_LT(loss.item<float>(), 0.1);
}
+
+// Regression test: saving and loading must traverse intermediate modules
+// that own no parameters or buffers of their own (B, and A whose direct
+// state is empty), so that state held by deeper descendants still
+// round-trips. Verifies the buffer "a.c.foo" survives save/load.
+TEST(
+ SerializeTest,
+ CanSerializeModulesWithIntermediateModulesWithoutParametersOrBuffers) {
+ // C is the only module with actual state: a buffer of five int32 ones.
+ struct C : torch::nn::Module {
+ C() {
+ register_buffer("foo", torch::ones(5, torch::kInt32));
+ }
+ };
+ // B has no parameters, buffers, or children at all.
+ struct B : torch::nn::Module {};
+ // A holds both the stateless B and the stateful C.
+ struct A : torch::nn::Module {
+ A() {
+ register_module("b", std::make_shared<B>());
+ register_module("c", std::make_shared<C>());
+ }
+ };
+ // M -> A -> C: the buffer is two levels below the root.
+ struct M : torch::nn::Module {
+ M() {
+ register_module("a", std::make_shared<A>());
+ }
+ };
+
+ // Round-trip through an in-memory stream.
+ auto out = std::make_shared<M>();
+ std::stringstream ss;
+ torch::save(out, ss);
+ auto in = std::make_shared<M>();
+ torch::load(in, ss);
+
+ // Sum of five int32 ones == 5 proves the buffer was restored intact.
+ const int output = in->named_buffers()["a.c.foo"].sum().item<int>();
+ ASSERT_EQ(output, 5);
+}
bool is_training_{true};
};
+/// Serializes a `Module` pointer into an `OutputArchive`.
+TORCH_API serialize::OutputArchive& operator<<(
+ serialize::OutputArchive& archive,
+ const std::shared_ptr<nn::Module>& module);
+
+/// Deserializes a `Module` from an `InputArchive`.
+TORCH_API serialize::InputArchive& operator>>(
+ serialize::InputArchive& archive,
+ const std::shared_ptr<nn::Module>& module);
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ nn::Module ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename ModuleType>
typename Head,
typename... Tail,
typename = typename std::enable_if<
- !(
- torch::detail::is_module_holder_of<Head, ContainedType>::value
- && (sizeof...(Tail) == 0)
- )
- >::type
- >
+ !(torch::detail::is_module_holder_of<Head, ContainedType>::value &&
+ (sizeof...(Tail) == 0))>::type>
explicit ModuleHolder(Head&& head, Tail&&... tail)
: impl_(new Contained(
std::forward<Head>(head),
}
};
-/// Serializes an `OptimizerBase` into an `OutputArchive`.
+/// Serializes a `ModuleHolder` into an `OutputArchive`.
+///
+/// Delegates to the `shared_ptr<nn::Module>` overload via `module.ptr()`,
+/// so the null check and save logic live in a single implementation.
template <typename ModuleType>
serialize::OutputArchive& operator<<(
serialize::OutputArchive& archive,
const nn::ModuleHolder<ModuleType>& module) {
- module->save(archive);
- return archive;
+ return archive << module.ptr();
}
-/// Deserializes a `Tensor` from an `InputArchive`.
+/// Deserializes a `ModuleHolder` from an `InputArchive`.
+///
+/// Delegates to the `shared_ptr<nn::Module>` overload via `module.ptr()`,
+/// so the null check and load logic live in a single implementation.
template <typename ModuleType>
serialize::InputArchive& operator>>(
serialize::InputArchive& archive,
nn::ModuleHolder<ModuleType>& module) {
- module->load(archive);
- return archive;
+ return archive >> module.ptr();
}
} // namespace nn
archive.read(buffer.key(), buffer.value(), /*is_buffer=*/true);
}
for (const auto& child : children_) {
- // Modules that have no state at all (parameters or buffers) are currently
- // not stored in Protobuf at all, so we can just skip them.
- if (!child.value()->parameters_.is_empty() ||
- !child.value()->buffers_.is_empty()) {
- serialize::InputArchive child_archive;
- archive.read(child.key(), child_archive);
- child.value()->load(child_archive);
- }
+ serialize::InputArchive child_archive;
+ archive.read(child.key(), child_archive);
+ child.value()->load(child_archive);
}
}
}
return std::const_pointer_cast<Module>(ptr);
}
+
+// Serializes a `Module` pointer into an `OutputArchive`.
+serialize::OutputArchive& operator<<(
+ serialize::OutputArchive& archive,
+ const std::shared_ptr<nn::Module>& module) {
+ // Fail loudly on a null module rather than dereferencing it below.
+ AT_CHECK(module != nullptr, "Cannot serialize empty module");
+ module->save(archive);
+ return archive;
+}
+
+// Deserializes archive contents into the pointed-to `Module`.
+serialize::InputArchive& operator>>(
+ serialize::InputArchive& archive,
+ const std::shared_ptr<nn::Module>& module) {
+ // Fail loudly on a null module rather than dereferencing it below.
+ AT_CHECK(module != nullptr, "Cannot deserialize empty module");
+ module->load(archive);
+ return archive;
+}
} // namespace nn
} // namespace torch