laplace = pairwise(Laplace, [-2.0, 4.0, -3.0, 6.0], [1.0, 2.5, 1.0, 2.5])
lognormal = pairwise(LogNormal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
normal = pairwise(Normal, [-2.0, 2.0, -3.0, 3.0], [1.0, 2.0, 1.0, 2.0])
+ independent = (Independent(normal[0], 1), Independent(normal[1], 1))
onehotcategorical = pairwise(OneHotCategorical, [[0.4, 0.3, 0.3],
[0.2, 0.7, 0.1],
[0.33, 0.33, 0.34],
(gumbel, gumbel),
(gumbel, normal),
(halfnormal, halfnormal),
+ (independent, independent),
(laplace, laplace),
(lognormal, lognormal),
(laplace, normal),
from .geometric import Geometric
from .gumbel import Gumbel
from .half_normal import HalfNormal
+from .independent import Independent
from .laplace import Laplace
from .logistic_normal import LogisticNormal
from .lowrank_multivariate_normal import (LowRankMultivariateNormal, _batch_lowrank_logdet,
result = t2 * (q.alpha + 1) - t1
result[p.low < q.support.lower_bound] = inf
return result
+
+
+@register_kl(Independent, Independent)
+def _kl_independent_independent(p, q):
+ if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims:
+ raise NotImplementedError
+ result = kl_divergence(p.base_dist, q.base_dist)
+ return _sum_rightmost(result, p.reinterpreted_batch_ndims)