Add Kullback-Leibler divergence for Independent distribution(s).
author     A. Unique TensorFlower <gardener@tensorflow.org>
           Tue, 6 Mar 2018 23:07:27 +0000 (15:07 -0800)
committer  TensorFlower Gardener <gardener@tensorflow.org>
           Tue, 6 Mar 2018 23:11:30 +0000 (15:11 -0800)
PiperOrigin-RevId: 188087902

tensorflow/contrib/distributions/python/kernel_tests/independent_test.py
tensorflow/contrib/distributions/python/ops/independent.py

diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py
index 06318ca..6a69f9e 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib
+from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.ops.distributions import normal as normal_lib
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
@@ -126,6 +127,100 @@ class ProductDistributionTest(test.TestCase):
       self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.)
       self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.)
 
+  def testKLRaises(self):
+    ind1 = independent_lib.Independent(
+        distribution=normal_lib.Normal(
+            loc=np.float32([-1., 1]),
+            scale=np.float32([0.1, 0.5])),
+        reinterpreted_batch_ndims=1)
+    ind2 = independent_lib.Independent(
+        distribution=normal_lib.Normal(
+            loc=np.float32(-1),
+            scale=np.float32(0.5)),
+        reinterpreted_batch_ndims=0)
+
+    with self.assertRaisesRegexp(
+        ValueError, "Event shapes do not match"):
+      kullback_leibler.kl_divergence(ind1, ind2)
+
+    ind1 = independent_lib.Independent(
+        distribution=normal_lib.Normal(
+            loc=np.float32([-1., 1]),
+            scale=np.float32([0.1, 0.5])),
+        reinterpreted_batch_ndims=1)
+    ind2 = independent_lib.Independent(
+        distribution=mvn_diag_lib.MultivariateNormalDiag(
+            loc=np.float32([-1., 1]),
+            scale_diag=np.float32([0.1, 0.5])),
+        reinterpreted_batch_ndims=0)
+
+    with self.assertRaisesRegexp(
+        NotImplementedError, "different event shapes"):
+      kullback_leibler.kl_divergence(ind1, ind2)
+
+  def testKLScalarToMultivariate(self):
+    normal1 = normal_lib.Normal(
+        loc=np.float32([-1., 1]),
+        scale=np.float32([0.1, 0.5]))
+    ind1 = independent_lib.Independent(
+        distribution=normal1, reinterpreted_batch_ndims=1)
+
+    normal2 = normal_lib.Normal(
+        loc=np.float32([-3., 3]),
+        scale=np.float32([0.3, 0.3]))
+    ind2 = independent_lib.Independent(
+        distribution=normal2, reinterpreted_batch_ndims=1)
+
+    normal_kl = kullback_leibler.kl_divergence(normal1, normal2)
+    ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
+    self.assertAllClose(
+        self.evaluate(math_ops.reduce_sum(normal_kl, axis=-1)),
+        self.evaluate(ind_kl))
+
+  def testKLIdentity(self):
+    normal1 = normal_lib.Normal(
+        loc=np.float32([-1., 1]),
+        scale=np.float32([0.1, 0.5]))
+    # This is functionally just a wrapper around normal1,
+    # and doesn't change any outputs.
+    ind1 = independent_lib.Independent(
+        distribution=normal1, reinterpreted_batch_ndims=0)
+
+    normal2 = normal_lib.Normal(
+        loc=np.float32([-3., 3]),
+        scale=np.float32([0.3, 0.3]))
+    # This is functionally just a wrapper around normal2,
+    # and doesn't change any outputs.
+    ind2 = independent_lib.Independent(
+        distribution=normal2, reinterpreted_batch_ndims=0)
+
+    normal_kl = kullback_leibler.kl_divergence(normal1, normal2)
+    ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
+    self.assertAllClose(
+        self.evaluate(normal_kl), self.evaluate(ind_kl))
+
+  def testKLMultivariateToMultivariate(self):
+    # (1, 1, 2) batch of MVNDiag
+    mvn1 = mvn_diag_lib.MultivariateNormalDiag(
+        loc=np.float32([[[[-1., 1, 3.], [2., 4., 3.]]]]),
+        scale_diag=np.float32([[[0.2, 0.1, 5.], [2., 3., 4.]]]))
+    ind1 = independent_lib.Independent(
+        distribution=mvn1, reinterpreted_batch_ndims=2)
+
+    # (1, 1, 2) batch of MVNDiag
+    mvn2 = mvn_diag_lib.MultivariateNormalDiag(
+        loc=np.float32([[[[-2., 3, 2.], [1., 3., 2.]]]]),
+        scale_diag=np.float32([[[0.1, 0.5, 3.], [1., 2., 1.]]]))
+
+    ind2 = independent_lib.Independent(
+        distribution=mvn2, reinterpreted_batch_ndims=2)
+
+    mvn_kl = kullback_leibler.kl_divergence(mvn1, mvn2)
+    ind_kl = kullback_leibler.kl_divergence(ind1, ind2)
+    self.assertAllClose(
+        self.evaluate(math_ops.reduce_sum(mvn_kl, axis=[-1, -2])),
+        self.evaluate(ind_kl))
+
   def _testMnistLike(self, static_shape):
     sample_shape = [4, 5]
     batch_shape = [10]
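
The expected values in the tests above rest on the identity that the KL between products of independent factors is the sum of the per-factor KLs. A NumPy-only sketch of that identity for the diagonal-Normal case used in `testKLScalarToMultivariate` (the helper `normal_kl` below is hypothetical, written from the standard closed-form Normal-to-Normal KL, and is not part of this change):

```
import numpy as np

def normal_kl(loc_p, scale_p, loc_q, scale_q):
  # Closed-form KL(N(loc_p, scale_p^2) || N(loc_q, scale_q^2)), elementwise.
  var_p, var_q = scale_p**2, scale_q**2
  return (np.log(scale_q / scale_p)
          + (var_p + (loc_p - loc_q)**2) / (2. * var_q) - 0.5)

loc1, scale1 = np.float32([-1., 1.]), np.float32([0.1, 0.5])
loc2, scale2 = np.float32([-3., 3.]), np.float32([0.3, 0.3])

per_dim_kl = normal_kl(loc1, scale1, loc2, scale2)  # shape [2], one KL per factor
product_kl = per_dim_kl.sum(axis=-1)                # KL of the factored (product) distribution
print(per_dim_kl, product_kl)
```
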
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index cbce005..7dcb3e3 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import check_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import distribution as distribution_lib
+from tensorflow.python.ops.distributions import kullback_leibler
 
 
 class Independent(distribution_lib.Distribution):
@@ -254,3 +255,58 @@ class Independent(distribution_lib.Distribution):
     else:
       which_maximum = np.maximum
     return which_maximum(0, ndims - 1)
+
+
+@kullback_leibler.RegisterKL(Independent, Independent)
+def _kl_independent(a, b, name="kl_independent"):
+  """Batched KL divergence `KL(a || b)` for Independent distributions.
+
+  We can leverage the fact that
+  ```
+  KL(Independent(a) || Independent(b)) = sum(KL(a || b))
+  ```
+  where the sum is over the `reinterpreted_batch_ndims`.
+
+  Args:
+    a: Instance of `Independent`.
+    b: Instance of `Independent`.
+    name: (optional) name to use for created ops. Default "kl_independent".
+
+  Returns:
+    Batchwise `KL(a || b)`.
+
+  Raises:
+    ValueError: If the event shapes of `a` and `b` do not match.
+    NotImplementedError: If the underlying distributions of `a` and `b` have
+      different event shapes.
+  """
+  p = a.distribution
+  q = b.distribution
+
+  # The KL between any two (non-)batched distributions is a scalar.
+  # Because the KL between factored distributions is the sum of the
+  # per-factor KLs, i.e.
+  # KL(p1(x)p2(y) || q1(x)q2(y)) = KL(p1 || q1) + KL(p2 || q2), we compute
+  # KL(p || q) and then `reduce_sum` over the reinterpreted batch dimensions.
+  if a.event_shape.is_fully_defined() and b.event_shape.is_fully_defined():
+    if a.event_shape == b.event_shape:
+      if p.event_shape == q.event_shape:
+        num_reduce_dims = a.event_shape.ndims - p.event_shape.ndims
+        reduce_dims = [-i - 1 for i in range(0, num_reduce_dims)]
+
+        return math_ops.reduce_sum(
+            kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
+      else:
+        raise NotImplementedError("KL between Independents with different "
+                                  "event shapes not supported.")
+    else:
+      raise ValueError("Event shapes do not match.")
+  else:
+    with ops.control_dependencies([
+        check_ops.assert_equal(a.event_shape_tensor(), b.event_shape_tensor()),
+        check_ops.assert_equal(p.event_shape_tensor(), q.event_shape_tensor())
+    ]):
+      # Compare the ranks of the event-shape tensors (not their first entry)
+      # to find how many trailing dimensions to sum over.
+      num_reduce_dims = (
+          array_ops.shape(a.event_shape_tensor())[0] -
+          array_ops.shape(p.event_shape_tensor())[0])
+      reduce_dims = math_ops.range(-num_reduce_dims, 0, 1)
+      return math_ops.reduce_sum(
+          kullback_leibler.kl_divergence(p, q, name=name), axis=reduce_dims)
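
For reference, a minimal usage sketch of the newly registered KL, assuming a TensorFlow 1.x build in which `tf.contrib.distributions` re-exports `Independent`, `Normal`, and `kl_divergence` (the import path is an assumption, not something established by this change):

```
import numpy as np
import tensorflow as tf

tfd = tf.contrib.distributions  # assumed TF 1.x contrib alias

# Reinterpret the length-2 batch of Normals as a single 2-dimensional event.
ind_a = tfd.Independent(
    distribution=tfd.Normal(loc=np.float32([-1., 1.]),
                            scale=np.float32([0.1, 0.5])),
    reinterpreted_batch_ndims=1)
ind_b = tfd.Independent(
    distribution=tfd.Normal(loc=np.float32([-3., 3.]),
                            scale=np.float32([0.3, 0.3])),
    reinterpreted_batch_ndims=1)

# With this change, kl_divergence dispatches to _kl_independent: the
# component Normal KLs are summed over the reinterpreted dimension.
ind_kl = tfd.kl_divergence(ind_a, ind_b)
component_kl = tf.reduce_sum(
    tfd.kl_divergence(ind_a.distribution, ind_b.distribution), axis=-1)

with tf.Session() as sess:
  kl_val, ref_val = sess.run([ind_kl, component_kl])
  np.testing.assert_allclose(kl_val, ref_val, rtol=1e-6)
```

This mirrors `testKLScalarToMultivariate` above: the KL between the `Independent` wrappers should equal the sum of the per-dimension Normal KLs.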