From: A. Unique TensorFlower
Date: Wed, 30 May 2018 16:39:57 +0000 (-0700)
Subject: KL divergence for two Dirichlet distributions.
X-Git-Tag: upstream/v1.9.0_rc1~37^2^2~14
X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5eb510994043d1342170f657860196be0b7ed15c;p=platform%2Fupstream%2Ftensorflow.git

KL divergence for two Dirichlet distributions.

PiperOrigin-RevId: 198573236
---

diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index 3bcfae0..bcec6ef 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -26,6 +26,7 @@ from tensorflow.python.framework import test_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops.distributions import dirichlet as dirichlet_lib
+from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.platform import test
 from tensorflow.python.platform import tf_logging
 
@@ -39,6 +40,7 @@ def try_import(name):  # pylint: disable=invalid-name
   return module
 
 
+special = try_import("scipy.special")
 stats = try_import("scipy.stats")
 
 
@@ -262,6 +264,39 @@ class DirichletTest(test.TestCase):
               a=1., b=2.).cdf)[0],
           0.01)
 
+  def testDirichletDirichletKL(self):
+    conc1 = np.array([[1., 2., 3., 1.5, 2.5, 3.5],
+                      [1.5, 2.5, 3.5, 4.5, 5.5, 6.5]])
+    conc2 = np.array([[0.5, 1., 1.5, 2., 2.5, 3.]])
+
+    d1 = dirichlet_lib.Dirichlet(conc1)
+    d2 = dirichlet_lib.Dirichlet(conc2)
+    x = d1.sample(int(1e4), seed=0)
+    kl_sample = math_ops.reduce_mean(d1.log_prob(x) - d2.log_prob(x), 0)
+    kl_actual = kullback_leibler.kl_divergence(d1, d2)
+
+    kl_sample_val = self.evaluate(kl_sample)
+    kl_actual_val = self.evaluate(kl_actual)
+
+    self.assertEqual(conc1.shape[:-1], kl_actual.get_shape())
+
+    if not special:
+      return
+
+    kl_expected = (
+        special.gammaln(np.sum(conc1, -1))
+        - special.gammaln(np.sum(conc2, -1))
+        - np.sum(special.gammaln(conc1) - special.gammaln(conc2), -1)
+        + np.sum((conc1 - conc2) * (special.digamma(conc1) - special.digamma(
+            np.sum(conc1, -1, keepdims=True))), -1))
+
+    self.assertAllClose(kl_expected, kl_actual_val, atol=0., rtol=1e-6)
+    self.assertAllClose(kl_sample_val, kl_actual_val, atol=0., rtol=1e-1)
+
+    # Make sure KL(d1 || d1) is 0.
+    kl_same = self.evaluate(kullback_leibler.kl_divergence(d1, d1))
+    self.assertAllClose(kl_same, np.zeros_like(kl_expected))
+
 
 if __name__ == "__main__":
   test.main()
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 1ab58c1..72567e6 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import special_math_ops
 from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.ops.distributions import kullback_leibler
 from tensorflow.python.ops.distributions import util as distribution_util
 from tensorflow.python.util.tf_export import tf_export
 
@@ -297,3 +298,80 @@ class Dirichlet(distribution.Distribution):
             math_ops.reduce_sum(x, -1),
             message="sample last-dimension must sum to `1`"),
     ], x)
+
+
+@kullback_leibler.RegisterKL(Dirichlet, Dirichlet)
+def _kl_dirichlet_dirichlet(d1, d2, name=None):
+  """Batchwise KL divergence KL(d1 || d2) with d1 and d2 Dirichlet.
+
+  Args:
+    d1: instance of a Dirichlet distribution object.
+    d2: instance of a Dirichlet distribution object.
+    name: (optional) Name to use for created operations.
+      Default is "kl_dirichlet_dirichlet".
+
+  Returns:
+    Batchwise KL(d1 || d2).
+  """
+  with ops.name_scope(name, "kl_dirichlet_dirichlet", values=[
+      d1.concentration, d2.concentration]):
+    # The KL between Dirichlet distributions can be derived as follows. We have
+    #
+    #   Dir(x; a) = 1 / B(a) * prod_i[x[i]^(a[i] - 1)]
+    #
+    # where B(a) is the multivariate Beta function:
+    #
+    #   B(a) = Gamma(a[1]) * ... * Gamma(a[n]) / Gamma(a[1] + ... + a[n])
+    #
+    # The KL is
+    #
+    #   KL(Dir(x; a), Dir(x; b)) = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
+    #
+    # so we'll need to know the log density of the Dirichlet. This is
+    #
+    #   log(Dir(x; a)) = sum_i[(a[i] - 1) log(x[i])] - log B(a)
+    #
+    # The only term that matters for the expectations is the log(x[i]). To
+    # compute the expectation of this term over the Dirichlet density, we can
+    # use the following facts about the Dirichlet in exponential family form:
+    #   1. log(x[i]) is a sufficient statistic
+    #   2. expected sufficient statistics (of any exp family distribution) are
+    #      equal to derivatives of the log normalizer with respect to
+    #      corresponding natural parameters: E{T[i](x)} = dA/d(eta[i])
+    #
+    # To proceed, we can rewrite the Dirichlet density in exponential family
+    # form as follows:
+    #
+    #   Dir(x; a) = exp{eta(a) . T(x) - A(a)}
+    #
+    # where '.' is the dot product of vectors eta and T, and A is a scalar:
+    #
+    #   eta[i](a) = a[i] - 1
+    #     T[i](x) = log(x[i])
+    #        A(a) = log B(a)
+    #
+    # Now, we can use fact (2) above to write
+    #
+    #   E_Dir(x; a)[log(x[i])]
+    #       = dA(a) / da[i]
+    #       = d/da[i] log B(a)
+    #       = d/da[i] [sum_j lgamma(a[j]) - lgamma(sum_j a[j])]
+    #       = digamma(a[i]) - digamma(sum_j a[j])
+    #
+    # Putting it all together, we have
+    #
+    # KL[Dir(x; a) || Dir(x; b)]
+    #     = E_Dir(x; a){log(Dir(x; a) / Dir(x; b))}
+    #     = E_Dir(x; a){sum_i[(a[i] - b[i]) log(x[i])]} - (lbeta(a) - lbeta(b))
+    #     = sum_i[(a[i] - b[i]) * E_Dir(x; a){log(x[i])}] - lbeta(a) + lbeta(b)
+    #     = sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j]))]
+    #          - lbeta(a) + lbeta(b)
+
+    digamma_sum_d1 = math_ops.digamma(
+        math_ops.reduce_sum(d1.concentration, axis=-1, keepdims=True))
+    digamma_diff = math_ops.digamma(d1.concentration) - digamma_sum_d1
+    concentration_diff = d1.concentration - d2.concentration
+
+    return (math_ops.reduce_sum(concentration_diff * digamma_diff, axis=-1) -
+            special_math_ops.lbeta(d1.concentration) +
+            special_math_ops.lbeta(d2.concentration))
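
Note (not part of the commit): the closed form implemented above can be checked outside TensorFlow with a few lines of NumPy/SciPy. The sketch below is an assumption-laden illustration; the helper name np_dirichlet_kl is made up here, and it simply evaluates the same expression the new test builds as kl_expected, namely lgamma(sum(a)) - lgamma(sum(b)) - sum_i[lgamma(a[i]) - lgamma(b[i])] + sum_i[(a[i] - b[i]) * (digamma(a[i]) - digamma(sum(a)))] along the last axis.

# Standalone sanity check for the Dirichlet-Dirichlet KL closed form.
# Hypothetical helper, not part of this commit.
import numpy as np
from scipy.special import digamma, gammaln


def np_dirichlet_kl(conc1, conc2):
  """Batchwise KL[Dir(conc1) || Dir(conc2)] over the last axis."""
  conc1 = np.asarray(conc1, dtype=np.float64)
  conc2 = np.asarray(conc2, dtype=np.float64)
  sum1 = np.sum(conc1, axis=-1)  # sum_j a[j]
  sum2 = np.sum(conc2, axis=-1)  # sum_j b[j]
  # -lbeta(a) + lbeta(b), written out with gammaln.
  lbeta_term = (gammaln(sum1) - gammaln(sum2)
                - np.sum(gammaln(conc1) - gammaln(conc2), axis=-1))
  # sum_i (a[i] - b[i]) * (digamma(a[i]) - digamma(sum_j a[j])).
  digamma_term = np.sum(
      (conc1 - conc2) * (digamma(conc1) - digamma(sum1[..., np.newaxis])),
      axis=-1)
  return lbeta_term + digamma_term


# KL(d || d) is exactly zero; KL between distinct Dirichlets is positive.
print(np_dirichlet_kl([1., 2., 3.], [1., 2., 3.]))        # -> 0.0
print(np_dirichlet_kl([[1., 2., 3.]], [[0.5, 1., 1.5]]))  # -> positive, shape (1,)

With the RegisterKL(Dirichlet, Dirichlet) registration in place, kullback_leibler.kl_divergence(d1, d2) (exposed in the 1.x API as tf.distributions.kl_divergence) dispatches to _kl_dirichlet_dirichlet for a pair of Dirichlet instances, which is exactly the path the new test exercises against both the SciPy closed form and a Monte Carlo estimate.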