- name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
+ result: lu_solve_forward_AD(self_t, LU_data_t, LU_data_p, LU_pivots, result)
- name: lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)
LU_data: lu_unpack_backward(grads, LU_data, unpack_data)
return std::make_tuple(self_grad, LU_data_grad);
}
+// Forward-mode AD (JVP) for lu_solve: given the tangent dB of the right-hand
+// side and the tangent dLU_data of the packed LU factorization, together with
+// the primal packed factors (LU_data), the pivots (LU_pivots, 1-based LAPACK
+// convention) and the primal solution X, computes the tangent dX of
+// X = lu_solve(B, LU_data, LU_pivots).
+Tensor lu_solve_forward_AD(
+  const Tensor& dB,
+  const Tensor& dLU_data,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots,
+  const Tensor& X
+) {
+  // Split the tangent of the packed factors: L is unit lower-triangular, so
+  // its tangent lives strictly below the diagonal (tril(-1)); U's tangent is
+  // the upper-triangular part including the diagonal (triu()).
+  auto dL = dLU_data.tril(-1);
+  auto dU = dLU_data.triu();
+
+  // From the derivations from above we have that:
+  // dX = -U^{-1} dU U^{-1} L^{-1} P^T B - U^{-1} L^{-1} dL L^{-1} P^T B + U^{-1} L^{-1} P^T dB,
+  // or, using that X = (LU)^{-1} P^T B,
+  // dX = -U^{-1} dU X - (LU)^{-1} dL U X + (LU)^{-1} P^T dB
+
+  // -U^{-1} dU X
+  auto U = LU_data.triu();
+  auto dU_part = -std::get<0>(at::triangular_solve(
+    dU.matmul(X),
+    U,
+    /*upper=*/true
+  ));
+
+  // (LU)^{-1} dL U X,
+  // we use lu_solve to solve this system which requires pivots which are returned by the lu routine.
+  // Since no pivoting is required for the system, we create a tensor of identity permutations
+  // which are 1-based because of the Fortran-like LAPACK interfaces.
+  auto identity_pivots = at::arange(1, LU_data.size(-1) + 1, LU_pivots.options()).expand(LU_pivots.sizes());
+  auto dL_part = at::lu_solve(dL.matmul(U).matmul(X), LU_data, identity_pivots);
+
+  // (LU)^{-1} P^T dB
+  auto dB_part = at::lu_solve(dB, LU_data, LU_pivots);
+
+  // Note the sign: dU_part already carries the leading minus, dL_part does not,
+  // hence it is subtracted here to match the formula above.
+  return dU_part - dL_part + dB_part;
+}
+
// Backward for lu_unpack. The parameter list here matches the visible call
// site lu_unpack_backward(grads, LU_data, unpack_data): the original
// declaration repeated `const Tensor& LU_data` twice (a compile error —
// duplicate parameter name) and took a fourth LU_pivots parameter that no
// caller supplies.
Tensor lu_unpack_backward(
    const variable_list& grads,
    const Tensor& LU_data,
    bool unpack_data
);
+// Forward-mode AD (JVP) for lu_solve: B_t and LU_data_t are the tangents of
+// the right-hand side and of the packed LU factors; LU_data, LU_pivots and X
+// are the primal factorization and primal solution. Returns the tangent of X.
+Tensor lu_solve_forward_AD(
+  const Tensor& B_t,
+  const Tensor& LU_data_t,
+  const Tensor& LU_data,
+  const Tensor& LU_pivots,
+  const Tensor& X
+);
Tensor lu_unpack_backward(
const variable_list& grads,
const Tensor& LU_data,
op=torch.lu_solve,
dtypes=floating_and_complex_types(),
check_batched_gradgrad=False,
+ supports_forward_ad=True,
sample_inputs_func=sample_inputs_lu_solve,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
OpInfo('lu_unpack',