Use CUDAGuard when serializing CUDA Tensors (#15807)
author     Richard Zou <zou3519@gmail.com>
           Tue, 8 Jan 2019 15:26:15 +0000 (07:26 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Tue, 8 Jan 2019 15:31:50 +0000 (07:31 -0800)
Summary:
Fixes #15308. Before this change, `torch.save` and `torch.load` would
initialize the CUDA context on GPU 0 if it hadn't been initialized
already, even if the serialized tensors are only on GPU 1.

This PR fixes that bug by using CUDAGuard in the storage serialization
path.
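
A minimal reproduction sketch of the pre-fix behavior (assumes a machine
with at least two GPUs; the file path is illustrative):

    import torch

    x = torch.randn(8, device='cuda:1')  # only GPU 1 should be touched
    torch.save(x, '/tmp/x.pt')           # pre-fix: also created a CUDA context
                                         # on GPU 0, visible as unexpected
                                         # memory usage in nvidia-smi
    y = torch.load('/tmp/x.pt')          # the load path had the same problem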
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15807

Differential Revision: D13593201

Pulled By: zou3519

fbshipit-source-id: 4addc91ea5a5278d56a03f3d422577ee39e99897

torch/csrc/generic/serialization.cpp

index 6c297db..f4e47a4 100644
@@ -2,9 +2,17 @@
 #define TH_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
 #else
 
+#ifdef THC_GENERIC_FILE
+#include <c10/cuda/CUDAGuard.h>
+#endif
+
 template <class io>
 void THPStorage_(writeFileRaw)(THWStorage *self, io fd)
 {
+#ifdef THC_GENERIC_FILE
+  c10::cuda::CUDAGuard guard(self->device());
+#endif
+
   scalar_t *data;
   int64_t size = THWStorage_(size)(LIBRARY_STATE self);
 #ifndef THC_GENERIC_FILE
@@ -50,6 +58,13 @@ template void THPStorage_(writeFileRaw<PyObject*>)(THWStorage *self, PyObject* f
 template <class io>
 THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage)
 {
+#ifdef THC_GENERIC_FILE
+  c10::cuda::OptionalCUDAGuard guard;
+  if (_storage != nullptr) {
+    guard.set_device(_storage->device());
+  }
+#endif
+
   scalar_t *data;
   int64_t size;
   doRead(file, &size, sizeof(int64_t));
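
Note that the OptionalCUDAGuard in readFileRaw only switches devices when a
destination storage is supplied; with a null _storage, the new storage is
allocated on whatever device is current. For readers coming from the Python
side, torch.cuda.device is a rough analogue of the RAII guard used above; a
small sketch of the scoped-switch behavior the patch relies on (illustrative,
not part of the diff):

    import torch

    prev = torch.cuda.current_device()
    # Like c10::cuda::CUDAGuard, this makes GPU 1 current inside the
    # block and restores the previous current device on exit.
    with torch.cuda.device(1):
        assert torch.cuda.current_device() == 1
    assert torch.cuda.current_device() == prev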