From 8045b3eb143ed70ab1873e148d266d56e0e1481f Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Mon, 11 Mar 2019 08:07:22 -0700
Subject: [PATCH] Registering of kl-divergence for independent distribution
 (#17681)

Summary:
This addresses issue https://github.com/pytorch/pytorch/issues/13545 and
implements the proposed fix together with a single test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17681

Differential Revision: D14360161

Pulled By: ezyang

fbshipit-source-id: 427afc88e9054b5b0dc39ebbab1087b990695ea5
---
 test/test_distributions.py | 2 ++
 torch/distributions/kl.py  | 9 +++++++++
 2 files changed, 11 insertions(+)

diff --git a/test/test_distributions.py b/test/test_distributions.py
index b803275..353c4cd 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -3078,6 +3078,7 @@ class TestKL(TestCase):
         laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
         lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
         normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
+        independent = (Independent(normal[0], 1), Independent(normal[1], 1))
         onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
                                                          [0.2, 0.7, 0.1],
                                                          [0.33, 0.33, 0.34],
@@ -3127,6 +3128,7 @@ class TestKL(TestCase):
             (gumbel, gumbel),
             (gumbel, normal),
             (halfnormal, halfnormal),
+            (independent, independent),
             (laplace, laplace),
             (lognormal, lognormal),
             (laplace, normal),
diff --git a/torch/distributions/kl.py b/torch/distributions/kl.py
index 7745ff6..650cf8a 100644
--- a/torch/distributions/kl.py
+++ b/torch/distributions/kl.py
@@ -17,6 +17,7 @@ from .gamma import Gamma
 from .geometric import Geometric
 from .gumbel import Gumbel
 from .half_normal import HalfNormal
+from .independent import Independent
 from .laplace import Laplace
 from .logistic_normal import LogisticNormal
 from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
@@ -730,3 +731,11 @@ def _kl_uniform_pareto(p, q):
     result = t2 * (q.alpha + 1) - t1
     result[p.low < q.support.lower_bound] = inf
     return result
+
+
+@register_kl(Independent, Independent)
+def _kl_independent_independent(p, q):
+    if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
+        raise NotImplementedError
+    result = kl_divergence(p.base_dist, q.base_dist)
+    return _sum_rightmost(result, p.reinterpreted_batch_ndims)
-- 
2.7.4
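
For reference, a minimal usage sketch of what this registration enables, assuming a
PyTorch build that includes the patch (this example is illustrative and not part of
the diff above):

    # Exercise the newly registered Independent/Independent KL path.
    import torch
    from torch.distributions import Independent, Normal
    from torch.distributions.kl import kl_divergence

    # Two batches of diagonal Gaussians; Independent(..., 1) reinterprets the
    # last batch dimension (size 4) as the event dimension.
    p = Independent(Normal(torch.zeros(3, 4), torch.ones(3, 4)), 1)
    q = Independent(Normal(torch.ones(3, 4), 2 * torch.ones(3, 4)), 1)

    kl = kl_divergence(p, q)  # shape (3,): one KL value per batch element

    # The result matches the base-distribution KL summed over the event dimension,
    # which is what _kl_independent_independent computes via _sum_rightmost.
    expected = kl_divergence(p.base_dist, q.base_dist).sum(-1)
    print(torch.allclose(kl, expected))  # True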