From c9989dfe37d41dec334b22719b5efc2dfaaf671e Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Thu, 28 Feb 2019 13:32:22 -0800 Subject: [PATCH] Make HIPStream also masquerade as CUDA. (#17469) Summary: HIPGuard interfaces that interacted with HIPStream were previously totally busted (because the streams had the wrong device type). This fixes it, following along the same lines of MasqueardingAsCUDA. Along the way I beefed up the explanatory comment. Signed-off-by: Edward Z. Yang cc jithunnair-amd iotamudelta bddppq Pull Request resolved: https://github.com/pytorch/pytorch/pull/17469 Differential Revision: D14243396 Pulled By: ezyang fbshipit-source-id: 972455753a62f8584ba9ab194f9c785db7bb9bde --- .../ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h | 68 +++++++----- .../ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h | 118 +++++++++++++++++++++ tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py | 16 +++ 3 files changed, 178 insertions(+), 24 deletions(-) create mode 100644 aten/src/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h index 0f150ee..5c7f207 100644 --- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -11,26 +11,46 @@ #include +#include + // Use of c10::hip namespace here makes hipification easier, because // I don't have to also fix namespaces. Sorry! namespace c10 { namespace hip { // Note [Masquerading as CUDA] // ~~~~~~~~~~~~~~~~~~~~~~~~~~~ -// HIPGuardImplMasqueradingAsCUDA is like a HIPGuardImpl, but -// it reports its DeviceType as CUDA (e.g., type() returns CUDA, -// getDevice() reports the current HIP device as a CUDA device.) -// We can't directly use HIPGuardImpl, since it (piously) requires -// the DeviceType to be HIP. +// c10_hip is very easy to understand: it is HIPified from c10_cuda, +// and anywhere you said CUDA, the source code now says HIP. HIPified +// PyTorch is much harder to understand: it is HIPified from regular +// PyTorch, yes, but NO source-to-source translation from CUDA to +// HIP occurs; instead, anywhere we see "CUDA", it actually means "HIP". +// For example, when you use HIPified PyTorch, you say x.cuda() to +// move a tensor onto ROCm device. We call this situation "HIP +// maquerading as CUDA". +// +// This leads to a very awkward situation when we want to call c10_hip +// code from PyTorch, since c10_hip is expecting things to be called +// HIP, but PyTorch is calling them CUDA (masquerading as HIP). To +// fix this impedance mismatch, we have MasqueradingAsCUDA variants +// for all c10_hip classes. These translate between the "HIP" and "CUDA +// masquerading as HIP" worlds. For example, +// HIPGuardImplMasqueradingAsCUDA (this file) provides something like a +// HIPGuardImpl, but it reports its DeviceType as CUDA (e.g., type() +// returns CUDA, getDevice() reports the current HIP device as a CUDA +// device.) +// +// We should be able to delete all of these classes entirely once +// we switch PyTorch to calling a HIP a HIP. // -// This is necessary for PyTorch at the moment, which is implemented -// by pretending that CUDA is actually HIP. Eventually, we want -// to make PyTorch treat HIP as a separate DeviceType, and then we -// can delete this class. +// When you add a new MasqueradingAsCUDA class/function, you need to +// also update the rewrite rules in tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py // -// Also, note that the cpp file associated with this also *overwrites* -// the entry in the DeviceGuardImpl registry for CUDA with this HIP -// implementation. +// +// +// By the way, note that the cpp file associated with this also +// *overwrites* the entry in the DeviceGuardImpl registry for CUDA with +// this HIP implementation. + struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface { static constexpr DeviceType static_type = DeviceType::CUDA; HIPGuardImplMasqueradingAsCUDA() {} @@ -61,12 +81,12 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI hipSetDevice(d.index()); } Stream getStream(Device d) const noexcept override { - return getCurrentHIPStream().unwrap(); + return getCurrentHIPStreamMasqueradingAsCUDA().unwrap(); } Stream exchangeStream(Stream s) const noexcept override { - HIPStream cs(s); - auto old_stream = getCurrentHIPStream(s.device().index()); - setCurrentHIPStream(cs); + HIPStreamMasqueradingAsCUDA cs(s); + auto old_stream = getCurrentHIPStreamMasqueradingAsCUDA(s.device().index()); + setCurrentHIPStreamMasqueradingAsCUDA(cs); return old_stream.unwrap(); } DeviceIndex deviceCount() const override { @@ -134,11 +154,11 @@ struct HIPStreamGuardMasqueradingAsCUDA { void reset_stream(Stream stream) { guard_.reset_stream(stream); } - HIPStream original_stream() const { - return HIPStream(HIPStream::UNCHECKED, guard_.original_stream()); + HIPStreamMasqueradingAsCUDA original_stream() const { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.original_stream()); } - HIPStream current_stream() const { - return HIPStream(HIPStream::UNCHECKED, guard_.current_stream()); + HIPStreamMasqueradingAsCUDA current_stream() const { + return HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, guard_.current_stream()); } Device current_device() const { return guard_.current_device(); } @@ -160,19 +180,19 @@ struct OptionalHIPStreamGuardMasqueradingAsCUDA { void reset_stream(Stream stream) { guard_.reset_stream(stream); } - optional original_stream() const { + optional original_stream() const { auto r = guard_.original_stream(); if (r.has_value()) { - return make_optional(HIPStream(HIPStream::UNCHECKED, r.value())); + return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value())); } else { return nullopt; } } - optional current_stream() const { + optional current_stream() const { auto r = guard_.current_stream(); if (r.has_value()) { - return make_optional(HIPStream(HIPStream::UNCHECKED, r.value())); + return make_optional(HIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA::UNCHECKED, r.value())); } else { return nullopt; } diff --git a/aten/src/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h new file mode 100644 index 0000000..f37aa22 --- /dev/null +++ b/aten/src/ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h @@ -0,0 +1,118 @@ +#pragma once + +#include + +// Use of c10::hip namespace here makes hipification easier, because +// I don't have to also fix namespaces. Sorry! +namespace c10 { namespace hip { + +// See Note [Masquerading as CUDA] for motivation + +class HIPStreamMasqueradingAsCUDA { +public: + + enum Unchecked { UNCHECKED }; + + explicit HIPStreamMasqueradingAsCUDA(Stream stream) + : HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) { + // We did the coercion unchecked; check that it was right. + AT_CHECK(stream.device().type() == DeviceType::CUDA /* !!! */); + } + + explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream) + // Unsafely coerce the "CUDA" stream into a HIP stream + : stream_( + HIPStream( + Stream( + Stream::UNSAFE, + Device(DeviceType::HIP, stream.device_index()), + stream.id()) + ) + ) {} + + // New constructor, just for this. Does NOT coerce. + explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {} + + bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept { + return stream_ == other.stream_; + } + + bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept { + return stream_ != other.stream_; + } + + operator hipStream_t() const { return stream_.stream(); } + + operator Stream() const { + // Unsafely coerce HIP stream into a "CUDA" stream + return Stream(Stream::UNSAFE, device(), id()); + } + + DeviceIndex device_index() const { return stream_.device_index(); } + + Device device() const { + // Unsafely coerce HIP device into CUDA device + return Device(DeviceType::CUDA, stream_.device_index()); + } + + StreamId id() const { return stream_.id(); } + bool query() const { return stream_.query(); } + void synchronize() const { stream_.synchronize(); } + int priority() const { return stream_.priority(); } + hipStream_t stream() const { return stream_.stream(); } + + Stream unwrap() const { + // Unsafely coerce HIP stream into "CUDA" stream + return Stream(Stream::UNSAFE, device(), id()); + } + + uint64_t pack() const noexcept { + // Unsafely coerce HIP stream into "CUDA" stream before packing + return unwrap().pack(); + } + + static HIPStreamMasqueradingAsCUDA unpack(uint64_t bits) { + // NB: constructor manages CUDA->HIP translation for us + return HIPStreamMasqueradingAsCUDA(Stream::unpack(bits)); + } + + static std::tuple priority_range() { return HIPStream::priority_range(); } + + // New method, gets the underlying HIPStream + HIPStream hip_stream() const { return stream_; } + +private: + HIPStream stream_; +}; + +HIPStreamMasqueradingAsCUDA +inline getStreamFromPoolMasqueradingAsCUDA(const bool isHighPriority = false, DeviceIndex device = -1) { + return HIPStreamMasqueradingAsCUDA(getStreamFromPool(isHighPriority, device)); +} + +inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) { + return HIPStreamMasqueradingAsCUDA(getDefaultHIPStream(device_index)); +} + +inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(DeviceIndex device_index = -1) { + return HIPStreamMasqueradingAsCUDA(getCurrentHIPStream(device_index)); +} + +inline void setCurrentHIPStreamMasqueradingAsCUDA(HIPStreamMasqueradingAsCUDA stream) { + setCurrentHIPStream(stream.hip_stream()); +} + +inline std::ostream& operator<<(std::ostream& stream, const HIPStreamMasqueradingAsCUDA& s) { + stream << s.hip_stream() << " (masquerading as CUDA)"; +} + +}} // namespace c10::hip + +namespace std { + template <> + struct hash { + size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept { + return std::hash{}(s.unwrap()); + } + }; +} // namespace std diff --git a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py index 926a797..42930f6 100644 --- a/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py +++ b/tools/amd_build/pyHIPIFY/cuda_to_hip_mappings.py @@ -2236,10 +2236,26 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict([ ("cuda::CUDACachingAllocator::get", ("hip::HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH)), ("CUDACachingAllocator::get", ("HIPCachingAllocatorMasqueradingAsCUDA::get", API_PYTORCH)), + ("cuda::CUDAStream", ("hip::HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("CUDAStream", ("HIPStreamMasqueradingAsCUDA", API_PYTORCH)), + + ("cuda::getStreamFromPool", ("hip::getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH)), + ("getStreamFromPool", ("getStreamFromPoolMasqueradingAsCUDA", API_PYTORCH)), + + ("cuda::getDefaultCUDAStream", ("hip::getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("getDefaultCUDAStream", ("getDefaultHIPStreamMasqueradingAsCUDA", API_PYTORCH)), + + ("cuda::getCurrentCUDAStream", ("hip::getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("getCurrentCUDAStream", ("getCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH)), + + ("cuda::setCurrentCUDAStream", ("hip::setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH)), + ("setCurrentCUDAStream", ("setCurrentHIPStreamMasqueradingAsCUDA", API_PYTORCH)), + # TODO: Undo this special-case; see the header for motivation behind this # hack. It's VERY important this is only applied to PyTorch HIPify. ("c10/cuda/CUDAGuard.h", ("ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h", API_PYTORCH)), ("c10/cuda/CUDACachingAllocator.h", ("ATen/hip/impl/HIPCachingAllocatorMasqueradingAsCUDA.h", API_PYTORCH)), + ("c10/cuda/CUDAStream.h", ("ATen/hip/impl/HIPStreamMasqueradingAsCUDA.h", API_PYTORCH)), ]) CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict([ -- 2.7.4