# LU factorization with partial pivoting. Backward reconstructs the input
# gradient from the packed LU and pivots; the forward-mode (jvp) entry maps
# the input tangent self_t to the tangent of the packed LU output.
- name: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)
self: _lu_with_info_backward(grad, self, LU, pivots)
+ LU: _lu_with_info_jvp(self_t, LU, pivots)
# Solve A x = b given the packed LU factorization of A; one backward formula
# covers both differentiable inputs (self and LU_data) jointly.
- name: lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor
self, LU_data: lu_solve_backward(grad, self, LU_data, LU_pivots)
return plu_backward_base({/*L_grad=*/grad, /*U_grad=*/grad}, self, P, L, U);
}
+Tensor _lu_with_info_jvp(
+  const Tensor& dX,
+  const Tensor& LU,
+  const Tensor& pivs
+) {
+  // Forward-mode (JVP) derivative of the pivoted LU decomposition:
+  // given the tangent dX of the input A (where P A = L U after unpacking),
+  // returns the tangent of the packed LU tensor. The packed convention is
+  // LU = L + U - I, so the returned tangent is dL + dU block-wise.
+  //
+  // dX:   tangent of the input matrix, same (batched) m x n shape as A.
+  // LU:   packed LU factors as produced by _lu_with_info.
+  // pivs: pivot indices as produced by _lu_with_info.
+  //
+  // This function is based on the forward AD derivations outlined
+  // in the description to the plu_backward_base function.
+
+  Tensor P, L, U;
+  std::tie(P, L, U) = at::lu_unpack(LU, pivs);
+
+  auto m = LU.size(-2);
+  auto n = LU.size(-1);
+  auto k = std::min(m, n);
+
+  // Undo the row permutation on the tangent: P is a permutation matrix,
+  // so P^{-1} = P^T and pdX = P^T dX satisfies pdX = dL U + L dU.
+  auto pdX = P.transpose(-1, -2).matmul(dX);
+
+  // similar to the backward implementation, we also consider block structures such as:
+  // for a matrix A of size m x n we decompose it as
+  // A = (A1 | A2) with A1 of size m x m if m <= n and
+  // A = (A1^T | A2^T)^T with A1 of size n x n if m > n.
+  auto pdX1 = pdX.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto L1 = L.narrow(-2, 0, k).narrow(-1, 0, k);
+  auto U1 = U.narrow(-2, 0, k).narrow(-1, 0, k);
+
+  // dK = L1^{-1} pdX1
+  // L1 has unit diagonal by the LU convention, hence unitriangular=true.
+  auto dK = std::get<0>(at::triangular_solve(
+    pdX1,
+    L1,
+    /*upper=*/false,
+    /*transpose=*/false,
+    /*unitriangular=*/true
+  ));
+  // dK <- dK U1^{-1}
+  // triangular_solve only solves from the left (A X = B), so the
+  // right-division by U1 is expressed as a transposed left solve.
+  dK = std::get<0>(at::triangular_solve(
+    dK.transpose(-1, -2),
+    U1,
+    /*upper=*/true,
+    /*transpose=*/true
+  )).transpose(-1, -2);
+
+  // Split dK = L1^{-1} pdX1 U1^{-1} into the strictly-lower part (which
+  // produces dL1, since dL1 is strictly lower triangular for unit-diagonal L)
+  // and the upper part (which produces dU1).
+  auto dL1 = L1.matmul(dK.tril(-1));
+  auto dU1 = dK.triu().matmul(U1);
+
+  // since LU = L + U - I, we have that dLU = dL + dU
+  // if LU is of size m x n, we always have
+  // dLU1 = dL1 + dU1, where the block indexing follows the rules
+  // outlined above.
+  if (m == n) {
+    return dL1 + dU1;
+  }
+  else {
+    auto dLU = at::zeros_like(LU);
+    dLU.narrow(-2, 0, k).narrow(-1, 0, k).copy_(dL1 + dU1);
+
+    if (m < n) {
+      // we only need to update dU2 defined as
+      // dU2 := L1^{-1} (pdX2 - dL1 U2)
+      // which follows from differentiating pdX2 = L1 dU2 + dL1 U2.
+      auto pdX2 = pdX.narrow(-1, k, n - k);
+      auto U2 = U.narrow(-1, k, n - k);
+      dLU.narrow(-1, k, n - k).copy_(std::get<0>(at::triangular_solve(
+        pdX2 - dL1.matmul(U2),
+        L1,
+        /*upper=*/false,
+        /*transpose=*/false,
+        /*unitriangular=*/true
+      )));
+    }
+    else {
+      // we only need to update dL2 defined as
+      // dL2 := (pdX2 - L2 dU1) U1^{-1}
+      // which follows from differentiating pdX2 = dL2 U1 + L2 dU1;
+      // the right-division by U1 is again a transposed left solve.
+      auto pdX2 = pdX.narrow(-2, k, m - k);
+      auto L2 = L.narrow(-2, k, m - k);
+      dLU.narrow(-2, k, m - k).copy_(std::get<0>(at::triangular_solve(
+        (pdX2 - L2.matmul(dU1)).transpose(-1, -2),
+        U1,
+        /*upper=*/true,
+        /*transpose=*/true
+      )).transpose(-1, -2));
+    }
+
+    return dLU;
+  }
+}
+
} // namespace details
} // namespace generated
} // namespace autograd