Allow non-integer values for Poisson CDF/PMF.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 21 Feb 2018 20:36:25 +0000 (12:36 -0800)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 21 Feb 2018 20:46:43 +0000 (12:46 -0800)
PiperOrigin-RevId: 186502845

tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py
tensorflow/contrib/distributions/python/ops/poisson.py

index d9c9008..19a7472 100644 (file)
@@ -18,6 +18,7 @@ from __future__ import division
 from __future__ import print_function
 
 import numpy as np
+from scipy import special
 from scipy import stats
 from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib
 from tensorflow.python.framework import constant_op
@@ -110,7 +111,7 @@ class PoissonTest(test.TestCase):
       batch_size = 6
       lam = constant_op.constant([3.0] * batch_size)
       lam_v = 3.0
-      x = [2.2, 3.1, 4., 5.5, 6., 7.]
+      x = [2., 3., 4., 5., 6., 7.]
 
       poisson = self._make_poisson(rate=lam)
       log_cdf = poisson.log_cdf(x)
@@ -121,12 +122,31 @@ class PoissonTest(test.TestCase):
       self.assertEqual(cdf.get_shape(), (6,))
       self.assertAllClose(cdf.eval(), stats.poisson.cdf(x, lam_v))
 
+  def testPoissonCDFNonIntegerValues(self):
+    with self.test_session():
+      batch_size = 6
+      lam = constant_op.constant([3.0] * batch_size)
+      lam_v = 3.0
+      x = np.array([2.2, 3.1, 4., 5.5, 6., 7.], dtype=np.float32)
+
+      poisson = self._make_poisson(rate=lam)
+      cdf = poisson.cdf(x)
+      self.assertEqual(cdf.get_shape(), (6,))
+
+      # The Poisson CDF should be valid on these non-integer values, and
+      # equal to igammac(1 + x, rate).
+      self.assertAllClose(cdf.eval(), special.gammaincc(1. + x, lam_v))
+
+      with self.assertRaisesOpError("cannot contain fractional components"):
+        poisson_validate = self._make_poisson(rate=lam, validate_args=True)
+        poisson_validate.cdf(x).eval()
+
   def testPoissonCdfMultidimensional(self):
     with self.test_session():
       batch_size = 6
       lam = constant_op.constant([[2.0, 4.0, 5.0]] * batch_size)
       lam_v = [2.0, 4.0, 5.0]
-      x = np.array([[2.2, 3.1, 4., 5.5, 6., 7.]], dtype=np.float32).T
+      x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=np.float32).T
 
       poisson = self._make_poisson(rate=lam)
       log_cdf = poisson.log_cdf(x)
index e967dcc..02e97c0 100644 (file)
@@ -35,9 +35,15 @@ __all__ = [
 
 
 _poisson_sample_note = """
-Note that the input value must be a non-negative floating point tensor with
-dtype `dtype` and whose shape can be broadcast with `self.rate`. `x` is only
-legal if it is non-negative and its components are equal to integer values.
+The Poisson distribution is technically only defined for non-negative integer
+values. When `validate_args=False`, non-integral inputs trigger an assertion.
+
+When `validate_args=False` calculations are otherwise unchanged despite
+integral or non-integral inputs.
+
+When `validate_args=False`, evaluating the pmf at non-integral values,
+corresponds to evaluations of an unnormalized distribution, that does not
+correspond to evaluations of the cdf.
 """
 
 
@@ -150,10 +156,6 @@ class Poisson(distribution.Distribution):
   def _cdf(self, x):
     if self.validate_args:
       x = distribution_util.embed_check_nonnegative_integer_form(x)
-    else:
-      # Whether or not x is integer-form, the following is well-defined.
-      # However, scipy takes the floor, so we do too.
-      x = math_ops.floor(x)
     return math_ops.igammac(1. + x, self.rate)
 
   def _log_normalization(self):
@@ -162,9 +164,6 @@ class Poisson(distribution.Distribution):
   def _log_unnormalized_prob(self, x):
     if self.validate_args:
       x = distribution_util.embed_check_nonnegative_integer_form(x)
-    else:
-      # For consistency with cdf, we take the floor.
-      x = math_ops.floor(x)
     return x * self.log_rate - math_ops.lgamma(1. + x)
 
   def _mean(self):