Replace caffe2::DeviceGuard with c10::cuda::CUDAGuard (#17623)
authorEdward Yang <ezyang@fb.com>
Wed, 6 Mar 2019 18:32:38 +0000 (10:32 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 6 Mar 2019 18:48:15 +0000 (10:48 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17623

Despite it's generic sounding name, caffe2::DeviceGuard actually
only worked on CUDA devices.  Rename it to something that more
clearly spells out its applicability.

I'm not sure if it's the right call, but in this patch I added
'using CUDAGuard = c10::cuda::CUDAGuard', as this seems to be more
in-line with how the Caffe2 codebase is currently written.  More
idiomatic c10 namespace style would be to say cuda::CUDAGuard.
Willing to change this if people shout.

This is a respin of D13156470 (#14284)

Reviewed By: dzhulgakov

Differential Revision: D14285504

fbshipit-source-id: 93b8ab938b064572b3b010c307e1261fde0fff3d

caffe2/contrib/nccl/cuda_nccl_gpu.cc
caffe2/core/blob_gpu_test.cc
caffe2/core/blob_serialization.cc
caffe2/core/common_gpu.h
caffe2/core/context_gpu.cu
caffe2/core/context_gpu.h
caffe2/core/context_gpu_test.cc
caffe2/core/cudnn_wrappers.h
caffe2/core/event_gpu.cc
caffe2/core/hip/miopen_wrapper.h
caffe2/python/pybind_state.h

index 8eb7085..1e63bc6 100644 (file)
@@ -26,7 +26,7 @@ class NCCLContext {
     streams_.resize(devices_.size());
     events_.resize(devices_.size());
     for (auto i = 0; i < devices_.size(); ++i) {
-      DeviceGuard g(devices_[i]);
+      CUDAGuard g(devices_[i]);
       // get stream priorities
       int lo_pri, hi_pri;
       CUDA_ENFORCE(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
@@ -35,18 +35,18 @@ class NCCLContext {
       CUDA_ENFORCE(cudaEventCreateWithFlags(
           &events_[i], cudaEventDefault | cudaEventDisableTiming));
     }
-    DeviceGuard g(master_gpu_id_);
+    CUDAGuard g(master_gpu_id_);
     CUDA_ENFORCE(cudaEventCreateWithFlags(
         &master_event_, cudaEventDefault | cudaEventDisableTiming));
   }
 
   ~NCCLContext() {
     for (auto i = 0; i < devices_.size(); ++i) {
-      DeviceGuard g(devices_[i]);
+      CUDAGuard g(devices_[i]);
       CUDA_ENFORCE(cudaStreamDestroy(streams_[i]));
       CUDA_ENFORCE(cudaEventDestroy(events_[i]));
     }
-    DeviceGuard g(master_gpu_id_);
+    CUDAGuard g(master_gpu_id_);
     CUDA_ENFORCE(cudaEventDestroy(master_event_));
 
     /*
@@ -137,7 +137,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
   // do initialization
   for (auto i = 0; i < ex.elements.size(); ++i) {
     auto& ctx = ex.elements[i];
-    DeviceGuard g(ctx.device);
+    CUDAGuard g(ctx.device);
     init_f(ex.elements[i]);
   }
 
@@ -150,7 +150,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
   // children streams, so the children streams are synchronized WRT
   // the original stream.
   {
-    DeviceGuard g(ex.stream_gpu_id);
+    CUDAGuard g(ex.stream_gpu_id);
     CUDA_ENFORCE(cudaEventRecord(context->master_event_, ex.stream));
   }
 
@@ -164,7 +164,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
 
     for (auto i = 0; i < ex.elements.size(); ++i) {
       auto& ctx = ex.elements[i];
-      DeviceGuard g(ctx.device);
+      CUDAGuard g(ctx.device);
       auto& comm = comms[i];
       auto& stream = streams[i];
       auto& event = events[i];
@@ -180,7 +180,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
 
     for (auto i = 0; i < ex.elements.size(); ++i) {
       auto& ctx = ex.elements[i];
-      DeviceGuard g(ctx.device);
+      CUDAGuard g(ctx.device);
       auto& comm = comms[i];
       auto& stream = streams[i];
       auto& event = events[i];
@@ -192,7 +192,7 @@ void runNCCL(const NCCLExecution& ex, InitF&& init_f, F&& f) {
   }
 
   // Now, wait on all the events in the original stream.
-  DeviceGuard dg(ex.stream_gpu_id);
+  CUDAGuard dg(ex.stream_gpu_id);
   for (auto& event : events) {
     CUDA_ENFORCE(cudaStreamWaitEvent(CHECK_NOTNULL(ex.stream), event, 0));
   }
index 8e5d34e..de6ea99 100644 (file)
@@ -190,7 +190,7 @@ TEST(TensorTest, TensorSerializationMultiDevices) {
     tensor.mutable_data<float>()[i] = i;
   }
   for (int gpu_id = 0; gpu_id < NumCudaDevices(); ++gpu_id) {
-    DeviceGuard guard(gpu_id);
+    CUDAGuard guard(gpu_id);
     CUDAContext context(gpu_id); // switch to the current gpu
     blob.Reset(new Tensor(tensor, CUDA));
     string serialized = SerializeBlob(blob, "test");
index b315e7c..34cfc98 100644 (file)
@@ -223,7 +223,7 @@ void TensorSerializer::Serialize(
   const TensorProto::DataType data_type = TypeMetaToDataType(input.dtype());
   proto.set_data_type(data_type);
   StoreDeviceDetail(input, &proto);
-  // TODO: use DeviceGuard here instead of context and employ explicit sync
+  // TODO: use CUDAGuard here instead of context and employ explicit sync
   // copy
   auto uniq_ptr = CreateContext(input.GetDevice());
   // A lot of copypaste is error prone. Should we create a macro for this?
index 4a69fc2..6e65bb5 100644 (file)
@@ -27,6 +27,7 @@
 
 #include "c10/cuda/CUDAMacros.h"
 #include "c10/cuda/CUDAMathCompat.h"
+#include <c10/cuda/CUDAGuard.h>
 
 // Defines CAFFE2_CUDA_EXPORT and CAFFE2_CUDA_IMPORT. On Windows, this
 // corresponds to different declarations (dllexport and dllimport). On
@@ -371,21 +372,7 @@ inline dim3 CAFFE_GET_BLOCKS_2D(const int N, const int /* M */) {
   return grid;
 }
 
-class DeviceGuard {
- public:
-  explicit DeviceGuard(int newDevice) : previous_(CaffeCudaGetDevice()) {
-    if (previous_ != newDevice) {
-      CaffeCudaSetDevice(newDevice);
-    }
-  }
-
-  ~DeviceGuard() noexcept {
-    CaffeCudaSetDevice(previous_);
-  }
-
- private:
-  int previous_;
-};
+using CUDAGuard = c10::cuda::CUDAGuard;
 
 template <typename T, int N>
 struct SimpleArray {
index 51ab9f0..d7f7fc6 100644 (file)
@@ -112,7 +112,7 @@ void CUDAContext::CopyBytesSync(
   // This emulates Caffe2 original behavior where sync copy doesn't change the
   // device. It's probably better for clarity to switch to the target device
   // explicitly here, but in the worst case CUDA would sync for us.
-  // TODO: change it to DeviceGuard
+  // TODO: change it to CUDAGuard
   CUDAContext context(-1); // take current device
   CUDA_ENFORCE(cudaMemcpyAsync(
       dst, src, nbytes, cudaMemcpyDefault, context.cuda_stream()));
@@ -212,7 +212,7 @@ static void Caffe2InitializeCuda() {
       "). Increase that and recompile.");
 
   for (DeviceIndex i = 0; i < NumCudaDevices(); ++i) {
-    DeviceGuard g(i);
+    CUDAGuard g(i);
     // Enable peer access.
     const int peer_group = i / CAFFE2_CUDA_MAX_PEER_SIZE;
     const int peer_start = peer_group * CAFFE2_CUDA_MAX_PEER_SIZE;
index a4b15a2..5cce1a0 100644 (file)
@@ -103,7 +103,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
   }
 
   cublasHandle_t GetHandle(c10::cuda::CUDAStream cuda_stream) {
-    DeviceGuard guard(cuda_stream.device_index());
+    CUDAGuard guard(cuda_stream.device_index());
     // Default construct in the map if it doesn't exist, and return a mutable
     // refernce to it.
     auto& r = cublas_handles_[cuda_stream];
@@ -127,7 +127,7 @@ class CAFFE2_CUDA_API ThreadLocalCUDAObjects {
   }
 
   cudnnHandle_t GetCudnnHandle(c10::cuda::CUDAStream cuda_stream) {
-    DeviceGuard guard(cuda_stream.device_index());
+    CUDAGuard guard(cuda_stream.device_index());
     auto& r = cudnn_handles_[cuda_stream];
     if (r == nullptr) {
       CUDNN_ENFORCE(cudnnCreate(&r));
@@ -234,7 +234,7 @@ class CAFFE2_CUDA_API CUDAContext final : public BaseContext {
 
   curandGenerator_t& curand_generator() {
     if (!curand_generator_) {
-      DeviceGuard guard(gpu_id_);
+      CUDAGuard guard(gpu_id_);
       CURAND_ENFORCE(
           curandCreateGenerator(&curand_generator_, CURAND_RNG_PSEUDO_DEFAULT));
       CURAND_ENFORCE(
index e354f28..b27dddf 100644 (file)
@@ -43,7 +43,7 @@ TEST(CUDAContextTest, MemoryPoolAllocateDealloc) {
   const int nbytes = 1048576;
   for (int i = 0; i < NumCudaDevices(); ++i) {
     LOG(INFO) << "Device " << i << " of " << NumCudaDevices();
-    DeviceGuard guard(i);
+    CUDAGuard guard(i);
     auto allocated = CUDAContext::New(nbytes);
     EXPECT_NE(allocated, nullptr);
     cudaPointerAttributes attr;
index 5507092..02efba7 100644 (file)
@@ -87,7 +87,7 @@ struct CuDNNWorkspace {
 class CuDNNState {
  public:
   explicit CuDNNState(size_t gpu_id) : gpu_id_(gpu_id) {
-    DeviceGuard g(gpu_id_);
+    CUDAGuard g(gpu_id_);
     CUDNN_ENFORCE(cudnnCreate(&cudnn_handle_));
     CUDA_ENFORCE(cudaEventCreate(&before_));
     CUDA_ENFORCE(cudaEventCreate(&after_));
@@ -96,7 +96,7 @@ class CuDNNState {
   }
 
   ~CuDNNState() noexcept {
-    DeviceGuard g(gpu_id_);
+    CUDAGuard g(gpu_id_);
     CUDNN_CHECK(cudnnDestroy(cudnn_handle_));
     CUDA_CHECK(cudaStreamDestroy(stream_));
     CUDA_CHECK(cudaEventDestroy(after_));
@@ -162,7 +162,7 @@ class CuDNNWrapper {
         state_idx < CAFFE2_COMPILE_TIME_MAX_CUDNN_STATES, "Invalid state_idx");
     auto& sync_state = cudnn_states()[context_->device_id()][state_idx];
 
-    DeviceGuard dg(context_->device_id());
+    CUDAGuard dg(context_->device_id());
 
     // We need to serialize execution on the CuDNNState as we can't
     // allow multiple threads to race through the cudaEventRecord
index b2898ad..f2ec2f8 100644 (file)
@@ -12,12 +12,12 @@ struct CudaEventWrapper {
         device_id_(option.device_id()),
         status_(EventStatus::EVENT_INITIALIZED) {
     CAFFE_ENFORCE(option.device_type(), PROTO_CUDA);
-    DeviceGuard g(device_id_);
+    CUDAGuard g(device_id_);
     CUDA_ENFORCE(cudaEventCreateWithFlags(
         &cuda_event_, cudaEventDefault | cudaEventDisableTiming));
   }
   ~CudaEventWrapper() {
-    DeviceGuard g(device_id_);
+    CUDAGuard g(device_id_);
     CUDA_CHECK(cudaEventDestroy(cuda_event_));
   }
 
@@ -96,7 +96,7 @@ void EventFinishCUDA(const Event* event) {
 
   if (wrapper->status_ == EventStatus::EVENT_SCHEDULED) {
     // ok, even if event is already completed and status was not yet updated
-    DeviceGuard g(wrapper->device_id_);
+    CUDAGuard g(wrapper->device_id_);
     auto cudaResult = cudaEventSynchronize(wrapper->cuda_event_);
     if (cudaResult == cudaSuccess) {
       wrapper->status_ = EventStatus::EVENT_SUCCESS;
index 4eefb5e..1df9949 100644 (file)
@@ -5,6 +5,8 @@
 #include "caffe2/core/hip/common_miopen.h"
 #include "caffe2/core/hip/context_gpu.h"
 
+#include <c10/hip/HIPGuard.h>
+
 namespace caffe2 {
 
 class MIOPENWrapper;
@@ -53,7 +55,7 @@ class MIOPENState
     public:
     explicit MIOPENState(size_t gpu_id) : gpu_id_(gpu_id)
     {
-        DeviceGuard g(gpu_id_);
+        HIPGuard g(gpu_id_);
         MIOPEN_ENFORCE(miopenCreate(&miopen_handle_));
         HIP_ENFORCE(hipEventCreate(&before_));
         HIP_ENFORCE(hipEventCreate(&after_));
@@ -63,7 +65,7 @@ class MIOPENState
 
     ~MIOPENState() noexcept
     {
-        DeviceGuard g(gpu_id_);
+        HIPGuard g(gpu_id_);
         MIOPEN_CHECK(miopenDestroy(miopen_handle_));
         HIP_CHECK(hipStreamDestroy(stream_));
         HIP_CHECK(hipEventDestroy(after_));
@@ -125,7 +127,7 @@ class MIOPENWrapper
         CAFFE_ENFORCE(state_idx < CAFFE2_COMPILE_TIME_MAX_MIOPEN_STATES, "Invalid state_idx");
         auto& sync_state = miopen_states()[context_->device_id()][state_idx];
 
-        DeviceGuard dg(context_->device_id());
+        HIPGuard dg(context_->device_id());
 
         // We need to serialize execution on the MIOPENState as we can't
         // allow multiple threads to race through the cudaEventRecord
index 3415de7..637bf47 100644 (file)
@@ -162,7 +162,7 @@ class TensorFetcher : public BlobFetcherBase {
     }
 
     if (result.copied) {
-      // TODO: use DeviceGuard here instead of context and employ explicit sync
+      // TODO: use CUDAGuard here instead of context and employ explicit sync
       // copy
       auto context = CreateContext(tensor.GetDeviceType());
       context->CopyBytesToCPU(tensor.nbytes(), tensor.raw_data(), outPtr);