push magma init into lazyInitCUDA (#18527)
author Soumith Chintala <soumith@gmail.com>
Wed, 3 Apr 2019 19:27:19 +0000 (12:27 -0700)
committer Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 3 Apr 2019 19:47:34 +0000 (12:47 -0700)
Summary:
Tries to fix C++ API's usage of MAGMA-based functions.

Attempts to Fix https://github.com/pytorch/pytorch/issues/18074
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18527

Differential Revision: D14691694

Pulled By: soumith

fbshipit-source-id: dd04e74418e486d73ea4a92193ddf79352ed71ba

aten/src/ATen/cuda/detail/CUDAHooks.cpp
test/cpp/api/tensor_cuda.cpp
torch/csrc/cuda/Module.cpp

index c93ef77..eb9b2f3 100644 (file)
@@ -33,6 +33,9 @@ std::unique_ptr<THCState, void (*)(THCState*)> CUDAHooks::initCUDA() const {
   THCState* thc_state = THCState_alloc();
 
   THCudaInit(thc_state);
+#ifdef USE_MAGMA
+  THCMagma_init(thc_state);
+#endif
   return std::unique_ptr<THCState, void (*)(THCState*)>(
       thc_state, [](THCState* p) {
         if (p)
index 44b6103..672f5c5 100644 (file)
@@ -113,3 +113,11 @@ TEST(TensorTest, ToDeviceAndDtype_MultiCUDA) {
   tensor = tensor.to(at::kCPU, at::kInt);
   REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);
 }
+
+TEST(TensorTest, MagmaInitializesCorrectly_CUDA) {
+  auto tensor = at::arange(1, 17, at::TensorOptions(at::kFloat).device(at::Device("cuda")));
+  tensor = tensor.view({4, 4});
+  if (at::hasMAGMA()) {
+    at::inverse(tensor);
+  }
+}
index 44819cb..2860112 100644 (file)
@@ -380,11 +380,6 @@ static PyObject * THCPModule_initExtension(PyObject *self)
   THCPByteStorage_postInit(m);
   THCPBoolStorage_postInit(m);
 
-  bool has_magma = at::hasMAGMA();
-  if (has_magma) {
-    THCMagma_init(state);
-  }
-
   bool has_half = true;
 
   auto set_module_attr = [&](const char* name, PyObject* v) {
@@ -394,7 +389,7 @@ static PyObject * THCPModule_initExtension(PyObject *self)
     }
   };
 
-  set_module_attr("has_magma", has_magma ? Py_True : Py_False);
+  set_module_attr("has_magma", at::hasMAGMA() ? Py_True : Py_False);
   set_module_attr("has_half", has_half ? Py_True : Py_False);
 
   auto _state_cdata = THPObjectPtr(PyLong_FromVoidPtr(state));