Add ability to query if built with CUDA and MKL-DNN. (#18362)
authorEdward Yang <ezyang@fb.com>
Mon, 25 Mar 2019 17:22:54 +0000 (10:22 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 25 Mar 2019 17:39:09 +0000 (10:39 -0700)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18362
ghimport-source-id: 374b7ab97e2d6a894368007133201f510539296f

Stack from [ghstack](https://github.com/ezyang/ghstack):
* #18242 Test running a CUDA build on CPU machine.
* **#18362 Add ability to query if built with CUDA and MKL-DNN.**

Fixes #18108.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Differential Revision: D14584430

fbshipit-source-id: 7605a1ac4e8f2a7c70d52e5a43ad7f03f0457473

aten/src/ATen/Context.cpp
aten/src/ATen/Context.h
test/test_torch.py
torch/backends/cuda/__init__.py
torch/backends/mkldnn/__init__.py [new file with mode: 0644]
torch/csrc/Module.cpp

index 8ca1f1f..29a1d67 100644 (file)
@@ -84,6 +84,14 @@ bool Context::hasMKL() const {
 #endif
 }
 
+bool Context::hasMKLDNN() const {
+#if AT_MKLDNN_ENABLED()
+  return true;
+#else
+  return false;
+#endif
+}
+
 bool Context::hasOpenMP() const {
 #ifdef _OPENMP
   return true;
index 05a6328..76d1a90 100644 (file)
@@ -68,6 +68,7 @@ class CAFFE2_API Context {
   bool hasOpenMP() const;
   bool hasMKL() const;
   bool hasLAPACK() const;
+  bool hasMKLDNN() const;
   bool hasMAGMA() const {
     return detail::getCUDAHooks().hasMAGMA();
   }
@@ -244,6 +245,10 @@ static inline bool hasMAGMA() {
   return globalContext().hasMAGMA();
 }
 
+static inline bool hasMKLDNN() {
+  return globalContext().hasMKLDNN();
+}
+
 static inline void manual_seed(uint64_t seed) {
   globalContext().defaultGenerator(DeviceType::CPU).manualSeed(seed);
   // NB: Sometimes we build with CUDA, but we don't have any GPUs
index 5cc1ee8..989c1f2 100644 (file)
@@ -8,6 +8,7 @@ import copy
 import shutil
 import torch
 import torch.cuda
+import torch.backends.cuda
 import tempfile
 import unittest
 import warnings
@@ -10417,7 +10418,9 @@ tensor([[[1., 1., 1.,  ..., 1., 1., 1.],
         self.assertTrue(grid_b2.equal(expected_grid_b))
         self.assertTrue(grid_c2.equal(expected_grid_c))
 
-    @unittest.skipIf(torch.cuda.is_available() or IS_SANDCASTLE, "CUDA is available, can't test CUDA not built error")
+    # NB: we must not be built with CUDA; if we are built with CUDA but no CUDA
+    # is available, we get a different error.
+    @unittest.skipIf(torch.backends.cuda.is_built() or IS_SANDCASTLE, "CUDA is built, can't test CUDA not built error")
     def test_cuda_not_built(self):
         msg = "Torch not compiled with CUDA enabled"
         self.assertRaisesRegex(AssertionError, msg, lambda: torch.cuda.current_device())
index 0252e3a..a83f034 100644 (file)
@@ -2,6 +2,14 @@ import sys
 import torch
 
 
+def is_built():
+    r"""Returns whether PyTorch is built with CUDA support.  Note that this
+    doesn't necessarily mean CUDA is available; just that if this PyTorch
+    binary were run a machine with working CUDA drivers and devices, we
+    would be able to use it."""
+    return torch._C.has_cuda
+
+
 class ContextProp(object):
     def __init__(self, getter, setter):
         self.getter = getter
diff --git a/torch/backends/mkldnn/__init__.py b/torch/backends/mkldnn/__init__.py
new file mode 100644 (file)
index 0000000..1b852c9
--- /dev/null
@@ -0,0 +1,6 @@
+import torch
+
+
+def is_available():
+    r"""Returns whether PyTorch is built with MKL-DNN support."""
+    return torch._C.has_mkldnn
index 0ddf9ac..7a7bb1a 100644 (file)
@@ -645,6 +645,15 @@ PyObject* initModule() {
   ASSERT_TRUE(set_module_attr("has_mkl", at::hasMKL() ? Py_True : Py_False));
   ASSERT_TRUE(set_module_attr("has_lapack", at::hasLAPACK() ? Py_True : Py_False));
 
+#ifdef USE_CUDA
+  PyObject *has_cuda = Py_True;
+#else
+  PyObject *has_cuda = Py_False;
+#endif
+  ASSERT_TRUE(set_module_attr("has_cuda", has_cuda));
+
+  ASSERT_TRUE(set_module_attr("has_mkldnn", at::hasMKLDNN() ? Py_True : Py_False));
+
 #ifdef _GLIBCXX_USE_CXX11_ABI
   ASSERT_TRUE(set_module_attr("_GLIBCXX_USE_CXX11_ABI", _GLIBCXX_USE_CXX11_ABI ? Py_True : Py_False));
 #else