Summary:
Tries to fix C++ API's usage of MAGMA-based functions.
Attempts to Fix https://github.com/pytorch/pytorch/issues/18074
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18527
Differential Revision:
D14691694
Pulled By: soumith
fbshipit-source-id:
dd04e74418e486d73ea4a92193ddf79352ed71ba
THCState* thc_state = THCState_alloc();
THCudaInit(thc_state);
+#ifdef USE_MAGMA
+ THCMagma_init(thc_state);
+#endif
return std::unique_ptr<THCState, void (*)(THCState*)>(
thc_state, [](THCState* p) {
if (p)
tensor = tensor.to(at::kCPU, at::kInt);
REQUIRE_TENSOR_OPTIONS(at::kCPU, -1, at::kInt, at::kStrided);
}
+
+TEST(TensorTest, MagmaInitializesCorrectly_CUDA) {
+ auto tensor = at::arange(1, 17, at::TensorOptions(at::kFloat).device(at::Device("cuda")));
+ tensor = tensor.view({4, 4});
+ if (at::hasMAGMA()) {
+ at::inverse(tensor);
+ }
+}
THCPByteStorage_postInit(m);
THCPBoolStorage_postInit(m);
- bool has_magma = at::hasMAGMA();
- if (has_magma) {
- THCMagma_init(state);
- }
-
bool has_half = true;
auto set_module_attr = [&](const char* name, PyObject* v) {
}
};
- set_module_attr("has_magma", has_magma ? Py_True : Py_False);
+ set_module_attr("has_magma", at::hasMAGMA() ? Py_True : Py_False);
set_module_attr("has_half", has_half ? Py_True : Py_False);
auto _state_cdata = THPObjectPtr(PyLong_FromVoidPtr(state));