def _rebuild_tensor(storage, storage_offset, size, stride):
- class_name = storage.__class__.__name__.replace('Storage', 'Tensor')
- module = importlib.import_module(storage.__module__)
- tensor_class = getattr(module, class_name)
- return tensor_class().set_(storage, storage_offset, size, stride)
+ # first construct a tensor with the correct dtype/device
+ t = torch.tensor([], dtype=storage.dtype, device=storage.device)
+ return t.set_(storage, storage_offset, size, stride)
def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):