Fix numerical stability in binomial.log_prob (#15962)
authorNeeraj Pradhan <npradhan@uber.com>
Thu, 17 Jan 2019 17:59:58 +0000 (09:59 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 17 Jan 2019 18:18:37 +0000 (10:18 -0800)
Summary:
This issue was discovered by fehiepsi in https://github.com/uber/pyro/issues/1706 with the `log_prob` computation for Binomial, ~and can be seen with `torch.float32` when we have a combination of low probability value and high `total_count` - a test is added to capture this (since scipy only uses float64, the comparison is done using relative tolerance).~

The problem is in the code that tries to pull out the minimum values amongst the logits (written by me earlier, presumably to avoid numerical instability issues), but it is not needed.

EDIT: After a few attempts, I have been unable to reliably show that the change is more numerically stable, and have removed my previous test which fails on linux. The reason is that the issue manifests itself when `total_count` is high and `probs` is very low. However, the precision of `lgamma` when `total_count` is high is bad enough to wash away any benefits. The justification for this still stands though - (a) simplifies code (removes the unnecessary bit), (b) is no worse than the previous implementation, (c) has better continuity behavior as observed by fehiepsi in the issue above.

cc. fehiepsi, alicanb, fritzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15962

Differential Revision: D13709541

Pulled By: ezyang

fbshipit-source-id: 596c6853b6e4d5fba42336afa168a665ab6fbde2

torch/distributions/binomial.py

index e1763c1..26fa54d 100644 (file)
@@ -113,11 +113,9 @@ class Binomial(Distribution):
         log_factorial_n = torch.lgamma(self.total_count + 1)
         log_factorial_k = torch.lgamma(value + 1)
         log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
-        max_val = (-self.logits).clamp(min=0.0)
-        # Note that: torch.log1p(-self.probs)) = max_val - torch.log1p((self.logits + 2 * max_val).exp()))
+        # Note that: torch.log1p(-self.probs)) = - torch.log1p(self.logits.exp()))
         return (log_factorial_n - log_factorial_k - log_factorial_nmk +
-                value * self.logits + self.total_count * max_val -
-                self.total_count * torch.log1p((self.logits + 2 * max_val).exp()))
+                value * self.logits - self.total_count * torch.log1p(self.logits.exp()))
 
     def enumerate_support(self, expand=True):
         total_count = int(self.total_count.max())