Restore device in cpp API (#14711)
authorLu Fang <lufang@fb.com>
Tue, 4 Dec 2018 08:44:43 +0000 (00:44 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 4 Dec 2018 08:46:41 +0000 (00:46 -0800)
Summary:
This is a stack PR based on https://github.com/pytorch/pytorch/pull/14454.

It enables the restoring the storage to appropriate device.

~~[TODO]: add/modify appropriate tests~~ Done
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14711

Reviewed By: dzhulgakov

Differential Revision: D13315746

Pulled By: houseroad

fbshipit-source-id: fe6f24a45c35e88fd1a2eebc09950d4430fac185

test/cpp/api/serialize.cpp
torch/csrc/api/include/torch/serialize/input-archive.h
torch/csrc/api/src/serialize/input-archive.cpp
torch/csrc/jit/import.cpp
torch/csrc/jit/import.h

index 973fbbf..f6e2b03 100644 (file)
@@ -215,9 +215,13 @@ TEST(SerializeTest, Optim) {
 TEST(SerializeTest, XOR_CUDA) {
   torch::manual_seed(0);
   // We better be able to save and load a XOR model!
-  auto getLoss = [](Sequential model, uint32_t batch_size) {
+  auto getLoss = [](Sequential model, uint32_t batch_size, bool is_cuda=false) {
     auto inputs = torch::empty({batch_size, 2});
     auto labels = torch::empty({batch_size});
+    if (is_cuda) {
+      inputs = inputs.cuda();
+      labels = labels.cuda();
+    }
     for (size_t i = 0; i < batch_size; i++) {
       inputs[i] = torch::randint(2, {2}, torch::kInt64);
       labels[i] = inputs[i][0].item<int64_t>() ^ inputs[i][1].item<int64_t>();
@@ -255,10 +259,13 @@ TEST(SerializeTest, XOR_CUDA) {
   ASSERT_LT(loss.item<float>(), 0.1);
 
   model2->to(torch::kCUDA);
+  loss = getLoss(model2, 100, true);
+  ASSERT_LT(loss.item<float>(), 0.1);
+
   auto tempfile2 = torch::utils::make_tempfile();
   torch::save(model2, tempfile2.name);
   torch::load(model3, tempfile2.name);
 
-  loss = getLoss(model3, 100);
+  loss = getLoss(model3, 100, true);
   ASSERT_LT(loss.item<float>(), 0.1);
 }
index dcfeb9e..c7c5bcc 100644 (file)
@@ -1,6 +1,9 @@
 #pragma once
 
+#include <c10/util/Optional.h>
+#include <c10/Device.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
+#include <torch/types.h>
 
 #include <iosfwd>
 #include <memory>
@@ -52,12 +55,16 @@ class TORCH_API InputArchive final {
   void read(const std::string& key, InputArchive& archive);
 
   /// Loads the `InputArchive` from a serialized representation stored in the
-  /// file at `filename`.
-  void load_from(const std::string& filename);
+  /// file at `filename`. Storage are remapped using device option. If device
+  /// is not specified, the module is loaded to the original device.
+  void load_from(const std::string& filename,
+      c10::optional<torch::Device> device = c10::nullopt);
 
   /// Loads the `InputArchive` from a serialized representation stored in the
-  /// given `stream`.
-  void load_from(std::istream& stream);
+  /// given `stream`. Storage are remapped using device option. If device
+  /// is not specified, the module is loaded to the original device.
+  void load_from(std::istream& stream,
+      c10::optional<torch::Device> device = c10::nullopt);
 
   /// Forwards all arguments to `read()`.
   /// Useful for generic code that can be re-used for both `InputArchive` and
index 1e55182..444fe79 100644 (file)
@@ -33,7 +33,11 @@ void InputArchive::read(
   // clang-format on
   if (tensor.defined()) {
     torch::NoGradGuard guard;
-    tensor.set_(*read_tensor->slot());
+    if (tensor.device() != read_tensor->slot()->device()) {
+      tensor.set_data(autograd::Variable(*read_tensor->slot()).data());
+    } else {
+      tensor.set_(*read_tensor->slot());
+    }
   } else {
     tensor = std::move(*read_tensor->slot());
   }
@@ -48,12 +52,14 @@ void InputArchive::read(const std::string& key, InputArchive& archive) {
   }
 }
 
-void InputArchive::load_from(const std::string& filename) {
-  module_ = torch::jit::load(filename);
+void InputArchive::load_from(const std::string& filename,
+    c10::optional<torch::Device> device /*= c10::nullopt*/) {
+  module_ = torch::jit::load(filename, device);
 }
 
-void InputArchive::load_from(std::istream& stream) {
-  module_ = torch::jit::load(stream);
+void InputArchive::load_from(std::istream& stream,
+    c10::optional<torch::Device> device /*= c10::nullopt*/) {
+  module_ = torch::jit::load(stream, device);
 }
 } // namespace serialize
 } // namespace torch
index 5cc139f..1dde3c1 100644 (file)
@@ -222,7 +222,8 @@ void import_ir_module(
   deserializer.deserialize(module_lookup, device);
 }
 
-std::shared_ptr<script::Module> load(std::istream& in) {
+std::shared_ptr<script::Module> load(std::istream& in,
+    c10::optional<at::Device> device) {
   auto module = std::make_shared<script::Module>();
 
   auto module_lookup = [&](const std::vector<std::string>& qualified_name) {
@@ -237,18 +238,18 @@ std::shared_ptr<script::Module> load(std::istream& in) {
   };
 
   ScriptModuleDeserializer deserializer(&in);
-  // TODO: add device support in C++ API
-  deserializer.deserialize(module_lookup, c10::optional<at::Device>(at::Device("cpu")));
+  deserializer.deserialize(module_lookup, device);
 
   return module;
 }
 
-std::shared_ptr<script::Module> load(const std::string& filename) {
+std::shared_ptr<script::Module> load(const std::string& filename,
+    c10::optional<at::Device> device) {
   std::ifstream in(filename, std::ios_base::binary);
 
   AT_CHECK(! in.fail(), "load: could not open file ", filename);
 
-  auto module = load(in);
+  auto module = load(in, device);
 
   return module;
 }
index 4b2a6e3..eaab9a1 100644 (file)
@@ -14,25 +14,27 @@ using ModuleLookup = std::function<std::shared_ptr<script::Module>(
 TORCH_API void import_ir_module(
     ModuleLookup module_lookup,
     const std::string& filename,
-    c10::optional<c10::Device> device);
+    c10::optional<c10::Device> device = c10::nullopt);
 
 TORCH_API void import_ir_module(
     ModuleLookup module_lookup,
     std::istream& in,
-    c10::optional<c10::Device> device);
+    c10::optional<c10::Device> device = c10::nullopt);
 
 /// Loads a serialized `script::Module` from the given `istream`.
 ///
 /// The istream must contain a serialized `script::Module`, exported via
 /// `torch::jit::ExportModule` in C++.
-TORCH_API std::shared_ptr<script::Module> load(std::istream& in);
+TORCH_API std::shared_ptr<script::Module> load(std::istream& in,
+    c10::optional<c10::Device> device = c10::nullopt);
 
 /// Loads a serialized `script::Module` from the given `filename`.
 ///
 /// The file stored at the location given in `filename` must contain a
 /// serialized `script::Module`, exported either via `ScriptModule.save()` in
 /// Python or `torch::jit::ExportModule` in C++.
-TORCH_API std::shared_ptr<script::Module> load(const std::string& filename);
+TORCH_API std::shared_ptr<script::Module> load(const std::string& filename,
+    c10::optional<c10::Device> device = c10::nullopt);
 
 } // namespace jit
 } // namespace torch