From ae1fc584ea3a55e5f483f7fd63e6de1b5453f79d Mon Sep 17 00:00:00 2001 From: Edward Yang Date: Tue, 12 Feb 2019 07:22:05 -0800 Subject: [PATCH] Refine return type Stream to HIPStream in HIPStreamGuardMasqueradingAsCUDA (#16978) Summary: Previously, we used the templated class directly to provide implementations. However, there is a subtle difference between this, and CUDAStreamGuard: CUDAStreamGuard has refined types for the Streams it returns. This lead to a compilation failure of HIPified ddp.cpp. This commit lines them up more closely, at the cost of copy-paste. A possible alternate strategy would have been to extend the InlineDeviceGuard templates to optionally accept refinements for Stream. I leave this for future work. Signed-off-by: Edward Z. Yang Pull Request resolved: https://github.com/pytorch/pytorch/pull/16978 Differential Revision: D14045346 Pulled By: ezyang fbshipit-source-id: 2b101606e62e4db588027c57902ea739a2119410 --- .../ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h | 109 ++++++++++++++++++++- 1 file changed, 105 insertions(+), 4 deletions(-) diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h index 195ce31..0f150ee 100644 --- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -78,9 +78,110 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI // All of the guards which have HIPGuardImpl burned in need to also have // variants using HIPGuardImplMasqueradingAsCUDA. -using HIPGuardMasqueradingAsCUDA = c10::impl::InlineDeviceGuard; -using OptionalHIPGuardMasqueradingAsCUDA = c10::impl::InlineOptionalDeviceGuard; -using HIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineStreamGuard; -using OptionalHIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineOptionalStreamGuard; + +/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with +/// the correct InlineDeviceGuard burned in. Sorry about the +/// copy-pasting. + +struct HIPGuardMasqueradingAsCUDA { + explicit HIPGuardMasqueradingAsCUDA() = delete; + explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {} + explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {} + + HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete; + HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete; + HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete; + HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete; + + void set_device(Device device) { guard_.set_device(device); } + void reset_device(Device device) { guard_.reset_device(device); } + void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + Device original_device() const { return guard_.original_device(); } + Device current_device() const { return guard_.current_device(); } + + private: + c10::impl::InlineDeviceGuard guard_; +}; + +struct OptionalHIPGuardMasqueradingAsCUDA { + explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {} + explicit OptionalHIPGuardMasqueradingAsCUDA(optional device_opt) : guard_(device_opt) {} + explicit OptionalHIPGuardMasqueradingAsCUDA(optional device_index_opt) : guard_(device_index_opt) {} + + OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; + OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete; + OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; + OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete; + + void set_device(Device device) { guard_.set_device(device); } + void reset_device(Device device) { guard_.reset_device(device); } + void set_index(DeviceIndex device_index) { guard_.set_index(device_index); } + optional original_device() const { return guard_.original_device(); } + optional current_device() const { return guard_.current_device(); } + void reset() { guard_.reset(); } + +private: + c10::impl::InlineOptionalDeviceGuard guard_; +}; + +struct HIPStreamGuardMasqueradingAsCUDA { + explicit HIPStreamGuardMasqueradingAsCUDA() = delete; + explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} + HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete; + HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete; + HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; + HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete; + + void reset_stream(Stream stream) { guard_.reset_stream(stream); } + + HIPStream original_stream() const { + return HIPStream(HIPStream::UNCHECKED, guard_.original_stream()); + } + HIPStream current_stream() const { + return HIPStream(HIPStream::UNCHECKED, guard_.current_stream()); + } + + Device current_device() const { return guard_.current_device(); } + Device original_device() const { return guard_.original_device(); } + +private: + c10::impl::InlineStreamGuard guard_; +}; + +struct OptionalHIPStreamGuardMasqueradingAsCUDA { + explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {} + explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {} + explicit OptionalHIPStreamGuardMasqueradingAsCUDA(optional stream_opt) : guard_(stream_opt) {} + + OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; + OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete; + + void reset_stream(Stream stream) { guard_.reset_stream(stream); } + + optional original_stream() const { + auto r = guard_.original_stream(); + if (r.has_value()) { + return make_optional(HIPStream(HIPStream::UNCHECKED, r.value())); + } else { + return nullopt; + } + } + + optional current_stream() const { + auto r = guard_.current_stream(); + if (r.has_value()) { + return make_optional(HIPStream(HIPStream::UNCHECKED, r.value())); + } else { + return nullopt; + } + } + + void reset() { guard_.reset(); } + +private: + c10::impl::InlineOptionalStreamGuard guard_; +}; }} // namespace c10::hip -- 2.7.4