Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18445
ghimport-source-id: 30d018737bf6989bc68b7e3676f44e0ca6141fde
Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18242 Test running a CUDA build on CPU machine.
* **#18445 Unify cudaGetDeviceCount implementations.**
I went about doing this by searching for calls to cudaGetDeviceCount,
and then methodically replacing them with references to c10::cuda::device_count()
or at::cuda::device_count().
There is a point to doing this: the various implementations differed
wildly in how they handled errors from cudaGetDeviceCount. The final
standardized behavior is that **all errors are swallowed** and we return
a device count of zero. This indirectly fixes running CUDA builds on CPU
machines, which was broken in #17847.
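To make the standardized behavior concrete, here is a hedged usage sketch (the helper function is hypothetical, not part of this PR): callers can now probe for devices unconditionally, even in a CUDA build running on a machine with no working GPU.

```cpp
#include <c10/cuda/CUDAFunctions.h>  // defines c10::cuda::device_count()

// Hypothetical caller: device_count() now swallows all CUDA errors and
// returns 0, so this branch is safe on CPU-only machines and on machines
// with a broken driver.
void maybe_use_cuda() {
  if (c10::cuda::device_count() > 0) {
    // ... CUDA path ...
  } else {
    // ... CPU fallback ...
  }
}
```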
I added `noexcept` to the `deviceCount` virtual method on DeviceGuardImpl.
This is a BC-breaking change for anyone inheriting from DeviceGuardImpl,
but all you need to do is put `noexcept` on your method and it remains
backwards compatible with older libtorch.
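For downstream code, the required change looks like the following minimal sketch (an illustrative stand-in, not the real c10 interface, which has more pure virtual methods):

```cpp
// Simplified stand-in for DeviceGuardImplInterface, to show why the
// change is BC-breaking: a noexcept virtual cannot be overridden by a
// potentially-throwing function.
struct GuardImplInterface {
  virtual ~GuardImplInterface() = default;
  virtual int deviceCount() const noexcept = 0;
};

struct MyGuardImpl final : GuardImplInterface {
  // Without 'noexcept' this override no longer compiles against the new
  // interface; with it, the same code still works against older libtorch,
  // since an override may always strengthen the exception specification.
  int deviceCount() const noexcept override {
    return 1;  // must not throw; report zero devices on failure instead
  }
};
```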
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14612189
fbshipit-source-id: 3c8d186e3dd623c0e27625212c7ce30f75d943cb
manage their own state. There is only a single CUDA context/state.
*/
-/* Device info */
+/**
+ * DEPRECATED: use device_count() instead
+ */
inline int64_t getNumGPUs() {
return c10::cuda::device_count();
}
/**
- * In some situations, you may have compiled with CUDA, but no CUDA
- * device is actually available. Test for this case using is_available().
+ * CUDA is available if we compiled with CUDA, and there are one or more
+ * devices. If we compiled with CUDA but there is a driver problem, etc.,
+ * this function will report that CUDA is not available (rather than
+ * raising an error).
*/
inline bool is_available() {
- int count;
- cudaError_t err = cudaGetDeviceCount(&count);
- if (err == cudaErrorInsufficientDriver) {
- return false;
- }
- return count > 0;
+ return c10::cuda::device_count() > 0;
}
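A brief, hedged usage sketch (the call site is hypothetical) of why the non-throwing contract matters: startup code can branch on availability without a try/catch.

```cpp
#include <ATen/cuda/CUDAContext.h>  // assumed home of at::cuda::is_available()

// Safe even when the CUDA runtime cannot initialize: is_available()
// reports false instead of raising.
at::Device pick_device() {
  return at::cuda::is_available() ? at::Device(at::kCUDA)
                                  : at::Device(at::kCPU);
}
```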
CAFFE2_API cudaDeviceProp* getCurrentDeviceProperties();
}
bool CUDAHooks::hasCUDA() const {
- int count;
- cudaError_t err = cudaGetDeviceCount(&count);
- if (err == cudaErrorInsufficientDriver) {
- return false;
- }
- return true;
+ return at::cuda::is_available();
}
bool CUDAHooks::hasMAGMA() const {
}
int CUDAHooks::getNumGPUs() const {
- int count;
- auto err = cudaGetDeviceCount(&count);
- if (err == cudaErrorNoDevice) {
- return 0;
- } else if (err != cudaSuccess) {
- AT_ERROR(
- "CUDA error (", static_cast<int>(err), "): ", cudaGetErrorString(err));
- }
- return count;
+ return at::cuda::device_count();
}
// Sigh, the registry doesn't support namespaces :(
// no-op
return Stream(Stream::DEFAULT, Device(DeviceType::CPU, -1));
}
- DeviceIndex deviceCount() const override {
+ DeviceIndex deviceCount() const noexcept override {
return 1;
}
};
setCurrentHIPStreamMasqueradingAsCUDA(cs);
return old_stream.unwrap();
}
- DeviceIndex deviceCount() const override {
+ DeviceIndex deviceCount() const noexcept override {
int deviceCnt;
C10_HIP_CHECK(hipGetDeviceCount(&deviceCnt));
return deviceCnt;
virtual Stream exchangeStream(Stream) const noexcept = 0;
/**
- * Get the number of devices.
+ * Get the number of devices. WARNING: This is REQUIRED to not raise
+ * an exception. If there is some sort of problem, e.g., driver error,
+ * you should report that there are zero available devices.
*/
- virtual DeviceIndex deviceCount() const = 0;
+ virtual DeviceIndex deviceCount() const noexcept = 0;
/**
* Intended use of this class is to leak the DeviceGuardImpl at program end.
current_streams_[s.device_index()] = s.id();
return Stream(Stream::UNSAFE, s.device(), old_id);
}
- DeviceIndex deviceCount() const override {
+ DeviceIndex deviceCount() const noexcept override {
return 1;
}
// Convenience methods for testing
Stream exchangeStream(Stream s) const noexcept override {
return impl_->exchangeStream(s);
}
- DeviceIndex deviceCount() const override {
+ DeviceIndex deviceCount() const noexcept override {
return impl_->deviceCount();
}
private:
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAException.h>
+#include <c10/cuda/CUDAFunctions.h>
#include <c10/util/UniqueVoidPtr.h>
#include <cuda_runtime_api.h>
}
static inline void assertValidDevice(int device) {
- int device_count;
- C10_CUDA_CHECK(cudaGetDeviceCount(&device_count));
- AT_ASSERTM(0 <= device && device < device_count, "Invalid device argument.");
+ int device_num = device_count();
+ AT_ASSERTM(0 <= device && device < device_num, "Invalid device argument.");
}
uint64_t currentMemoryAllocated(int device)
namespace c10 {
namespace cuda {
-inline DeviceIndex device_count() {
+inline DeviceIndex device_count() noexcept {
int count;
- C10_CUDA_CHECK(cudaGetDeviceCount(&count));
+ // NB: In the past, we were inconsistent about whether this reported an
+ // error when there were driver problems. Based on experience
+ // interacting with users, it seems that people basically ~never want this
+ // function to fail; it should just return zero if things are not working.
+ // Oblige them.
+ cudaError_t err = cudaGetDeviceCount(&count);
+ if (err != cudaSuccess) {
+ // Clear out the error state, so we don't spuriously trigger someone else.
+ // (This shouldn't really matter, since we won't be running very much CUDA
+ // code in this regime.)
+ cudaGetLastError();
+ return 0;
+ }
return static_cast<DeviceIndex>(count);
}
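As a hedged sanity check (test scaffolding is illustrative, not part of the PR), the new contract means this program exits cleanly everywhere:

```cpp
#include <c10/cuda/CUDAFunctions.h>
#include <iostream>

int main() {
  // Prints the number of visible devices on a GPU machine; prints 0 on a
  // CPU-only machine or when the driver is broken. Never throws.
  std::cout << "CUDA devices: "
            << static_cast<int>(c10::cuda::device_count()) << std::endl;
  return 0;
}
```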
#include <c10/cuda/CUDAException.h>
#include <c10/cuda/CUDAStream.h>
+#include <c10/cuda/CUDAFunctions.h>
#include <cuda_runtime_api.h>
setCurrentCUDAStream(cs);
return old_stream.unwrap();
}
- DeviceIndex deviceCount() const override {
- int deviceCnt;
- C10_CUDA_CHECK(cudaGetDeviceCount(&deviceCnt));
- return deviceCnt;
+ DeviceIndex deviceCount() const noexcept override {
+ return device_count();
}
};
Stream exchangeStream(Stream s) const noexcept override {
return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
}
- DeviceIndex deviceCount() const override {
+ DeviceIndex deviceCount() const noexcept override {
return 1;
}
};
// NB: the semantics of this are different from at::globalContext().hasCUDA();
// ATen's function tells you if you have a working driver and CUDA build,
// whereas this function also tells you if you actually have any GPUs.
+ // This function matches the semantics of at::cuda::is_available().
return cuda::device_count() > 0;
}
}
void onEachDevice(std::function<void(int)> op) override {
at::cuda::OptionalCUDAGuard device_guard;
- int count;
- TORCH_CUDA_CHECK(cudaGetDeviceCount(&count));
+ int count = at::cuda::device_count();
for(int i = 0; i < count; i++) {
device_guard.set_index(i);
op(i);
PyObject * THCPModule_getDeviceCount_wrap(PyObject *self)
{
HANDLE_TH_ERRORS
- int ndevice;
- if (cudaGetDeviceCount(&ndevice) != cudaSuccess) {
- cudaGetLastError();
- ndevice = 0;
- }
- return PyLong_FromLong(ndevice);
+ return PyLong_FromLong(at::cuda::device_count());
END_HANDLE_TH_ERRORS
}