Use non-legacy constructors for tensor deserialization. (#18750)
author Gregory Chanan <gchanan@fb.com>
Wed, 3 Apr 2019 14:51:15 +0000 (07:51 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 14:54:11 +0000 (07:54 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18750
ghimport-source-id: f1475cfb67841c41d9867d4429ba9125d5c7dd07

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18751 Disallow changing the device of a tensor via set_.
* **#18750 Use non-legacy constructors for tensor deserialization.**
* #18749 Add device and dtype to storage.

Deserialization currently uses the legacy tensor constructors.  This is bad because we have to keep maintaining them, but there is a more immediate problem:
1) We are trying to implement device caching on TensorImpl to get rid of a virtual dispatch.
2) Device caching doesn't work if the device of the Tensor underlying a Variable can be changed.
3) Deserialization does exactly that: it default-constructs a tensor and then re-points it at the deserialized storage with `set_`, which can change the device.

So the plan is to change deserialization here, and then enforce that the device of a Variable is never changed out from underneath it (#18751); a sketch of the two construction paths follows.
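
For illustration, here is a minimal sketch (not code from this patch) of the legacy and new construction paths; it assumes a CPU float storage and the storage `dtype`/`device` attributes added in #18749:

```python
import torch

storage = torch.FloatStorage(4)  # stand-in for a storage produced by the unpickler

# Legacy path (what deserialization used to do): default-construct a tensor of the
# matching legacy class, then re-point it at the storage via set_.  If the storage
# lives on a different device than the freshly constructed tensor, set_ changes the
# tensor's device in place -- exactly the mutation the rest of this stack forbids.
legacy = torch.FloatTensor().set_(storage, 0, (2, 2), (2, 1))

# New path (this patch): construct the tensor with the storage's dtype and device
# up front, so set_ only has to install the storage/size/stride, never the device.
t = torch.tensor([], dtype=storage.dtype, device=storage.device)
t = t.set_(storage, 0, (2, 2), (2, 1))
```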

Differential Revision: D14729513

fbshipit-source-id: 090d6cdb375b94dc1bf4f554b2df243952b8cdc6

torch/_utils.py

index 10b9dca..ec6029a 100644
@@ -125,10 +125,9 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
 
 
 def _rebuild_tensor(storage, storage_offset, size, stride):
-    class_name = storage.__class__.__name__.replace('Storage', 'Tensor')
-    module = importlib.import_module(storage.__module__)
-    tensor_class = getattr(module, class_name)
-    return tensor_class().set_(storage, storage_offset, size, stride)
+    # first construct a tensor with the correct dtype/device
+    t = torch.tensor([], dtype=storage.dtype, device=storage.device)
+    return t.set_(storage, storage_offset, size, stride)
 
 
 def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, backward_hooks):
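
As a rough sanity check (illustrative only, not part of the patch), the rebuilt tensor should now inherit the storage's dtype and device directly, and the usual save/load round trip still works:

```python
import io
import torch
from torch._utils import _rebuild_tensor

# Rebuild a 2x3 view over a float storage, roughly the way the unpickler would.
storage = torch.arange(6, dtype=torch.float32).storage()
t = _rebuild_tensor(storage, 0, (2, 3), (3, 1))
assert t.dtype == torch.float32
assert t.device == storage.device
assert t.shape == (2, 3)

# torch.save/torch.load go through _rebuild_tensor_v2, which delegates to the
# function above, so a round trip should reproduce the same tensor.
buf = io.BytesIO()
torch.save(t, buf)
buf.seek(0)
assert torch.equal(torch.load(buf), t)
```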