// All of the guards which have HIPGuardImpl burned in need to also have
// variants using HIPGuardImplMasqueradingAsCUDA.
-using HIPGuardMasqueradingAsCUDA = c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA>;
-using OptionalHIPGuardMasqueradingAsCUDA = c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA>;
-using HIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA>;
-using OptionalHIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA>;
+
+/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with
+/// the correct InlineDeviceGuard burned in. Sorry about the
+/// copy-pasting.
+
+struct HIPGuardMasqueradingAsCUDA {
+ explicit HIPGuardMasqueradingAsCUDA() = delete;
+ explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
+ explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}
+
+ HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
+ HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
+ HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
+ HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;
+
+ void set_device(Device device) { guard_.set_device(device); }
+ void reset_device(Device device) { guard_.reset_device(device); }
+ void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
+ Device original_device() const { return guard_.original_device(); }
+ Device current_device() const { return guard_.current_device(); }
+
+ private:
+ c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct OptionalHIPGuardMasqueradingAsCUDA {
+ explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
+ explicit OptionalHIPGuardMasqueradingAsCUDA(optional<Device> device_opt) : guard_(device_opt) {}
+ explicit OptionalHIPGuardMasqueradingAsCUDA(optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
+
+ OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
+ OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
+ OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
+ OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
+
+ void set_device(Device device) { guard_.set_device(device); }
+ void reset_device(Device device) { guard_.reset_device(device); }
+ void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
+ optional<Device> original_device() const { return guard_.original_device(); }
+ optional<Device> current_device() const { return guard_.current_device(); }
+ void reset() { guard_.reset(); }
+
+private:
+ c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct HIPStreamGuardMasqueradingAsCUDA {
+ explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
+ explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
+ HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
+ HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
+ HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+ HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+
+ void reset_stream(Stream stream) { guard_.reset_stream(stream); }
+
+ HIPStream original_stream() const {
+ return HIPStream(HIPStream::UNCHECKED, guard_.original_stream());
+ }
+ HIPStream current_stream() const {
+ return HIPStream(HIPStream::UNCHECKED, guard_.current_stream());
+ }
+
+ Device current_device() const { return guard_.current_device(); }
+ Device original_device() const { return guard_.original_device(); }
+
+private:
+ c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct OptionalHIPStreamGuardMasqueradingAsCUDA {
+ explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
+ explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
+ explicit OptionalHIPStreamGuardMasqueradingAsCUDA(optional<Stream> stream_opt) : guard_(stream_opt) {}
+
+ OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
+ OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
+ OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+ OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+
+ void reset_stream(Stream stream) { guard_.reset_stream(stream); }
+
+ optional<HIPStream> original_stream() const {
+ auto r = guard_.original_stream();
+ if (r.has_value()) {
+ return make_optional(HIPStream(HIPStream::UNCHECKED, r.value()));
+ } else {
+ return nullopt;
+ }
+ }
+
+ optional<HIPStream> current_stream() const {
+ auto r = guard_.current_stream();
+ if (r.has_value()) {
+ return make_optional(HIPStream(HIPStream::UNCHECKED, r.value()));
+ } else {
+ return nullopt;
+ }
+ }
+
+ void reset() { guard_.reset(); }
+
+private:
+ c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
}} // namespace c10::hip