From 8f11147d43386acec2642a23cf08710ee0ab1af5 Mon Sep 17 00:00:00 2001 From: Richard Zou Date: Tue, 8 Jan 2019 07:26:15 -0800 Subject: [PATCH] Use CUDAGuard when serializing CUDA Tensors (#15807) Summary: Fixes #15308. Before this change, `torch.save` and `torch.load` would initialize the CUDA context on GPU 0 if it hadn't been initialized already, even if the serialized tensors are only on GPU 1. This PR fixes that bug by using CUDAGuard in the storage serialization path. Pull Request resolved: https://github.com/pytorch/pytorch/pull/15807 Differential Revision: D13593201 Pulled By: zou3519 fbshipit-source-id: 4addc91ea5a5278d56a03f3d422577ee39e99897 --- torch/csrc/generic/serialization.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/torch/csrc/generic/serialization.cpp b/torch/csrc/generic/serialization.cpp index 6c297db..f4e47a4 100644 --- a/torch/csrc/generic/serialization.cpp +++ b/torch/csrc/generic/serialization.cpp @@ -2,9 +2,17 @@ #define TH_GENERIC_FILE "torch/csrc/generic/serialization.cpp" #else +#ifdef THC_GENERIC_FILE +#include +#endif + template void THPStorage_(writeFileRaw)(THWStorage *self, io fd) { +#ifdef THC_GENERIC_FILE + c10::cuda::CUDAGuard guard(self->device()); +#endif + scalar_t *data; int64_t size = THWStorage_(size)(LIBRARY_STATE self); #ifndef THC_GENERIC_FILE @@ -50,6 +58,13 @@ template void THPStorage_(writeFileRaw)(THWStorage *self, PyObject* f template THWStorage * THPStorage_(readFileRaw)(io file, THWStorage *_storage) { +#ifdef THC_GENERIC_FILE + c10::cuda::OptionalCUDAGuard guard; + if (_storage != nullptr) { + guard.set_device(_storage->device()); + } +#endif + scalar_t *data; int64_t size; doRead(file, &size, sizeof(int64_t)); -- 2.7.4