Restore device when import jit script module (#14454)
authorLu Fang <lufang@fb.com>
Mon, 3 Dec 2018 22:07:50 +0000 (14:07 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 3 Dec 2018 22:10:30 +0000 (14:10 -0800)
Summary:
We align the restore logic to `torch.load`, we try to restore to the right device, and if the device is not available, an exception is raised. We allow user to remap the device through a parameter `map_location`, it can be 1) a string like 'cuda:0`, `cpu`, 2) a device, torch.device('cpu'), 3) a dict, {'cuda:1', 'cuda:0'}, and a function, and its signature looks like string map_location(tensor, saved_device_string).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14454

Reviewed By: zrphercule

Differential Revision: D13271956

Pulled By: houseroad

fbshipit-source-id: dfd6b6049b0dc07549ddeddf2dea03ac53ba6d49

caffe2/proto/torch.proto
test/test_jit.py
torch/csrc/jit/export.cpp
torch/csrc/jit/import.cpp
torch/csrc/jit/import.h
torch/csrc/jit/script/init.cpp
torch/jit/__init__.py
torch/serialization.py

index 78e4c0d..0dbca7c 100644 (file)
@@ -18,7 +18,9 @@ message TensorDef {
 
   optional RecordRef data = 6;
 
-  // future: device options
+  // device field stores the canonical device string, and it follows the
+  // format below: `(cpu|cuda)[:<device-index>]`, e.g., 'cuda:0'
+  optional string device = 7;
 }
 
 message ParameterDef {
index b109ec6..6dca770 100644 (file)
@@ -266,11 +266,11 @@ class JitTestCase(TestCase):
             if pp != pp2:
                 self.assertMultiLineEqual(pp, pp2)
 
-    def getExportImportCopy(self, m, also_test_file=True):
+    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
         buffer = io.BytesIO()
         torch.jit.save(m, buffer)
         buffer.seek(0)
-        imported = torch.jit.load(buffer)
+        imported = torch.jit.load(buffer, map_location=map_location)
 
         if not also_test_file:
             return imported
@@ -282,7 +282,7 @@ class JitTestCase(TestCase):
         try:
             f.close()
             imported.save(f.name)
-            result = torch.jit.load(f.name)
+            result = torch.jit.load(f.name, map_location=map_location)
         finally:
             os.unlink(f.name)
 
@@ -489,6 +489,50 @@ class TestJit(JitTestCase):
         self.assertExpectedGraph(trace)
         self.assertExportImport(trace, (x, y))
 
+    def test_restore_device(self):
+        # main purpose is checking map_location works
+        m = torch.jit.ScriptModule()
+        cpu_device_str = 'cpu'
+        m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
+                                         device=cpu_device_str))
+        m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
+                                             device=cpu_device_str))
+        m2 = self.getExportImportCopy(m)
+        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
+        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
+        self.assertFalse(m2.p0.is_cuda)
+        self.assertFalse(m2.b0.is_cuda)
+
+    @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
+    def test_restore_device_cuda(self):
+        m = torch.jit.ScriptModule()
+        cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)
+        m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
+                                         device=cuda_device_str))
+        m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
+                                             device=cuda_device_str))
+        self.assertTrue(m.p0.is_cuda)
+        self.assertTrue(m.b0.is_cuda)
+
+        # restore to the saved devices
+        m2 = self.getExportImportCopy(m)
+        self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
+        self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
+        self.assertEqual(str(m2.p0.device), cuda_device_str)
+        self.assertEqual(str(m2.b0.device), cuda_device_str)
+
+        # restore all to cpu using string
+        cpu_device_str = 'cpu'
+        m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
+        self.assertEqual(str(m3.p0.device), cpu_device_str)
+        self.assertEqual(str(m3.b0.device), cpu_device_str)
+
+        # restore all to first gpu using device
+        m4 = self.getExportImportCopy(
+            m3, map_location=torch.device('cuda:0'))
+        self.assertEqual(str(m4.p0.device), 'cuda:0')
+        self.assertEqual(str(m4.b0.device), 'cuda:0')
+
     def test_typeas_trace_check(self):
         a = torch.tensor([0.4], requires_grad=True)
         b = torch.tensor([0.7], requires_grad=True)
@@ -6214,7 +6258,7 @@ a")
 
             def __init__(self):
                 super(M, self).__init__(False)
-                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda').random_())
+                self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
 
             @torch.jit.script_method
             def foo(self):
@@ -6224,7 +6268,7 @@ a")
         m_import = self.getExportImportCopy(m_orig)
         # check to make sure the storage wasn't resized
         self.assertTrue(m_orig.param.storage().size() == 25)
-        self.assertTrue(m_import.foo().device == torch.device('cpu'))
+        self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
         self.assertEqual(m_orig.foo(), m_import.foo())
         self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
 
index 6c8876c..932efbb 100644 (file)
@@ -579,7 +579,10 @@ void ScriptModuleSerializer::convertAndWriteTensor(
   auto* data = tensor_proto->mutable_data();
   data->set_key(storage_it->second);
 
-  // TODO handle device case, set the device_detail and load to CUDA device
+  // handle device case, set the device_detail and load to CUDA device
+  std::stringstream ss;
+  ss << tensor.device();
+  tensor_proto->set_device(ss.str());
 }
 
 void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
index 1540547..20f4312 100644 (file)
@@ -37,7 +37,8 @@ class ScriptModuleDeserializer final {
 
   ScriptModuleDeserializer(std::istream* is);
 
-  void deserialize(ModuleLookup module_lookup);
+  void deserialize(ModuleLookup module_lookup,
+      c10::optional<at::Device> device);
 
 private:
  at::Tensor loadTensor(
@@ -53,6 +54,7 @@ private:
  // this is a hack to make sure the script module created in C++ is the
  // same as created in Python
  ModuleLookup moduleLookup_;
+ c10::optional<at::Device> device_;
  std::vector<std::string> moduleStack_;
 
  std::vector<at::Tensor> tensor_table_;
@@ -66,7 +68,8 @@ ScriptModuleDeserializer::ScriptModuleDeserializer(const std::string& filename)
 ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
     : ifs_(), reader_(is) {}
 
-void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup) {
+void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup,
+    c10::optional<at::Device> device) {
   torch::ModelDef model_def;
   at::DataPtr data_ptr;
   size_t data_size;
@@ -95,6 +98,7 @@ void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup) {
       model_def.ParseFromString(binary_string),
       "JSON transcoder produced invalid protobuf output.");
   moduleLookup_ = module_lookup;
+  device_ = device;
 
   const auto& module_def = model_def.main_module();
   loadTensorTable(&model_def);
@@ -116,23 +120,54 @@ at::Tensor ScriptModuleDeserializer::loadTensor(const torch::TensorDef& tensor_p
   std::vector<int64_t> strides(tensor_proto.strides().begin(), tensor_proto.strides().end());
   auto type = at::typeMetaToScalarType(
       caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
-
   const std::string& record_key = tensor_proto.data().key();
+  AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
+  at::Device device(tensor_proto.device());
+  if (device_.has_value()) {
+    // override the device, if user provides map_location
+    device = device_.value();
+  }
+
   auto storage_it = storageMap.find(record_key);
   if (storage_it == storageMap.end()) {
     at::DataPtr storage_ptr;
     uint64_t record_size;
     std::tie(storage_ptr, record_size) = reader_.getRecord(record_key);
-    auto storage = at::Storage(
+    auto cpu_storage = at::Storage(
         at::CPU(type).typeMeta(),
         std::move(storage_ptr),
         record_size / at::CPU(type).typeMeta().itemsize(),
         nullptr); // NB: we didn't set any allocator for the tensor
-    storage_it = storageMap.insert(std::make_pair(record_key, storage)).first;
+    if (device.type() == at::DeviceType::CPU) {
+      storage_it = storageMap.insert(std::make_pair(
+            record_key, cpu_storage)).first;
+    } else if (device.type() == at::DeviceType::CUDA) {
+      at::Tensor cpu_tensor = at::CPU(type)._th_tensor(
+          cpu_storage, tensor_proto.offset(), dims, strides);
+      at::Storage cuda_storage = cpu_tensor.to(device,
+          cpu_tensor.scalar_type()).storage();
+      storage_it = storageMap.insert(std::make_pair(
+            record_key, cuda_storage)).first;
+    } else {
+      AT_ERROR("supported devices include CPU and CUDA, however got ",
+          at::DeviceTypeName(device.type(), false));
+    }
   }
-  auto t = at::CPU(type)._th_tensor(
-      storage_it->second, tensor_proto.offset(), dims, strides);
-  return autograd::make_variable(t, tensor_proto.requires_grad());
+  AT_ASSERT(storage_it->second.device() == device);
+
+  at::Tensor result;
+  if (device.type() == at::DeviceType::CPU) {
+    result = at::CPU(type)._th_tensor(
+        storage_it->second, tensor_proto.offset(), dims, strides);
+  } else if (device.type() == at::DeviceType::CUDA) {
+    result = at::CUDA(type)._th_tensor(
+        storage_it->second, tensor_proto.offset(), dims, strides);
+  }
+  AT_ASSERT(result.defined());
+
+  result = autograd::make_variable(result, tensor_proto.requires_grad());
+
+  return result;
 }
 
 void ScriptModuleDeserializer::convertModule(
@@ -164,16 +199,18 @@ void ScriptModuleDeserializer::convertModule(
 
 void import_ir_module(
     ModuleLookup module_lookup,
-    std::istream& in) {
+    std::istream& in,
+    c10::optional<at::Device> device) {
   ScriptModuleDeserializer deserializer(&in);
-  deserializer.deserialize(module_lookup);
+  deserializer.deserialize(module_lookup, device);
 }
 
 void import_ir_module(
     ModuleLookup module_lookup,
-    const std::string& filename) {
+    const std::string& filename,
+    c10::optional<at::Device> device) {
   ScriptModuleDeserializer deserializer(filename);
-  deserializer.deserialize(module_lookup);
+  deserializer.deserialize(module_lookup, device);
 }
 
 std::shared_ptr<script::Module> load(std::istream& in) {
@@ -191,7 +228,8 @@ std::shared_ptr<script::Module> load(std::istream& in) {
   };
 
   ScriptModuleDeserializer deserializer(&in);
-  deserializer.deserialize(module_lookup);
+  // TODO: add device support in C++ API
+  deserializer.deserialize(module_lookup, c10::optional<at::Device>(at::Device("cpu")));
 
   return module;
 }
index 084b6a8..4b2a6e3 100644 (file)
@@ -13,9 +13,13 @@ using ModuleLookup = std::function<std::shared_ptr<script::Module>(
 
 TORCH_API void import_ir_module(
     ModuleLookup module_lookup,
-    const std::string& filename);
+    const std::string& filename,
+    c10::optional<c10::Device> device);
 
-TORCH_API void import_ir_module(ModuleLookup module_lookup, std::istream& in);
+TORCH_API void import_ir_module(
+    ModuleLookup module_lookup,
+    std::istream& in,
+    c10::optional<c10::Device> device);
 
 /// Loads a serialized `script::Module` from the given `istream`.
 ///
index 68058f9..2659db6 100644 (file)
@@ -712,12 +712,24 @@ void initJitScriptBindings(PyObject* module) {
   });
 
   m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
-  m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename) {
-    import_ir_module(module_lookup, filename);
+  m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename,
+        py::object map_location) {
+    c10::optional<at::Device> optional_device;
+    if (!map_location.is(py::none())) {
+      AT_ASSERT(THPDevice_Check(map_location.ptr()));
+      optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+    }
+    import_ir_module(module_lookup, filename, optional_device);
   });
-  m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup, const std::string& buffer) {
+  m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup,
+        const std::string& buffer, py::object map_location) {
     std::istringstream in(buffer);
-    import_ir_module(module_lookup, in);
+    c10::optional<at::Device> optional_device;
+    if (!map_location.is(py::none())) {
+      AT_ASSERT(THPDevice_Check(map_location.ptr()));
+      optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+    }
+    import_ir_module(module_lookup, in, optional_device);
   });
   m.def("_jit_import_methods", import_methods);
   m.def("_jit_set_emit_module_hook", setEmitModuleHook);
index b930928..7b5a98d 100644 (file)
@@ -1,11 +1,13 @@
 import torch._C
 from torch import Tensor
 from torch.autograd import Variable, function
+from torch.serialization import validate_cuda_device
 from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
 from torch.jit.frontend import get_jit_ast, get_default_args
 import torch.backends.cudnn as cudnn
 import torch.jit.annotations
-from torch._six import raise_from, with_metaclass, get_function_from_type
+from torch._six import raise_from, with_metaclass, get_function_from_type, \
+    string_classes
 from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
     _weak_script_methods, _weak_modules, _weak_types, COMPILED, \
     COMPILATION_PENDING, _boolean_dispatched
@@ -70,17 +72,23 @@ def scope(scope_name):
             tracing_state.pop_scope()
 
 
-def load(f):
+def load(f, map_location=None):
     r"""
         Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
 
-        .. DANGER::
-           All previously saved modules, no matter their device, are always loaded onto the CPU.
-           This is different from :func:`torch.load`'s semantics and may change in the future.
+        All previously saved modules, no matter their device, are first loaded onto CPU,
+        and then are moved to the devices they were saved from. If this fails (e.g. because
+        the run time system doesn't have certain devices), an exception is raised.
+        However, storages can be dynamically remapped to an alternative set of devices
+        using the `map_location` argument. Comparing to :func:`torch.load`, `map_location`
+        in this function is simplified, which only accepts a string (e.g., 'cpu', 'cuda:0'),
+        or torch.device (e.g., torch.device('cpu'))
 
         Arguments:
             f: a file-like object (has to implement read, readline, tell, and seek),
                 or a string containing a file name
+            map_location: can a string (e.g., 'cpu', 'cuda:0'), a device (e.g.,
+                torch.device('cpu'))
 
         Returns:
             A ``ScriptModule`` object.
@@ -90,7 +98,12 @@ def load(f):
             # Load ScriptModule from io.BytesIO object
             >>> with open('scriptmodule.pt', 'rb') as f:
                     buffer = io.BytesIO(f.read())
+            # Load all tensors to the original device
             >>> torch.jit.load(buffer)
+            # Load all tensors onto CPU, using a device
+            >>> torch.jit.load(buffer, map_location=torch.device('cpu'))
+            # Load all tensors onto CPU, using a string
+            >>> torch.jit.load(buffer, map_location='cpu')
     """
     m = ScriptModule()
 
@@ -102,12 +115,21 @@ def load(f):
             curr = getattr(curr, name)
         return curr
 
+    if isinstance(map_location, string_classes):
+        map_location = torch.device(map_location)
+    elif not (map_location is None or
+              isinstance(map_location, torch.device)):
+        raise ValueError("map_location should be either None, string or torch.device, "
+                         "but got type: " + str(type(map_location)))
+    if (str(map_location).startswith('cuda')):
+        validate_cuda_device(map_location)
+
     if isinstance(f, str) or \
             (sys.version_info[0] == 2 and isinstance(f, unicode)) or \
             (sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
-        torch._C.import_ir_module(module_lookup, f)
+        torch._C.import_ir_module(module_lookup, f, map_location)
     else:
-        torch._C.import_ir_module_from_buffer(module_lookup, f.read())
+        torch._C.import_ir_module_from_buffer(module_lookup, f.read(), map_location)
     return m
 
 
index 98e88c9..ba73ab3 100644 (file)
@@ -64,25 +64,34 @@ def _cpu_deserialize(obj, location):
         return obj
 
 
+def validate_cuda_device(location):
+    if isinstance(location, torch.device):
+        location = str(location)
+    if not isinstance(location, _string_classes):
+        raise ValueError("location should be a string or torch.device")
+    if location[5:] == '':
+        device = 0
+    else:
+        device = max(int(location[5:]), 0)
+
+    if not torch.cuda.is_available():
+        raise RuntimeError('Attempting to deserialize object on a CUDA '
+                           'device but torch.cuda.is_available() is False. '
+                           'If you are running on a CPU-only machine, '
+                           'please use torch.load with map_location=\'cpu\' '
+                           'to map your storages to the CPU.')
+    if device >= torch.cuda.device_count():
+        raise RuntimeError('Attempting to deserialize object on CUDA device '
+                           '{} but torch.cuda.device_count() is {}. Please use '
+                           'torch.load with map_location to map your storages '
+                           'to an existing device.'.format(
+                               device, torch.cuda.device_count()))
+    return device
+
+
 def _cuda_deserialize(obj, location):
     if location.startswith('cuda'):
-        if location[5:] == '':
-            device = 0
-        else:
-            device = max(int(location[5:]), 0)
-
-        if not torch.cuda.is_available():
-            raise RuntimeError('Attempting to deserialize object on a CUDA '
-                               'device but torch.cuda.is_available() is False. '
-                               'If you are running on a CPU-only machine, '
-                               'please use torch.load with map_location=\'cpu\' '
-                               'to map your storages to the CPU.')
-        if device >= torch.cuda.device_count():
-            raise RuntimeError('Attempting to deserialize object on CUDA device '
-                               '{} but torch.cuda.device_count() is {}. Please use '
-                               'torch.load with map_location to map your storages '
-                               'to an existing device.'.format(
-                                   device, torch.cuda.device_count()))
+        device = validate_cuda_device(location)
         return obj.cuda(device)