optional RecordRef data = 6;
- // future: device options
+ // device field stores the canonical device string, and it follows the
+ // format below: `(cpu|cuda)[:<device-index>]`, e.g., 'cuda:0'
+ optional string device = 7;
}
message ParameterDef {
if pp != pp2:
self.assertMultiLineEqual(pp, pp2)
- def getExportImportCopy(self, m, also_test_file=True):
+ def getExportImportCopy(self, m, also_test_file=True, map_location=None):
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
- imported = torch.jit.load(buffer)
+ imported = torch.jit.load(buffer, map_location=map_location)
if not also_test_file:
return imported
try:
f.close()
imported.save(f.name)
- result = torch.jit.load(f.name)
+ result = torch.jit.load(f.name, map_location=map_location)
finally:
os.unlink(f.name)
self.assertExpectedGraph(trace)
self.assertExportImport(trace, (x, y))
+ def test_restore_device(self):
+ # CPU round trip: save and reload a module whose parameter and buffer live
+ # on CPU, without passing map_location, and check that the values match
+ # and that the reloaded tensors are not on a CUDA device.
+ m = torch.jit.ScriptModule()
+ cpu_device_str = 'cpu'
+ m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
+ device=cpu_device_str))
+ m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
+ device=cpu_device_str))
+ # map_location is left at its default (None) here
+ m2 = self.getExportImportCopy(m)
+ self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
+ self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
+ self.assertFalse(m2.p0.is_cuda)
+ self.assertFalse(m2.b0.is_cuda)
+
+ # Checks device restoration and map_location overrides for a module whose
+ # parameter and buffer live on a CUDA device.
+ @unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
+ def test_restore_device_cuda(self):
+ m = torch.jit.ScriptModule()
+ # pick the highest-indexed GPU; with more than one GPU this differs from
+ # the 'cuda:0' target used at the end of the test
+ cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)
+ m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
+ device=cuda_device_str))
+ m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
+ device=cuda_device_str))
+ self.assertTrue(m.p0.is_cuda)
+ self.assertTrue(m.b0.is_cuda)
+
+ # restore to the saved devices
+ m2 = self.getExportImportCopy(m)
+ self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
+ self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
+ self.assertEqual(str(m2.p0.device), cuda_device_str)
+ self.assertEqual(str(m2.b0.device), cuda_device_str)
+
+ # restore all to cpu using string
+ cpu_device_str = 'cpu'
+ m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
+ self.assertEqual(str(m3.p0.device), cpu_device_str)
+ self.assertEqual(str(m3.b0.device), cpu_device_str)
+
+ # restore all to first gpu using device
+ # (the source here is m3, the CPU copy, so this maps CPU tensors to cuda:0)
+ m4 = self.getExportImportCopy(
+ m3, map_location=torch.device('cuda:0'))
+ self.assertEqual(str(m4.p0.device), 'cuda:0')
+ self.assertEqual(str(m4.b0.device), 'cuda:0')
+
def test_typeas_trace_check(self):
a = torch.tensor([0.4], requires_grad=True)
b = torch.tensor([0.7], requires_grad=True)
def __init__(self):
super(M, self).__init__(False)
- self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda').random_())
+ self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
@torch.jit.script_method
def foo(self):
m_import = self.getExportImportCopy(m_orig)
# check to make sure the storage wasn't resized
self.assertTrue(m_orig.param.storage().size() == 25)
- self.assertTrue(m_import.foo().device == torch.device('cpu'))
+ self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
self.assertEqual(m_orig.foo(), m_import.foo())
self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
auto* data = tensor_proto->mutable_data();
data->set_key(storage_it->second);
- // TODO handle device case, set the device_detail and load to CUDA device
+ // record the tensor's device as a string (e.g. 'cuda:0') in the proto so
+ // that the importer can restore the tensor onto that device
+ std::stringstream ss;
+ ss << tensor.device();
+ tensor_proto->set_device(ss.str());
}
void ScriptModuleSerializer::writeTensorTable(torch::ModelDef* model_def) {
ScriptModuleDeserializer(std::istream* is);
- void deserialize(ModuleLookup module_lookup);
+ void deserialize(ModuleLookup module_lookup,
+ c10::optional<at::Device> device);
private:
at::Tensor loadTensor(
// this is a hack to make sure the script module created in C++ is the
// same as created in Python
ModuleLookup moduleLookup_;
+ c10::optional<at::Device> device_;
std::vector<std::string> moduleStack_;
std::vector<at::Tensor> tensor_table_;
ScriptModuleDeserializer::ScriptModuleDeserializer(std::istream* is)
: ifs_(), reader_(is) {}
-void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup) {
+void ScriptModuleDeserializer::deserialize(ModuleLookup module_lookup,
+ c10::optional<at::Device> device) {
torch::ModelDef model_def;
at::DataPtr data_ptr;
size_t data_size;
model_def.ParseFromString(binary_string),
"JSON transcoder produced invalid protobuf output.");
moduleLookup_ = module_lookup;
+ device_ = device;
const auto& module_def = model_def.main_module();
loadTensorTable(&model_def);
std::vector<int64_t> strides(tensor_proto.strides().begin(), tensor_proto.strides().end());
auto type = at::typeMetaToScalarType(
caffe2::DataTypeToTypeMeta(tensor_proto.data_type()));
-
const std::string& record_key = tensor_proto.data().key();
+ AT_ASSERT(tensor_proto.has_device() && !tensor_proto.device().empty());
+ at::Device device(tensor_proto.device());
+ if (device_.has_value()) {
+ // override the device, if user provides map_location
+ device = device_.value();
+ }
+
auto storage_it = storageMap.find(record_key);
if (storage_it == storageMap.end()) {
at::DataPtr storage_ptr;
uint64_t record_size;
std::tie(storage_ptr, record_size) = reader_.getRecord(record_key);
- auto storage = at::Storage(
+ auto cpu_storage = at::Storage(
at::CPU(type).typeMeta(),
std::move(storage_ptr),
record_size / at::CPU(type).typeMeta().itemsize(),
nullptr); // NB: we didn't set any allocator for the tensor
- storage_it = storageMap.insert(std::make_pair(record_key, storage)).first;
+ if (device.type() == at::DeviceType::CPU) {
+ storage_it = storageMap.insert(std::make_pair(
+ record_key, cpu_storage)).first;
+ } else if (device.type() == at::DeviceType::CUDA) {
+ at::Tensor cpu_tensor = at::CPU(type)._th_tensor(
+ cpu_storage, tensor_proto.offset(), dims, strides);
+ at::Storage cuda_storage = cpu_tensor.to(device,
+ cpu_tensor.scalar_type()).storage();
+ storage_it = storageMap.insert(std::make_pair(
+ record_key, cuda_storage)).first;
+ } else {
+ AT_ERROR("supported devices include CPU and CUDA, however got ",
+ at::DeviceTypeName(device.type(), false));
+ }
}
- auto t = at::CPU(type)._th_tensor(
- storage_it->second, tensor_proto.offset(), dims, strides);
- return autograd::make_variable(t, tensor_proto.requires_grad());
+ AT_ASSERT(storage_it->second.device() == device);
+
+ at::Tensor result;
+ if (device.type() == at::DeviceType::CPU) {
+ result = at::CPU(type)._th_tensor(
+ storage_it->second, tensor_proto.offset(), dims, strides);
+ } else if (device.type() == at::DeviceType::CUDA) {
+ result = at::CUDA(type)._th_tensor(
+ storage_it->second, tensor_proto.offset(), dims, strides);
+ }
+ AT_ASSERT(result.defined());
+
+ result = autograd::make_variable(result, tensor_proto.requires_grad());
+
+ return result;
}
void ScriptModuleDeserializer::convertModule(
+// Deserializes a script module from the stream `in`. `device` is forwarded
+// unchanged to ScriptModuleDeserializer::deserialize; an empty optional means
+// no device override was requested.
void import_ir_module(
ModuleLookup module_lookup,
- std::istream& in) {
+ std::istream& in,
+ c10::optional<at::Device> device) {
ScriptModuleDeserializer deserializer(&in);
- deserializer.deserialize(module_lookup);
+ deserializer.deserialize(module_lookup, device);
}
+// Deserializes a script module from the file named `filename`. `device` is
+// forwarded unchanged to ScriptModuleDeserializer::deserialize; an empty
+// optional means no device override was requested.
void import_ir_module(
ModuleLookup module_lookup,
- const std::string& filename) {
+ const std::string& filename,
+ c10::optional<at::Device> device) {
ScriptModuleDeserializer deserializer(filename);
- deserializer.deserialize(module_lookup);
+ deserializer.deserialize(module_lookup, device);
}
std::shared_ptr<script::Module> load(std::istream& in) {
};
ScriptModuleDeserializer deserializer(&in);
- deserializer.deserialize(module_lookup);
+ // TODO: add device support in C++ API
+ deserializer.deserialize(module_lookup, c10::optional<at::Device>(at::Device("cpu")));
return module;
}
+/// Imports a serialized module from the file `filename`. `device` is an
+/// optional device override; pass an empty optional to request none.
TORCH_API void import_ir_module(
ModuleLookup module_lookup,
- const std::string& filename);
+ const std::string& filename,
+ c10::optional<c10::Device> device);
-TORCH_API void import_ir_module(ModuleLookup module_lookup, std::istream& in);
+/// Stream overload: imports a serialized module from `in`, with the same
+/// optional `device` override.
+TORCH_API void import_ir_module(
+ ModuleLookup module_lookup,
+ std::istream& in,
+ c10::optional<c10::Device> device);
/// Loads a serialized `script::Module` from the given `istream`.
///
});
m.def("merge_type_from_type_comment", &mergeTypesFromTypeComment);
- m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename) {
- import_ir_module(module_lookup, filename);
+ // Python binding for importing a module from a file. `map_location` must be
+ // None or a torch.device object (asserted below); None leaves the optional
+ // device empty.
+ m.def("import_ir_module", [](ModuleLookup module_lookup, const std::string& filename,
+ py::object map_location) {
+ c10::optional<at::Device> optional_device;
+ if (!map_location.is(py::none())) {
+ AT_ASSERT(THPDevice_Check(map_location.ptr()));
+ // unwrap the at::Device held by the Python torch.device object
+ optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+ }
+ import_ir_module(module_lookup, filename, optional_device);
});
- m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup, const std::string& buffer) {
+ // Python binding for importing a module from an in-memory buffer; the bytes
+ // are wrapped in an istringstream. `map_location` must be None or a
+ // torch.device object (asserted below); None leaves the optional device empty.
+ m.def("import_ir_module_from_buffer", [](ModuleLookup module_lookup,
+ const std::string& buffer, py::object map_location) {
std::istringstream in(buffer);
- import_ir_module(module_lookup, in);
+ c10::optional<at::Device> optional_device;
+ if (!map_location.is(py::none())) {
+ AT_ASSERT(THPDevice_Check(map_location.ptr()));
+ // unwrap the at::Device held by the Python torch.device object
+ optional_device = reinterpret_cast<THPDevice*>(map_location.ptr())->device;
+ }
+ import_ir_module(module_lookup, in, optional_device);
});
m.def("_jit_import_methods", import_methods);
m.def("_jit_set_emit_module_hook", setEmitModuleHook);
import torch._C
from torch import Tensor
from torch.autograd import Variable, function
+from torch.serialization import validate_cuda_device
from torch.nn import Module, ModuleList, ParameterList, Parameter, Sequential
from torch.jit.frontend import get_jit_ast, get_default_args
import torch.backends.cudnn as cudnn
import torch.jit.annotations
-from torch._six import raise_from, with_metaclass, get_function_from_type
+from torch._six import raise_from, with_metaclass, get_function_from_type, \
+ string_classes
from .._jit_internal import createResolutionCallback, _compiled_weak_fns, \
_weak_script_methods, _weak_modules, _weak_types, COMPILED, \
COMPILATION_PENDING, _boolean_dispatched
tracing_state.pop_scope()
-def load(f):
+def load(f, map_location=None):
r"""
Load a ``ScriptModule`` previously saved with :func:`save <torch.jit.save>`
- .. DANGER::
- All previously saved modules, no matter their device, are always loaded onto the CPU.
- This is different from :func:`torch.load`'s semantics and may change in the future.
+ All previously saved modules, no matter their device, are first loaded onto CPU,
+ and then are moved to the devices they were saved from. If this fails (e.g. because
+ the runtime system doesn't have certain devices), an exception is raised.
+ However, storages can be dynamically remapped to an alternative device
+ using the `map_location` argument. Compared to :func:`torch.load`, `map_location`
+ in this function is simplified: it only accepts a string (e.g., 'cpu', 'cuda:0')
+ or a torch.device (e.g., torch.device('cpu')).
Arguments:
f: a file-like object (has to implement read, readline, tell, and seek),
or a string containing a file name
+ map_location: can be a string (e.g., 'cpu', 'cuda:0') or a device (e.g.,
+ torch.device('cpu')); defaults to None, in which case no remapping is requested
Returns:
A ``ScriptModule`` object.
# Load ScriptModule from io.BytesIO object
>>> with open('scriptmodule.pt', 'rb') as f:
buffer = io.BytesIO(f.read())
+ # Load all tensors to the original device
>>> torch.jit.load(buffer)
+ # Load all tensors onto CPU, using a device
+ >>> torch.jit.load(buffer, map_location=torch.device('cpu'))
+ # Load all tensors onto CPU, using a string
+ >>> torch.jit.load(buffer, map_location='cpu')
"""
m = ScriptModule()
curr = getattr(curr, name)
return curr
+ if isinstance(map_location, string_classes):
+ map_location = torch.device(map_location)
+ elif not (map_location is None or
+ isinstance(map_location, torch.device)):
+ raise ValueError("map_location should be either None, string or torch.device, "
+ "but got type: " + str(type(map_location)))
+ if (str(map_location).startswith('cuda')):
+ validate_cuda_device(map_location)
+
if isinstance(f, str) or \
(sys.version_info[0] == 2 and isinstance(f, unicode)) or \
(sys.version_info[0] == 3 and isinstance(f, pathlib.Path)):
- torch._C.import_ir_module(module_lookup, f)
+ torch._C.import_ir_module(module_lookup, f, map_location)
else:
- torch._C.import_ir_module_from_buffer(module_lookup, f.read())
+ torch._C.import_ir_module_from_buffer(module_lookup, f.read(), map_location)
return m
return obj
+# Checks that the CUDA device named by `location` can be used on this machine
+# and returns its integer index.
+#
+# `location` is a string or a torch.device (converted with str()). Everything
+# after the first five characters is parsed as the device index, so the value
+# is assumed to have the form 'cuda' or 'cuda:<index>'; the prefix itself is
+# not checked here.
+#
+# Raises ValueError if `location` is neither a string nor a torch.device, and
+# RuntimeError if CUDA is unavailable or the index is not below
+# torch.cuda.device_count().
+def validate_cuda_device(location):
+ if isinstance(location, torch.device):
+ location = str(location)
+ if not isinstance(location, _string_classes):
+ raise ValueError("location should be a string or torch.device")
+ # no index after the five-character prefix (e.g. plain 'cuda') means device 0
+ if location[5:] == '':
+ device = 0
+ else:
+ # negative indices are clamped to 0
+ device = max(int(location[5:]), 0)
+
+ if not torch.cuda.is_available():
+ raise RuntimeError('Attempting to deserialize object on a CUDA '
+ 'device but torch.cuda.is_available() is False. '
+ 'If you are running on a CPU-only machine, '
+ 'please use torch.load with map_location=\'cpu\' '
+ 'to map your storages to the CPU.')
+ if device >= torch.cuda.device_count():
+ raise RuntimeError('Attempting to deserialize object on CUDA device '
+ '{} but torch.cuda.device_count() is {}. Please use '
+ 'torch.load with map_location to map your storages '
+ 'to an existing device.'.format(
+ device, torch.cuda.device_count()))
+ return device
+
+
def _cuda_deserialize(obj, location):
if location.startswith('cuda'):
- if location[5:] == '':
- device = 0
- else:
- device = max(int(location[5:]), 0)
-
- if not torch.cuda.is_available():
- raise RuntimeError('Attempting to deserialize object on a CUDA '
- 'device but torch.cuda.is_available() is False. '
- 'If you are running on a CPU-only machine, '
- 'please use torch.load with map_location=\'cpu\' '
- 'to map your storages to the CPU.')
- if device >= torch.cuda.device_count():
- raise RuntimeError('Attempting to deserialize object on CUDA device '
- '{} but torch.cuda.device_count() is {}. Please use '
- 'torch.load with map_location to map your storages '
- 'to an existing device.'.format(
- device, torch.cuda.device_count()))
+ device = validate_cuda_device(location)
return obj.cuda(device)