From 2d485ffb17ebdec54b399df591e1da031e101d46 Mon Sep 17 00:00:00 2001
From: Edward Yang
Date: Wed, 12 Dec 2018 11:19:03 -0800
Subject: [PATCH] Move CUDAGuard, CUDAStream and CUDAGuardImpl to c10/cuda
 (#14248)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14248

This diff also introduces a horrifying hack to override CUDA's
DeviceGuardImpl with a HIPGuardImplMasqueradingAsCUDA, to accommodate
PyTorch's current behavior of pretending CUDA is HIP when you build with
ROCm enabled.

Reviewed By: bddppq

Differential Revision: D13145293

fbshipit-source-id: ee0e207b6fd132f0d435512957424a002d588f02
---
 aten/src/ATen/CMakeLists.txt | 4 +-
 aten/src/ATen/cuda/CUDAContext.cpp | 2 +
 aten/src/ATen/cuda/CUDAContext.h | 2 +-
 aten/src/ATen/cuda/CUDAEvent.h | 10 +--
 aten/src/ATen/cuda/CUDAMultiStreamGuard.h | 2 +-
 .../hip/impl/HIPGuardImplMasqueradingAsCUDA.cpp | 14 ++++
 .../ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h | 79 ++++++++++++++++++++++
 aten/src/ATen/miopen/Utils.h | 2 +-
 aten/src/ATen/native/DispatchStub.h | 3 +-
 aten/src/ATen/native/cuda/Copy.cu | 2 +-
 aten/src/ATen/native/cuda/Resize.cuh | 2 +-
 aten/src/ATen/test/cuda_stream_test.cpp | 2 +-
 aten/src/THC/THCCachingAllocator.cpp | 2 +-
 aten/src/THC/THCCachingAllocator.h | 3 +-
 aten/src/THC/THCCachingHostAllocator.h | 2 +-
 aten/src/THC/THCGeneral.cpp | 2 +-
 aten/src/THC/THCGeneral.h.in | 2 +-
 aten/src/THC/THCStream.cpp | 3 +-
 c10/cuda/CMakeLists.txt | 11 ++-
 {aten/src/ATen => c10}/cuda/CUDAGuard.h | 15 ++--
 {aten/src/ATen => c10}/cuda/CUDAStream.cpp | 8 +--
 {aten/src/ATen => c10}/cuda/CUDAStream.h | 10 +--
 .../detail => c10/cuda/impl}/CUDAGuardImpl.cpp | 6 +-
 .../cuda/detail => c10/cuda/impl}/CUDAGuardImpl.h | 20 +++---
 c10/impl/InlineDeviceGuard.h | 3 +-
 docs/cpp/source/Doxyfile | 4 +-
 tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py | 33 +++++++++
 tools/amd_build/pyHIPIFY/hipify_python.py | 5 +-
 tools/cwrap/plugins/AutoGPU.py | 2 +-
 tools/cwrap/plugins/NNExtension.py | 8 ++-
 torch/csrc/autograd/engine.cpp | 13 ++--
 torch/csrc/autograd/profiler.cpp | 2 +-
 torch/csrc/cuda/Stream.cpp | 2 +-
 torch/csrc/cuda/comm.cpp | 2 +-
 torch/csrc/cuda/nccl.cpp | 2 +-
 torch/csrc/cuda/python_nccl.cpp | 2 +-
 torch/csrc/distributed/c10d/ddp.cpp | 2 +-
 torch/csrc/generic/StorageSharing.cpp | 2 +-
 torch/csrc/jit/fuser/cuda/fused_kernel.cpp | 2 +-
 torch/csrc/utils.h | 2 +-
 .../lib/THD/base/data_channels/DataChannelNccl.cpp | 2 +-
 torch/lib/c10d/ProcessGroupGloo.cpp | 4 +-
 torch/lib/c10d/ProcessGroupGloo.hpp | 2 +-
 torch/lib/c10d/ProcessGroupNCCL.cpp | 2 +-
 torch/lib/c10d/test/CUDATest.hpp | 2 +-
 torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp | 2 +-
 torch/lib/c10d/test/ProcessGroupNCCLTest.cpp | 4 +-
 47 files changed, 229 insertions(+), 83 deletions(-)
 create mode 100644 aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.cpp
 create mode 100644 aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
 rename {aten/src/ATen => c10}/cuda/CUDAGuard.h (96%)
 rename {aten/src/ATen => c10}/cuda/CUDAStream.cpp (98%)
 rename {aten/src/ATen => c10}/cuda/CUDAStream.h (97%)
 rename {aten/src/ATen/cuda/detail => c10/cuda/impl}/CUDAGuardImpl.cpp (58%)
 rename {aten/src/ATen/cuda/detail => c10/cuda/impl}/CUDAGuardImpl.h (73%)

diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt
index 765ac38..35e67b1 100644
--- a/aten/src/ATen/CMakeLists.txt
+++ b/aten/src/ATen/CMakeLists.txt
@@ -37,8 +37,8 @@ FILE(GLOB cudnn_h "cudnn/*.h" "cudnn/*.cuh")
 FILE(GLOB cudnn_cpp "cudnn/*.cpp")
 FILE(GLOB hip_h "hip/*.h" "hip/detail/*.h" "hip/*.cuh"
"hip/detail/*.cuh") -FILE(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp") -FILE(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip") +FILE(GLOB hip_cpp "hip/*.cpp" "hip/detail/*.cpp" "hip/impl/*.cpp") +FILE(GLOB hip_hip "hip/*.hip" "hip/detail/*.hip" "hip/impl/*.hip") FILE(GLOB miopen_h "miopen/*.h") FILE(GLOB miopen_cpp "miopen/*.cpp") diff --git a/aten/src/ATen/cuda/CUDAContext.cpp b/aten/src/ATen/cuda/CUDAContext.cpp index 30dd2a3..7f9e6ef 100644 --- a/aten/src/ATen/cuda/CUDAContext.cpp +++ b/aten/src/ATen/cuda/CUDAContext.cpp @@ -1,6 +1,8 @@ #include #include +#include + namespace at { namespace cuda { /* Device info */ diff --git a/aten/src/ATen/cuda/CUDAContext.h b/aten/src/ATen/cuda/CUDAContext.h index 97cdb38..ffd2802 100644 --- a/aten/src/ATen/cuda/CUDAContext.h +++ b/aten/src/ATen/cuda/CUDAContext.h @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include diff --git a/aten/src/ATen/cuda/CUDAEvent.h b/aten/src/ATen/cuda/CUDAEvent.h index 51d6b12..b12f0b7 100644 --- a/aten/src/ATen/cuda/CUDAEvent.h +++ b/aten/src/ATen/cuda/CUDAEvent.h @@ -2,8 +2,8 @@ #include #include -#include -#include +#include +#include #include #include @@ -35,7 +35,7 @@ struct AT_CUDA_API CUDAEvent { ~CUDAEvent() { try { if (is_created_) { - at::cuda::CUDAGuard device_guard(static_cast(device_index_)); + CUDAGuard device_guard(static_cast(device_index_)); cudaEventDestroy(event_); } } catch (...) { /* No throw */ } @@ -74,7 +74,7 @@ struct AT_CUDA_API CUDAEvent { // Note: cudaEventRecord must be called on the same device as the stream. void record(const CUDAStream& stream) { - at::cuda::CUDAGuard guard(static_cast(stream.device_index())); + CUDAGuard guard(static_cast(stream.device_index())); if (is_created_) { AT_ASSERT(device_index_ == stream.device_index()); @@ -92,7 +92,7 @@ struct AT_CUDA_API CUDAEvent { // The event has no actual GPU resources associated with it. void block(const CUDAStream& stream) { if (is_created_) { - at::cuda::CUDAGuard guard(static_cast(stream.device_index())); + CUDAGuard guard(static_cast(stream.device_index())); AT_CUDA_CHECK(cudaStreamWaitEvent(stream, event_, 0)); } } diff --git a/aten/src/ATen/cuda/CUDAMultiStreamGuard.h b/aten/src/ATen/cuda/CUDAMultiStreamGuard.h index b4ed2e6..7fd549c 100644 --- a/aten/src/ATen/cuda/CUDAMultiStreamGuard.h +++ b/aten/src/ATen/cuda/CUDAMultiStreamGuard.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include #include diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.cpp b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.cpp new file mode 100644 index 0000000..2215f55 --- /dev/null +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.cpp @@ -0,0 +1,14 @@ +#include + +// THIS IS A MASSIVE HACK. This will BREAK you Caffe2 CUDA code if you +// load ATen_hip, even if you don't ever actually use ATen_hip at runtime. +// +// If you ever link ATen_hip statically into the full library along +// with ATen_cuda (libomnibus), the loading order of this versus the regular +// ATen_cuda will be nondeterministic, and you'll nondeterministically get +// one or the other. (This will be obvious because all of your code +// will fail.) +// +// This hack can be removed once PyTorch is out-of-place HIPified, and +// doesn't pretend CUDA is HIP. 
+C10_REGISTER_GUARD_IMPL(CUDA, at::cuda::HIPGuardImplMasqueradingAsCUDA); diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h new file mode 100644 index 0000000..3ccda82 --- /dev/null +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -0,0 +1,79 @@ +#pragma once + +#include + +// The includes of HIPGuard.h +#include +#include +#include +#include +#include + +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10 { namespace hip { + +// HIPGuardImplMasqueradingAsCUDA is like a HIPGuardImpl, but +// it reports its DeviceType as CUDA (e.g., type() returns CUDA, +// getDevice() reports the current HIP device as a CUDA device.) +// We can't directly use HIPGuardImpl, since it (piously) requires +// the DeviceType to be HIP. +// +// This is necessary for PyTorch at the moment, which is implemented +// by pretending that CUDA is actually HIP. Eventually, we want +// to make PyTorch treat HIP as a separate DeviceType, and then we +// can delete this class. +// +// Also, note that the cpp file associated with this also *overwrites* +// the entry in the DeviceGuardImpl registry for CUDA with this HIP +// implementation. +struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface { + static constexpr DeviceType static_type = DeviceType::CUDA; + HIPGuardImplMasqueradingAsCUDA() {} + HIPGuardImplMasqueradingAsCUDA(DeviceType t) { + AT_ASSERT(t == DeviceType::CUDA); + } + DeviceType type() const override { + return DeviceType::CUDA; + } + Device exchangeDevice(Device d) const override { + AT_ASSERT(d.type() == DeviceType::CUDA); + Device old_device = getDevice(); + if (old_device.index() != d.index()) { + C10_HIP_CHECK(hipSetDevice(d.index())); + } + return old_device; + } + Device getDevice() const override { + int device; + C10_HIP_CHECK(hipGetDevice(&device)); + return Device(DeviceType::CUDA, device); + } + void setDevice(Device d) const override { + AT_ASSERT(d.type() == DeviceType::CUDA); + C10_HIP_CHECK(hipSetDevice(d.index())); + } + void uncheckedSetDevice(Device d) const noexcept override { + hipSetDevice(d.index()); + } + Stream getStream(Device d) const noexcept override { + return getCurrentHIPStream().unwrap(); + } + Stream exchangeStream(Stream s) const noexcept override { + HIPStream cs(s); + auto old_stream = getCurrentHIPStream(s.device().index()); + setCurrentHIPStream(cs); + return old_stream.unwrap(); + } +}; + +// All of the guards which have HIPGuardImpl burned in need to also have +// variants using HIPGuardImplMasqueradingAsCUDA. 
+using HIPGuardMasqueradingAsCUDA = c10::impl::InlineDeviceGuard; +using OptionalHIPGuardMasqueradingAsCUDA = c10::impl::InlineOptionalDeviceGuard; +using HIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineStreamGuard; +using OptionalHIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineOptionalStreamGuard; + +}} // namespace c10::hip diff --git a/aten/src/ATen/miopen/Utils.h b/aten/src/ATen/miopen/Utils.h index c264650..8a48b8e 100644 --- a/aten/src/ATen/miopen/Utils.h +++ b/aten/src/ATen/miopen/Utils.h @@ -10,7 +10,7 @@ namespace at { namespace native { inline void setMIOpenStreamToCurrent() { // NB: Due to in-place HIPify, getCurrentCUDAStream actually means // getCurrentHIPStream - MIOPEN_CHECK(miopenSetStream(getMiopenHandle(), at::cuda::getCurrentCUDAStream())); + MIOPEN_CHECK(miopenSetStream(getMiopenHandle(), at::hip::getCurrentHIPStream())); } // This function makes tensors which have zero stride contiguous, by diff --git a/aten/src/ATen/native/DispatchStub.h b/aten/src/ATen/native/DispatchStub.h index 89011fa..d0c303d 100644 --- a/aten/src/ATen/native/DispatchStub.h +++ b/aten/src/ATen/native/DispatchStub.h @@ -119,7 +119,8 @@ struct RegisterCUDADispatch { template struct RegisterHIPDispatch { RegisterHIPDispatch(DispatchStub& stub, FnPtr value) { - stub.hip_dispatch_ptr = value; + // TODO: make this point at hip_dispatch_ptr + stub.cuda_dispatch_ptr = value; } }; } // anonymous namespace diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu index 85c72d1..fabaa0d 100644 --- a/aten/src/ATen/native/cuda/Copy.cu +++ b/aten/src/ATen/native/cuda/Copy.cu @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include namespace { diff --git a/aten/src/ATen/native/cuda/Resize.cuh b/aten/src/ATen/native/cuda/Resize.cuh index 3520f20..9fcfb0f 100644 --- a/aten/src/ATen/native/cuda/Resize.cuh +++ b/aten/src/ATen/native/cuda/Resize.cuh @@ -3,7 +3,7 @@ #include #include -#include +#include namespace at { namespace native { diff --git a/aten/src/ATen/test/cuda_stream_test.cpp b/aten/src/ATen/test/cuda_stream_test.cpp index 674c43a..22cd5f7 100644 --- a/aten/src/ATen/test/cuda_stream_test.cpp +++ b/aten/src/ATen/test/cuda_stream_test.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include diff --git a/aten/src/THC/THCCachingAllocator.cpp b/aten/src/THC/THCCachingAllocator.cpp index 7f5ab16..44ebac1 100644 --- a/aten/src/THC/THCCachingAllocator.cpp +++ b/aten/src/THC/THCCachingAllocator.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include #include diff --git a/aten/src/THC/THCCachingAllocator.h b/aten/src/THC/THCCachingAllocator.h index 3735368..626694a 100644 --- a/aten/src/THC/THCCachingAllocator.h +++ b/aten/src/THC/THCCachingAllocator.h @@ -2,7 +2,8 @@ #define THC_DEVICE_ALLOCATOR_INC #ifdef __cplusplus -#include +#include +#include #endif #if (__cplusplus >= 201103L) || (defined(_MSC_VER) && defined(__cplusplus)) diff --git a/aten/src/THC/THCCachingHostAllocator.h b/aten/src/THC/THCCachingHostAllocator.h index 759b30c..66acf01 100644 --- a/aten/src/THC/THCCachingHostAllocator.h +++ b/aten/src/THC/THCCachingHostAllocator.h @@ -5,7 +5,7 @@ #ifdef __cplusplus -#include +#include #endif // diff --git a/aten/src/THC/THCGeneral.cpp b/aten/src/THC/THCGeneral.cpp index 75f74aa..0290e5e 100644 --- a/aten/src/THC/THCGeneral.cpp +++ b/aten/src/THC/THCGeneral.cpp @@ -5,7 +5,7 @@ #include #include -#include +#include #include #include diff --git a/aten/src/THC/THCGeneral.h.in b/aten/src/THC/THCGeneral.h.in index f01ca2d..ca510e2 
100644 --- a/aten/src/THC/THCGeneral.h.in +++ b/aten/src/THC/THCGeneral.h.in @@ -9,7 +9,7 @@ #undef expm1 #ifdef __cplusplus -#include +#include #endif #include diff --git a/aten/src/THC/THCStream.cpp b/aten/src/THC/THCStream.cpp index af4a6ff..f3a6ae2 100644 --- a/aten/src/THC/THCStream.cpp +++ b/aten/src/THC/THCStream.cpp @@ -1,2 +1 @@ - -#include +#include diff --git a/c10/cuda/CMakeLists.txt b/c10/cuda/CMakeLists.txt index c1262d5..0b0b1d8 100644 --- a/c10/cuda/CMakeLists.txt +++ b/c10/cuda/CMakeLists.txt @@ -15,13 +15,22 @@ configure_file( # Note: if you want to add ANY dependency to the c10 library, make sure you # check with the core PyTorch developers as the dependendency will be # transitively passed on to all libraries dependent on PyTorch. + +# Note: if you add a new source file/header, you will need to update +# tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py for new files +# and headers you add set(C10_CUDA_SRCS + CUDAStream.cpp + impl/CUDAGuardImpl.cpp impl/CUDATest.cpp ) set(C10_CUDA_HEADERS + CUDAException.h + CUDAGuard.h CUDAMacros.h CUDAMathCompat.h - CUDAException.h + CUDAStream.h + impl/CUDAGuardImpl.h impl/CUDATest.h ) set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) diff --git a/aten/src/ATen/cuda/CUDAGuard.h b/c10/cuda/CUDAGuard.h similarity index 96% rename from aten/src/ATen/cuda/CUDAGuard.h rename to c10/cuda/CUDAGuard.h index 3aa34f4..c7e8aca 100644 --- a/aten/src/ATen/cuda/CUDAGuard.h +++ b/c10/cuda/CUDAGuard.h @@ -1,13 +1,14 @@ #pragma once -#include +#include +#include #include #include #include #include -namespace at { namespace cuda { +namespace c10 { namespace cuda { // This code is kind of boilerplatey. See Note [Whither the DeviceGuard boilerplate] @@ -56,7 +57,7 @@ struct CUDAGuard { private: /// The guard for the current device. - c10::impl::InlineDeviceGuard guard_; + c10::impl::InlineDeviceGuard guard_; }; /// A variant of OptionalDeviceGuard that is specialized for CUDA. See @@ -108,7 +109,7 @@ struct OptionalCUDAGuard { void reset() { guard_.reset(); } private: - c10::impl::InlineOptionalDeviceGuard guard_; + c10::impl::InlineOptionalDeviceGuard guard_; }; /// A variant of StreamGuard that is specialized for CUDA. See CUDAGuard @@ -165,7 +166,7 @@ struct CUDAStreamGuard { Device original_device() const { return guard_.original_device(); } private: - c10::impl::InlineStreamGuard guard_; + c10::impl::InlineStreamGuard guard_; }; /// A variant of OptionalStreamGuard that is specialized for CUDA. See CUDAGuard @@ -228,8 +229,8 @@ struct OptionalCUDAStreamGuard { void reset() { guard_.reset(); } private: - c10::impl::InlineOptionalStreamGuard guard_; + c10::impl::InlineOptionalStreamGuard guard_; }; } // namespace cuda -} // namespace at +} // namespace c10 diff --git a/aten/src/ATen/cuda/CUDAStream.cpp b/c10/cuda/CUDAStream.cpp similarity index 98% rename from aten/src/ATen/cuda/CUDAStream.cpp rename to c10/cuda/CUDAStream.cpp index 4c6d0b8..a8ba415 100644 --- a/aten/src/ATen/cuda/CUDAStream.cpp +++ b/c10/cuda/CUDAStream.cpp @@ -1,5 +1,5 @@ -#include -#include +#include +#include #include #include @@ -10,7 +10,7 @@ #include #include -namespace at { +namespace c10 { namespace cuda { namespace { @@ -193,7 +193,7 @@ static void initGlobalStreamState() { static void initDeviceStreamState(DeviceIndex device_index) { // Switches to the requested device so streams are properly associated // with it. 
- at::cuda::CUDAGuard device_guard{device_index}; + CUDAGuard device_guard{device_index}; for (auto i = decltype(kStreamsPerPool){0}; i < kStreamsPerPool; ++i) { auto& lowpri_stream = low_priority_streams[device_index][i]; diff --git a/aten/src/ATen/cuda/CUDAStream.h b/c10/cuda/CUDAStream.h similarity index 97% rename from aten/src/ATen/cuda/CUDAStream.h rename to c10/cuda/CUDAStream.h index 61a6c1a..9845abc 100644 --- a/aten/src/ATen/cuda/CUDAStream.h +++ b/c10/cuda/CUDAStream.h @@ -5,7 +5,7 @@ #include -#include +#include #include #include @@ -51,14 +51,14 @@ * a kernel on the same stream from two different threads. */ -namespace at { +namespace c10 { namespace cuda { // Value object representing a CUDA stream. This is just a wrapper // around c10::Stream, but it comes with a little extra CUDA-specific // functionality (conversion to cudaStream_t), and a guarantee that // the wrapped c10::Stream really is a CUDA stream. -class AT_CUDA_API CUDAStream { +class C10_CUDA_API CUDAStream { public: enum Unchecked { UNCHECKED }; @@ -178,8 +178,8 @@ C10_API std::ostream& operator<<(std::ostream& stream, const CUDAStream& s); namespace std { template <> - struct hash { - size_t operator()(at::cuda::CUDAStream s) const noexcept { + struct hash { + size_t operator()(c10::cuda::CUDAStream s) const noexcept { return std::hash{}(s.unwrap()); } }; diff --git a/aten/src/ATen/cuda/detail/CUDAGuardImpl.cpp b/c10/cuda/impl/CUDAGuardImpl.cpp similarity index 58% rename from aten/src/ATen/cuda/detail/CUDAGuardImpl.cpp rename to c10/cuda/impl/CUDAGuardImpl.cpp index 69e2b9b..b0be679 100644 --- a/aten/src/ATen/cuda/detail/CUDAGuardImpl.cpp +++ b/c10/cuda/impl/CUDAGuardImpl.cpp @@ -1,6 +1,6 @@ -#include +#include -namespace at { +namespace c10 { namespace cuda { namespace impl { @@ -8,4 +8,4 @@ constexpr DeviceType CUDAGuardImpl::static_type; C10_REGISTER_GUARD_IMPL(CUDA, CUDAGuardImpl); -}}} // namespace at::cuda::detail +}}} // namespace c10::cuda::detail diff --git a/aten/src/ATen/cuda/detail/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h similarity index 73% rename from aten/src/ATen/cuda/detail/CUDAGuardImpl.h rename to c10/cuda/impl/CUDAGuardImpl.h index a5f55e6..f58282d 100644 --- a/aten/src/ATen/cuda/detail/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -3,12 +3,12 @@ #include #include -#include -#include +#include +#include #include -namespace at { +namespace c10 { namespace cuda { namespace impl { @@ -25,32 +25,32 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { AT_ASSERT(d.type() == DeviceType::CUDA); Device old_device = getDevice(); if (old_device.index() != d.index()) { - AT_CUDA_CHECK(cudaSetDevice(d.index())); + C10_CUDA_CHECK(cudaSetDevice(d.index())); } return old_device; } Device getDevice() const override { int device; - AT_CUDA_CHECK(cudaGetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); return Device(DeviceType::CUDA, device); } void setDevice(Device d) const override { AT_ASSERT(d.type() == DeviceType::CUDA); - AT_CUDA_CHECK(cudaSetDevice(d.index())); + C10_CUDA_CHECK(cudaSetDevice(d.index())); } void uncheckedSetDevice(Device d) const noexcept override { cudaSetDevice(d.index()); } Stream getStream(Device d) const noexcept override { - return at::cuda::getCurrentCUDAStream().unwrap(); + return getCurrentCUDAStream().unwrap(); } // NB: These do NOT set the current device Stream exchangeStream(Stream s) const noexcept override { CUDAStream cs(s); - auto old_stream = at::cuda::getCurrentCUDAStream(s.device().index()); - 
at::cuda::setCurrentCUDAStream(cs); + auto old_stream = getCurrentCUDAStream(s.device().index()); + setCurrentCUDAStream(cs); return old_stream.unwrap(); } }; -}}} // namespace at::cuda::impl +}}} // namespace c10::cuda::impl diff --git a/c10/impl/InlineDeviceGuard.h b/c10/impl/InlineDeviceGuard.h index 2df018c..39ad950 100644 --- a/c10/impl/InlineDeviceGuard.h +++ b/c10/impl/InlineDeviceGuard.h @@ -105,7 +105,8 @@ public: /// Sets the device to the given one. template ::value, int>::type = 0> void set_device(at::Device device) { - AT_ASSERT(device.type() == U::static_type); + AT_ASSERT((U::static_type == DeviceType::HIP && device.type() == DeviceType::CUDA) || + device.type() == U::static_type); auto index = device.index(); if (index == -1) return; impl_.setDevice(device); diff --git a/docs/cpp/source/Doxyfile b/docs/cpp/source/Doxyfile index b1c1dd8..1bd4682 100644 --- a/docs/cpp/source/Doxyfile +++ b/docs/cpp/source/Doxyfile @@ -35,8 +35,6 @@ INPUT = ../../../aten/src/ATen/ATen.h \ ../../../aten/src/ATen/core/ScalarType.h \ ../../../aten/src/ATen/core/Tensor.h \ ../../../aten/src/ATen/cuda/CUDAContext.h \ - ../../../aten/src/ATen/cuda/CUDAGuard.h \ - ../../../aten/src/ATen/cuda/CUDAStream.h \ ../../../aten/src/ATen/cudnn/Descriptors.h \ ../../../aten/src/ATen/cudnn/Handles.h \ ../../../aten/src/ATen/cudnn/Types.h \ @@ -52,6 +50,8 @@ INPUT = ../../../aten/src/ATen/ATen.h \ ../../../c10/util/ArrayRef.h \ ../../../c10/util/Exception.h \ ../../../c10/util/Optional.h \ + ../../../c10/cuda/CUDAGuard.h \ + ../../../c10/cuda/CUDAStream.h \ ../../../torch/csrc/api/include \ ../../../torch/csrc/api/src \ ../../../torch/csrc/autograd/generated/variable_factories.h \ diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py index 22aa972..53cfeb3 100644 --- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py @@ -2214,6 +2214,22 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict([ ("cudaHostAllocator", ("hipHostAllocator", API_PYTORCH)), ("cudaDeviceAllocator", ("hipDeviceAllocator", API_PYTORCH)), ("define MAX_NUM_BLOCKS 200", ("define MAX_NUM_BLOCKS 64", API_PYTORCH)), + + ("cuda::CUDAGuard", ("hip::HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAGuard", ("HIPGuardMasqueradingAsCUDA", API_PYTORCH)), + + ("cuda::OptionalCUDAGuard", ("hip::OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH)), + ("OptionalCUDAGuard", ("OptionalHIPGuardMasqueradingAsCUDA", API_PYTORCH)), + + ("cuda::CUDAStreamGuard", ("hip::HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAStreamGuard", ("HIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + + ("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + ("OptionalCUDAStreamGuard", ("OptionalHIPStreamGuardMasqueradingAsCUDA", API_PYTORCH)), + + # TODO: Undo this special-case; see the header for motivation behind this + # hack. It's VERY important this is only applied to PyTorch HIPify. 
+ ("c10/cuda/CUDAGuard.h", ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH)), ]) CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([ @@ -2254,6 +2270,11 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([ ("CuDNN" ,("MIOPEN", API_CAFFE2)), ("cudnn" ,("miopen", API_CAFFE2)), ("namespace cuda", ("namespace hip", API_CAFFE2)), + ("cuda::CUDAGuard", ("hip::HIPGuard", API_CAFFE2)), + ("cuda::OptionalCUDAGuard", ("hip::OptionalHIPGuard", API_CAFFE2)), + ("cuda::CUDAStreamGuard", ("hip::HIPStreamGuard", API_CAFFE2)), + ("cuda::OptionalCUDAStreamGuard", ("hip::OptionalHIPStreamGuard", API_CAFFE2)), + ("c10/cuda/CUDAGuard.h", ("c10/hip/HIPGuard.h", API_CAFFE2)), ]) # We must tread very carefully here. Blanket conversions like are done @@ -2272,15 +2293,27 @@ C10_MAPPINGS = collections.OrderedDict([ ("c10/cuda/CUDAMacros.h", ("c10/hip/HIPMacros.h", API_C10)), ("c10/cuda/CUDAMathCompat.h", ("c10/hip/HIPMathCompat.h", API_C10)), ("c10/cuda/CUDAFunctions.h", ("c10/hip/HIPFunctions.h", API_C10)), + ("c10/cuda/CUDAStream.h", ("c10/hip/HIPStream.h", API_C10)), ("c10/cuda/impl/CUDATest.h", ("c10/hip/impl/HIPTest.h", API_C10)), + ("c10/cuda/impl/CUDAGuardImpl.h", ("c10/hip/impl/HIPGuardImpl.h", API_C10)), ("c10/cuda/impl/cuda_cmake_macros.h", ("c10/hip/impl/hip_cmake_macros.h", API_C10)), ("C10_CUDA_CHECK", ("C10_HIP_CHECK", API_C10)), ("c10::cuda", ("c10::hip", API_C10)), + ("cuda::CUDAStream", ("hip::HIPStream", API_C10)), + ("CUDAStream", ("HIPStream", API_C10)), # This substitution is not permissible, because there's another copy of this # function in torch/cuda.h # ("cuda::device_count", ("hip::device_count", API_C10)), ("cuda::current_device", ("hip::current_device", API_C10)), ("cuda::set_device", ("hip::set_device", API_C10)), + ("cuda::getStreamFromPool", ("hip::getStreamFromPool", API_C10)), + ("getStreamFromPool", ("getStreamFromPool", API_C10)), + ("cuda::getDefaultCUDAStream", ("hip::getDefaultHIPStream", API_C10)), + ("getDefaultCUDAStream", ("getDefaultHIPStream", API_C10)), + ("cuda::getCurrentCUDAStream", ("hip::getCurrentHIPStream", API_C10)), + ("getCurrentCUDAStream", ("getCurrentHIPStream", API_C10)), + ("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStream", API_C10)), + ("setCurrentCUDAStream", ("setCurrentHIPStream", API_C10)), ]) # NB: C10 mappings are more specific than Caffe2 mappings, so run them diff --git a/tools/amd_build/pyHIPIFY/hipify_python.py b/tools/amd_build/pyHIPIFY/hipify_python.py index f3dde66..34c93fe 100755 --- a/tools/amd_build/pyHIPIFY/hipify_python.py +++ b/tools/amd_build/pyHIPIFY/hipify_python.py @@ -876,8 +876,9 @@ for mapping in CUDA_TO_HIP_MAPPINGS: if constants.API_CAFFE2 not in meta_data: PYTORCH_TRIE.add(src) PYTORCH_MAP[src] = dst - CAFFE2_TRIE.add(src) - CAFFE2_MAP[src] = dst + if constants.API_PYTORCH not in meta_data: + CAFFE2_TRIE.add(src) + CAFFE2_MAP[src] = dst RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern()) RE_PYTORCH_PREPROCESSOR = re.compile(r'\b{0}\b'.format(PYTORCH_TRIE.pattern())) diff --git a/tools/cwrap/plugins/AutoGPU.py b/tools/cwrap/plugins/AutoGPU.py index 3ea14d0..61019b1 100644 --- a/tools/cwrap/plugins/AutoGPU.py +++ b/tools/cwrap/plugins/AutoGPU.py @@ -10,5 +10,5 @@ class AutoGPU(CWrapPlugin): def process_pre_arg_assign(self, template, option): if not option.get('device_guard', True): return template - call = 'at::cuda::CUDAGuard device_guard(get_device(args));' + call = 'SpecializedDeviceGuard device_guard(get_device(args));' return [call] + template diff --git 
a/tools/cwrap/plugins/NNExtension.py b/tools/cwrap/plugins/NNExtension.py index b30a433..d90b1be 100644 --- a/tools/cwrap/plugins/NNExtension.py +++ b/tools/cwrap/plugins/NNExtension.py @@ -13,11 +13,13 @@ MODULE_HEAD = """ // HIPify isn't being applied to autogenerated files, so defensively // handle both the CUDA and ROCM cases. #if defined(USE_CUDA) -#include +#include +using SpecializedDeviceGuard = c10::cuda::CUDAGuard; #elif defined(USE_ROCM) -#include +#include +// I'm not sure why the build doesn't like c10::cuda namespace... +using SpecializedDeviceGuard = at::cuda::HIPGuardMasqueradingAsCUDA; #endif - """ REGISTER_METHOD_TEMPLATE = Template(' {"$name", (PyCFunction)$name, METH_STATIC | METH_VARARGS, NULL},\n') diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 564005c..f12a927 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -29,14 +29,12 @@ #ifdef USE_CUDA #include -#include -#include +#include #endif // USE_CUDA #ifdef USE_ROCM #include -#include -#include +#include #endif // USE_ROCM namespace torch { namespace autograd { @@ -212,7 +210,7 @@ Engine::~Engine() = default; // not CUDA. auto Engine::thread_init(int device) -> void { THInferNumThreads(); -#if defined(USE_CUDA) || defined(USE_ROCM) +#if defined(USE_CUDA) // NB: We MUST NOT construct the guard for device -1, // as in some settings we compile with USE_CUDA, but // have lazy stubs for CUDA functionality (so actually @@ -222,6 +220,11 @@ auto Engine::thread_init(int device) -> void { if (device != -1) { guard.set_index(device); } +#elif defined(USE_ROCM) + at::cuda::OptionalHIPGuardMasqueradingAsCUDA guard; + if (device != -1) { + guard.set_index(device); + } #endif worker_device = device; thread_main(nullptr); diff --git a/torch/csrc/autograd/profiler.cpp b/torch/csrc/autograd/profiler.cpp index 74a9267..25b2780 100644 --- a/torch/csrc/autograd/profiler.cpp +++ b/torch/csrc/autograd/profiler.cpp @@ -2,7 +2,7 @@ #include #ifdef USE_CUDA -#include +#include #endif #include diff --git a/torch/csrc/cuda/Stream.cpp b/torch/csrc/cuda/Stream.cpp index 8d356b6..0f5c9fe 100644 --- a/torch/csrc/cuda/Stream.cpp +++ b/torch/csrc/cuda/Stream.cpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include diff --git a/torch/csrc/cuda/comm.cpp b/torch/csrc/cuda/comm.cpp index 1f0b0cd..c6bc582 100644 --- a/torch/csrc/cuda/comm.cpp +++ b/torch/csrc/cuda/comm.cpp @@ -11,7 +11,7 @@ #include #include -#include +#include #include #include diff --git a/torch/csrc/cuda/nccl.cpp b/torch/csrc/cuda/nccl.cpp index d043d9a..af94da7 100644 --- a/torch/csrc/cuda/nccl.cpp +++ b/torch/csrc/cuda/nccl.cpp @@ -4,7 +4,7 @@ #include #include -#include +#include #include #include diff --git a/torch/csrc/cuda/python_nccl.cpp b/torch/csrc/cuda/python_nccl.cpp index b618806..b7dab1f 100644 --- a/torch/csrc/cuda/python_nccl.cpp +++ b/torch/csrc/cuda/python_nccl.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include diff --git a/torch/csrc/distributed/c10d/ddp.cpp b/torch/csrc/distributed/c10d/ddp.cpp index b80963f..3dd0a03 100644 --- a/torch/csrc/distributed/c10d/ddp.cpp +++ b/torch/csrc/distributed/c10d/ddp.cpp @@ -9,7 +9,7 @@ #include #include -#include +#include #include #include diff --git a/torch/csrc/generic/StorageSharing.cpp b/torch/csrc/generic/StorageSharing.cpp index 0b73880..8fc91f3 100644 --- a/torch/csrc/generic/StorageSharing.cpp +++ b/torch/csrc/generic/StorageSharing.cpp @@ -1,7 +1,7 @@ #ifdef USE_CUDA #include #include -#include +#include #endif 
#include diff --git a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp index 93c9a1c..42432ba 100644 --- a/torch/csrc/jit/fuser/cuda/fused_kernel.cpp +++ b/torch/csrc/jit/fuser/cuda/fused_kernel.cpp @@ -1,7 +1,7 @@ #include #include -#include +#include #include #include #include diff --git a/torch/csrc/utils.h b/torch/csrc/utils.h index fe54956..411d0de 100644 --- a/torch/csrc/utils.h +++ b/torch/csrc/utils.h @@ -11,7 +11,7 @@ #ifdef USE_CUDA #include -#include +#include #endif #define THPUtils_(NAME) TH_CONCAT_4(THP,Real,Utils_,NAME) diff --git a/torch/lib/THD/base/data_channels/DataChannelNccl.cpp b/torch/lib/THD/base/data_channels/DataChannelNccl.cpp index 1c29b54..e21d1b7 100644 --- a/torch/lib/THD/base/data_channels/DataChannelNccl.cpp +++ b/torch/lib/THD/base/data_channels/DataChannelNccl.cpp @@ -3,7 +3,7 @@ #include #include -#include +#include #include #include diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 2f1493d..613b854 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -10,8 +10,8 @@ #ifdef USE_CUDA #include -#include -#include +#include +#include #include #include #endif diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index edb59d2..1c7ac65 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -17,7 +17,7 @@ #ifdef USE_CUDA #include -#include +#include #endif #include diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index 5242b09..b6c699b 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -7,7 +7,7 @@ #include #include -#include +#include #include diff --git a/torch/lib/c10d/test/CUDATest.hpp b/torch/lib/c10d/test/CUDATest.hpp index 5e02e31..defaff8 100644 --- a/torch/lib/c10d/test/CUDATest.hpp +++ b/torch/lib/c10d/test/CUDATest.hpp @@ -1,6 +1,6 @@ #pragma once -#include +#include namespace c10d { namespace test { diff --git a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp index b66238b..de1432e 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooAsyncTest.cpp @@ -1,6 +1,6 @@ #include -#include +#include #include #include diff --git a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp index 7801053..f18c559 100644 --- a/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupNCCLTest.cpp @@ -5,9 +5,9 @@ #include #include -#include +#include #include -#include +#include using namespace c10d::test; -- 2.7.4
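
For illustration, a minimal sketch of what caller code looks like after this move: the headers and the c10::cuda types below are the ones introduced by the diff, while the function name and the device/stream choices are made up for the example.

// Illustrative only -- not part of the patch. Shows the post-move spellings:
// guards and streams now come from c10/cuda rather than ATen/cuda.
#include <c10/cuda/CUDAGuard.h>   // previously lived under ATen/cuda/
#include <c10/cuda/CUDAStream.h>  // previously lived under ATen/cuda/

// Hypothetical helper; the device index is an argument just for the example.
void example_launch(c10::DeviceIndex device_index) {
  // Switch the current device for the scope of this function.
  c10::cuda::CUDAGuard device_guard(device_index);

  // Grab a stream from the pool and make it current for this scope.
  // getStreamFromPool/getCurrentCUDAStream keep their names; only the
  // namespace and the header move in this diff.
  c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard stream_guard(stream.unwrap());

  // ... enqueue kernels on c10::cuda::getCurrentCUDAStream() here ...
}

On ROCm builds, the hipify mappings added in this diff rewrite the same spellings to hip::HIPGuardMasqueradingAsCUDA / hip::HIPStream, which is why the masquerading header above has to exist at all.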