Implements backward for `torch.lu_solve` (#61681)
author Nikita Vedeneev <nik@quansight.com>
Fri, 13 Aug 2021 04:15:42 +0000 (21:15 -0700)
committer Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 13 Aug 2021 04:17:11 +0000 (21:17 -0700)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/22620

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61681

Reviewed By: ngimel

Differential Revision: D30063116

Pulled By: mruberry

fbshipit-source-id: e095b0cadfb7c8b37a7ef91bae5b5dc170d8ef1c
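
For illustration only, not part of the diff: a minimal sketch of what the new derivative enables, assuming the usual double-precision gradcheck conventions; the matrix a, right-hand side b, and seed below are illustrative. Gradients now flow through both the right-hand side and the LU factors, and since 'lu_solve' is also added to GRADIENT_IMPLEMENTED_FOR_COMPLEX below, the same check should also work with a complex dtype such as torch.cdouble.

    import torch
    from torch.autograd import gradcheck

    torch.manual_seed(0)
    # well-conditioned square matrix and a right-hand side, in double precision for gradcheck
    a = torch.randn(3, 3, dtype=torch.double) + 3 * torch.eye(3, dtype=torch.double)
    b = torch.randn(3, 2, dtype=torch.double, requires_grad=True)

    lu_data, pivots = a.lu()       # pivots are integral and receive no gradient
    lu_data.requires_grad_(True)

    # exercises the new lu_solve_backward for both differentiable inputs (self and LU_data)
    assert gradcheck(lambda b, lu: torch.lu_solve(b, lu, pivots), (b, lu_data))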

tools/autograd/derivatives.yaml
tools/autograd/gen_variable_type.py
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/testing/_internal/common_methods_invocations.py

index 4224e41..4dbe9b4 100644 (file)
   self: not_implemented("lu_with_info")
 
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
-  self: not_implemented("lu_solve")
+  self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
 
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
   LU_data: lu_unpack_backward(grads, LU_data, unpack_data)
index 663afc3..f28033d 100644 (file)
@@ -102,7 +102,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
     'eig', 'lerp', 'linalg_vector_norm', 'cumprod', 'prod', 'index_copy', 'lu', 'unfold', 'unfold_backward',
     'index', 'masked_fill', 'cross', 'lu_unpack', 'renorm', '_conj_physical',
     'scatter', 'scatter_add', 'sigmoid', 'sigmoid_backward', 'trapezoid', 'cumulative_trapezoid',
-    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper',
+    'conj_physical_', '_neg_view', '_reshape_alias', '_det_lu_based_helper', 'lu_solve',
 }
 
 GRADIENT_IMPLEMENTED_FOR_SPARSE_COMPLEX = {
index 4ac1f50..8fc2223 100644 (file)
@@ -3593,6 +3593,147 @@ Tensor i1e_backward(
   });
 }
 
+// lu_solve is a map (LU, P, B) -> (PLU)^{-1} B,
+// where LU = L + U - I and P is a permutation matrix, which is held fixed.
+//
+// Let 1 = ones_like(LU),
+// 1_U = 1.triu(),
+// 1_L = 1 - 1_U (note the zero diagonal),
+// * := the Hadamard (element-wise) product
+//
+// Forward AD:
+//
+// Let X := U^{-1} L^{-1} P^T B be the output of the function.
+// Also, the LU input of the function could be represented as
+// LU = L + U - I.
+//
+// Differentiating LU = L + U - I produces:
+// dLU = dL + dU.
+// Since dL is strictly lower-triangular (the unit diagonal of L is never
+// explicitly stored, so diag(dL) = 0) and dU is upper-triangular, it follows that
+// dL = dLU * 1_L,
+// dU = dLU * 1_U.
+//
+// Differentiating X = U^{-1} L^{-1} P^T B produces:
+// dX = dU^{-1} L^{-1} P^T B + U^{-1} dL^{-1} P^T B + U^{-1} L^{-1} P^T dB
+// Note that for any invertible matrix A we have A A^{-1} = I, hence
+// dA A^{-1} + A dA^{-1} = 0 => dA^{-1} = -A^{-1} dA A^{-1}.
+// Inserting it back into the definition of dX gives:
+// dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB
+//
+// Backward AD:
+//
+// Using the definition of dL, dU from above:
+// Tr(L_grad^H dL) + Tr(U_grad^H dU) = Tr(L_grad^H (dLU * 1_L)) + Tr(U_grad^H (dLU * 1_U))
+//                                   = [using Tr(A (B * C)) = Tr((A * B^T) C)]
+//                                   = Tr((L_grad^H * 1_L^T) dLU) + Tr((U_grad^H * 1_U^T) dLU),
+// hence
+// LU_grad = L_grad * 1_L + U_grad * 1_U (!!!)
+//
+// Using the definition of dX:
+// Tr(X_grad^H dX) = Tr(-U^{-1} L^{-1} P^T B X_grad^H U^{-1} dU)
+//                 + Tr(-L^{-1} P^T B X_grad^H U^{-1} L^{-1} dL)
+//                 + Tr(X_grad^H U^{-1} L^{-1} P^T dB).
+// And we immediately get:
+// B_grad = [X_grad^H U^{-1} L^{-1} P^T]^H = [U^{-1} L^{-1} P^T]^H X_grad.
+// Let Z := L^{-1} P^T B X_grad^H U^{-1}, then
+// U_grad = [-U^{-1} Z]^H,
+// L_grad = [-Z L^{-1}]^H.
+// After inserting U_grad and L_grad into (!!!) we get the value for LU_grad.
+
+std::tuple<Tensor, Tensor> lu_solve_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots
+) {
+  if (!grad.defined()) {
+    return std::make_tuple(Tensor{}, Tensor{});
+  }
+
+  Tensor P, L, U;
+  std::tie(P, L, U) = at::lu_unpack(LU_data, LU_pivots);
+
+  auto n = LU_data.size(-1);
+  auto nrhs = self.size(-1);
+
+  // stores L^{-1} P^T
+  Tensor Y;
+
+  Tensor LU_data_grad;
+  if (LU_data.requires_grad()) {
+    // X = -L^{-1} P^T B grad^H
+    auto X = -std::get<0>(at::triangular_solve(
+      (nrhs < n) ? P.transpose(-2, -1).matmul(self) : P.transpose(-2, -1),
+      L,
+      /*upper=*/false,
+      /*transpose=*/false,
+      /*unitriangular=*/true
+    ));
+    if (nrhs >= n) {
+      // Y stores L^{-1} P^T to be reused in the computation of self_grad
+      if (self.requires_grad()) {
+        Y = -X;
+      }
+      X = X.matmul(self);
+    }
+    X = X.matmul(grad.transpose(-2, -1).conj());
+
+    // X <- X U^{-1}
+    X = std::get<0>(at::triangular_solve(
+      X.transpose(-2, -1),
+      U,
+      /*upper=*/true,
+      /*transpose=*/true,
+      /*unitriangular=*/false
+    )).transpose(-2, -1);
+
+    // U_grad = [U^{-1} X]^H
+    auto U_grad = std::get<0>(at::triangular_solve(
+      X,
+      U,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1).conj();
+
+    // L_grad = L^{-H} X^H
+    auto L_grad = std::get<0>(at::triangular_solve(
+      X.transpose(-2, -1),
+      L,
+      /*upper=*/false,
+      /*transpose=*/true,
+      /*unitriangular=*/true
+    )).conj();
+
+    // LU_data_grad = L_grad * 1_L + U_grad * 1_U
+    LU_data_grad = L_grad.tril(-1) + U_grad.triu();
+  }
+
+  // self_grad = [grad^H U^{-1} L^{-1} P^T]^H = [U^{-1} L^{-1} P^T]^H grad
+  Tensor self_grad;
+  if (self.requires_grad()) {
+    self_grad = std::get<0>(at::triangular_solve(
+      // reuse Y := L^{-1} P^T if already computed
+      Y.defined() ? Y : std::get<0>(at::triangular_solve(
+        P.transpose(-2, -1),
+        L,
+        /*upper=*/false,
+        /*transpose=*/false,
+        /*unitriangular=*/true
+      )),
+      U,
+      /*upper=*/true,
+      /*transpose=*/false,
+      /*unitriangular=*/false
+    )).transpose(-2, -1).conj().matmul(grad);
+  }
+
+  return std::make_tuple(self_grad, LU_data_grad);
+}
+
 Tensor lu_unpack_backward(
   const variable_list& grads,
   const Tensor& LU_data,
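
For illustration only, not part of the diff: the derivation in the comment block above can be sanity-checked numerically by rebuilding X = U^{-1} L^{-1} P^T B from differentiable pieces (lu_unpack plus two triangular solves) and comparing autograd's gradients along that path with the gradients produced by the new lu_solve_backward. This is a sketch; the names a, b, and grad_out are illustrative, and P is detached because it is piecewise constant with respect to LU_data.

    import torch

    torch.manual_seed(0)
    n, k = 4, 2
    a = torch.randn(n, n, dtype=torch.double) + n * torch.eye(n, dtype=torch.double)
    b = torch.randn(n, k, dtype=torch.double, requires_grad=True)

    lu_data, pivots = a.lu()
    lu_data.requires_grad_(True)

    # reference path: X = U^{-1} L^{-1} P^T B built from lu_unpack + triangular solves
    P, L, U = torch.lu_unpack(lu_data, pivots)
    P = P.detach()  # P carries no gradient w.r.t. LU_data
    y = torch.triangular_solve(P.transpose(-2, -1) @ b, L, upper=False, unitriangular=True).solution
    x_ref = torch.triangular_solve(y, U, upper=True).solution

    grad_out = torch.randn_like(x_ref)
    b_grad_ref, lu_grad_ref = torch.autograd.grad(x_ref, (b, lu_data), grad_out)

    # new path: the dedicated backward registered in derivatives.yaml
    x = torch.lu_solve(b, lu_data, pivots)
    b_grad, lu_grad = torch.autograd.grad(x, (b, lu_data), grad_out)

    assert torch.allclose(b_grad, b_grad_ref)
    assert torch.allclose(lu_grad, lu_grad_ref)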
index 288ae11..91ddec3 100644 (file)
@@ -247,6 +247,12 @@ Tensor i1e_backward(
     const Tensor& grad,
     const Tensor& self,
     const Tensor& result);
+std::tuple<Tensor, Tensor> lu_solve_backward(
+  const Tensor& grad,
+  const Tensor& self,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots
+);
 Tensor lu_unpack_backward(
   const variable_list& grads,
   const Tensor& LU_data,
index 3640436..e589e9c 100644 (file)
@@ -3150,6 +3150,33 @@ def sample_inputs_lu(op_info, device, dtype, requires_grad=False, **kwargs):
     return list(generate_samples())
 
 
+def sample_inputs_lu_solve(op_info, device, dtype, requires_grad=False, **kwargs):
+    from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
+
+    batches = [(), (0, ), (2, )]
+    ns = [5, 3, 0]
+    nrhs = [0, 1, 6]
+
+    def generate_samples():
+        for n, batch, rhs in product(ns, batches, nrhs):
+            a = random_fullrank_matrix_distinct_singular_value(n, *batch, dtype=dtype, device=device)
+            requires_grad_options = (False,) if not requires_grad else (True, False)
+            # we try all possible combinations of requires_grad for each input
+            for lu_requires_grad, b_requires_grad in product(requires_grad_options, requires_grad_options):
+                # when requires_grad == True, at least one input has to have requires_grad enabled
+                if requires_grad and not lu_requires_grad and not b_requires_grad:
+                    continue
+                # we run LU several times to guarantee that the produced SampleInputs are independent
+                # this is especially important when setting different requires_grad for the same tensors!
+                lu, pivs = a.lu()
+                lu.requires_grad = lu_requires_grad
+                b = torch.randn(*batch, n, rhs, dtype=dtype, device=device)
+                b.requires_grad = b_requires_grad
+                yield SampleInput(b, args=(lu, pivs))
+
+    return list(generate_samples())
+
+
 def sample_inputs_lu_unpack(op_info, device, dtype, requires_grad=False, **kwargs):
     # not needed once OpInfo tests support Iterables
     def generate_samples():
@@ -6447,6 +6474,12 @@ op_db: List[OpInfo] = [
                # Skip operator schema test because this is a functional and not an operator
                SkipInfo('TestOperatorSignatures', 'test_get_torch_func_signature_exhaustive'),
            )),
+    OpInfo('lu_solve',
+           op=torch.lu_solve,
+           dtypes=floating_and_complex_types(),
+           check_batched_gradgrad=False,
+           sample_inputs_func=sample_inputs_lu_solve,
+           decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
     OpInfo('lu_unpack',
            op=torch.lu_unpack,
            dtypes=floating_and_complex_types(),
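
For context, not part of the diff: with the OpInfo entry above, lu_solve is picked up automatically by the generic OpInfo-driven suites, so the new backward is exercised by the autograd tests in test/test_ops.py over the samples produced by sample_inputs_lu_solve (for example via something like pytest test/test_ops.py -k lu_solve, assuming a pytest-based run). check_batched_gradgrad=False is understood to skip only the batched (vmap-based) grad-grad variant, while the plain gradcheck and gradgradcheck paths still run.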