`torch.lu`: forward AD support (#64742)
authorNikita Vedeneev <nik@quansight.com>
Fri, 10 Sep 2021 14:17:30 +0000 (07:17 -0700)
committerFacebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Fri, 10 Sep 2021 14:19:11 +0000 (07:19 -0700)
Summary:
As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64742

Reviewed By: H-Huang

Differential Revision: D30841227

Pulled By: albanD

fbshipit-source-id: dc4d043ab94358594adb110fbbbb60750c98262a

tools/autograd/derivatives.yaml
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/testing/_internal/common_methods_invocations.py

index 79ef447..660c188 100644 (file)
 
 - name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
   self: _lu_with_info_backward(grad, self, LU, pivots)
+  LU: _lu_with_info_jvp(self_t, LU, pivots)
 
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
index 6f05646..e53a0c5 100644 (file)
@@ -4051,6 +4051,91 @@ Tensor _lu_with_info_backward(
   return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U);
 }
 
+Tensor _lu_with_info_jvp(
+  const Tensor& dX,
+  const Tensor& LU,
+  const Tensor& pivs
+) {
+  // This function is based on the forward AD derivations outlined
+  // in the description to the plu_backward_base function.
+
+  Tensor P, L, U;
+  std::tie(P, L, U) = at::lu_unpack(LU, pivs);
+
+  auto m = LU.size(-2);
+  auto n = LU.size(-1);
+  auto k = std::min(m, n);
+
+  auto pdX = P.transpose(-1, -2).matmul(dX);
+
+  // similar to the backward implementation, we also consider block structures such as:
+  // for a matrix A of size m x n we decompose it as
+  // A = (A1 | A2) with A1 of size m x m if m <= n and
+  // A = (A1^T | A2^T)^T with A1 of size n x n if m > n.
+  auto pdX1 = pdX.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k);
+
+  // dK = L1^{-1} pdX1
+  auto dK = std::get<0>(at::triangular_solve(
+    pdX1,
+    L1,
+    /*upper=*/false,
+    /*transpose=*/false,
+    /*unitriangular=*/true
+  ));
+  // dK <- dK U1^{-1}
+  dK = std::get<0>(at::triangular_solve(
+    dK.transpose(-1, -2),
+    U1,
+    /*upper=*/true,
+    /*transpose=*/true
+  )).transpose(-1, -2);
+
+  auto dL1 = L1.matmul(dK.tril(-1));
+  auto dU1 = dK.triu().matmul(U1);
+
+  // since LU = L + U - I, we have that dLU = dL + dU
+  // if LU is of size m x n, we always have
+  // dLU1 = dL1 + dU1, where the block indexing follows the rules
+  // outlined above.
+  if (m == n) {
+    return dL1 + dU1;
+  }
+  else {
+    auto dLU = at::zeros_like(LU);
+    dLU.narrow(-2, 0, k).narrow(-1, 0, k).copy_(dL1 + dU1);
+
+    if (m < n) {
+      // we only need to update dU2 defined as
+      // dU2 := L1^{-1} (pdX2 - dL1 U2)
+      auto pdX2 = pdX.narrow(-1, k, n - k);
+      auto U2 = U.narrow(-1, k, n - k);
+      dLU.narrow(-1, k, n - k).copy_(std::get<0>(at::triangular_solve(
+        pdX2 - dL1.matmul(U2),
+        L1,
+        /*upper=*/false,
+        /*transpose=*/false,
+        /*unitriangular=*/true
+      )));
+    }
+    else {
+      // we only need to update dL2 defined as
+      // dL2 := (pdX2 - L2 dU1) U1^{-1}
+      auto pdX2 = pdX.narrow(-2, k, m - k);
+      auto L2 = L.narrow(-2, k, m - k);
+      dLU.narrow(-2, k, m - k).copy_(std::get<0>(at::triangular_solve(
+        (pdX2 - L2.matmul(dU1)).transpose(-1, -2),
+        U1,
+        /*upper=*/true,
+        /*transpose=*/true
+      )).transpose(-1, -2));
+    }
+
+    return dLU;
+  }
+}
+
 } // namespace details
 } // namespace generated
 } // namespace autograd
index 1aa1062..b24dce7 100644 (file)
@@ -287,6 +287,11 @@ Tensor _lu_with_info_backward(
   const Tensor& LU,
   const Tensor& pivs
 );
+Tensor _lu_with_info_jvp(
+  const Tensor& dX,
+  const Tensor& LU,
+  const Tensor& pivs
+);
 
 Tensor cat_jvp(at::TensorList tensors, int64_t dim);
 Tensor cumprod_jvp(Tensor self_t, Tensor self_p, Tensor result, int dim);
index f0c163c..8451791 100644 (file)
@@ -7090,6 +7090,7 @@ op_db: List[OpInfo] = [
            # This causes vmap failures, hence we skip batched gradient checks
            check_batched_grad=False,
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            supports_out=False,
            sample_inputs_func=sample_inputs_lu,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],