KL divergence for two Dirichlet distributions.
authorA. Unique TensorFlower <gardener@tensorflow.org>
Wed, 30 May 2018 16:39:57 +0000 (09:39 -0700)
committerTensorFlower Gardener <gardener@tensorflow.org>
Wed, 30 May 2018 16:42:43 +0000 (09:42 -0700)
PiperOrigin-RevId: 198573236

tensorflow/python/kernel_tests/distributions/dirichlet_test.py
tensorflow/python/ops/distributions/dirichlet.py

index 3bcfae0..bcec6ef 100644 (file)
@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib
+from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
 
@@ -39,6 +40,7 @@ def try_import(name):  # pylint: disable=invalid-name
   return module
 
 
+special = try_import("scipy.special")
 stats = try_import("scipy.stats")
 
 
@@ -262,6 +264,39 @@ class DirichletTest(test.TestCase):
                   a=1., b=2.).cdf)[0],
           0.01)
 
+  def testDirichletDirichletKL(self):
+    conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
+                      [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
+    conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]])
+
+    d1 = dirichlet_lib.Dirichlet(conc1)
+    d2 = dirichlet_lib.Dirichlet(conc2)
+    x = d1.sample(int(1e4), seed=0)
+    kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0)
+    kl_actual = kullback_leibler.kl_divergence(d1, d2)
+
+    kl_sample_val = self.evaluate(kl_sample)
+    kl_actual_val = self.evaluate(kl_actual)
+
+    self.assertEqual(conc1.shape[:-1], kl_actual.get_shape())
+
+    if not special:
+      return
+
+    kl_expected = (
+        special.gammaln(np.sum(conc1, -1))
+        - special.gammaln(np.sum(conc2, -1))
+        - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1)
+        + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma(
+            np.sum(conc1, -1, keepdims=True))), -1))
+
+    self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6)
+    self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1)
+
+    # Make sure KL(d1||d1) is 0
+    kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
+    self.assertAllClose(kl_same, np.zeros_like(kl_expected))
+
 
 if __name__ == "__main__":
   test.main()
index 1ab58c1..72567e6 100644 (file)
@@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
 from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.ops.distributions import util as distribution_util
 from tensorflow.python.util.tf_export import tf_export
 
@@ -297,3 +298,80 @@ class Dirichlet(distribution.Distribution):
             math_ops.reduce_sum(x, -1),
             message="sample last-dimension must sum to `1`"),
     ], x)
+
+
+@kullback_leibler.RegisterKL(Dirichlet, Dirichlet)
+def _kl_dirichlet_dirichlet(d1, d2, name=None):
+  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.
+
+  Args:
+    d1: instance of a Dirichlet distribution object.
+    d2: instance of a Dirichlet distribution object.
+    name: (optional) Name to use for created operations.
+      default is "kl_dirichlet_dirichlet".
+
+  Returns:
+    Batchwise KL(d1 || d2)
+  """
+  with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[
+      d1.concentration, d2.concentration]):
+    # The KL between Dirichlet distributions can be derived as follows. We have
+    #
+    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
+    #
+    # where B(a) is the multivariate Beta function:
+    #
+    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
+    #
+    # The KL is
+    #
+    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
+    #
+    # so we'll need to know the log density of the Dirichlet. This is
+    #
+    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
+    #
+    # The only term that matters for the expectations is the log(x[i]). To
+    # compute the expectation of this term over the Dirichlet density, we can
+    # use the following facts about the Dirichlet in exponential family form:
+    #   1. log(x[i]) is a sufficient statistic
+    #   2. expected sufficient statistics (of any exp family distribution) are
+    #      equal to derivatives of the log normalizer with respect to
+    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i])
+    #
+    # To proceed, we can rewrite the Dirichlet density in exponential family
+    # form as follows:
+    #
+    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
+    #
+    # where '.' is the dot product of vectors eta and T, and A is a scalar:
+    #
+    #   eta[i](a) = a[i] - 1
+    #     T[i](x) = log(x[i])
+    #        A(a) = log B(a)
+    #
+    # Now, we can use fact (2) above to write
+    #
+    #   E_Dir(x; a)[log(x[i])]
+    #       = dA(a) / da[i]
+    #       = d/da[i] log B(a)
+    #       = d/da[i] (sum_j lgamma(a[j])) - lgamma(sum_j a[j])
+    #       = digamma(a[i])) - digamma(sum_j a[j])
+    #
+    # Putting it all together, we have
+    #
+    # KL[Dir(x; a) || Dir(x; b)]
+    #     = E_Dir(x; a){log(Dir(x; a) / Dir(x; b)}
+    #     = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])} - (lbeta(a) - lbeta(b))
+    #     = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
+    #     = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
+    #          - lbeta(a) + lbeta(b))
+
+    digamma_sum_d1 = math_ops.digamma(
+        math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True))
+    digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1
+    concentration_diff = d1.concentration - d2.concentration
+
+    return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
+            special_math_ops.lbeta(d1.concentration) +
+            special_math_ops.lbeta(d2.concentration))