Unify cudaGetDeviceCount implementations. (#18445)
authorEdward Yang <ezyang@fb.com>
Tue, 26 Mar 2019 16:42:41 +0000 (09:42 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 16:50:14 +0000 (09:50 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18445
ghimport-source-id: 30d018737bf6989bc68b7e3676f44e0ca6141fde

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18242 Test running a CUDA build on CPU machine.
* **#18445 Unify cudaGetDeviceCount implementations.**

I went about doing this by searching for calls to cudaGetDeviceCount,
and then methodically replacing them with references to c10::cuda::device_count()
or at::cuda::device_count().

There is a point to doing this: the various implementations wildly differed
in their handling of what to do when cudaGetDeviceCount returns an error.
The final standardized behavior is that **all errors are swallowed** and
we return device count of zero.  This indirectly fixes running CUDA builds
on CPU, which was broken in #17847.

I added 'noexcept' to the 'deviceCount' virtual method on DeviceGuardImpl.
This is a BC-breaking change for anyone inheriting from DeviceGuardImpl
but all you need to do is put 'noexcept' on your method and it is backwards
compatible with older libtorch.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14612189

fbshipit-source-id: 3c8d186e3dd623c0e27625212c7ce30f75d943cb

14 files changed:
aten/src/ATen/cuda/CUDAContext.h
aten/src/ATen/cuda/detail/CUDAHooks.cpp
aten/src/ATen/detail/CPUGuardImpl.h
aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
c10/core/impl/DeviceGuardImplInterface.h
c10/core/impl/FakeGuardImpl.h
c10/core/impl/VirtualGuardImpl.h
c10/cuda/CUDACachingAllocator.cpp
c10/cuda/CUDAFunctions.h
c10/cuda/impl/CUDAGuardImpl.h
test/cpp_extensions/msnpu_extension.cpp
torch/csrc/api/src/cuda.cpp
torch/csrc/autograd/profiler_cuda.cpp
torch/csrc/cuda/Module.cpp

index ffd2802..718f5f9 100644 (file)
@@ -36,22 +36,20 @@ It is expected that the modules whose functions compose this interface will
 manage their own state. There is only a single CUDA context/state.
 */
 
-/* Device info */
+/**
+ * DEPRECATED: use device_count() instead
+ */
 inline int64_t getNumGPUs() {
     return c10::cuda::device_count();
 }
 
 /**
- * In some situations, you may have compiled with CUDA, but no CUDA
- * device is actually available.  Test for this case using is_available().
+ * CUDA is available if we compiled with CUDA, and there are one or more
+ * devices.  If we compiled with CUDA but there is a driver problem, etc.,
+ * this function will report CUDA is not available (rather than raise an error.)
  */
 inline bool is_available() {
-    int count;
-    cudaError_t err = cudaGetDeviceCount(&count);
-    if (err == cudaErrorInsufficientDriver) {
-      return false;
-    }
-    return count > 0;
+    return c10::cuda::device_count() > 0;
 }
 
 CAFFE2_API cudaDeviceProp* getCurrentDeviceProperties();
index d1da474..c93ef77 100644 (file)
@@ -46,12 +46,7 @@ std::unique_ptr<Generator> CUDAHooks::initCUDAGenerator(
 }
 
 bool CUDAHooks::hasCUDA() const {
-  int count;
-  cudaError_t err = cudaGetDeviceCount(&count);
-  if (err == cudaErrorInsufficientDriver) {
-    return false;
-  }
-  return true;
+  return at::cuda::is_available();
 }
 
 bool CUDAHooks::hasMAGMA() const {
@@ -152,15 +147,7 @@ void CUDAHooks::cuFFTClearPlanCache() const {
 }
 
 int CUDAHooks::getNumGPUs() const {
-  int count;
-  auto err = cudaGetDeviceCount(&count);
-  if (err == cudaErrorNoDevice) {
-    return 0;
-  } else if (err != cudaSuccess) {
-    AT_ERROR(
-        "CUDA error (", static_cast<int>(err), "): ", cudaGetErrorString(err));
-  }
-  return count;
+  return at::cuda::device_count();
 }
 
 // Sigh, the registry doesn't support namespaces :(
index ca08974..8a86ab7 100644 (file)
@@ -34,7 +34,7 @@ struct CPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     // no-op
     return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1));
   }
-  DeviceIndex deviceCount() const override {
+  DeviceIndex deviceCount() const noexcept override {
     return 1;
   }
 };
index 5c7f207..7723e7a 100644 (file)
@@ -89,7 +89,7 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
     setCurrentHIPStreamMasqueradingAsCUDA(cs);
     return old_stream.unwrap();
   }
-  DeviceIndex deviceCount() const override {
+  DeviceIndex deviceCount() const noexcept override {
     int deviceCnt;
     C10_HIP_CHECK(hipGetDeviceCount(&deviceCnt));
     return deviceCnt;
index 66b1dff..08264d1 100644 (file)
@@ -81,9 +81,11 @@ struct C10_API DeviceGuardImplInterface {
   virtual Stream exchangeStream(Stream) const noexcept = 0;
 
   /**
-   * Get the number of devices.
+   * Get the number of devices.  WARNING: This is REQUIRED to not raise
+   * an exception.  If there is some sort of problem, e.g., driver error,
+   * you should report that there are zero available devices.
    */
-  virtual DeviceIndex deviceCount() const = 0;
+  virtual DeviceIndex deviceCount() const noexcept = 0;
 
   /**
    * Intended use of this class is to leak the DeviceGuardImpl at program end.
index 95088bb..7c18c8d 100644 (file)
@@ -54,7 +54,7 @@ struct FakeGuardImpl final : public DeviceGuardImplInterface {
     current_streams_[s.device_index()] = s.id();
     return Stream(Stream::UNSAFE, s.device(), old_id);
   }
-  DeviceIndex deviceCount() const override {
+  DeviceIndex deviceCount() const noexcept override {
     return 1;
   }
   // Convenience methods for testing
index 07e743a..608ffe6 100644 (file)
@@ -40,7 +40,7 @@ public:
   Stream exchangeStream(Stream s) const noexcept override {
     return impl_->exchangeStream(s);
   }
-  DeviceIndex deviceCount() const override {
+  DeviceIndex deviceCount() const noexcept override {
     return impl_->deviceCount();
   }
 private:
index 33d26ab..ff39c5c 100644 (file)
@@ -2,6 +2,7 @@
 
 #include <c10/cuda/CUDAGuard.h>
 #include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAFunctions.h>
 #include <c10/util/UniqueVoidPtr.h>
 
 #include <cuda_runtime_api.h>
@@ -624,9 +625,8 @@ std::mutex* getFreeMutex()
 }
 
 static inline void assertValidDevice(int device) {
-  int device_count;
-  C10_CUDA_CHECK(cudaGetDeviceCount(&device_count));
-  AT_ASSERTM(0 <= device && device < device_count, "Invalid device argument.");
+  int device_num = device_count();
+  AT_ASSERTM(0 <= device && device < device_num, "Invalid device argument.");
 }
 
 uint64_t currentMemoryAllocated(int device)
index a01ae1a..5597831 100644 (file)
 namespace c10 {
 namespace cuda {
 
-inline DeviceIndex device_count() {
+inline DeviceIndex device_count() noexcept {
   int count;
-  C10_CUDA_CHECK(cudaGetDeviceCount(&count));
+  // NB: In the past, we were inconsistent about whether or not this reported
+  // an error if there were driver problems are not.  Based on experience
+  // interacting with users, it seems that people basically ~never want this
+  // function to fail; it should just return zero if things are not working.
+  // Oblige them.
+  cudaError_t err = cudaGetDeviceCount(&count);
+  if (err != cudaSuccess) {
+    // Clear out the error state, so we don't spuriously trigger someone else.
+    // (This shouldn't really matter, since we won't be running very much CUDA
+    // code in this regime.)
+    cudaGetLastError();
+    return 0;
+  }
   return static_cast<DeviceIndex>(count);
 }
 
index 21b7298..6dbf9e5 100644 (file)
@@ -5,6 +5,7 @@
 
 #include <c10/cuda/CUDAException.h>
 #include <c10/cuda/CUDAStream.h>
+#include <c10/cuda/CUDAFunctions.h>
 
 #include <cuda_runtime_api.h>
 
@@ -54,10 +55,8 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     setCurrentCUDAStream(cs);
     return old_stream.unwrap();
   }
-  DeviceIndex deviceCount() const override {
-    int deviceCnt;
-    C10_CUDA_CHECK(cudaGetDeviceCount(&deviceCnt));
-    return deviceCnt;
+  DeviceIndex deviceCount() const noexcept override {
+    return device_count();
   }
 };
 
index 6336e5e..dfffd08 100644 (file)
@@ -122,7 +122,7 @@ struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
   Stream exchangeStream(Stream s) const noexcept override {
     return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
   }
-  DeviceIndex deviceCount() const override {
+  DeviceIndex deviceCount() const noexcept override {
     return 1;
   }
 };
index f4fe00d..77e7f75 100644 (file)
@@ -14,6 +14,7 @@ bool is_available() {
   // NB: the semantics of this are different from at::globalContext().hasCUDA();
   // ATen's function tells you if you have a working driver and CUDA build,
   // whereas this function also tells you if you actually have any GPUs.
+  // This function matches the semantics of at::cuda::is_available()
   return cuda::device_count() > 0;
 }
 
index dadadf2..b03d574 100644 (file)
@@ -43,8 +43,7 @@ struct CUDAMethods : public CUDAStubs {
   }
   void onEachDevice(std::function<void(int)> op) override {
     at::cuda::OptionalCUDAGuard device_guard;
-    int count;
-    TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
+    int count = at::cuda::device_count();
     for(int i = 0; i < count; i++) {
       device_guard.set_index(i);
       op(i);
index 1aa4e65..44819cb 100644 (file)
@@ -57,12 +57,7 @@ PyObject * THCPModule_getDevice_wrap(PyObject *self)
 PyObject * THCPModule_getDeviceCount_wrap(PyObject *self)
 {
   HANDLE_TH_ERRORS
-  int ndevice;
-  if (cudaGetDeviceCount(&ndevice) != cudaSuccess) {
-    cudaGetLastError();
-    ndevice = 0;
-  }
-  return PyLong_FromLong(ndevice);
+  return PyLong_FromLong(at::cuda::device_count());
   END_HANDLE_TH_ERRORS
 }