From: Tongzhou Wang
Date: Tue, 26 Mar 2019 22:25:26 +0000 (-0700)
Subject: Improve numerical precision of (s)logdet (#18449)
X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~614
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5292685d2f144d9781ab8b7991c0a1153098a477;p=platform%2Fupstream%2Fpytorch.git

Improve numerical precision of (s)logdet (#18449)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/18448 and https://github.com/pytorch/pytorch/issues/18450
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18449

Differential Revision: D14611638

Pulled By: soumith

fbshipit-source-id: 4f1f27ab5316a92d2783e734169f599afed743cf
---

diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 9ec2165..5042acc 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include <limits>
 
 namespace at {
 namespace native {
@@ -26,7 +27,7 @@ static inline std::tuple _lu_det_P_diag_U_info(const Tensor
   int int_info = info.squeeze_().item();
   AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
   auto n = self.size(0);
-  auto num_exchanges = (at::arange(1, n + 1, p.type()) != p).nonzero().size(0);
+  auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0);
   if (num_exchanges % 2 == 1) {
     return std::make_tuple(-1., lu.diag(), int_info);
   } else {
@@ -44,7 +45,7 @@ Tensor det(const Tensor& self) {
   int info;
   std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
   if (info > 0) {
-    return at::zeros({}, self.type());
+    return at::zeros({}, self.options());
   } else {
     return diag_U.prod().mul_(det_P);
   }
@@ -56,16 +57,18 @@ Tensor logdet(const Tensor& self) {
            "logdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
            "of floating types");
   double det_P;
-  Tensor diag_U, det;
+  Tensor diag_U;
   int info;
   std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
   if (info > 0) {
-    det = at::zeros({}, self.type());
-  } else {
-    det = diag_U.prod().mul_(det_P);
+    return at::full({}, -std::numeric_limits<double>::infinity(), self.options());
   }
-  if (det.sign().item<double>() <= 0) {
-    return det.log_();  // in order to get proper -inf (det=0) or nan (det<0)
+  // `det_sign` is the sign of the determinant. We work on `diag_U.sign()` for
+  // numerical stability when diag_U has a lot of small values.
+  auto det_sign = diag_U.sign().prod().mul_(det_P);
+  // This synchronizes on GPU, but `_lu_det_P_diag_U_info` above already synchronizes
+  if (det_sign.item<double>() <= 0) {
+    return det_sign.log_();  // get proper nan (det<0) or -inf (det=0)
   } else {
     return diag_U.abs_().log_().sum();
   }
@@ -77,15 +80,17 @@ std::tuple slogdet(const Tensor& self) {
            "slogdet(", self.type(), "{", self.sizes(), "}): expected a 2D square tensor "
            "of floating types");
   double det_P;
-  Tensor diag_U, det;
+  Tensor diag_U;
   int info;
   std::tie(det_P, diag_U, info) = _lu_det_P_diag_U_info(self);
   if (info > 0) {
     return std::make_tuple(at::zeros({}, self.options()),
-                           at::empty({}, self.options()).fill_(-INFINITY));
+                           at::full({}, -std::numeric_limits<double>::infinity(), self.options()));
   } else {
-    det = diag_U.prod().mul_(det_P);
-    return std::make_tuple(det.sign(), diag_U.abs_().log_().sum());
+    // `det_sign` is the sign of the determinant. We work on `diag_U.sign()` for
+    // numerical stability when diag_U has a lot of small values.
+    auto det_sign = diag_U.sign().prod().mul_(det_P);
+    return std::make_tuple(det_sign, diag_U.abs_().log_().sum());
   }
 }
 
@@ -309,9 +314,9 @@ static inline Tensor& bmm_out_or_baddbmm_(Tensor& self_or_result, const Tensor&
       });
     }
   } else if (at::hasMKL() && at::native::is_floating_point(self_or_result)
-      && batch_items_contiguous_or_transposed(batch1)
-      && batch_items_contiguous_or_transposed(batch2)
-      && self_or_result.is_contiguous()) {
+            && batch_items_contiguous_or_transposed(batch1)
+            && batch_items_contiguous_or_transposed(batch2)
+            && self_or_result.is_contiguous()) {
     at::native::_baddbmm_mkl_(self_or_result, batch1, batch2, beta, alpha);
   } else { // split along batch dimension
     if (is_bmm_out) {
diff --git a/test/test_cuda.py b/test/test_cuda.py
index a34e51f..93fcd7d 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -2138,7 +2138,7 @@ class TestCuda(TestCase):
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_det_logdet_slogdet(self):
-        _TestTorchMixin._test_det_logdet_slogdet(self, lambda t: t.cuda())
+        _TestTorchMixin._test_det_logdet_slogdet(self, 'cuda')
 
     @unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
     def test_solve(self):
diff --git a/test/test_torch.py b/test/test_torch.py
index 8a041cb..f895516 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -5531,47 +5531,60 @@ class _TestTorchMixin(object):
         self._test_chain_matmul(self, cast=lambda x: x)
 
     @staticmethod
-    def _test_det_logdet_slogdet(self, conv_fn):
-        def reference_det(M):
-            # naive row reduction
-            M = M.clone()
-            l = M.size(0)
-            multiplier = 1
-            for i in range(l):
-                if M[i, 0] != 0:
-                    if i != 0:
-                        M[0], M[i] = M[i], M[0]
-                        multiplier = -1
-                    break
+    def _test_det_logdet_slogdet(self, device):
+        def reference_slogdet(M):
+            if TEST_NUMPY:
+                sdet, logabsdet = np.linalg.slogdet(M.detach().cpu().numpy())
+                return M.new_tensor(sdet), M.new_tensor(logabsdet)
             else:
-                return 0
-            for i in range(1, l):
-                row = M[i]
-                for j in range(i):
-                    row -= row[j] / M[j, j] * M[j]
-                M[i] = row
-            return M.diag().prod() * multiplier
+                # naive row reduction
+                M = M.clone()
+                l = M.size(0)
+                multiplier = 1
+                for i in range(l):
+                    if M[i, 0].item() != 0:
+                        if i != 0:
+                            M[0], M[i] = M[i], M[0]
+                            multiplier = -1
+                        break
+                else:
+                    return 0
+                for i in range(1, l):
+                    row = M[i]
+                    for j in range(i):
+                        row -= row[j] / M[j, j] * M[j]
+                    M[i] = row
+                sdet = M.diag().sign().prod() * multiplier
+                logabsdet = M.diag().abs_().log_().sum()
+                return sdet, logabsdet
 
         def test_single_det(M, target, desc):
+            target_sdet, target_logabsdet = target
+
             det = M.det()
             logdet = M.logdet()
             sdet, logabsdet = M.slogdet()
-            self.assertEqual(det, target, 1e-7, '{} (det)'.format(desc))
-            if det.item() < 0:
+
+            # Test det
+            self.assertEqual(det, target_sdet * target_logabsdet.exp(), 1e-7, '{} (det)'.format(desc))
+
+            # Test slogdet
+            # Compare the overall value rather than individual parts because of
+            # precision issues when det is near zero.
+            self.assertEqual(sdet * logabsdet.exp(), target_sdet * target_logabsdet.exp(), 1e-7, '{} (slogdet)'.format(desc))
+
+            # Test logdet
+            # Compare logdet against our own pytorch slogdet because they should
+            # be consistent, while it may behave slightly differently with other
+            # slogdet implementations when det is near zero due to precision
+            # issues.
+            if sdet.item() < 0:
                 self.assertTrue(logdet.item() != logdet.item(), '{} (logdet negative case)'.format(desc))
-                self.assertTrue(sdet.item() == -1, '{} (slogdet sign negative case)'.format(desc))
-                self.assertEqual(logabsdet.exp(), det.abs(), 1e-7, '{} (slogdet logabsdet negative case)'.format(desc))
-            elif det.item() == 0:
-                self.assertEqual(logdet.exp().item(), 0, 1e-7, '{} (logdet zero case)'.format(desc))
-                self.assertTrue(sdet.item() == 0, '{} (slogdet sign zero case)'.format(desc))
-                self.assertEqual(logabsdet.exp().item(), 0, 1e-7, '{} (slogdet logabsdet zero case)'.format(desc))
             else:
-                self.assertEqual(logdet.exp(), det, 1e-7, '{} (logdet positive case)'.format(desc))
-                self.assertTrue(sdet.item() == 1, '{} (slogdet sign positive case)'.format(desc))
-                self.assertEqual(logabsdet.exp(), det, 1e-7, '{} (slogdet logabsdet positive case)'.format(desc))
+                self.assertEqual(logdet.exp(), target_logabsdet.exp(), 1e-7, '{} (logdet non-negative case)'.format(desc))
 
-        eye = conv_fn(torch.eye(5))
-        test_single_det(eye, torch.tensor(1, dtype=eye.dtype), 'identity')
+        eye = torch.eye(5, device=device)
+        test_single_det(eye, (torch.ones((), device=device), torch.zeros((), device=device)), 'identity')
 
         # TODO: Remove when MAGMA 2.5.0 is built for CUDA 8 and CUDA 9.2
         is_cuda_8_92 = False
@@ -5580,22 +5593,29 @@
 
         def test(M):
             assert M.size(0) >= 5, 'this helper fn assumes M to be at least 5x5'
-            M = conv_fn(M)
+            M = M.to(device)
 
             if M.is_cuda and is_cuda_8_92:
                 return
 
-            M_det = M.det()
-            ref_M_det = reference_det(M)
+            ref_M_sdet, ref_M_logabsdet = reference_slogdet(M)
 
-            test_single_det(M, ref_M_det, 'basic')
-            if abs(ref_M_det.item()) >= 1e-10:  # skip singular
-                test_single_det(M, M.inverse().det().pow_(-1), 'inverse')
-            test_single_det(M, M.t().det(), 'transpose')
+            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'basic')
+            if ref_M_logabsdet.exp().item() >= 1e-6:  # skip singular
+                M_inv = M.inverse()
+                test_single_det(M_inv, reference_slogdet(M_inv), 'inverse')
+
+            test_single_det(M, (ref_M_sdet, ref_M_logabsdet), 'transpose')
 
             for x in [0, 2, 4]:
                 for scale in [-2, -0.1, 0, 10]:
-                    target = M_det * scale
+                    if scale > 0:
+                        target = ref_M_sdet, ref_M_logabsdet + math.log(scale)
+                    elif scale == 0:
+                        target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
+                    else:
+                        target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-scale)
+
                     # dim 0
                     M_clone = M.clone()
                     M_clone[:, x] *= scale
@@ -5607,7 +5627,7 @@
 
             for x1, x2 in [(0, 3), (4, 1), (3, 2)]:
                 assert x1 != x2, 'x1 and x2 needs to be different for this test'
-                target = M_det.clone().zero_()
+                target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
                 # dim 0
                 M_clone = M.clone()
                 M_clone[:, x2] = M_clone[:, x1]
@@ -5618,7 +5638,14 @@
                 test_single_det(M_clone, target, 'two columns are same')
 
             for scale1, scale2 in [(0.3, -1), (0, 2), (10, 0.1)]:
-                target = -M_det * scale1 * scale2
+                det_scale = scale1 * scale2 * -1
+                if det_scale > 0:
+                    target = ref_M_sdet, ref_M_logabsdet + math.log(det_scale)
+                elif det_scale == 0:
+                    target = torch.zeros_like(ref_M_sdet), torch.full_like(ref_M_logabsdet, -inf)
+                else:
+                    target = ref_M_sdet.neg(), ref_M_logabsdet + math.log(-det_scale)
+
                 # dim 0
                 M_clone = M.clone()
                 t = M_clone[:, x1] * scale1
@@ -5648,39 +5675,51 @@ class _TestTorchMixin(object):
             # scale by sqrt( (n-1)! )^(-1/n) = ( (n-1)! )^(-1/(2n))
             #
             # source: https://arxiv.org/pdf/1112.0752.pdf
+
+            # TODO: technically we need subexponential distn for this to hold,
+            #       but we mostly use gaussian entries below. Consider switching
+            #       to Chi-sq if this turns out not to be stable enough, since
+            #       Chi-sq is easy enough to sample from.
             return math.factorial(n - 1) ** (-1.0 / (2 * n))
 
         for n in [5, 10, 25]:
             scale = get_random_mat_scale(n)
-            test(torch.randn(n, n) * scale)
-            r = torch.randn(n, n) * scale
+            test(torch.randn(n, n, device=device) * scale)
+            r = torch.randn(n, n, device=device) * scale
             # symmetric psd
             test(r.mm(r.t()))
             # symmetric pd
-            r = torch.randn(n, n) * scale
-            test(r.mm(r.t()) + torch.eye(n) * 1e-6)
+            r = torch.randn(n, n, device=device) * scale
+            test(r.mm(r.t()) + torch.eye(n, device=device) * 1e-6)
             # symmetric
-            r = torch.randn(n, n) * scale
+            r = torch.randn(n, n, device=device) * scale
             for i in range(n):
                 for j in range(i):
                     r[i, j] = r[j, i]
             test(r)
             # non-contiguous
-            test((torch.randn(n, n, n + 1) * scale)[:, 2, 1:])
+            test((torch.randn(n, n, n + 1, device=device) * scale)[:, 2, 1:])
             # det = 0
-            r = torch.randn(n, n) * scale
+            r = torch.randn(n, n, device=device) * scale
             u, s, v = r.svd()
-            if reference_det(u) < 0:
+            if reference_slogdet(u)[0] < 0:
                 u = -u
-            if reference_det(v) < 0:
+            if reference_slogdet(v)[0] < 0:
                 v = -v
             s[0] *= -1
             s[-1] = 0
             test(u.mm(s.diag()).mm(v))
 
+        # Small values to test numerical stability. Note that we don't scale
+        # this matrix.
+        r = torch.randn(512, 512, device=device)
+        u, s, v = r.svd()
+        s.fill_(1. / (100 * s.numel()))
+        test(u.mm(s.diag()).mm(v))
+
     @skipIfNoLapack
     def test_det_logdet_slogdet(self):
-        self._test_det_logdet_slogdet(self, lambda x: x)
+        self._test_det_logdet_slogdet(self, 'cpu')
 
     @staticmethod
     def _test_fft_ifft_rfft_irfft(self, device='cpu'):
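
A minimal sketch of the precision issue the new `logdet`/`slogdet` path avoids: when the diagonal of the LU factor U holds many small entries, forming their product first underflows, while accumulating the sign and the log-magnitudes separately stays finite. The snippet below is illustrative only and is not part of the diff above; it uses only basic tensor ops.

```python
import torch

# diag_u stands in for the diagonal of the LU factor U: 512 entries of 1e-4,
# so |det| = 1e-4 ** 512, far below the smallest normal double (~1e-308).
diag_u = torch.full((512,), 1e-4, dtype=torch.float64)

# Product-first computation: the product underflows to 0, so its log is -inf
# and its sign is 0, even though the true log|det| is finite.
naive_det = diag_u.prod()
print(naive_det, naive_det.abs().log())    # tensor(0.)  tensor(-inf)

# Patched-style computation: sign and log-magnitude are accumulated
# separately, so nothing ever underflows.
det_sign = diag_u.sign().prod()            # tensor(1.)
logabsdet = diag_u.abs().log().sum()       # about -4715.7, i.e. 512 * log(1e-4)
print(det_sign, logabsdet)
```

The 512x512 matrix with uniformly tiny singular values added at the end of `_test_det_logdet_slogdet` exercises exactly this regime.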
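The scale factor in `get_random_mat_scale` follows from the cited reference (arXiv:1112.0752): for an n x n matrix with i.i.d. standard normal entries, log|det| concentrates around 0.5*log((n-1)!). Taking that concentration value as an assumption, a quick check (not from the commit) that the chosen scale recenters log|det| near zero:

```python
import math

# Scaling every entry by c multiplies det by c ** n, i.e. adds n * log(c) to
# log|det|. With c = ((n - 1)!) ** (-1 / (2n)) the shift is exactly
# -0.5 * log((n - 1)!), cancelling the typical unscaled value quoted above.
for n in [5, 10, 25]:
    c = math.factorial(n - 1) ** (-1.0 / (2 * n))
    recentered = 0.5 * math.log(math.factorial(n - 1)) + n * math.log(c)
    print(n, recentered)   # ~0.0 for each n, up to floating-point rounding
```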