Registering of kl-divergence for independent distribution (#17681)
authorNicki Skafte <skaftenicki@gmail.com>
Mon, 11 Mar 2019 15:07:22 +0000 (08:07 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Mon, 11 Mar 2019 15:10:16 +0000 (08:10 -0700)
Summary:
This address issue https://github.com/pytorch/pytorch/issues/13545 and implements the proposed fix together with a single test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17681

Differential Revision: D14360161

Pulled By: ezyang

fbshipit-source-id: 427afc88e9054b5b0dc39ebbab1087b990695ea5

test/test_distributions.py
torch/distributions/kl.py

index b803275..353c4cd 100644 (file)
@@ -3078,6 +3078,7 @@ class TestKL(TestCase):
         laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
         lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
         normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
+        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
         onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
                                                          [0.2, 0.7, 0.1],
                                                          [0.33, 0.33, 0.34],
@@ -3127,6 +3128,7 @@ class TestKL(TestCase):
             (gumbel, gumbel),
             (gumbel, normal),
             (halfnormal, halfnormal),
+            (independent, independent),
             (laplace, laplace),
             (lognormal, lognormal),
             (laplace, normal),
index 7745ff6..650cf8a 100644 (file)
@@ -17,6 +17,7 @@ from .gamma import Gamma
 from .geometric import Geometric
 from .gumbel import Gumbel
 from .half_normal import HalfNormal
+from .independent import Independent
 from .laplace import Laplace
 from .logistic_normal import LogisticNormal
 from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
@@ -730,3 +731,11 @@ def _kl_uniform_pareto(p, q):
     result = t2 * (q.alpha + 1) - t1
     result[p.low < q.support.lower_bound] = inf
     return result
+
+
+@register_kl(Independent, Independent)
+def _kl_independent_independent(p, q):
+    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
+        raise NotImplementedError
+    result = kl_divergence(p.base_dist, q.base_dist)
+    return _sum_rightmost(result, p.reinterpreted_batch_ndims)