#pragma once
-#include <ATen/DeviceGuard.h>
-#include <c10/util/ArrayRef.h>
-#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/CUDAGuardImpl.h>
#include <c10/DeviceType.h>
#include <c10/impl/InlineDeviceGuard.h>
#include <c10/impl/InlineStreamGuard.h>
#include <cstddef>
-#include <vector>
namespace at { namespace cuda {
c10::impl::InlineOptionalStreamGuard<at::cuda::detail::CUDAGuardImpl> guard_;
};
-// TODO: Implement this generically in c10. You'll need some way to get
-// the number of GPUs from the GuardImpl, in that case.
-struct CUDAMultiStreamGuard {
- /// Calls `set_stream` on each of the streams in the list.
- /// This may be useful if you need to set different streams
- /// for different devices.
- explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) : CUDAMultiStreamGuard() {
- for (const auto& s : streams) {
- setCurrentCUDAStream(s);
- }
- }
-
- CUDAMultiStreamGuard() {
- const size_t device_count = getNumGPUs();
- original_streams_.reserve(device_count);
- for (size_t device = 0; device < device_count; ++device) {
- original_streams_.push_back(getCurrentCUDAStream(device));
- }
- }
-
- CUDAMultiStreamGuard(const CUDAGuard&) = delete;
- CUDAMultiStreamGuard& operator=(const CUDAGuard&) = delete;
-
- // See Note [Move construction for RAII guards is tricky]
- CUDAMultiStreamGuard(CUDAGuard&& other) = delete;
-
- // See Note [Move assignment for RAII guards is tricky]
- CUDAMultiStreamGuard& operator=(CUDAGuard&& other) = delete;
-
- ArrayRef<CUDAStream> original_streams() const {
- return original_streams_;
- }
-
- /// Resets the CUDA stream on each device to the one that was active upon
- /// construction.
- ~CUDAMultiStreamGuard() {
- for (const auto& s : original_streams_) {
- uncheckedSetCurrentCUDAStream(s);
- }
- }
-
-private:
- /// The original streams that were active on all devices.
- std::vector<CUDAStream> original_streams_;
-};
-
} // namespace cuda
} // namespace at
--- /dev/null
+#pragma once
+
+#include <c10/util/ArrayRef.h>
+#include <ATen/cuda/CUDAStream.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <vector>
+
+namespace at { namespace cuda {
+
+// TODO: Implement this generically in c10. You'll need some way to get
+// the number of GPUs from the GuardImpl, in that case.
+/// Scope guard that records the current CUDA stream of every device when it
+/// is constructed and puts those streams back when it is destroyed.
+///
+/// The guard is neither copyable nor movable: a second instance holding the
+/// same saved streams would restore them a second time on destruction.
+class CUDAMultiStreamGuard final {
+public:
+  /// Saves the current stream of every device, then calls
+  /// `setCurrentCUDAStream` on each of the streams in the list.
+  /// This may be useful if you need to set different streams
+  /// for different devices.
+  ///
+  /// The snapshot is taken by the delegated default constructor, i.e. before
+  /// any stream in `streams` is made current.
+  explicit CUDAMultiStreamGuard(ArrayRef<CUDAStream> streams) : CUDAMultiStreamGuard() {
+    for (const auto& s : streams) {
+      setCurrentCUDAStream(s);
+    }
+  }
+
+  /// Saves the current stream of each device index in
+  /// `[0, getNumGPUs())` without changing any of them.
+  CUDAMultiStreamGuard() {
+    const size_t device_count = getNumGPUs();
+    original_streams_.reserve(device_count);
+    for (size_t device = 0; device < device_count; ++device) {
+      original_streams_.push_back(getCurrentCUDAStream(device));
+    }
+  }
+
+  // Copying is disabled. These must name CUDAMultiStreamGuard itself:
+  // overloads declared on another guard type are not copy operations of this
+  // class and leave the implicitly generated ones available.
+  CUDAMultiStreamGuard(const CUDAMultiStreamGuard&) = delete;
+  CUDAMultiStreamGuard& operator=(const CUDAMultiStreamGuard&) = delete;
+
+  // See Note [Move construction for RAII guards is tricky]
+  CUDAMultiStreamGuard(CUDAMultiStreamGuard&& other) = delete;
+
+  // See Note [Move assignment for RAII guards is tricky]
+  CUDAMultiStreamGuard& operator=(CUDAMultiStreamGuard&& other) = delete;
+
+  /// Returns a non-owning view of the streams saved at construction, one per
+  /// device, ordered by device index. The view refers to storage owned by
+  /// this guard.
+  ArrayRef<CUDAStream> original_streams() const {
+    return original_streams_;
+  }
+
+  /// Resets the CUDA stream on each device to the one that was active upon
+  /// construction.
+  ~CUDAMultiStreamGuard() {
+    // NOTE(review): the unchecked setter is presumably used so that the
+    // destructor cannot throw -- confirm against its declaration.
+    for (const auto& s : original_streams_) {
+      uncheckedSetCurrentCUDAStream(s);
+    }
+  }
+
+private:
+  /// The original streams that were active on all devices.
+  std::vector<CUDAStream> original_streams_;
+};
+
+}} // namespace at::cuda
#include "ATen/cuda/CUDAContext.h"
#include "ATen/cuda/CUDAGuard.h"
+#include "ATen/cuda/CUDAMultiStreamGuard.h"
#include "ATen/cuda/CUDAEvent.h"
#include "cuda_runtime.h"
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAEvent.h>
#include <ATen/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <cstddef>
#include <memory>
#include <gloo/transport/tcp/device.h>
#include <ATen/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <c10d/FileStore.hpp>
#include <c10d/ProcessGroupGloo.hpp>
#include <c10d/test/TestUtils.hpp>
#include <ATen/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAMultiStreamGuard.h>
#include <ATen/cuda/CUDAStream.h>
using namespace c10d::test;