Stop swapping in Storages of the wrong device for Tensors. (#18831)
authorGregory Chanan <gchanan@fb.com>
Thu, 4 Apr 2019 13:19:54 +0000 (06:19 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 4 Apr 2019 13:25:33 +0000 (06:25 -0700)
commit486fae563d79989b31e5708fb488a1376aec391b
treebf964e0d570ef1f18ad40351047b5cee9bde1f70
parentd70c6f23f40bab5baf88a62cfd045abac4335b39
Stop swapping in Storages of the wrong device for Tensors. (#18831)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18831
ghimport-source-id: 2741e0d70ebe2c2217572c3af54ddd9d2047e342

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18833 [STACK] Cache device on TensorImpl; clean up TensorImpl constructors.
* #18832 [STACK] Disallow changing the device of a tensor via set_.
* **#18831 [STACK] Stop swapping in Storages of the wrong device for Tensors.**

This is necessary to support device caching, see https://github.com/pytorch/pytorch/pull/18751 and https://github.com/pytorch/pytorch/pull/18578.

In library code, we potentially swap in Storages with the wrong device when device_guard is False.  This happens as follows with "view-like" operations.
1) We allocate a tensor on the 'wrong' device (because device_guard is false).
2) We swap out the 'wrong' storage with the 'right' storage using e.g. THCTensor_setStorage.

Instead, we can just construct the Tensor with the correct Storage from the beginning.  This is what we do with 'view'.

Note there are two other "view-like" cases where this happens:
1) unfold
2) set_()

Because these aren't performance critical, I just added the device_guard instead of applying the above correction.

For completeness, this also includes a test that all `device_guard: false` functions behave properly under these conditions.

Reviewed By: dzhulgakov

Differential Revision: D14766232

fbshipit-source-id: 0865c3ddae3f415df5da7a9869b1ea9f210e81bc
aten/src/ATen/native/native_functions.yaml
aten/src/TH/THStorageFunctions.hpp
aten/src/THC/generic/THCTensor.cpp
test/test_torch.py