Refine return type Stream to HIPStream in HIPStreamGuardMasqueradingAsCUDA (#16978)
authorEdward Yang <ezyang@fb.com>
Tue, 12 Feb 2019 15:22:05 +0000 (07:22 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Feb 2019 15:36:25 +0000 (07:36 -0800)
Summary:
Previously, we used the templated class directly to provide
implementations.  However, there is a subtle difference
between this, and CUDAStreamGuard: CUDAStreamGuard has refined types
for the Streams it returns.  This lead to a compilation failure
of HIPified ddp.cpp.  This commit lines them up more closely,
at the cost of copy-paste.

A possible alternate strategy would have been to extend the
InlineDeviceGuard templates to optionally accept refinements
for Stream.  I leave this for future work.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16978

Differential Revision: D14045346

Pulled By: ezyang

fbshipit-source-id: 2b101606e62e4db588027c57902ea739a2119410

aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h

index 195ce31..0f150ee 100644 (file)
@@ -78,9 +78,110 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
 
 // All of the guards which have HIPGuardImpl burned in need to also have
 // variants using HIPGuardImplMasqueradingAsCUDA.
-using HIPGuardMasqueradingAsCUDA = c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA>;
-using OptionalHIPGuardMasqueradingAsCUDA = c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA>;
-using HIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA>;
-using OptionalHIPStreamGuardMasqueradingAsCUDA = c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA>;
+
+/// This code is all a direct copy from c10/cuda/HIPGuardMasqueradingAsCUDA.h, but with
+/// the correct InlineDeviceGuard burned in.  Sorry about the
+/// copy-pasting.
+
+struct HIPGuardMasqueradingAsCUDA {
+  explicit HIPGuardMasqueradingAsCUDA() = delete;
+  explicit HIPGuardMasqueradingAsCUDA(DeviceIndex device_index) : guard_(device_index) {}
+  explicit HIPGuardMasqueradingAsCUDA(Device device) : guard_(device) {}
+
+  HIPGuardMasqueradingAsCUDA(const HIPGuardMasqueradingAsCUDA&) = delete;
+  HIPGuardMasqueradingAsCUDA& operator=(const HIPGuardMasqueradingAsCUDA&) = delete;
+  HIPGuardMasqueradingAsCUDA(HIPGuardMasqueradingAsCUDA&& other) = delete;
+  HIPGuardMasqueradingAsCUDA& operator=(HIPGuardMasqueradingAsCUDA&& other) = delete;
+
+  void set_device(Device device) { guard_.set_device(device); }
+  void reset_device(Device device) { guard_.reset_device(device); }
+  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
+  Device original_device() const { return guard_.original_device(); }
+  Device current_device() const { return guard_.current_device(); }
+
+ private:
+  c10::impl::InlineDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct OptionalHIPGuardMasqueradingAsCUDA {
+  explicit OptionalHIPGuardMasqueradingAsCUDA() : guard_() {}
+  explicit OptionalHIPGuardMasqueradingAsCUDA(optional<Device> device_opt) : guard_(device_opt) {}
+  explicit OptionalHIPGuardMasqueradingAsCUDA(optional<DeviceIndex> device_index_opt) : guard_(device_index_opt) {}
+
+  OptionalHIPGuardMasqueradingAsCUDA(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPGuardMasqueradingAsCUDA& operator=(const OptionalHIPGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPGuardMasqueradingAsCUDA(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
+  OptionalHIPGuardMasqueradingAsCUDA& operator=(OptionalHIPGuardMasqueradingAsCUDA&& other) = delete;
+
+  void set_device(Device device) { guard_.set_device(device); }
+  void reset_device(Device device) { guard_.reset_device(device); }
+  void set_index(DeviceIndex device_index) { guard_.set_index(device_index); }
+  optional<Device> original_device() const { return guard_.original_device(); }
+  optional<Device> current_device() const { return guard_.current_device(); }
+  void reset() { guard_.reset(); }
+
+private:
+  c10::impl::InlineOptionalDeviceGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct HIPStreamGuardMasqueradingAsCUDA {
+  explicit HIPStreamGuardMasqueradingAsCUDA() = delete;
+  explicit HIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
+  HIPStreamGuardMasqueradingAsCUDA(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
+  HIPStreamGuardMasqueradingAsCUDA& operator=(const HIPStreamGuardMasqueradingAsCUDA&) = delete;
+  HIPStreamGuardMasqueradingAsCUDA(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+  HIPStreamGuardMasqueradingAsCUDA& operator=(HIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+
+  void reset_stream(Stream stream) { guard_.reset_stream(stream); }
+
+  HIPStream original_stream() const {
+    return HIPStream(HIPStream::UNCHECKED, guard_.original_stream());
+  }
+  HIPStream current_stream() const {
+    return HIPStream(HIPStream::UNCHECKED, guard_.current_stream());
+  }
+
+  Device current_device() const { return guard_.current_device(); }
+  Device original_device() const { return guard_.original_device(); }
+
+private:
+  c10::impl::InlineStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
+
+struct OptionalHIPStreamGuardMasqueradingAsCUDA {
+  explicit OptionalHIPStreamGuardMasqueradingAsCUDA() : guard_() {}
+  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(Stream stream) : guard_(stream) {}
+  explicit OptionalHIPStreamGuardMasqueradingAsCUDA(optional<Stream> stream_opt) : guard_(stream_opt) {}
+
+  OptionalHIPStreamGuardMasqueradingAsCUDA(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(const OptionalHIPStreamGuardMasqueradingAsCUDA&) = delete;
+  OptionalHIPStreamGuardMasqueradingAsCUDA(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+  OptionalHIPStreamGuardMasqueradingAsCUDA& operator=(OptionalHIPStreamGuardMasqueradingAsCUDA&& other) = delete;
+
+  void reset_stream(Stream stream) { guard_.reset_stream(stream); }
+
+  optional<HIPStream> original_stream() const {
+    auto r = guard_.original_stream();
+    if (r.has_value()) {
+      return make_optional(HIPStream(HIPStream::UNCHECKED, r.value()));
+    } else {
+      return nullopt;
+    }
+  }
+
+  optional<HIPStream> current_stream() const {
+    auto r = guard_.current_stream();
+    if (r.has_value()) {
+      return make_optional(HIPStream(HIPStream::UNCHECKED, r.value()));
+    } else {
+      return nullopt;
+    }
+  }
+
+  void reset() { guard_.reset(); }
+
+private:
+  c10::impl::InlineOptionalStreamGuard<HIPGuardImplMasqueradingAsCUDA> guard_;
+};
 
 }} // namespace c10::hip