From 6e0c5a8a4e03e489b60f8d72692f27ea6f377a67 Mon Sep 17 00:00:00 2001 From: Lu Fang Date: Tue, 4 Dec 2018 00:44:43 -0800 Subject: [PATCH] Restore device in cpp API (#14711) Summary: This is a stack PR based on https://github.com/pytorch/pytorch/pull/14454. It enables restoring the storage to the appropriate device. ~~[TODO]: add/modify appropriate tests~~ Done Pull Request resolved: https://github.com/pytorch/pytorch/pull/14711 Reviewed By: dzhulgakov Differential Revision: D13315746 Pulled By: houseroad fbshipit-source-id: fe6f24a45c35e88fd1a2eebc09950d4430fac185 --- test/cpp/api/serialize.cpp | 11 +++++++++-- torch/csrc/api/include/torch/serialize/input-archive.h | 15 +++++++++++---- torch/csrc/api/src/serialize/input-archive.cpp | 16 +++++++++++----- torch/csrc/jit/import.cpp | 11 ++++++----- torch/csrc/jit/import.h | 10 ++++++---- 5 files changed, 43 insertions(+), 20 deletions(-) diff --git a/test/cpp/api/serialize.cpp b/test/cpp/api/serialize.cpp index 973fbbf..f6e2b03 100644 --- a/test/cpp/api/serialize.cpp +++ b/test/cpp/api/serialize.cpp @@ -215,9 +215,13 @@ TEST(SerializeTest, Optim) { TEST(SerializeTest, XOR_CUDA) { torch::manual_seed(0); // We better be able to save and load a XOR model! 
- auto getLoss = [](Sequential model, uint32_t batch_size) { + auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda=false) { auto inputs = torch::empty({batch_size, 2}); auto labels = torch::empty({batch_size}); + if (is_cuda) { + inputs = inputs.cuda(); + labels = labels.cuda(); + } for (size_t i = 0; i < batch_size; i++) { inputs[i] = torch::randint(2, {2}, torch::kInt64); labels[i] = inputs[i][0].item() ^ inputs[i][1].item(); @@ -255,10 +259,13 @@ TEST(SerializeTest, XOR_CUDA) { ASSERT_LT(loss.item(), 0.1); model2->to(torch::kCUDA); + loss = getLoss(model2, 100, true); + ASSERT_LT(loss.item(), 0.1); + auto tempfile2 = torch::utils::make_tempfile(); torch::save(model2, tempfile2.name); torch::load(model3, tempfile2.name); - loss = getLoss(model3, 100); + loss = getLoss(model3, 100, true); ASSERT_LT(loss.item(), 0.1); } diff --git a/torch/csrc/api/include/torch/serialize/input-archive.h b/torch/csrc/api/include/torch/serialize/input-archive.h index dcfeb9e..c7c5bcc 100644 --- a/torch/csrc/api/include/torch/serialize/input-archive.h +++ b/torch/csrc/api/include/torch/serialize/input-archive.h @@ -1,6 +1,9 @@ #pragma once +#include +#include #include +#include #include #include @@ -52,12 +55,16 @@ class TORCH_API InputArchive final { void read(const std::string& key, InputArchive& archive); /// Loads the `InputArchive` from a serialized representation stored in the - /// file at `filename`. - void load_from(const std::string& filename); + /// file at `filename`. Storage are remapped using device option. If device + /// is not specified, the module is loaded to the original device. + void load_from(const std::string& filename, + c10::optional device = c10::nullopt); /// Loads the `InputArchive` from a serialized representation stored in the - /// given `stream`. - void load_from(std::istream& stream); + /// given `stream`. Storage are remapped using device option. If device + /// is not specified, the module is loaded to the original device. 
+ void load_from(std::istream& stream, + c10::optional device = c10::nullopt); /// Forwards all arguments to `read()`. /// Useful for generic code that can be re-used for both `InputArchive` and diff --git a/torch/csrc/api/src/serialize/input-archive.cpp b/torch/csrc/api/src/serialize/input-archive.cpp index 1e55182..444fe79 100644 --- a/torch/csrc/api/src/serialize/input-archive.cpp +++ b/torch/csrc/api/src/serialize/input-archive.cpp @@ -33,7 +33,11 @@ void InputArchive::read( // clang-format on if (tensor.defined()) { torch::NoGradGuard guard; - tensor.set_(*read_tensor->slot()); + if (tensor.device() != read_tensor->slot()->device()) { + tensor.set_data(autograd::Variable(*read_tensor->slot()).data()); + } else { + tensor.set_(*read_tensor->slot()); + } } else { tensor = std::move(*read_tensor->slot()); } @@ -48,12 +52,14 @@ void InputArchive::read(const std::string& key, InputArchive& archive) { } } -void InputArchive::load_from(const std::string& filename) { - module_ = torch::jit::load(filename); +void InputArchive::load_from(const std::string& filename, + c10::optional device /*= c10::nullopt*/) { + module_ = torch::jit::load(filename, device); } -void InputArchive::load_from(std::istream& stream) { - module_ = torch::jit::load(stream); +void InputArchive::load_from(std::istream& stream, + c10::optional device /*= c10::nullopt*/) { + module_ = torch::jit::load(stream, device); } } // namespace serialize } // namespace torch diff --git a/torch/csrc/jit/import.cpp b/torch/csrc/jit/import.cpp index 5cc139f..1dde3c1 100644 --- a/torch/csrc/jit/import.cpp +++ b/torch/csrc/jit/import.cpp @@ -222,7 +222,8 @@ void import_ir_module( deserializer.deserialize(module_lookup, device); } -std::shared_ptr load(std::istream& in) { +std::shared_ptr load(std::istream& in, + c10::optional device) { auto module = std::make_shared(); auto module_lookup = [&](const std::vector& qualified_name) { @@ -237,18 +238,18 @@ std::shared_ptr load(std::istream& in) { }; 
ScriptModuleDeserializer deserializer(&in); - // TODO: add device support in C++ API - deserializer.deserialize(module_lookup, c10::optional(at::Device("cpu"))); + deserializer.deserialize(module_lookup, device); return module; } -std::shared_ptr load(const std::string& filename) { +std::shared_ptr load(const std::string& filename, + c10::optional device) { std::ifstream in(filename, std::ios_base::binary); AT_CHECK(! in.fail(), "load: could not open file ", filename); - auto module = load(in); + auto module = load(in, device); return module; } diff --git a/torch/csrc/jit/import.h b/torch/csrc/jit/import.h index 4b2a6e3..eaab9a1 100644 --- a/torch/csrc/jit/import.h +++ b/torch/csrc/jit/import.h @@ -14,25 +14,27 @@ using ModuleLookup = std::function( TORCH_API void import_ir_module( ModuleLookup module_lookup, const std::string& filename, - c10::optional device); + c10::optional device = c10::nullopt); TORCH_API void import_ir_module( ModuleLookup module_lookup, std::istream& in, - c10::optional device); + c10::optional device = c10::nullopt); /// Loads a serialized `script::Module` from the given `istream`. /// /// The istream must contain a serialized `script::Module`, exported via /// `torch::jit::ExportModule` in C++. -TORCH_API std::shared_ptr load(std::istream& in); +TORCH_API std::shared_ptr load(std::istream& in, + c10::optional device = c10::nullopt); /// Loads a serialized `script::Module` from the given `filename`. /// /// The file stored at the location given in `filename` must contain a /// serialized `script::Module`, exported either via `ScriptModule.save()` in /// Python or `torch::jit::ExportModule` in C++. -TORCH_API std::shared_ptr load(const std::string& filename); +TORCH_API std::shared_ptr load(const std::string& filename, + c10::optional device = c10::nullopt); } // namespace jit } // namespace torch -- 2.7.4