Fix expanded mvn and lowrankmvn (#14557)
authorfehiepsi <fehiepsi@gmail.com>
Fri, 30 Nov 2018 18:44:56 +0000 (10:44 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Fri, 30 Nov 2018 18:49:13 +0000 (10:49 -0800)
Summary:
This PR fixes an issue of the slowness expanded MVN.

A notebook to show the problem is [here](https://gist.github.com/fehiepsi/b15ac2978f1045d6d96b1d35b640d742). Basically, mvn's sample and log_prob have expensive computations based on `cholesky` and `trtrs`. We can save a lot of computation based on caching the unbroadcasted version of `scale_tril` (or `cov_diag`, `cov_factor` in lowrank mvn).
When expanding, this cached tensor should not be expanded together with other arguments.

Ref: https://github.com/uber/pyro/issues/1586

cc neerajprad fritzo
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14557

Differential Revision: D13277408

Pulled By: soumith

fbshipit-source-id: a6b16f999b008d5da148ccf519b7f32d9c6a5351

torch/distributions/kl.py
torch/distributions/lowrank_multivariate_normal.py
torch/distributions/multivariate_normal.py

index 8e43638..d166e99 100644 (file)
@@ -300,19 +300,24 @@ def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q):
         raise ValueError("KL-divergence between two Low Rank Multivariate Normals with\
                           different event shapes cannot be computed")
 
-    term1 = (_batch_lowrank_logdet(q.cov_factor, q.cov_diag, q._capacitance_tril) -
-             _batch_lowrank_logdet(p.cov_factor, p.cov_diag, p._capacitance_tril))
-    term3 = _batch_lowrank_mahalanobis(q.cov_factor, q.cov_diag, q.loc - p.loc,
+    term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+                                   q._capacitance_tril) -
+             _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
+                                   p._capacitance_tril))
+    term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+                                       q.loc - p.loc,
                                        q._capacitance_tril)
     # Expands term2 according to
     # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD)
     #                  = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
-    qWt_qDinv = q.cov_factor.transpose(-1, -2) / q.cov_diag.unsqueeze(-2)
+    qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
+                 q._unbroadcasted_cov_diag.unsqueeze(-2))
     A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril)
-    term21 = (p.cov_diag / q.cov_diag).sum(-1)
-    term22 = _batch_trace_XXT(p.cov_factor * q.cov_diag.rsqrt().unsqueeze(-1))
-    term23 = _batch_trace_XXT(A * p.cov_diag.sqrt().unsqueeze(-2))
-    term24 = _batch_trace_XXT(A.matmul(p.cov_factor))
+    term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
+    term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
+                              q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
+    term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2))
+    term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor))
     term2 = term21 + term22 - term23 - term24
     return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
 
@@ -323,16 +328,20 @@ def _kl_multivariatenormal_lowrankmultivariatenormal(p, q):
         raise ValueError("KL-divergence between two (Low Rank) Multivariate Normals with\
                           different event shapes cannot be computed")
 
-    term1 = (_batch_lowrank_logdet(q.cov_factor, q.cov_diag, q._capacitance_tril) -
+    term1 = (_batch_lowrank_logdet(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+                                   q._capacitance_tril) -
              2 * _batch_diag(p._unbroadcasted_scale_tril).log().sum(-1))
-    term3 = _batch_lowrank_mahalanobis(q.cov_factor, q.cov_diag, q.loc - p.loc,
+    term3 = _batch_lowrank_mahalanobis(q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag,
+                                       q.loc - p.loc,
                                        q._capacitance_tril)
     # Expands term2 according to
     # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T
     #                  = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
-    qWt_qDinv = q.cov_factor.transpose(-1, -2) / q.cov_diag.unsqueeze(-2)
+    qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
+                 q._unbroadcasted_cov_diag.unsqueeze(-2))
     A = _batch_trtrs_lower(qWt_qDinv, q._capacitance_tril)
-    term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril * q.cov_diag.rsqrt().unsqueeze(-1))
+    term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
+                              q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
     term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
     term2 = term21 - term22
     return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
@@ -345,16 +354,19 @@ def _kl_lowrankmultivariatenormal_multivariatenormal(p, q):
                           different event shapes cannot be computed")
 
     term1 = (2 * _batch_diag(q._unbroadcasted_scale_tril).log().sum(-1) -
-             _batch_lowrank_logdet(p.cov_factor, p.cov_diag, p._capacitance_tril))
+             _batch_lowrank_logdet(p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag,
+                                   p._capacitance_tril))
     term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
     # Expands term2 according to
     # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD)
     combined_batch_shape = torch._C._infer_size(q._unbroadcasted_scale_tril.shape[:-2],
-                                                p.cov_factor.shape[:-2])
+                                                p._unbroadcasted_cov_factor.shape[:-2])
     n = p.event_shape[0]
     q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
-    p_cov_factor = p.cov_factor.expand(combined_batch_shape + (n, p.cov_factor.size(-1)))
-    p_cov_diag = _batch_vector_diag(p.cov_diag.sqrt()).expand(combined_batch_shape + (n, n))
+    p_cov_factor = p._unbroadcasted_cov_factor.expand(combined_batch_shape +
+                                                      (n, p.cov_factor.size(-1)))
+    p_cov_diag = (_batch_vector_diag(p._unbroadcasted_cov_diag.sqrt())
+                  .expand(combined_batch_shape + (n, n)))
     term21 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_factor, q_scale_tril))
     term22 = _batch_trace_XXT(_batch_trtrs_lower(p_cov_diag, q_scale_tril))
     term2 = term21 + term22
index 7377522..019ac62 100644 (file)
@@ -112,7 +112,9 @@ class LowRankMultivariateNormal(Distribution):
         self.cov_diag = cov_diag_[..., 0]
         batch_shape = self.loc.shape[:-1]
 
-        self._capacitance_tril = _batch_capacitance_tril(self.cov_factor, self.cov_diag)
+        self._unbroadcasted_cov_factor = cov_factor
+        self._unbroadcasted_cov_diag = cov_diag
+        self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag)
         super(LowRankMultivariateNormal, self).__init__(batch_shape, event_shape,
                                                         validate_args=validate_args)
 
@@ -123,7 +125,9 @@ class LowRankMultivariateNormal(Distribution):
         new.loc = self.loc.expand(loc_shape)
         new.cov_diag = self.cov_diag.expand(loc_shape)
         new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:])
-        new._capacitance_tril = self._capacitance_tril.expand(batch_shape + self._capacitance_tril.shape[-2:])
+        new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor
+        new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag
+        new._capacitance_tril = self._capacitance_tril
         super(LowRankMultivariateNormal, new).__init__(batch_shape,
                                                        self.event_shape,
                                                        validate_args=False)
@@ -136,7 +140,8 @@ class LowRankMultivariateNormal(Distribution):
 
     @lazy_property
     def variance(self):
-        return self.cov_factor.pow(2).sum(-1) + self.cov_diag
+        return (self._unbroadcasted_cov_factor.pow(2).sum(-1)
+                + self._unbroadcasted_cov_diag).expand(self._batch_shape + self._event_shape)
 
     @lazy_property
     def scale_tril(self):
@@ -146,46 +151,58 @@ class LowRankMultivariateNormal(Distribution):
         # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1,
         # hence it is well-conditioned and safe to take Cholesky decomposition.
         n = self._event_shape[0]
-        cov_diag_sqrt_unsqueeze = self.cov_diag.sqrt().unsqueeze(-1)
-        Dinvsqrt_W = self.cov_factor / cov_diag_sqrt_unsqueeze
+        cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1)
+        Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze
         K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.transpose(-1, -2)).contiguous()
         K.view(-1, n * n)[:, ::n + 1] += 1  # add identity matrix to K
-        return cov_diag_sqrt_unsqueeze * torch.cholesky(K)
+        scale_tril = cov_diag_sqrt_unsqueeze * torch.cholesky(K)
+        return scale_tril.expand(self._batch_shape + self._event_shape + self._event_shape)
 
     @lazy_property
     def covariance_matrix(self):
-        return (torch.matmul(self.cov_factor, self.cov_factor.transpose(-1, -2)) +
-                _batch_vector_diag(self.cov_diag))
+        covariance_matrix = (torch.matmul(self._unbroadcasted_cov_factor,
+                                          self._unbroadcasted_cov_factor.transpose(-1, -2))
+                             + _batch_vector_diag(self._unbroadcasted_cov_diag))
+        return covariance_matrix.expand(self._batch_shape + self._event_shape +
+                                        self._event_shape)
 
     @lazy_property
     def precision_matrix(self):
         # We use "Woodbury matrix identity" to take advantage of low rank form::
         #     inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D)
         # where :math:`C` is the capacitance matrix.
-        Wt_Dinv = self.cov_factor.transpose(-1, -2) / self.cov_diag.unsqueeze(-2)
+        Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
+                   / self._unbroadcasted_cov_diag.unsqueeze(-2))
         A = _batch_trtrs_lower(Wt_Dinv, self._capacitance_tril)
-        return (_batch_vector_diag(self.cov_diag.reciprocal()) -
-                torch.matmul(A.transpose(-1, -2), A))
+        precision_matrix = (_batch_vector_diag(self._unbroadcasted_cov_diag.reciprocal())
+                            - torch.matmul(A.transpose(-1, -2), A))
+        return precision_matrix.expand(self._batch_shape + self._event_shape +
+                                       self._event_shape)
 
     def rsample(self, sample_shape=torch.Size()):
         shape = self._extended_shape(sample_shape)
         W_shape = shape[:-1] + self.cov_factor.shape[-1:]
         eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device)
         eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device)
-        return self.loc + _batch_mv(self.cov_factor, eps_W) + self.cov_diag.sqrt() * eps_D
+        return (self.loc + _batch_mv(self._unbroadcasted_cov_factor, eps_W)
+                + self._unbroadcasted_cov_diag.sqrt() * eps_D)
 
     def log_prob(self, value):
         if self._validate_args:
             self._validate_sample(value)
         diff = value - self.loc
-        M = _batch_lowrank_mahalanobis(self.cov_factor, self.cov_diag, diff,
+        M = _batch_lowrank_mahalanobis(self._unbroadcasted_cov_factor,
+                                       self._unbroadcasted_cov_diag,
+                                       diff,
                                        self._capacitance_tril)
-        log_det = _batch_lowrank_logdet(self.cov_factor, self.cov_diag,
+        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
+                                        self._unbroadcasted_cov_diag,
                                         self._capacitance_tril)
         return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M)
 
     def entropy(self):
-        log_det = _batch_lowrank_logdet(self.cov_factor, self.cov_diag,
+        log_det = _batch_lowrank_logdet(self._unbroadcasted_cov_factor,
+                                        self._unbroadcasted_cov_diag,
                                         self._capacitance_tril)
         H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det)
         if len(self._batch_shape) == 0:
index f6330d6..fa13cc3 100644 (file)
@@ -136,7 +136,7 @@ class MultivariateNormal(Distribution):
         loc_shape = batch_shape + self.event_shape
         cov_shape = batch_shape + self.event_shape + self.event_shape
         new.loc = self.loc.expand(loc_shape)
-        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape)
+        new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril
         if 'covariance_matrix' in self.__dict__:
             new.covariance_matrix = self.covariance_matrix.expand(cov_shape)
         if 'scale_tril' in self.__dict__: