Improve numerical precision of (s)logdet (#18449)
authorTongzhou Wang <tongzhou.wang.1994@gmail.com>
Tue, 26 Mar 2019 22:25:26 +0000 (15:25 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 26 Mar 2019 22:32:14 +0000 (15:32 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/18448 and https://github.com/pytorch/pytorch/issues/18450
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18449

Differential Revision: D14611638

Pulled By: soumith

fbshipit-source-id: 4f1f27ab5316a92d2783e734169f599afed743cf

aten/src/ATen/native/LinearAlgebra.cpp
test/test_cuda.py
test/test_torch.py

index 9ec2165..5042acc 100644 (file)
@@ -9,6 +9,7 @@
 #include <functional>
 #include <numeric>
 #include <vector>
+#include <limits>
 
 namespace at {
 namespace native {
@@ -26,7 +27,7 @@ static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor
   int int_info = info.squeeze_().item<int32_t>();
   AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
   auto n = self.size(0);
-  auto num_exchanges = (at::arange(1, n + 1, p.type()) != p).nonzero().size(0);
+  auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0);
   if (num_exchanges % 2 == 1) {
     return std::make_tuple(-1., lu.diag(), int_info);
   } else {
@@ -44,7 +45,7 @@ Tensor det(const Tensor& self) {
   int info;
   std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
   if (info > 0) {
-    return at::zeros({}, self.type());
+    return at::zeros({}, self.options());
   } else {
     return diag_U.prod().mul_(det_P);
   }
@@ -56,16 +57,18 @@ Tensor logdet(const Tensor& self) {
            "logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
            "of floating types");
   double det_P;
-  Tensor diag_U, det;
+  Tensor diag_U;
   int info;
   std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
   if (info > 0) {
-    det = at::zeros({}, self.type());
-  } else {
-    det = diag_U.prod().mul_(det_P);
+    return at::full({}, -std::numeric_limits<double>::infinity(), self.options());
   }
-  if (det.sign().item<double>() <= 0) {
-    return det.log_();  // in order to get proper -inf (det=0) or nan (det<0)
+  // `det_sign` is the sign of the determinant. We work on `diag_U.sign()` for
+  // numerical stability when diag_U has a lot small values.
+  auto det_sign = diag_U.sign().prod().mul_(det_P);
+  // This synchronizes on GPU, but `_lu_det_P_diag_U_info` above already synchronizes
+  if (det_sign.item<double>() <= 0) {
+    return det_sign.log_();  // get proper nan (det<0) or -inf (det=0)
   } else {
     return diag_U.abs_().log_().sum();
   }
@@ -77,15 +80,17 @@ std::tuple<Tensor, Tensor> slogdet(const Tensor& self) {
            "slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
            "of floating types");
   double det_P;
-  Tensor diag_U, det;
+  Tensor diag_U;
   int info;
   std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
   if (info > 0) {
     return std::make_tuple(at::zeros({}, self.options()),
-                           at::empty({}, self.options()).fill_(-INFINITY));
+                           at::full({}, -std::numeric_limits<double>::infinity(), self.options()));
   } else {
-    det = diag_U.prod().mul_(det_P);
-    return std::make_tuple(det.sign(), diag_U.abs_().log_().sum());
+    // `det_sign` is the sign of the determinant. We work on `diag_U.sign()` for
+    // numerical stability when diag_U has a lot small values.
+    auto det_sign = diag_U.sign().prod().mul_(det_P);
+    return std::make_tuple(det_sign, diag_U.abs_().log_().sum());
   }
 }
 
@@ -309,9 +314,9 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
         });
     }
   } else if (at::hasMKL() && at::native::is_floating_point(self_or_result)
-            && batch_items_contiguous_or_transposed(batch1)
-            && batch_items_contiguous_or_transposed(batch2)
-            && self_or_result.is_contiguous()) {
+            && batch_items_contiguous_or_transposed(batch1)
+            && batch_items_contiguous_or_transposed(batch2)
+            && self_or_result.is_contiguous()) {
     at::native::_baddbmm_mkl_(self_or_result, batch1, batch2, beta, alpha);
   } else { // split along batch dimension
     if (is_bmm_out) {
index a34e51f..93fcd7d 100644 (file)
@@ -2138,7 +2138,7 @@ class TestCuda(TestCase):
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_det_logdet_slogdet(self):
-        _TestTorchMixin._test_det_logdet_slogdet(self, lambda t: t.cuda())
+        _TestTorchMixin._test_det_logdet_slogdet(self, 'cuda')
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_solve(self):
index 8a041cb..f895516 100644 (file)
@@ -5531,47 +5531,60 @@ class _TestTorchMixin(object):
         self._test_chain_matmul(self, cast=lambda x: x)
 
     @staticmethod
-    def _test_det_logdet_slogdet(self, conv_fn):
-        def reference_det(M):
-            # naive row reduction
-            M = M.clone()
-            l = M.size(0)
-            multiplier = 1
-            for i in range(l):
-                if M[i, 0] != 0:
-                    if i != 0:
-                        M[0], M[i] = M[i], M[0]
-                        multiplier = -1
-                    break
+    def _test_det_logdet_slogdet(self, device):
+        def reference_slogdet(M):
+            if TEST_NUMPY:
+                sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy())
+                return M.new_tensor(sdet), M.new_tensor(logabsdet)
             else:
-                return 0
-            for i in range(1, l):
-                row = M[i]
-                for j in range(i):
-                    row -= row[j] / M[j, j] * M[j]
-                M[i] = row
-            return M.diag().prod() * multiplier
+                # naive row reduction
+                M = M.clone()
+                l = M.size(0)
+                multiplier = 1
+                for i in range(l):
+                    if M[i, 0].item() != 0:
+                        if i != 0:
+                            M[0], M[i] = M[i], M[0]
+                            multiplier = -1
+                        break
+                else:
+                    return 0
+                for i in range(1, l):
+                    row = M[i]
+                    for j in range(i):
+                        row -= row[j] / M[j, j] * M[j]
+                    M[i] = row
+            sdet = M.diag().sign().prod()
+            logabsdet = M.diag().abs_().log_().sum().add_(math.log(multiplier))
+            return sdet, logabsdet
 
         def test_single_det(M, target, desc):
+            target_sdet, target_logabsdet = target
+
             det = M.det()
             logdet = M.logdet()
             sdet, logabsdet = M.slogdet()
-            self.assertEqual(det, target, 1e-7, '{} (det)'.format(desc))
-            if det.item() < 0:
+
+            # Test det
+            self.assertEqual(det, target_sdet * target_logabsdet.exp(), 1e-7, '{} (det)'.format(desc))
+
+            # Test slogdet
+            # Compare the overall value rather than individual parts because of
+            # precision issues when det is near zero.
+            self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), 1e-7, '{} (slogdet)'.format(desc))
+
+            # Test logdet
+            # Compare logdet against our own pytorch slogdet because they should
+            # be consistent, while it may behave slightly differently with other
+            # slogdet implementations when det is near zero due to precision
+            # issues.
+            if sdet.item() < 0:
                 self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc))
-                self.assertTrue(sdet.item() == -1, '{} (slogdet sign negative case)'.format(desc))
-                self.assertEqual(logabsdet.exp(), det.abs(), 1e-7, '{} (slogdet logabsdet negative case)'.format(desc))
-            elif det.item() == 0:
-                self.assertEqual(logdet.exp().item(), 0, 1e-7, '{} (logdet zero case)'.format(desc))
-                self.assertTrue(sdet.item() == 0, '{} (slogdet sign zero case)'.format(desc))
-                self.assertEqual(logabsdet.exp().item(), 0, 1e-7, '{} (slogdet logabsdet zero case)'.format(desc))
             else:
-                self.assertEqual(logdet.exp(), det, 1e-7, '{} (logdet positive case)'.format(desc))
-                self.assertTrue(sdet.item() == 1, '{} (slogdet sign  positive case)'.format(desc))
-                self.assertEqual(logabsdet.exp(), det, 1e-7, '{} (slogdet logabsdet positive case)'.format(desc))
+                self.assertEqual(logdet.exp(), target_logabsdet.exp(), 1e-7, '{} (logdet non-negative case)'.format(desc))
 
-        eye = conv_fn(torch.eye(5))
-        test_single_det(eye, torch.tensor(1, dtype=eye.dtype), 'identity')
+        eye = torch.eye(5, device=device)
+        test_single_det(eye, (torch.ones((), device=device), torch.zeros((), device=device)), 'identity')
 
         # TODO: Remove when MAGMA 2.5.0 is built for CUDA 8 and CUDA 9.2
         is_cuda_8_92 = False
@@ -5580,22 +5593,29 @@ class _TestTorchMixin(object):
 
         def test(M):
             assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
-            M = conv_fn(M)
+            M = M.to(device)
 
             if M.is_cuda and is_cuda_8_92:
                 return
 
-            M_det = M.det()
-            ref_M_det = reference_det(M)
+            ref_M_sdet, ref_M_logabsdet = reference_slogdet(M)
 
-            test_single_det(M, ref_M_det, 'basic')
-            if abs(ref_M_det.item()) >= 1e-10:  # skip singular
-                test_single_det(M, M.inverse().det().pow_(-1), 'inverse')
-            test_single_det(M, M.t().det(), 'transpose')
+            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic')
+            if ref_M_logabsdet.exp().item() >= 1e-6:  # skip singular
+                M_inv = M.inverse()
+                test_single_det(M_inv, reference_slogdet(M_inv), 'inverse')
+
+            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose')
 
             for x in [0, 2, 4]:
                 for scale in [-2, -0.1, 0, 10]:
-                    target = M_det * scale
+                    if scale > 0:
+                        target = ref_M_sdet, ref_M_logabsdet + math.log(scale)
+                    elif scale == 0:
+                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
+                    else:
+                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale)
+
                     # dim 0
                     M_clone = M.clone()
                     M_clone[:, x] *= scale
@@ -5607,7 +5627,7 @@ class _TestTorchMixin(object):
 
             for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
                 assert x1 != x2, 'x1 and x2 needs to be different for this test'
-                target = M_det.clone().zero_()
+                target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
                 # dim 0
                 M_clone = M.clone()
                 M_clone[:, x2] = M_clone[:, x1]
@@ -5618,7 +5638,14 @@ class _TestTorchMixin(object):
                 test_single_det(M_clone, target, 'two columns are same')
 
                 for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
-                    target = -M_det * scale1 * scale2
+                    det_scale = scale1 * scale2 * -1
+                    if det_scale > 0:
+                        target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale)
+                    elif det_scale == 0:
+                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
+                    else:
+                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale)
+
                     # dim 0
                     M_clone = M.clone()
                     t = M_clone[:, x1] * scale1
@@ -5648,39 +5675,51 @@ class _TestTorchMixin(object):
             #   scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n))
             #
             # source: https://arxiv.org/pdf/1112.0752.pdf
+
+            # TODO: technically we need subexponential distn for this to hold,
+            #       but we mostly use gaussian entries below. Consider switching
+            #       to Chi-sq if this turns out not stable enough, since Chi-sq
+            #       is easy enough to sample from.
             return math.factorial(n - 1) ** (-1.0 / (2 * n))
 
         for n in [5, 10, 25]:
             scale = get_random_mat_scale(n)
-            test(torch.randn(n, n) * scale)
-            r = torch.randn(n, n) * scale
+            test(torch.randn(n, n, device=device) * scale)
+            r = torch.randn(n, n, device=device) * scale
             # symmetric psd
             test(r.mm(r.t()))
             # symmetric pd
-            r = torch.randn(n, n) * scale
-            test(r.mm(r.t()) + torch.eye(n) * 1e-6)
+            r = torch.randn(n, n, device=device) * scale
+            test(r.mm(r.t()) + torch.eye(n, device=device) * 1e-6)
             # symmetric
-            r = torch.randn(n, n) * scale
+            r = torch.randn(n, n, device=device) * scale
             for i in range(n):
                 for j in range(i):
                     r[i, j] = r[j, i]
             test(r)
             # non-contiguous
-            test((torch.randn(n, n, n + 1) * scale)[:, 2, 1:])
+            test((torch.randn(n, n, n + 1, device=device) * scale)[:, 2, 1:])
             # det = 0
-            r = torch.randn(n, n) * scale
+            r = torch.randn(n, n, device=device) * scale
             u, s, v = r.svd()
-            if reference_det(u) < 0:
+            if reference_slogdet(u)[0] < 0:
                 u = -u
-            if reference_det(v) < 0:
+            if reference_slogdet(v)[0] < 0:
                 v = -v
             s[0] *= -1
             s[-1] = 0
             test(u.mm(s.diag()).mm(v))
 
+        # Small values to test numerical stability. Note that we don't scale
+        # this matrix.
+        r = torch.randn(512, 512, device=device)
+        u, s, v = r.svd()
+        s.fill_(1. / (100 * s.numel()))
+        test(u.mm(s.diag()).mm(v))
+
     @skipIfNoLapack
     def test_det_logdet_slogdet(self):
-        self._test_det_logdet_slogdet(self, lambda x: x)
+        self._test_det_logdet_slogdet(self, 'cpu')
 
     @staticmethod
     def _test_fft_ifft_rfft_irfft(self, device='cpu'):