Split out CUDAMultiStreamGuard from CUDAGuard (#13912)
authorEdward Yang <ezyang@fb.com>
Mon, 19 Nov 2018 16:13:08 +0000 (08:13 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 19 Nov 2018 16:20:11 +0000 (08:20 -0800)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13912

The implementation and API of CUDAMultiStreamGuard is less mature,
and it cannot be implemented generically (yet) in c10_cuda.  This
might be a reasonable thing to do eventually, but not for now.

Reviewed By: smessmer

Differential Revision: D13046500

fbshipit-source-id: 4ea39ca1344f1ad5ae7c82c98617aa348c327848

aten/src/ATen/cuda/CUDAGuard.h
aten/src/ATen/cuda/CUDAMultiStreamGuard.h [new file with mode: 0644]
aten/src/ATen/test/stream_test.cpp
torch/csrc/distributed/c10d/ddp.cpp
torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp
torch/lib/c10d/test/ProcessGroupNCCLTest.cpp

index 9a0cc65..1af7dea 100644 (file)
@@ -1,15 +1,11 @@
 #pragma once
 
-#include <ATen/DeviceGuard.h>
-#include <c10/util/ArrayRef.h>
-#include <ATen/cuda/CUDAContext.h>
 #include <ATen/cuda/detail/CUDAGuardImpl.h>
 #include <c10/DeviceType.h>
 #include <c10/impl/InlineDeviceGuard.h>
 #include <c10/impl/InlineStreamGuard.h>
 
 #include <cstddef>
-#include <vector>
 
 namespace at { namespace cuda {
 
@@ -235,51 +231,5 @@ private:
   c10::impl::InlineOptionalStreamGuard<at::cuda::detail::CUDAGuardImpl> guard_;
 };
 
-// TODO: Implement this generically in c10.  You'll need some way to get
-// the number of GPUs from the GuardImpl, in that case.
-struct CUDAMultiStreamGuard {
-  /// Calls `set_stream` on each of the streams in the list.
-  /// This may be useful if you need to set different streams
-  /// for different devices.
-  explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) : CUDAMultiStreamGuard() {
-    for (const auto& s : streams) {
-      setCurrentCUDAStream(s);
-    }
-  }
-
-  CUDAMultiStreamGuard() {
-    const size_t device_count = getNumGPUs();
-    original_streams_.reserve(device_count);
-    for (size_t device = 0; device < device_count; ++device) {
-      original_streams_.push_back(getCurrentCUDAStream(device));
-    }
-  }
-
-  CUDAMultiStreamGuard(const CUDAGuard&) = delete;
-  CUDAMultiStreamGuard& operator=(const CUDAGuard&) = delete;
-
-  // See Note [Move construction for RAII guards is tricky]
-  CUDAMultiStreamGuard(CUDAGuard&& other) = delete;
-
-  // See Note [Move assignment for RAII guards is tricky]
-  CUDAMultiStreamGuard& operator=(CUDAGuard&& other) = delete;
-
-  ArrayRef<CUDAStream> original_streams() const {
-    return original_streams_;
-  }
-
-  /// Resets the CUDA stream on each device to the one that was active upon
-  /// construction.
-  ~CUDAMultiStreamGuard() {
-    for (const auto& s : original_streams_) {
-      uncheckedSetCurrentCUDAStream(s);
-    }
-  }
-
-private:
-  /// The original streams that were active on all devices.
-  std::vector<CUDAStream> original_streams_;
-};
-
 } // namespace cuda
 } // namespace at
diff --git a/aten/src/ATen/cuda/CUDAMultiStreamGuard.h b/aten/src/ATen/cuda/CUDAMultiStreamGuard.h
new file mode 100644 (file)
index 0000000..c2484a3
--- /dev/null
@@ -0,0 +1,58 @@
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+#include <ATen/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <vector>
+
+namespace at { namespace cuda {
+
+// TODO: Implement this generically in c10.  You'll need some way to get
+// the number of GPUs from the GuardImpl, in that case.
+class CUDAMultiStreamGuard final {
+public:
+  /// Calls `set_stream` on each of the streams in the list.
+  /// This may be useful if you need to set different streams
+  /// for different devices.
+  explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) : CUDAMultiStreamGuard() {
+    for (const auto& s : streams) {
+      setCurrentCUDAStream(s);
+    }
+  }
+
+  CUDAMultiStreamGuard() {
+    const size_t device_count = getNumGPUs();
+    original_streams_.reserve(device_count);
+    for (size_t device = 0; device < device_count; ++device) {
+      original_streams_.push_back(getCurrentCUDAStream(device));
+    }
+  }
+
+  CUDAMultiStreamGuard(const CUDAGuard&) = delete;
+  CUDAMultiStreamGuard& operator=(const CUDAGuard&) = delete;
+
+  // See Note [Move construction for RAII guards is tricky]
+  CUDAMultiStreamGuard(CUDAGuard&& other) = delete;
+
+  // See Note [Move assignment for RAII guards is tricky]
+  CUDAMultiStreamGuard& operator=(CUDAGuard&& other) = delete;
+
+  ArrayRef<CUDAStream> original_streams() const {
+    return original_streams_;
+  }
+
+  /// Resets the CUDA stream on each device to the one that was active upon
+  /// construction.
+  ~CUDAMultiStreamGuard() {
+    for (const auto& s : original_streams_) {
+      uncheckedSetCurrentCUDAStream(s);
+    }
+  }
+
+private:
+  /// The original streams that were active on all devices.
+  std::vector<CUDAStream> original_streams_;
+};
+
+}} // namespace at::cuda
index 2dfc469..327285f 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "ATen/cuda/CUDAContext.h"
 #include "ATen/cuda/CUDAGuard.h"
+#include "ATen/cuda/CUDAMultiStreamGuard.h"
 #include "ATen/cuda/CUDAEvent.h"
 
 #include "cuda_runtime.h"
index 898a7db..b80963f 100644 (file)
@@ -10,6 +10,7 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAEvent.h>
 #include <ATen/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAMultiStreamGuard.h>
 
 #include <cstddef>
 #include <memory>
index fcdd9aa..69e2180 100644 (file)
@@ -1,6 +1,7 @@
 #include <gloo/transport/tcp/device.h>
 
 #include <ATen/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAMultiStreamGuard.h>
 
 #include <c10d/FileStore.hpp>
 #include <c10d/ProcessGroupGloo.hpp>
index 158eef1..7801053 100644 (file)
@@ -6,6 +6,7 @@
 #include <c10d/test/TestUtils.hpp>
 
 #include <ATen/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAMultiStreamGuard.h>
 #include <ATen/cuda/CUDAStream.h>
 
 using namespace c10d::test;