From 62151aa259154a9206d812b16525fd5b55ce4a89 Mon Sep 17 00:00:00 2001 From: Igor Fedan Date: Thu, 27 Dec 2018 15:24:22 -0800 Subject: [PATCH] Added deviceCount() virtual method to DeviceGuardImplInterface (#15574) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15574 Added deviceCount() virtual method to DeviceGuardImplInterface, also added correspondent implementation for CPUGuardImpl, CUDAGuardImpl, FakeGuardImpl, VirtualGuardImpl, HIPGuardImplMasqueradingAsCUDA Reviewed By: soumith Differential Revision: D13554609 fbshipit-source-id: 913bf2aad44a0a356efe54505ee4abaf6c4622db --- aten/src/ATen/detail/CPUGuardImpl.h | 3 +++ aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h | 5 +++++ c10/cuda/impl/CUDAGuardImpl.h | 5 +++++ c10/impl/DeviceGuardImplInterface.h | 5 +++++ c10/impl/FakeGuardImpl.h | 3 +++ c10/impl/VirtualGuardImpl.h | 3 +++ 6 files changed, 24 insertions(+) diff --git a/aten/src/ATen/detail/CPUGuardImpl.h b/aten/src/ATen/detail/CPUGuardImpl.h index f47cdc3..223c955 100644 --- a/aten/src/ATen/detail/CPUGuardImpl.h +++ b/aten/src/ATen/detail/CPUGuardImpl.h @@ -34,6 +34,9 @@ struct CPUGuardImpl final : public c10::impl::DeviceGuardImplInterface { // no-op return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1)); } + DeviceIndex deviceCount() const override { + return 1; + } }; }} // namespace at::detail diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h index 3ccda82..6772b90 100644 --- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h +++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h @@ -67,6 +67,11 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI setCurrentHIPStream(cs); return old_stream.unwrap(); } + DeviceIndex deviceCount() const override { + int deviceCnt; + C10_HIP_CHECK(hipGetDeviceCount(&deviceCnt)); + return deviceCnt; + } }; // All of the guards which have HIPGuardImpl burned in need to also have diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index f58282d..90c2d59 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -51,6 +51,11 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { setCurrentCUDAStream(cs); return old_stream.unwrap(); } + DeviceIndex deviceCount() const override { + int deviceCnt; + C10_CUDA_CHECK(cudaGetDeviceCount(&deviceCnt)); + return deviceCnt; + } }; }}} // namespace c10::cuda::impl diff --git a/c10/impl/DeviceGuardImplInterface.h b/c10/impl/DeviceGuardImplInterface.h index d0780dc..e53eff5 100644 --- a/c10/impl/DeviceGuardImplInterface.h +++ b/c10/impl/DeviceGuardImplInterface.h @@ -81,6 +81,11 @@ struct C10_API DeviceGuardImplInterface { virtual Stream exchangeStream(Stream) const noexcept = 0; /** + * Get the number of devices. + */ + virtual DeviceIndex deviceCount() const = 0; + + /** * Intended use of this class is to leak the DeviceGuardImpl at program end. * So you better not call the destructor, buster! */ diff --git a/c10/impl/FakeGuardImpl.h b/c10/impl/FakeGuardImpl.h index fd0c0fa..f406a06 100644 --- a/c10/impl/FakeGuardImpl.h +++ b/c10/impl/FakeGuardImpl.h @@ -54,6 +54,9 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface { current_streams_[s.device_index()] = s.id(); return Stream(Stream::UNSAFE, s.device(), old_id); } + DeviceIndex deviceCount() const override { + return 1; + } // Convenience methods for testing static DeviceIndex getDeviceIndex() { return current_device_; diff --git a/c10/impl/VirtualGuardImpl.h b/c10/impl/VirtualGuardImpl.h index 4cb10fd..26ae3ad 100644 --- a/c10/impl/VirtualGuardImpl.h +++ b/c10/impl/VirtualGuardImpl.h @@ -40,6 +40,9 @@ public: Stream exchangeStream(Stream s) const noexcept override { return impl_->exchangeStream(s); } + DeviceIndex deviceCount() const override { + return impl_->deviceCount(); + } private: const DeviceGuardImplInterface* impl_ = nullptr; }; -- 2.7.4