`torch.lu_solve`: forward AD support (#64646)
author: Nikita Vedeneev <nik@quansight.com>
Thu, 9 Sep 2021 15:56:29 +0000 (08:56 -0700)
committer: Facebook GitHub Bot <facebook-github-bot@users.noreply.github.com>
Thu, 9 Sep 2021 15:58:00 +0000 (08:58 -0700)
Summary:
As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/64646

Reviewed By: VitalyFedyunin

Differential Revision: D30807898

Pulled By: albanD

fbshipit-source-id: 1f943c22357dd1b3662cfe0d2a26af68e3a2df4c

tools/autograd/derivatives.yaml
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/testing/_internal/common_methods_invocations.py

index 505130f..79ef447 100644 (file)
 
 - name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
   self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
+  result: lu_solve_forward_AD(self_t, LU_data_t, LU_data_p, LU_pivots, result)
 
 - name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
   LU_data: lu_unpack_backward(grads, LU_data, unpack_data)
index 9ccfbd1..6f05646 100644 (file)
@@ -3746,6 +3746,42 @@ std::tuple<Tensor, Tensor> lu_solve_backward(
   return std::make_tuple(self_grad, LU_data_grad);
 }
 
+// Forward-mode AD (JVP) for at::lu_solve.
+// Given the primal X solving (P L U) X = B, computes the output tangent dX
+// from the input tangents:
+//   dB        -- tangent of the right-hand side B (self_t in derivatives.yaml)
+//   dLU_data  -- tangent of the packed LU factorization (LU_data_t)
+//   LU_data   -- packed LU factors: unit-lower L strictly below the diagonal,
+//                U on and above it
+//   LU_pivots -- 1-based LAPACK-style pivots encoding the permutation P
+//   X         -- the primal lu_solve result
+Tensor lu_solve_forward_AD(
+  const Tensor& dB,
+  const Tensor& dLU_data,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots,
+  const Tensor& X
+) {
+  // The packed tangent splits the same way LU_data does:
+  // the strictly-lower part perturbs L, the upper part (incl. diagonal) perturbs U.
+  auto dL = dLU_data.tril(-1);
+  auto dU = dLU_data.triu();
+
+  // Differentiating P L U X = B (P is constant) and solving for dX gives:
+  // dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB,
+  // or, using that X = (LU)^{-1} P^T B,
+  // dX = -U^{-1} dU X - (LU)^{-1} dL U X + (LU)^{-1} P^T dB
+
+  // -U^{-1} dU X: a single triangular solve against the upper factor.
+  auto U = LU_data.triu();
+  auto dU_part = -std::get<0>(at::triangular_solve(
+    dU.matmul(X),
+    U,
+    /*upper=*/true
+  ));
+
+  // (LU)^{-1} dL U X,
+  // we use lu_solve to solve this system which requires pivots which are returned by the lu routine.
+  // Since no pivoting is required for the system, we create a tensor of identity permutations
+  // which are 1-based because of the Fortran-like LAPACK interfaces.
+  auto identity_pivots = at::arange(1, LU_data.size(-1) + 1, LU_pivots.options()).expand(LU_pivots.sizes());
+  auto dL_part = at::lu_solve(dL.matmul(U).matmul(X), LU_data, identity_pivots);
+
+  // (LU)^{-1} P^T dB: the pivots here DO apply, since dB is an un-permuted rhs.
+  auto dB_part = at::lu_solve(dB, LU_data, LU_pivots);
+
+  return dU_part - dL_part + dB_part;
+}
+
+
 Tensor lu_unpack_backward(
   const variable_list& grads,
   const Tensor& LU_data,
index 6684bcb..1aa1062 100644 (file)
@@ -253,6 +253,13 @@ std::tuple<Tensor, Tensor> lu_solve_backward(
   const Tensor& LU_data,
   const Tensor& LU_pivots
 );
+// Forward-mode AD (JVP) for at::lu_solve; defined in FunctionsManual.cpp.
+// B_t/LU_data_t are the input tangents; X is the primal lu_solve result.
+Tensor lu_solve_forward_AD(
+  const Tensor& B_t,
+  const Tensor& LU_data_t,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots,
+  const Tensor& X
+);
 Tensor lu_unpack_backward(
   const variable_list& grads,
   const Tensor& LU_data,
index b1fec69..40ae6b4 100644 (file)
@@ -7101,6 +7101,7 @@ op_db: List[OpInfo] = [
            op=torch.lu_solve,
            dtypes=floating_and_complex_types(),
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_lu_solve,
            decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
     OpInfo('lu_unpack',