Add backend checks to solve methods (gesv, cholesky_solve) (#18116)
authorvishwakftw <cs15btech11043@iith.ac.in>
Tue, 19 Mar 2019 17:36:23 +0000 (10:36 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 19 Mar 2019 17:44:45 +0000 (10:44 -0700)
Summary:
Changelog:
- Incorporate a simple backend check in the linearSolveCheckInputs function in LinearAlgebraUtils.h
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18116

Differential Revision: D14504469

Pulled By: soumith

fbshipit-source-id: 7402b6dbaa8d73048946613b806d54f68bcbd8f4

aten/src/ATen/native/LinearAlgebraUtils.h
test/test_torch.py

index 0d95096..fc23d3c 100644 (file)
@@ -1,6 +1,7 @@
 #include <ATen/ATen.h>
 #include <ATen/ExpandUtils.h>
 #include <limits>
+#include <sstream>
 
 namespace at { namespace native {
 
@@ -75,8 +76,29 @@ static inline double _get_epsilon(const ScalarType& sc_type) {
   }
 }
 
-// Validates input shapes for linear solve methods (solve, cholesky_solve)
+// Validates input shapes and devices for linear solve methods (gesv, cholesky_solve)
 static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
+  int64_t self_is_cuda = self.is_cuda();
+  int64_t A_is_cuda = A.is_cuda();
+
+  std::stringstream ss;
+  if (self_is_cuda != A_is_cuda) {
+    ss << "Expected b and A to be on the same device, but found b on ";
+    if (self_is_cuda) {
+      ss << "GPU";
+    } else {
+      ss << "CPU";
+    }
+    ss << " and A on ";
+    if (A_is_cuda) {
+      ss << "GPU";
+    } else {
+      ss << "CPU";
+    }
+    ss << " instead.";
+    AT_ERROR(ss.str());
+  }
+
   AT_CHECK(A.size(-1) == A.size(-2),
            "A must be batches of square matrices, "
            "but they are ", A.size(-1), " by ", A.size(-2), " matrices");
index 07da497..6d2c5e1 100644 (file)
@@ -4711,6 +4711,23 @@ class _TestTorchMixin(object):
     def test_solve_batched_dims(self):
         self._test_solve_batched_dims(self, lambda t: t)
 
+    def test_solve_methods_arg_device(self):
+        if not torch.cuda.is_available():
+            return
+
+        for b_device, A_device in product(['cpu', 'cuda'], repeat=2):
+            if b_device == A_device:
+                continue
+
+            b = torch.randn(3, 1, device=b_device)
+            A = torch.randn(3, 3, device=A_device)
+            err_str = "Expected b and A to be on the same device"
+            with self.assertRaisesRegex(RuntimeError, err_str):
+                torch.gesv(b, A)
+
+            with self.assertRaisesRegex(RuntimeError, err_str):
+                torch.cholesky_solve(b, A)
+
     @skipIfNoLapack
     def test_qr(self):