def _cuda_deserialize(obj, location):
if location.startswith('cuda'):
device = validate_cuda_device(location)
- return obj.cuda(device)
+ if getattr(obj, "_torch_load_uninitialized", False):
+ storage_type = getattr(torch.cuda, type(obj).__name__)
+ with torch.cuda.device(device):
+ return storage_type(obj.size())
+ else:
+ return obj.cuda(device)
register_package(10, _cpu_tag, _cpu_deserialize)
data_type, root_key, location, size, view_metadata = data
location = maybe_decode_ascii(location)
if root_key not in deserialized_objects:
- deserialized_objects[root_key] = restore_location(
- data_type(size), location)
+ obj = data_type(size)
+ obj._torch_load_uninitialized = True
+ deserialized_objects[root_key] = restore_location(obj, location)
storage = deserialized_objects[root_key]
if view_metadata is not None:
view_key, offset, view_size = view_metadata