raise ValueError("KL-divergence between two Low Rank Multivariate Normals with\
different event shapes cannot be computed")
- term1 = (_batch_lowrank_logdet(q.cov_factor, q.cov_diag, q._capacitance_tril) -
- _batch_lowrank_logdet(p.cov_factor, p.cov_diag, p._capacitance_tril))
- term3 = _batch_lowrank_mahalanobis(q.cov_factor, q.cov_diag, q.loc - p.loc,
+ term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q._capacitance_tril) -
+ _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
+ p._capacitance_tril))
+ term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q.loc - p.loc,
q._capacitance_tril)
# Expands term2 according to
# inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
# = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
- qWt_qDinv = q.cov_factor.transpose(-1, -2) / q.cov_diag.unsqueeze(-2)
+ qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
+ q._unbroadcasted_cov_diag.unsqueeze(-2))
A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril)
- term21 = (p.cov_diag / q.cov_diag).sum(-1)
- term22 = _batch_trace_XXT(p.cov_factor * q.cov_diag.rsqrt().unsqueeze(-1))
- term23 = _batch_trace_XXT(A * p.cov_diag.sqrt().unsqueeze(-2))
- term24 = _batch_trace_XXT(A.matmul(p.cov_factor))
+ term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
+ term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
+ q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
+ term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
+ term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
term2 = term21 + term22 - term23 - term24
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
different event shapes cannot be computed")
- term1 = (_batch_lowrank_logdet(q.cov_factor, q.cov_diag, q._capacitance_tril) -
+ term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q._capacitance_tril) -
2 * _batch_diag(p._unbroadcasted_scale_tril).log().sum(-1))
- term3 = _batch_lowrank_mahalanobis(q.cov_factor, q.cov_diag, q.loc - p.loc,
+ term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+ q.loc - p.loc,
q._capacitance_tril)
# Expands term2 according to
# inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
# = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
- qWt_qDinv = q.cov_factor.transpose(-1, -2) / q.cov_diag.unsqueeze(-2)
+ qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
+ q._unbroadcasted_cov_diag.unsqueeze(-2))
A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril)
- term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril * q.cov_diag.rsqrt().unsqueeze(-1))
+ term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
+ q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
term2 = term21 - term22
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
different event shapes cannot be computed")
term1 = (2 * _batch_diag(q._unbroadcasted_scale_tril).log().sum(-1) -
- _batch_lowrank_logdet(p.cov_factor, p.cov_diag, p._capacitance_tril))
+ _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
+ p._capacitance_tril))
term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
# Expands term2 according to
# inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
- p.cov_factor.shape[:-2])
+ p._unbroadcasted_cov_factor.shape[:-2])
n = p.event_shape[0]
q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
- p_cov_factor = p.cov_factor.expand(combined_batch_shape + (n, p.cov_factor.size(-1)))
- p_cov_diag = _batch_vector_diag(p.cov_diag.sqrt()).expand(combined_batch_shape + (n, n))
+ p_cov_factor = p._unbroadcasted_cov_factor.expand(combined_batch_shape +
+ (n, p.cov_factor.size(-1)))
+ p_cov_diag = (_batch_vector_diag(p._unbroadcasted_cov_diag.sqrt())
+ .expand(combined_batch_shape + (n, n)))
term21 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_factor, q_scale_tril))
term22 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_diag, q_scale_tril))
term2 = term21 + term22
self.cov_diag = cov_diag_[..., 0]
batch_shape = self.loc.shape[:-1]
- self._capacitance_tril = _batch_capacitance_tril(self.cov_factor, self.cov_diag)
+ self._unbroadcasted_cov_factor = cov_factor
+ self._unbroadcasted_cov_diag = cov_diag
+ self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,
validate_args=validate_args)
new.loc = self.loc.expand(loc_shape)
new.cov_diag = self.cov_diag.expand(loc_shape)
new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
- new._capacitance_tril = self._capacitance_tril.expand(batch_shape + self._capacitance_tril.shape[-2:])
+ new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
+ new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
+ new._capacitance_tril = self._capacitance_tril
super(LowRankMultivariateNormal, new).__init__(batch_shape,
self.event_shape,
validate_args=False)
@lazy_property
def variance(self):
- return self.cov_factor.pow(2).sum(-1) + self.cov_diag
+ return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
+ + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)
@lazy_property
def scale_tril(self):
# The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
# hence it is well-conditioned and safe to take Cholesky decomposition.
n = self._event_shape[0]
- cov_diag_sqrt_unsqueeze = self.cov_diag.sqrt().unsqueeze(-1)
- Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
+ cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
+ Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
K.view(-1, n * n)[:, ::n + 1] += 1 # add identity matrix to K
- return cov_diag_sqrt_unsqueeze * torch.cholesky(K)
+ scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
+ return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)
@lazy_property
def covariance_matrix(self):
- return (torch.matmul(self.cov_factor, self.cov_factor.transpose(-1, -2)) +
- _batch_vector_diag(self.cov_diag))
+ covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_factor.transpose(-1, -2))
+ + _batch_vector_diag(self._unbroadcasted_cov_diag))
+ return covariance_matrix.expand(self._batch_shape + self._event_shape +
+ self._event_shape)
@lazy_property
def precision_matrix(self):
# We use "Woodbury matrix identity" to take advantage of low rank form::
# inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
# where :math:`C` is the capacitance matrix.
- Wt_Dinv = self.cov_factor.transpose(-1, -2) / self.cov_diag.unsqueeze(-2)
+ Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
+ / self._unbroadcasted_cov_diag.unsqueeze(-2))
A = _batch_trtrs_lower(Wt_Dinv, self._capacitance_tril)
- return (_batch_vector_diag(self.cov_diag.reciprocal()) -
- torch.matmul(A.transpose(-1, -2), A))
+ precision_matrix = (_batch_vector_diag(self._unbroadcasted_cov_diag.reciprocal())
+ - torch.matmul(A.transpose(-1, -2), A))
+ return precision_matrix.expand(self._batch_shape + self._event_shape +
+ self._event_shape)
def rsample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
W_shape = shape[:-1] + self.cov_factor.shape[-1:]
eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
- return self.loc + _batch_mv(self.cov_factor, eps_W) + self.cov_diag.sqrt() * eps_D
+ return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
+ + self._unbroadcasted_cov_diag.sqrt() * eps_D)
def log_prob(self, value):
if self._validate_args:
self._validate_sample(value)
diff = value - self.loc
- M = _batch_lowrank_mahalanobis(self.cov_factor, self.cov_diag, diff,
+ M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_diag,
+ diff,
self._capacitance_tril)
- log_det = _batch_lowrank_logdet(self.cov_factor, self.cov_diag,
+ log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_diag,
self._capacitance_tril)
return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)
def entropy(self):
- log_det = _batch_lowrank_logdet(self.cov_factor, self.cov_diag,
+ log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
+ self._unbroadcasted_cov_diag,
self._capacitance_tril)
H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
if len(self._batch_shape) == 0: