#include <ATen/ATen.h>
#include <ATen/ExpandUtils.h>
#include <limits>
+#include <sstream>
namespace at { namespace native {
-// Validates input shapes for linear solve methods (solve, cholesky_solve)
+// Validates input shapes and devices for linear solve methods (gesv, cholesky_solve)
static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
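+  // Solves dispatch to a single backend (LAPACK on CPU, MAGMA on CUDA), so b
+  // and A must live on the same device; fail early with a readable error.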
+  int64_t self_is_cuda = self.is_cuda();
+  int64_t A_is_cuda = A.is_cuda();
+
+  std::stringstream ss;
+  if (self_is_cuda != A_is_cuda) {
+    ss << "Expected b and A to be on the same device, but found b on ";
+    if (self_is_cuda) {
+      ss << "GPU";
+    } else {
+      ss << "CPU";
+    }
+    ss << " and A on ";
+    if (A_is_cuda) {
+      ss << "GPU";
+    } else {
+      ss << "CPU";
+    }
+    ss << " instead.";
+    AT_ERROR(ss.str());
+  }
+
  AT_CHECK(A.size(-1) == A.size(-2),
           "A must be batches of square matrices, "
           "but they are ", A.size(-1), " by ", A.size(-2), " matrices");
    def test_solve_batched_dims(self):
        self._test_solve_batched_dims(self, lambda t: t)
+    def test_solve_methods_arg_device(self):
+        if not torch.cuda.is_available():
+            return
+
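+        # Exercise both mixed placements: b on CPU with A on CUDA, and the reverse.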
+        for b_device, A_device in product(['cpu', 'cuda'], repeat=2):
+            if b_device == A_device:
+                continue
+
+            b = torch.randn(3, 1, device=b_device)
+            A = torch.randn(3, 3, device=A_device)
+            err_str = "Expected b and A to be on the same device"
+            with self.assertRaisesRegex(RuntimeError, err_str):
+                torch.gesv(b, A)
+
+            with self.assertRaisesRegex(RuntimeError, err_str):
+                torch.cholesky_solve(b, A)
+
    @skipIfNoLapack
    def test_qr(self):