Poisson zero rate (#61511)
Author: Till Hoffmann <tillahoffmann@gmail.com>
Thu, 19 Aug 2021 15:28:55 +0000 (08:28 -0700)
Committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 19 Aug 2021 15:30:28 +0000 (08:30 -0700)
Summary:
This PR fixes https://github.com/pytorch/pytorch/issues/53485 by allowing zero rates for the Poisson distribution. This implementation is consistent with `scipy.stats.poisson` which admits zero rates. In addition to addressing the aforementioned issue, this PR makes two supporting changes:

1. add a `nonnegative` constraint to enforce non-negative rates for the Poisson distribution.
2. adjust the evaluation of the gradient of `xlogy` such that it is well defined for `x == 0 and y == 0`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61511

Reviewed By: ejguan

Differential Revision: D30352917

Pulled By: albanD

fbshipit-source-id: f3d33da58360e80d75eb83519f199b93232a2a2d

test/distributions/test_distributions.py
tools/autograd/derivatives.yaml
torch/distributions/constraint_registry.py
torch/distributions/constraints.py
torch/distributions/poisson.py

index 85e4dba..319b557 100644 (file)
@@ -387,6 +387,12 @@ EXAMPLES = [
         },
         {
             'rate': 0.2,
+        },
+        {
+            'rate': torch.tensor([0.0], requires_grad=True),
+        },
+        {
+            'rate': 0.0,
         }
     ]),
     Example(RelaxedBernoulli, [
@@ -667,7 +673,7 @@ BAD_EXAMPLES = [
     ]),
     Example(Poisson, [
         {
-            'rate': torch.tensor([0.0], requires_grad=True),
+            'rate': torch.tensor([-0.1], requires_grad=True),
         },
         {
             'rate': -1.0,
@@ -1315,17 +1321,29 @@ class TestDistributions(TestCase):
     def test_poisson_log_prob(self):
         rate = torch.randn(2, 3).abs().requires_grad_()
         rate_1d = torch.randn(1).abs().requires_grad_()
+        rate_zero = torch.zeros([], requires_grad=True)
 
-        def ref_log_prob(idx, x, log_prob):
-            l = rate.view(-1)[idx].detach()
+        def ref_log_prob(ref_rate, idx, x, log_prob):
+            l = ref_rate.view(-1)[idx].detach()
             expected = scipy.stats.poisson.logpmf(x, l)
             self.assertEqual(log_prob, expected, atol=1e-3, rtol=0)
 
         set_rng_seed(0)
-        self._check_log_prob(Poisson(rate), ref_log_prob)
+        self._check_log_prob(Poisson(rate), lambda *args: ref_log_prob(rate, *args))
+        self._check_log_prob(Poisson(rate_zero), lambda *args: ref_log_prob(rate_zero, *args))
         self._gradcheck_log_prob(Poisson, (rate,))
         self._gradcheck_log_prob(Poisson, (rate_1d,))
 
+        # We cannot check gradients automatically for zero rates because the finite difference
+        # approximation enters the forbidden parameter space. We instead compare with the
+        # theoretical results.
+        dist = Poisson(rate_zero)
+        s = dist.sample()
+        dist.log_prob(s).backward()
+        torch.testing.assert_allclose(rate_zero.grad, -1.0)
+        dist.log_prob(torch.ones_like(rate_zero)).backward()
+        torch.testing.assert_allclose(rate_zero.grad, torch.inf)
+
     @unittest.skipIf(IS_MACOS, "See https://github.com/pytorch/pytorch/issues/60347")
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_poisson_sample(self):
index b52b690..49e574a 100644 (file)
 
 - name: xlogy.Tensor(Tensor self, Tensor other) -> Tensor
   self: grad * at::xlogy((self != 0), other)
-  other: grad * self / other
+  other: grad * at::where(other.isnan() | (self != 0), self / other, zeros_like(other))
   result: self_t * at::xlogy((self_p != 0), other_p) + other_t * self_p / other_p
 
 - name: xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor
-  other: grad * self / other
+  other: grad * at::where(other.isnan() | (!self.equal(0)), self / other, zeros_like(other))
   result: auto_element_wise
 
 - name: xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor
index cbe987e..c03f0ad 100644 (file)
@@ -173,7 +173,9 @@ def _transform_to_independent(constraint):
 
 
 @biject_to.register(constraints.positive)
+@biject_to.register(constraints.nonnegative)
 @transform_to.register(constraints.positive)
+@transform_to.register(constraints.nonnegative)
 def _transform_to_positive(constraint):
     return transforms.ExpTransform()
 
index 99808b6..5eed19a 100644 (file)
@@ -545,6 +545,7 @@ integer_interval = _IntegerInterval
 real = _Real()
 real_vector = independent(real, 1)
 positive = _GreaterThan(0.)
+nonnegative = _GreaterThanEq(0.)
 greater_than = _GreaterThan
 greater_than_eq = _GreaterThanEq
 less_than = _LessThan
index 954ed6e..9adb641 100644 (file)
@@ -24,7 +24,7 @@ class Poisson(ExponentialFamily):
     Args:
         rate (Number, Tensor): the rate parameter
     """
-    arg_constraints = {'rate': constraints.positive}
+    arg_constraints = {'rate': constraints.nonnegative}
     support = constraints.nonnegative_integer
 
     @property
@@ -60,7 +60,7 @@ class Poisson(ExponentialFamily):
         if self._validate_args:
             self._validate_sample(value)
         rate, value = broadcast_all(self.rate, value)
-        return (rate.log() * value) - rate - (value + 1).lgamma()
+        return value.xlogy(rate) - rate - (value + 1).lgamma()
 
     @property
     def _natural_params(self):