Batched cholesky decomposition (#14017)
authorvishwakftw <cs15btech11043@iith.ac.in>
Sat, 17 Nov 2018 18:47:17 +0000 (10:47 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Sat, 17 Nov 2018 18:49:15 +0000 (10:49 -0800)
Summary:
Implements batching for the Cholesky decomposition.

Performance could be improved with a dedicated batched `tril` and `triu` op, which is also impeding autograd operations.

Changes made:
- batching code
- tests in `test_torch.py`, `test_cuda.py` and `test_autograd.py`.
- doc string modification
- autograd modification
- removal of `_batch_potrf` in `MultivariateNormal`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14017

Differential Revision: D13087945

Pulled By: ezyang

fbshipit-source-id: 2386db887140295475ffc247742d5e9562a42f6e

15 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/core/aten_interned_strings.h
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/LinearAlgebraUtils.h
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
test/test_autograd.py
test/test_cuda.py
test/test_torch.py
tools/autograd/gen_python_functions.py
tools/autograd/templates/Functions.cpp
torch/_torch_docs.py
torch/distributions/lowrank_multivariate_normal.py
torch/distributions/multivariate_normal.py

index 230afdc..79ff168 100644 (file)
     - THTensor* self
 ]]
 [[
-  name: _th_potrf
+  name: _th_potrf_single
   cname: potrf
   types:
     - Float
index 192affb..1890538 100644 (file)
@@ -42,6 +42,7 @@ _(aten, _cast_Long) \
 _(aten, _cast_Short) \
 _(aten, _cat) \
 _(aten, _ceil) \
+_(aten, _cholesky_helper) \
 _(aten, _convolution) \
 _(aten, _convolution_double_backward) \
 _(aten, _convolution_nogroup) \
index 6f451f0..e06ba29 100644 (file)
@@ -27,6 +27,10 @@ extern "C" void sgetri_(int *n, float *a, int *lda, int *ipiv, float *work, int
 // potrs
 extern "C" void dpotrs_(char *uplo, int *n, int *nrhs, double *a, int *lda, double *b, int *ldb, int *info);
 extern "C" void spotrs_(char *uplo, int *n, int *nrhs, float *a, int *lda, float *b, int *ldb, int *info);
+
+// potrf
+extern "C" void dpotrf_(char *uplo, int *n, double *a, int *lda, int *info);
+extern "C" void spotrf_(char *uplo, int *n, float *a, int *lda, int *info);
 #endif
 
 namespace at {
@@ -54,6 +58,11 @@ void lapackPotrs(char uplo, int n, int nrhs, scalar_t *a, int lda, scalar_t *b,
   AT_ERROR("potrs only takes float or double Tensors");
 }
 
+template<class scalar_t>
+void lapackCholesky(char uplo, int n, scalar_t *a, int lda, int *info) {
+  AT_ERROR("cholesky only takes float or double Tensors");
+}
+
 #ifdef USE_LAPACK
 template<> void lapackGesv<double>(int n, int nrhs, double *a, int lda, int *ipiv, double *b, int ldb, int *info) {
   dgesv_(&n, &nrhs, a, &lda, ipiv, b, &ldb, info);
@@ -86,6 +95,14 @@ template<> void lapackPotrs<double>(char uplo, int n, int nrhs, double *a, int l
 template<> void lapackPotrs<float>(char uplo, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
   spotrs_(&uplo, &n, &nrhs, a, &lda, b, &ldb, info);
 }
+
+template<> void lapackCholesky<double>(char uplo, int n, double *a, int lda, int *info) {
+  dpotrf_(&uplo, &n, a, &lda, info);
+}
+
+template<> void lapackCholesky<float>(char uplo, int n, float *a, int lda, int *info) {
+  spotrf_(&uplo, &n, a, &lda, info);
+}
 #endif
 
 // Below of the definitions of the functions operating on a batch that are going to be dispatched
@@ -97,7 +114,7 @@ template<typename scalar_t>
 static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 #ifndef USE_LAPACK
   AT_ERROR("gesv: LAPACK library not found in compilation");
-#endif
+#else
   auto A_data = A.data<scalar_t>();
   auto b_data = b.data<scalar_t>();
   auto A_mat_stride = matrixStride(A);
@@ -109,8 +126,8 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
 
   auto ipiv = at::empty({n}, b.type().toScalarType(kInt));
 
+  int info;
   for (int64_t i = 0; i < batch_size; i++) {
-    int info;
     scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
     scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
     lapackGesv<scalar_t>(n, nrhs, A_working_ptr, n, ipiv.data<int>(), b_working_ptr, n, &info);
@@ -119,6 +136,7 @@ static void apply_gesv(Tensor& b, Tensor& A, std::vector<int64_t>& infos) {
       return;
     }
   }
+#endif
 }
 
 std::tuple<Tensor, Tensor> _gesv_helper_cpu(const Tensor& self, const Tensor& A) {
@@ -159,7 +177,7 @@ template <typename scalar_t>
 static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
 #ifndef USE_LAPACK
   AT_ERROR("inverse: LAPACK library not found in compilation");
-#endif
+#else
   auto self_data = self.data<scalar_t>();
   auto self_matrix_stride = matrixStride(self);
 
@@ -194,6 +212,7 @@ static void apply_inverse(Tensor& self, std::vector<int64_t>& infos) {
       return;
     }
   }
+#endif
 }
 
 Tensor _inverse_helper_cpu(const Tensor& self) {
@@ -213,7 +232,7 @@ Tensor inverse(const Tensor &self) {
   if (self.dim() == 2) {
     return at::_th_getri_single(self);
   }
-  inverseCheckInputs(self);
+  squareCheckInputs(self);
   return at::_inverse_helper(self);
 }
 
@@ -231,7 +250,7 @@ template<typename scalar_t>
 static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>& infos) {
 #ifndef USE_LAPACK
   AT_ERROR("potrs: LAPACK library not found in compilation");
-#endif
+#else
   char uplo = upper ? 'U' : 'L';
 
   auto A_data = A.data<scalar_t>();
@@ -253,6 +272,7 @@ static void apply_potrs(Tensor& b, Tensor& A, bool upper, std::vector<int64_t>&
       return;
     }
   }
+#endif
 }
 
 Tensor _potrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper) {
@@ -284,4 +304,74 @@ Tensor& potrs_out(Tensor& result, const Tensor& self, const Tensor& A, bool uppe
   return at::_th_potrs_single_out(result, self, A, upper);
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template<typename scalar_t>
+static void apply_cholesky(Tensor& self, bool upper, std::vector<int64_t>& infos) {
+#ifndef USE_LAPACK
+  AT_ERROR("cholesky: LAPACK library not found in compilation");
+#else
+  char uplo = upper ? 'U' : 'L';
+
+  auto self_data = self.data<scalar_t>();
+  auto self_matrix_stride = matrixStride(self);
+
+  auto batch_size = batchCount(self);
+  auto n = self.size(-2);
+
+  int info;
+  for (int64_t i = 0; i < batch_size; i++) {
+    scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
+    lapackCholesky<scalar_t>(uplo, n, self_working_ptr, n, &info);
+    infos[i] = info;
+    if (info != 0) {
+      return;
+    }
+  }
+#endif
+}
+
+Tensor _cholesky_helper_cpu(const Tensor& self, bool upper) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  auto self_working_copy = cloneBatchedColumnMajor(self);
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky", [&]{
+    apply_cholesky<scalar_t>(self_working_copy, upper, infos);
+  });
+  batchCheckErrors(infos, "cholesky");
+  return self_working_copy;
+}
+
+Tensor cholesky(const Tensor &self, bool upper) {
+  if (self.size(-1) == 0) {
+    return at::empty_like(self);
+  }
+  if (self.dim() == 2) {
+    return at::_th_potrf_single(self, upper);
+  }
+  squareCheckInputs(self);
+
+  // TODO: (#14071) Once `triu`, `tril` is implemented for batched tensors,
+  // this can be simplified. Currently, we are zero-ing out values in the
+  // batch of matrices by using a mask and the `where` function.
+  // The simplification with batched `triu` and `tril` would be this:
+  // if (upper) {
+  //   return raw_cholesky_output.triu();
+  // } else {
+  //   return raw_cholesky_output.tril();
+  // }
+  auto raw_cholesky_output = at::_cholesky_helper(self, upper);
+  int64_t n = self.size(-1);
+  auto indices = at::ones({n, n}, self.options().dtype(at::kByte));
+  indices = upper ? indices.tril(-1).expand_as(self) : indices.triu(1).expand_as(self);
+  return at::where(indices, at::zeros({}, self.options()), raw_cholesky_output);
+}
+
+Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
+  if (self.size(-1) == 0) {
+    return result.resize_as_(self);
+  }
+  result.copy_(native::cholesky(self, upper));
+  return result;
+}
+
 }}  // namespace at::native
index acd59bd..a0fa3ae 100644 (file)
@@ -483,14 +483,6 @@ std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool comput
   return at::_th_svd(self, some, compute_uv);
 }
 
-Tensor & cholesky_out(Tensor & result, const Tensor & self, bool upper) {
-  return at::_th_potrf_out(result, self, upper);
-}
-
-Tensor cholesky(const Tensor & self, bool upper) {
-  return at::_th_potrf(self, upper);
-}
-
 Tensor & potri_out(Tensor & result, const Tensor & self, bool upper) {
   return at::_th_potri_out(result, self, upper);
 }
index 9c467f2..25668ca 100644 (file)
@@ -65,8 +65,8 @@ static inline void linearSolveCheckInputs(const Tensor& self, const Tensor& A) {
            " but each b matrix is ", self.size(-2), " by ", self.size(-1));
 }
 
-// Validates input shapes for inverse
-static inline void inverseCheckInputs(const Tensor& self) {
+// Validates input shapes for operations on batches of square matrices (inverse, cholesky)
+static inline void squareCheckInputs(const Tensor& self) {
   AT_CHECK(self.size(-1) == self.size(-2),
            "A must be batches of square matrices, "
            "but they are ", self.size(-1), " by ", self.size(-2), " matrices");
index 7b2d7bd..4c15c15 100644 (file)
@@ -50,6 +50,13 @@ void magmaPotrsBatched(
   AT_ERROR("potrs only takes float or double Tensors");
 }
 
+template<class scalar_t>
+void magmaCholeskyBatched(
+    magma_uplo_t uplo, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
+    magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+  AT_ERROR("cholesky only takes float or double Tensors");
+}
+
 template<>
 void magmaGesvBatched<double>(
     magma_int_t n, magma_int_t nrhs, double** dA_array, magma_int_t ldda,
@@ -111,6 +118,20 @@ void magmaPotrsBatched<float>(
     float** dB_array, magma_int_t lddb, magma_int_t& info, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
     info = magma_spotrs_batched(uplo, n, nrhs, dA_array, ldda, dB_array, lddb, batchsize, magma_queue.get_queue());
 }
+
+template<>
+void magmaCholeskyBatched<double>(
+    magma_uplo_t uplo, magma_int_t n, double** dA_array, magma_int_t ldda,
+    magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+    magma_dpotrf_batched(uplo, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
+}
+
+template<>
+void magmaCholeskyBatched<float>(
+    magma_uplo_t uplo, magma_int_t n, float** dA_array, magma_int_t ldda,
+    magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
+    magma_spotrf_batched(uplo, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
+}
 #endif
 
 #define ALLOCATE_ARRAY(name, type, size, dummy_tensor) \
@@ -297,6 +318,64 @@ Tensor _potrs_helper_cuda(const Tensor& self, const Tensor& A, bool upper) {
   return self_working_copy;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ cholesky ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+template <typename scalar_t>
+static void apply_cholesky(Tensor& self, bool upper, std::vector<int64_t>& infos) {
+#ifndef USE_MAGMA
+AT_ERROR("cholesky: MAGMA library not found in "
+    "compilation. Please rebuild with MAGMA.");
+#else
+  magma_uplo_t uplo = upper ? MagmaUpper : MagmaLower;
+
+  auto self_data = self.data<scalar_t>();
+  auto self_mat_stride = matrixStride(self);
+
+  magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
+  magma_int_t n = magma_int_cast(self.size(-2), "self.size(-2)");
+
+  magma_int_t* info_array;
+  scalar_t** self_array;
+
+  ALLOCATE_ARRAY(info_array, magma_int_t, batch_size, self);
+  ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
+
+  // Set up the created arrays
+  for (int64_t i = 0; i < batch_size; i++) {
+    self_array[i] = &self_data[i * self_mat_stride];
+  }
+
+  MAGMAQueue magma_queue(self.get_device());
+  magmaCholeskyBatched<scalar_t>(
+    uplo, n, self_array, n, info_array,
+    batch_size, magma_queue);
+
+  for (int64_t i = 0; i < batch_size; i++) {
+    infos[i] = info_array[i];
+  }
+#endif
+}
+
+Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
+  std::vector<int64_t> infos(batchCount(self), 0);
+  Tensor self_working_copy;
+  if (upper) {
+    self_working_copy = cloneBatchedColumnMajor(self.transpose(-1, -2));
+  } else {
+    self_working_copy = cloneBatchedColumnMajor(self);
+  }
+
+  AT_DISPATCH_FLOATING_TYPES(self.type(), "cholesky", [&]{
+    apply_cholesky<scalar_t>(self_working_copy, false, infos);
+  });
+  batchCheckErrors(infos, "cholesky");
+  if (upper) {
+    return self_working_copy.transpose(-1, -2);
+  } else {
+    return self_working_copy;
+  }
+}
+
 }}  // namespace at::native
 
 #undef ALLOCATE_ARRAY
index 691f336..2e745a3 100644 (file)
 - func: cholesky(Tensor self, bool upper=false) -> Tensor
   variants: method, function
 
+- func: _cholesky_helper(Tensor self, bool upper) -> Tensor
+  variants: function
+  dispatch:
+    CPU: _cholesky_helper_cpu
+    CUDA: _cholesky_helper_cuda
+
 - func: potrs_out(Tensor result, Tensor self, Tensor input2, bool upper=true) -> Tensor
 
 - func: potrs(Tensor self, Tensor input2, bool upper=true) -> Tensor
index 25e81fb..68525c4 100644 (file)
@@ -2042,18 +2042,23 @@ class TestAutograd(TestCase):
 
     @skipIfNoLapack
     def test_cholesky(self):
-        root = torch.tril(torch.rand(S, S)).requires_grad_()
+        def func(root):
+            x = torch.matmul(root, root.transpose(-1, -2)) + 1e-05
+            return torch.cholesky(x, upper)
 
-        def run_test(upper):
-            def func(root):
-                x = torch.mm(root, root.t())
-                return torch.cholesky(x, upper)
+        def run_test(upper, dims):
+            root = torch.rand(*dims)
+            indices = torch.ones(dims[-1], dims[-1], dtype=torch.uint8).tril()
+            indices = indices.expand_as(root)
+            root[indices] = 0
+            root.requires_grad_()
 
             gradcheck(func, [root])
             gradgradcheck(func, [root])
 
-        run_test(upper=True)
-        run_test(upper=False)
+        for upper, dims in product([True, False], [(3, 3), (4, 3, 2, 2)]):
+            run_test(upper, dims)
+            run_test(upper, dims)
 
     @skipIfNoLapack
     def test_trtrs(self):
index 18bbee3..68149e6 100644 (file)
@@ -1593,6 +1593,14 @@ class TestCuda(TestCase):
     def test_potrs_batched_dims(self):
         _TestTorchMixin._test_potrs_batched_dims(self, lambda t: t.cuda())
 
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_cholesky(self):
+        _TestTorchMixin._test_cholesky(self, lambda t: t.cuda())
+
+    @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
+    def test_cholesky_batched(self):
+        _TestTorchMixin._test_cholesky_batched(self, lambda t: t.cuda())
+
     def test_view(self):
         _TestTorchMixin._test_view(self, lambda t: t.cuda())
 
index adb40fe..cb98a31 100644 (file)
@@ -5314,9 +5314,9 @@ class _TestTorchMixin(object):
         self.assertEqual(x, y)
         torch.set_rng_state(rng_state)
 
-    @skipIfNoLapack
-    def test_cholesky(self):
-        x = torch.rand(10, 10) + 1e-1
+    @staticmethod
+    def _test_cholesky(self, cast):
+        x = cast(torch.rand(10, 10) + 1e-1)
         A = torch.mm(x, x.t())
 
         # default Case
@@ -5334,6 +5334,29 @@ class _TestTorchMixin(object):
         B = torch.mm(L, L.t())
         self.assertEqual(A, B, 1e-14, 'cholesky (lower) did not allow rebuilding the original matrix')
 
+    @skipIfNoLapack
+    def test_cholesky(self):
+        self._test_cholesky(self, lambda t: t)
+
+    @staticmethod
+    def _test_cholesky_batched(self, cast):
+        from common_utils import random_symmetric_pd_matrix
+
+        def cholesky_test_helper(n, batch_dims, cast, upper):
+            A = cast(random_symmetric_pd_matrix(n, *batch_dims))
+            cholesky_exp = torch.stack([m.cholesky(upper=upper) for m in A.reshape(-1, n, n)])
+            cholesky_exp = cholesky_exp.reshape_as(A)
+            print(torch.cholesky(A, upper=upper))
+            print(cholesky_exp)
+            self.assertEqual(cholesky_exp, torch.cholesky(A, upper=upper))
+
+        for upper, batchsize in product([True, False], [(3,), (3, 4), (2, 3, 4)]):
+            cholesky_test_helper(3, batchsize, cast, upper)
+
+    @skipIfNoLapack
+    def test_cholesky_batched(self):
+        self._test_cholesky_batched(self, lambda t: t)
+
     @staticmethod
     def _test_potrs(self, cast):
         a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
@@ -5367,15 +5390,9 @@ class _TestTorchMixin(object):
     def _test_potrs_batched(self, cast):
         from common_utils import random_symmetric_pd_matrix
 
-        # TODO: This function should be replaced after batch potrf is ready
-        def get_cholesky(bmat, upper):
-            n = bmat.size(-1)
-            cholesky = torch.stack([m.cholesky(upper) for m in bmat.reshape(-1, n, n)])
-            return cholesky.reshape_as(bmat)
-
         def potrs_test_helper(A_dims, b_dims, cast, upper):
             A = cast(random_symmetric_pd_matrix(*A_dims))
-            L = get_cholesky(A, upper)
+            L = torch.cholesky(A, upper)
             b = cast(torch.randn(*b_dims))
             return A, L, b
 
@@ -5413,7 +5430,7 @@ class _TestTorchMixin(object):
             A = cast(A).permute(0, 2, 1)
             b = cast(b).permute(2, 1, 0)
             assert not A.is_contiguous() and not b.is_contiguous(), "contiguous inputs"
-            L = get_cholesky(A, upper)
+            L = torch.cholesky(A, upper)
             x = torch.potrs(b, L, upper=upper)
             self.assertEqual(x, cast(x_exp))
 
@@ -5429,18 +5446,12 @@ class _TestTorchMixin(object):
         from numpy.linalg import solve
         from common_utils import random_symmetric_pd_matrix
 
-        # TODO: This function should be replaced after batch potrf is ready
-        def get_cholesky(bmat, upper):
-            n = bmat.size(-1)
-            cholesky = torch.stack([m.cholesky(upper) for m in bmat.reshape(-1, n, n)])
-            return cholesky.reshape_as(bmat)
-
         def run_test(A_dims, b_dims, cast, upper):
             A = random_symmetric_pd_matrix(*A_dims)
             b = torch.randn(*b_dims)
             x_exp = torch.Tensor(solve(A.numpy(), b.numpy()))
             A, b = cast(A), cast(b)
-            L = get_cholesky(A, upper)
+            L = torch.cholesky(A, upper)
             x = torch.potrs(b, L, upper=upper)
             self.assertEqual(x, cast(x_exp))
 
index 4b19a93..f0b0de1 100644 (file)
@@ -26,7 +26,8 @@ SKIP_PYTHON_BINDINGS = [
     '_indexCopy_', 'max_values', 'min_values', 'argmax', 'argmin',
     '_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
     '_th_.*', '_thnn_.*',
-    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*', '_potrs.*',
+    'arange.*', 'range.*', '_gesv.*', '_getri.*', '_inverse.*',
+    '_potrs.*', '_cholesky.*',
     'slice', 'randint(_out)?',
     '_local_scalar', '_local_scalar_dense',
     'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
index 5a411bd..eb2beaf 100644 (file)
@@ -642,27 +642,35 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntList
 Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
   // cf. Iain Murray (2016); arXiv 1602.07527
   if (upper) {
-    L = L.t();
-    grad = grad.t();
+    grad = grad.transpose(-1, -2);
+  } else {
+    L = L.transpose(-1, -2);
   }
 
-  auto phi = [](const Tensor & A) -> Tensor {
-    auto B = A.tril();
-    B = B - 0.5 * at::diag(at::diag(B));
+  auto batch_tril = [](const Tensor & A) -> Tensor {
+    int64_t n = A.size(-1);
+    auto indices = at::ones({n, n}, A.options().dtype(at::kByte));
+    indices = indices.triu(1).expand_as(A);
+    return at::where(indices, at::zeros({}, A.options()), A);
+  };
+
+  auto phi = [&batch_tril](const Tensor & A) -> Tensor {
+    auto B = A.dim() == 2 ? A.tril() : batch_tril(A);
+    B = B - 0.5 * at::diag_embed(B.diagonal(0, -1, -2), 0, -2, -1);
     return B;
   };
 
   // make sure not to double-count variation, since
   // only half of output matrix is unique
-  auto Lbar = grad.tril();
+  auto Lbar = grad.dim() == 2 ? grad.tril() : batch_tril(grad);
 
-  auto P = phi(at::mm(L.t(), Lbar));
+  auto P = phi(at::matmul(L, Lbar));
   Tensor S;
-  std::tie(S, std::ignore) = at::gesv(P + P.t(), L.t());
-  std::tie(S, std::ignore) = at::gesv(S.t(), L.t());
+  std::tie(S, std::ignore) = at::gesv(P + P.transpose(-1, -2), L);
+  std::tie(S, std::ignore) = at::gesv(S.transpose(-1, -2), L);
   S = phi(S);
   if (upper) {
-    S = S.t();
+    S = S.transpose(-1, -2);
   }
   return S;
 }
index 07e7cc8..afbe3f5 100644 (file)
@@ -826,10 +826,10 @@ Example::
 """)
 
 add_docstr(torch.cholesky, r"""
-cholesky(a, upper=False, out=None) -> Tensor
+cholesky(A, upper=False, out=None) -> Tensor
 
 Computes the Cholesky decomposition of a symmetric positive-definite
-matrix :math:`A`.
+matrix :math:`A` or for batches of symmetric positive-definite matrices.
 
 If :attr:`upper` is ``True``, the returned matrix `U` is upper-triangular, and
 the decomposition has the form:
@@ -845,16 +845,23 @@ the decomposition has the form:
 
     A = LL^T
 
+If :attr:`upper` is ``True``, and :attr:`A` is a batch of symmetric positive-definite
+matrices, then the returned tensor will be composed of upper-triangular Cholesky factors
+of each of the individual matrices. Similarly, when :attr:`upper` is ``False``, the returned
+tensor will be composed of lower-triangular Cholesky factors of each of the individual
+matrices.
+
 Args:
-    a (Tensor): the input 2-D tensor, a symmetric positive-definite matrix
-    upper (bool, optional): flag that indicates whether to return the
+    a (Tensor): the input tensor of size (*, n, n) where `*` is zero or more
+                batch dimensions consisting of symmetric positive-definite matrices.
+    upper (bool, optional): flag that indicates whether to return a
                             upper or lower triangular matrix. Default: ``False``
     out (Tensor, optional): the output matrix
 
 Example::
 
     >>> a = torch.randn(3, 3)
-    >>> a = torch.mm(a, a.t()) # make symmetric positive definite
+    >>> a = torch.mm(a, a.t()) # make symmetric positive-definite
     >>> l = torch.cholesky(a)
     >>> a
     tensor([[ 2.4112, -0.7486,  1.4551],
@@ -868,6 +875,12 @@ Example::
     tensor([[ 2.4112, -0.7486,  1.4551],
             [-0.7486,  1.3544,  0.1294],
             [ 1.4551,  0.1294,  1.6724]])
+    >>> a = torch.randn(3, 2, 2)
+    >>> a = torch.matmul(a, a.transpose(-1, -2)) + 1e-03 # make symmetric positive-definite
+    >>> l = torch.cholesky(a)
+    >>> z = torch.matmul(l, l.transpose(-1, -2))
+    >>> torch.max(torch.abs(z - a)) # Max non-zero
+    tensor(2.3842e-07)
 """)
 
 add_docstr(torch.clamp,
index 4cc3d98..7377522 100644 (file)
@@ -4,7 +4,7 @@ import torch
 from torch.distributions import constraints
 from torch.distributions.distribution import Distribution
 from torch.distributions.multivariate_normal import (_batch_diag, _batch_mahalanobis, _batch_mv,
-                                                     _batch_potrf_lower, _batch_trtrs_lower)
+                                                     _batch_trtrs_lower)
 from torch.distributions.utils import _standard_normal, lazy_property
 
 
@@ -27,7 +27,7 @@ def _batch_capacitance_tril(W, D):
     Wt_Dinv = W.transpose(-1, -2) / D.unsqueeze(-2)
     K = torch.matmul(Wt_Dinv, W).contiguous()
     K.view(-1, m * m)[:, ::m + 1] += 1  # add identity matrix to K
-    return _batch_potrf_lower(K)
+    return torch.cholesky(K)
 
 
 def _batch_lowrank_logdet(W, D, capacitance_tril):
@@ -150,7 +150,7 @@ class LowRankMultivariateNormal(Distribution):
         Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
         K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
         K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
-        return cov_diag_sqrt_unsqueeze * _batch_potrf_lower(K)
+        return cov_diag_sqrt_unsqueeze * torch.cholesky(K)
 
     @lazy_property
     def covariance_matrix(self):
index ef96bd4..f6330d6 100644 (file)
@@ -20,15 +20,6 @@ def _batch_mv(bmat, bvec):
     return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1)
 
 
-def _batch_potrf_lower(bmat):
-    r"""
-    Applies a Cholesky decomposition to all matrices in a batch of arbitrary shape.
-    """
-    n = bmat.size(-1)
-    cholesky_ = torch.stack([m.cholesky(upper=False) for m in bmat.reshape(-1, n, n)])
-    return cholesky_.reshape(bmat.shape)
-
-
 def _batch_diag(bmat):
     r"""
     Returns the diagonals of a batch of square matrices.
@@ -137,7 +128,7 @@ class MultivariateNormal(Distribution):
         else:
             if precision_matrix is not None:
                 self.covariance_matrix = torch.inverse(precision_matrix).expand_as(loc_)
-            self._unbroadcasted_scale_tril = _batch_potrf_lower(self.covariance_matrix)
+            self._unbroadcasted_scale_tril = torch.cholesky(self.covariance_matrix)
 
     def expand(self, batch_shape, _instance=None):
         new = self._get_checked_instance(MultivariateNormal, _instance)