Summary:
This issue was discovered by fehiepsi in https://github.com/uber/pyro/issues/1706 with the `log_prob` computation for Binomial. ~~It can be seen with `torch.float32` when a low probability is combined with a high `total_count`; a test is added to capture this (since scipy only uses float64, the comparison is done with relative tolerance).~~
The problem is in the code that pulls out the minimum value among the logits (written by me earlier, presumably to avoid numerical instability issues), which turns out not to be needed.
EDIT: After a few attempts, I have been unable to reliably show that the change is more numerically stable, and I have removed my previous test, which fails on Linux. The issue manifests when `total_count` is high and `probs` is very low, but the precision of `lgamma` at high `total_count` is poor enough to wash away any benefit. The justification for the change still stands: (a) it simplifies the code (removes the unnecessary clamping), (b) it is no worse than the previous implementation, and (c) it has better continuity behavior, as observed by fehiepsi in the issue above.
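To illustrate why `lgamma` precision dominates at high `total_count`, here is a small stdlib-only sketch (the values `n = 8000, k = 4000` are arbitrary choices for illustration, not from the removed test) comparing the `lgamma`-based log binomial coefficient, as used in `log_prob`, against an exact big-integer reference:

```python
import math

# Log binomial coefficient via lgamma, as in the log_prob implementation.
n, k = 8000, 4000
via_lgamma = math.lgamma(n + 1) - math.lgamma(k + 1) - math.lgamma(n - k + 1)

# Exact reference: math.comb works with arbitrary-precision integers,
# and math.log handles ints larger than the float range.
exact = math.log(math.comb(n, k))

# In float64 the absolute error here is tiny; in float32 (the dtype from
# the original report) the relative error of lgamma at large arguments is
# large enough to mask any gain from the simplification.
print(via_lgamma - exact)
```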
cc. fehiepsi, alicanb, fritzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15962
Differential Revision: D13709541
Pulled By: ezyang
fbshipit-source-id: 596c6853b6e4d5fba42336afa168a665ab6fbde2
```diff
         log_factorial_n = torch.lgamma(self.total_count + 1)
         log_factorial_k = torch.lgamma(value + 1)
         log_factorial_nmk = torch.lgamma(self.total_count - value + 1)
-        max_val = (-self.logits).clamp(min=0.0)
-        # Note that: torch.log1p(-self.probs)) = max_val - torch.log1p((self.logits + 2 * max_val).exp()))
+        # Note that: torch.log1p(-self.probs) = -torch.log1p(self.logits.exp())
         return (log_factorial_n - log_factorial_k - log_factorial_nmk +
-                value * self.logits + self.total_count * max_val -
-                self.total_count * torch.log1p((self.logits + 2 * max_val).exp()))
+                value * self.logits - self.total_count * torch.log1p(self.logits.exp()))

     def enumerate_support(self, expand=True):
         total_count = int(self.total_count.max())
```
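The simplified expression can be sanity-checked without torch against an exact big-integer reference. The sketch below re-implements the patched formula in pure Python (`binomial_log_prob` is a hypothetical helper name, not the real tensor code) and relies on the identity `log1p(-probs) = -log1p(exp(logits))` noted in the diff:

```python
import math

def binomial_log_prob(total_count, probs, value):
    # Pure-Python version of the simplified formula from the patch
    # (hypothetical scalar helper; the real code operates on tensors).
    logits = math.log(probs / (1.0 - probs))
    log_factorial_n = math.lgamma(total_count + 1)
    log_factorial_k = math.lgamma(value + 1)
    log_factorial_nmk = math.lgamma(total_count - value + 1)
    # Identity used: log1p(-probs) = -log1p(exp(logits))
    return (log_factorial_n - log_factorial_k - log_factorial_nmk +
            value * logits - total_count * math.log1p(math.exp(logits)))

# Exact reference: log C(n, k) + k*log(p) + (n - k)*log(1 - p),
# with the binomial coefficient computed in exact integer arithmetic.
n, k, p = 100, 3, 0.01
exact = math.log(math.comb(n, k)) + k * math.log(p) + (n - k) * math.log(1.0 - p)
print(binomial_log_prob(n, p, k), exact)  # the two should agree closely
```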