Tensor addcmul(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
Tensor addcdiv(const Tensor & tensor1, const Tensor & tensor2, Scalar value=1) const;
std::tuple<Tensor,Tensor> gels(const Tensor & A) const;
- std::tuple<Tensor,Tensor> trtrs(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const;
+ std::tuple<Tensor,Tensor> triangular_solve(const Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) const;
std::tuple<Tensor,Tensor> symeig(bool eigenvectors=false, bool upper=true) const;
std::tuple<Tensor,Tensor> eig(bool eigenvectors=false) const;
std::tuple<Tensor,Tensor,Tensor> svd(bool some=true, bool compute_uv=true) const;
inline std::tuple<Tensor,Tensor> Tensor::gels(const Tensor & A) const {
return type().gels(*this, A);
}
-inline std::tuple<Tensor,Tensor> Tensor::trtrs(const Tensor & A, bool upper, bool transpose, bool unitriangular) const {
- return type().trtrs(*this, A, upper, transpose, unitriangular);
+inline std::tuple<Tensor,Tensor> Tensor::triangular_solve(const Tensor & A, bool upper, bool transpose, bool unitriangular) const {
+ return type().triangular_solve(*this, A, upper, transpose, unitriangular);
}
inline std::tuple<Tensor,Tensor> Tensor::symeig(bool eigenvectors, bool upper) const {
return type().symeig(*this, eigenvectors, upper);
virtual Tensor addcmul(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
virtual Tensor addcdiv(const Tensor & self, const Tensor & tensor1, const Tensor & tensor2, Scalar value) const = 0;
virtual std::tuple<Tensor,Tensor> gels(const Tensor & self, const Tensor & A) const = 0;
- virtual std::tuple<Tensor,Tensor> trtrs(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const = 0;
+ virtual std::tuple<Tensor,Tensor> triangular_solve(const Tensor & self, const Tensor & A, bool upper, bool transpose, bool unitriangular) const = 0;
virtual std::tuple<Tensor,Tensor> symeig(const Tensor & self, bool eigenvectors, bool upper) const = 0;
virtual std::tuple<Tensor,Tensor> eig(const Tensor & self, bool eigenvectors) const = 0;
virtual std::tuple<Tensor,Tensor,Tensor> svd(const Tensor & self, bool some, bool compute_uv) const = 0;
_(aten, topk) \
_(aten, trace) \
_(aten, transpose) \
+_(aten, triangular_solve) \
_(aten, tril) \
_(aten, triplet_margin_loss) \
_(aten, triu) \
-_(aten, trtrs) \
_(aten, trunc) \
_(aten, type_as) \
_(aten, unbind) \
}
template<class scalar_t>
-void lapackTrtrs(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
- AT_ERROR("trtrs only takes float or double Tensors");
+void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, scalar_t *a, int lda, scalar_t *b, int ldb, int *info) {
+ AT_ERROR("triangular_solve only takes float or double Tensors");
}
#ifdef USE_LAPACK
spotrf_(&uplo, &n, a, &lda, info);
}
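+// Note: dtrtrs_/strtrs_ below are LAPACK's triangular solvers. They solve
+// op(A) X = B in place (B is overwritten with the solution X) and set
+// info > 0 when a diagonal element of A is exactly zero, i.e. when A is
+// singular; the unspecialized lapackTriangularSolve template above exists
+// only to fail loudly for unsupported scalar types.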
-template<> void lapackTrtrs<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
+template<> void lapackTriangularSolve<double>(char uplo, char trans, char diag, int n, int nrhs, double *a, int lda, double *b, int ldb, int *info) {
dtrtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
}
-template<> void lapackTrtrs<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
+template<> void lapackTriangularSolve<float>(char uplo, char trans, char diag, int n, int nrhs, float *a, int lda, float *b, int ldb, int *info) {
strtrs_(&uplo, &trans, &diag, &n, &nrhs, a, &lda, b, &ldb, info);
}
#endif
}
Tensor& cholesky_solve_out(Tensor& result, const Tensor& self, const Tensor& A, bool upper) {
- AT_CHECK(self.dim() == 2 && A.dim() == 2,
- "torch.cholesky_solve() with the `out` keyword does not support batching. "
- "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
Tensor result_tmp;
result_tmp = at::_cholesky_solve_helper(self, A, upper);
result.resize_as_(result_tmp).copy_(result_tmp);
return result;
}
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trtrs ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<typename scalar_t>
-static void apply_trtrs(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular, std::vector<int64_t>& infos) {
+static void apply_triangular_solve(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
#ifndef USE_LAPACK
- AT_ERROR("trtrs: LAPACK library not found in compilation");
+ AT_ERROR("triangular_solve: LAPACK library not found in compilation");
#else
char uplo = upper ? 'U' : 'L';
char trans = transpose ? 'T' : 'N';
int info;
if (b.dim() == 2) {
- lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
- infos[0] = info;
+ lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n, &info);
} else {
auto A_mat_stride = matrixStride(A);
auto b_mat_stride = matrixStride(b);
for (int64_t i = 0; i < batch_size; i++) {
scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
- lapackTrtrs<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
- infos[i] = info;
- if (info != 0) {
- return;
- }
+ lapackTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
}
}
#endif
}
-std::tuple<Tensor, Tensor> _trtrs_helper_cpu(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+std::tuple<Tensor, Tensor> _triangular_solve_helper_cpu(const Tensor& self, const Tensor& A,
+ bool upper, bool transpose, bool unitriangular) {
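+  // LAPACK expects Fortran (column-major) storage, so both inputs are cloned
+  // into column-major buffers; the solver then overwrites the copy of `self`
+  // with the solution, which is what gets returned below.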
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
- std::vector<int64_t> infos(batchCount(self), 0);
- AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trtrs_cpu", [&]{
- apply_trtrs<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular, infos);
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cpu", [&]{
+ apply_triangular_solve<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
});
- if (self.dim() > 2) {
- batchCheckErrors(infos, "trtrs_cpu");
- } else {
- singleCheckErrors(infos[0], "trtrs_cpu");
- }
return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
}
// Supports arbitrary batch dimensions for self and A
-std::tuple<Tensor, Tensor> trtrs(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+std::tuple<Tensor, Tensor> triangular_solve(const Tensor& self, const Tensor& A,
+ bool upper, bool transpose, bool unitriangular) {
AT_CHECK(self.dim() >= 2,
"b should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
AT_CHECK(A.dim() >= 2,
"u should have at least 2 dimensions, but has ", A.dim(), " dimensions instead");
Tensor self_broadcasted, A_broadcasted;
std::tie(self_broadcasted, A_broadcasted) = _linear_solve_broadcast_args(self, A);
- return at::_trtrs_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular);
+ return at::_triangular_solve_helper(self_broadcasted, A_broadcasted, upper, transpose, unitriangular);
}
-std::tuple<Tensor&, Tensor&> trtrs_out(Tensor& result, Tensor& clone_A,
- const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
- AT_CHECK(self.dim() == 2 && A.dim() == 2,
- "torch.trtrs() with the `out` keyword does not support batching. "
- "b.dim() (", self.dim(), ") and A.dim() (", A.dim(), ") must both be 2.");
+std::tuple<Tensor&, Tensor&> triangular_solve_out(Tensor& result, Tensor& clone_A, const Tensor& self, const Tensor& A,
+ bool upper, bool transpose, bool unitriangular) {
Tensor result_tmp, clone_A_tmp;
- std::tie(result_tmp, clone_A_tmp) = at::_trtrs_helper(self, A, upper, transpose, unitriangular);
+ std::tie(result_tmp, clone_A_tmp) = at::_triangular_solve_helper(self, A, upper, transpose, unitriangular);
result.resize_as_(result_tmp).copy_(result_tmp);
clone_A.resize_as_(clone_A_tmp).copy_(clone_A_tmp);
return std::tuple<Tensor&, Tensor&>(result, clone_A);
}
template<class scalar_t>
-void magmaTrsm(
+void magmaTriangularSolve(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
scalar_t* dA, magma_int_t ldda, scalar_t* dB, magma_int_t lddb) {
- AT_ERROR("trtrs only takes float or double Tensors");
+ AT_ERROR("triangular_solve only takes float or double Tensors");
}
template<class scalar_t>
-void magmaTrsmBatched(
+void magmaTriangularSolveBatched(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
scalar_t** dA_array, magma_int_t ldda, scalar_t** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
- AT_ERROR("trtrs only takes float or double Tensors");
+ AT_ERROR("triangular_solve only takes float or double Tensors");
}
template<>
}
template<>
-void magmaTrsm<double>(
+void magmaTriangularSolve<double>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
double* dA, magma_int_t ldda, double* dB, magma_int_t lddb) {
magma_dtrsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
}
template<>
-void magmaTrsm<float>(
+void magmaTriangularSolve<float>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
float* dA, magma_int_t ldda, float* dB, magma_int_t lddb) {
magma_strsm(MagmaLeft, uplo, trans, diag, m, n, 1, dA, ldda, dB, lddb);
}
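+// Note: the trsm-family routines overwrite dB with the solution of
+// op(A) X = alpha * B (here alpha = 1 and side = MagmaLeft) and return no
+// info flag, so singular triangular matrices are not reported on this path.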
template<>
-void magmaTrsmBatched<double>(
+void magmaTriangularSolveBatched<double>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
double** dA_array, magma_int_t ldda, double** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
}
template<>
-void magmaTrsmBatched<float>(
+void magmaTriangularSolveBatched<float>(
magma_uplo_t uplo, magma_trans_t trans, magma_diag_t diag, magma_int_t m, magma_int_t n,
float** dA_array, magma_int_t ldda, float** dB_array, magma_int_t lddb, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
return triu_tril_cuda_template<true>(result, self_c, k, "triu");
}
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ trsm ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangular_solve ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
-static void apply_trsm(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
+static void apply_triangular_solve(Tensor& b, Tensor& A, bool upper, bool transpose, bool unitriangular) {
#ifndef USE_MAGMA
AT_ERROR("cholesky_solve: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
magma_int_t nrhs = magma_int_cast(b.size(-1), "b.size(-1)");
if (b.dim() == 2) {
- magmaTrsm<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
+ magmaTriangularSolve<scalar_t>(uplo, trans, diag, n, nrhs, A_data, n, b_data, n);
} else {
auto A_mat_stride = matrixStride(A);
auto b_mat_stride = matrixStride(b);
}
MAGMAQueue magma_queue(b.get_device());
- magmaTrsmBatched<scalar_t>(
+ magmaTriangularSolveBatched<scalar_t>(
uplo, trans, diag, n, nrhs, A_array, n,
b_array, n, batch_size, magma_queue);
}
#endif
}
-std::tuple<Tensor, Tensor> _trsm_helper_cuda(const Tensor& self, const Tensor& A, bool upper, bool transpose, bool unitriangular) {
+std::tuple<Tensor, Tensor> _triangular_solve_helper_cuda(const Tensor& self, const Tensor& A,
+ bool upper, bool transpose, bool unitriangular) {
auto self_working_copy = cloneBatchedColumnMajor(self);
auto A_working_copy = cloneBatchedColumnMajor(A);
- AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "trsm_cuda", [&]{
- apply_trsm<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "triangular_solve_cuda", [&]{
+ apply_triangular_solve<scalar_t>(self_working_copy, A_working_copy, upper, transpose, unitriangular);
});
return std::tuple<Tensor, Tensor>(self_working_copy, A_working_copy);
}
matches_jit_signature: True
variants: method, function
-- func: trtrs(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!), Tensor(b!))
+- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!), Tensor(b!))
matches_jit_signature: True
-- func: trtrs(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor, Tensor)
+- func: triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor, Tensor)
matches_jit_signature: True
variants: method, function
-- func: _trtrs_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor)
+- func: _triangular_solve_helper(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular) -> (Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:
- CPU: _trtrs_helper_cpu
- CUDA: _trsm_helper_cuda
+ CPU: _triangular_solve_helper_cpu
+ CUDA: _triangular_solve_helper_cuda
- func: symeig(Tensor self, bool eigenvectors=False, bool upper=True, *, Tensor(a!) e, Tensor(b!) V) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
matches_jit_signature: True
.. automethod:: trace
.. automethod:: transpose
.. automethod:: transpose_
+ .. automethod:: triangular_solve
.. automethod:: tril
.. automethod:: tril_
.. automethod:: triu
.. autofunction:: solve
.. autofunction:: svd
.. autofunction:: symeig
+.. autofunction:: triangular_solve
.. autofunction:: trtrs
Utilities
run_test(upper, dims)
@skipIfNoLapack
- def test_trtrs(self):
+ def test_triangular_solve(self):
def _test_with_size(A_dims, B_dims):
A = torch.rand(*A_dims).requires_grad_()
b = torch.rand(*B_dims).requires_grad_()
for upper, transpose, unitriangular in product((True, False), repeat=3):
def func(A, b):
- return torch.trtrs(b, A, upper, transpose, unitriangular)
+ return torch.triangular_solve(b, A, upper, transpose, unitriangular)
gradcheck(func, [A, b])
gradgradcheck(func, [A, b])
_TestTorchMixin._test_geqrf(self, lambda t: t.cuda())
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
- def test_trtrs(self):
- _TestTorchMixin._test_trtrs(self, lambda t: t.cuda())
+ def test_triangular_solve(self):
+ _TestTorchMixin._test_triangular_solve(self, lambda t: t.cuda())
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
- def test_trtrs_batched(self):
- _TestTorchMixin._test_trtrs_batched(self, lambda t: t.cuda())
+ def test_triangular_solve_batched(self):
+ _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t.cuda())
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
- def test_trtrs_batched_dims(self):
- _TestTorchMixin._test_trtrs_batched_dims(self, lambda t: t.cuda())
+ def test_triangular_solve_batched_dims(self):
+ _TestTorchMixin._test_triangular_solve_batched_dims(self, lambda t: t.cuda())
@unittest.skipIf(not TEST_MULTIGPU, "only one GPU detected")
def test_get_set_rng_state_all(self):
self._test_geqrf(self, lambda t: t)
@staticmethod
- def _test_trtrs(self, cast):
+ def _test_triangular_solve(self, cast):
a = torch.Tensor(((6.80, -2.11, 5.66, 5.97, 8.23),
(-6.05, -3.30, 5.36, -4.44, 1.08),
(-0.45, 2.58, -2.70, 0.27, 9.04),
L = torch.tril(a)
# solve Ux = b
- x = torch.trtrs(b, U)[0]
+ x = torch.triangular_solve(b, U)[0]
self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
- x = torch.trtrs(b, U, True, False, False)[0]
+ x = torch.triangular_solve(b, U, True, False, False)[0]
self.assertLessEqual(b.dist(torch.mm(U, x)), 1e-12)
# solve Lx = b
- x = torch.trtrs(b, L, False)[0]
+ x = torch.triangular_solve(b, L, False)[0]
self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
- x = torch.trtrs(b, L, False, False, False)[0]
+ x = torch.triangular_solve(b, L, False, False, False)[0]
self.assertLessEqual(b.dist(torch.mm(L, x)), 1e-12)
# solve U'x = b
- x = torch.trtrs(b, U, True, True)[0]
+ x = torch.triangular_solve(b, U, True, True)[0]
self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
- x = torch.trtrs(b, U, True, True, False)[0]
+ x = torch.triangular_solve(b, U, True, True, False)[0]
self.assertLessEqual(b.dist(torch.mm(U.t(), x)), 1e-12)
# solve U'x = b by manual transposition
- y = torch.trtrs(b, U.t(), False, False)[0]
+ y = torch.triangular_solve(b, U.t(), False, False)[0]
self.assertLessEqual(x.dist(y), 1e-12)
# solve L'x = b
- x = torch.trtrs(b, L, False, True)[0]
+ x = torch.triangular_solve(b, L, False, True)[0]
self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
- x = torch.trtrs(b, L, False, True, False)[0]
+ x = torch.triangular_solve(b, L, False, True, False)[0]
self.assertLessEqual(b.dist(torch.mm(L.t(), x)), 1e-12)
# solve L'x = b by manual transposition
- y = torch.trtrs(b, L.t(), True, False)[0]
+ y = torch.triangular_solve(b, L.t(), True, False)[0]
self.assertLessEqual(x.dist(y), 1e-12)
# test reuse
- res1 = torch.trtrs(b, a)[0]
+ res1 = torch.triangular_solve(b, a)[0]
ta = cast(torch.Tensor())
tb = cast(torch.Tensor())
- torch.trtrs(b, a, out=(tb, ta))
+ torch.triangular_solve(b, a, out=(tb, ta))
self.assertEqual(res1, tb, 0)
tb.zero_()
- torch.trtrs(b, a, out=(tb, ta))
+ torch.triangular_solve(b, a, out=(tb, ta))
self.assertEqual(res1, tb, 0)
@skipIfNoLapack
- def test_trtrs(self):
- self._test_trtrs(self, lambda t: t)
+ def test_triangular_solve(self):
+ self._test_triangular_solve(self, lambda t: t)
@staticmethod
- def _test_trtrs_batched(self, cast):
- def trtrs_test_helper(A_dims, b_dims, cast, upper, unitriangular):
+ def _test_triangular_solve_batched(self, cast):
+ def triangular_solve_test_helper(A_dims, b_dims, cast, upper, unitriangular):
A = cast(torch.randn(*A_dims))
A = A.triu() if upper else A.tril()
if unitriangular:
return A, b
for upper, transpose, unitriangular in product([True, False], repeat=3):
- # test against trtrs: one batch with all possible arguments
- A, b = trtrs_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
- x_exp = torch.trtrs(b.squeeze(0), A.squeeze(0),
- upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
- x = torch.trtrs(b, A,
- upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+ # test against triangular_solve: one batch with all possible arguments
+ A, b = triangular_solve_test_helper((1, 5, 5), (1, 5, 10), cast, upper, unitriangular)
+ x_exp = torch.triangular_solve(b.squeeze(0), A.squeeze(0),
+ upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
+ x = torch.triangular_solve(b, A,
+ upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
self.assertEqual(x, x_exp.unsqueeze(0))
- # test against trtrs in a loop: four batches with all possible arguments
- A, b = trtrs_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
+ # test against triangular_solve in a loop: four batches with all possible arguments
+ A, b = triangular_solve_test_helper((4, 5, 5), (4, 5, 10), cast, upper, unitriangular)
x_exp_list = []
for i in range(4):
- x_exp = torch.trtrs(b[i], A[i],
- upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+ x_exp = torch.triangular_solve(b[i], A[i],
+ upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
x_exp_list.append(x_exp)
x_exp = torch.stack(x_exp_list)
- x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+ x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
self.assertEqual(x, x_exp)
# basic correctness test
- A, b = trtrs_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
- x = torch.trtrs(b, A, upper=upper, unitriangular=unitriangular, transpose=transpose)[0]
+ A, b = triangular_solve_test_helper((3, 5, 5), (3, 5, 10), cast, upper, unitriangular)
+ x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
if transpose:
self.assertLessEqual(b.dist(torch.matmul(A.transpose(-1, -2), x)), 2e-12)
else:
self.assertLessEqual(b.dist(torch.matmul(A, x)), 2e-12)
@skipIfNoLapack
- def test_trtrs_batched(self):
- _TestTorchMixin._test_trtrs_batched(self, lambda t: t)
+ def test_triangular_solve_batched(self):
+ _TestTorchMixin._test_triangular_solve_batched(self, lambda t: t)
@staticmethod
- def _test_trtrs_batched_dims(self, cast):
+ def _test_triangular_solve_batched_dims(self, cast):
if not TEST_SCIPY:
return
x_exp = torch.Tensor(scipy_tri_solve_batched(A.numpy(), b.numpy(),
upper, transpose, unitriangular))
A, b = cast(A), cast(b)
- x = torch.trtrs(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
+ x = torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular)[0]
self.assertEqual(x, cast(x_exp))
run_test((1, 3, 1, 4, 4), (2, 1, 3, 4, 5), cast, upper, transpose, unitriangular) # broadcasting A & b
@skipIfNoLapack
- def test_trtrs_batched_dims(self):
- self._test_trtrs_batched_dims(self, lambda t: t)
+ def test_triangular_solve_batched_dims(self):
+ self._test_triangular_solve_batched_dims(self, lambda t: t)
@skipIfNoLapack
def test_gels(self):
- name: transpose_(Tensor self, int64_t dim0, int64_t dim1)
self: grad.transpose(dim0, dim1)
+- name: triangular_solve(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
+ self, A: triangular_solve_backward(grads[0], grads[1], self, A, result0, upper, transpose, unitriangular, grad_input_mask)
+
- name: tril(Tensor self, int64_t diagonal)
self: grad.tril(diagonal)
- name: triu(Tensor self, int64_t diagonal)
self: grad.triu(diagonal)
-- name: trtrs(Tensor self, Tensor A, bool upper, bool transpose, bool unitriangular)
- self, A: trtrs_backward(grads[0], grads[1], self, A, result0, upper, transpose, unitriangular, grad_input_mask)
-
- name: trunc(Tensor self)
self: zeros_like(grad)
'_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
'_th_.*', '_thnn_.*',
'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
- '_cholesky.*', '_btrifact.*',
+ '_cholesky.*', '_btrifact.*', '_triangular_solve.*',
'slice', 'randint(_out)?',
'item', '_local_scalar_dense',
'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
// Reference:
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
// Sec. 2.3.1 Matrix inverse product
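+//
+// In short: differentiating A X = b gives A dX = db - dA X, hence
+//   grad_b = A^{-T} grad_x   (a triangular solve with the `transpose` flag flipped)
+//   grad_a = -grad_b x^T     (then restricted to the triangular part of A)
+// with the corresponding transposed variants when `transpose` is true.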
-std::tuple<Tensor, Tensor> trtrs_backward(
+std::tuple<Tensor, Tensor> triangular_solve_backward(
const Tensor & grad_x, const Tensor & grad_m,
const Tensor & b, const Tensor & a, const Tensor & x,
const bool upper, const bool transpose, const bool unitriangular,
std::array<bool, 2> output_mask) {
Tensor grad_b, grad_a;
if (grad_x.defined()) {
- grad_b = std::get<0>(grad_x.trtrs(a, upper, !transpose, unitriangular));
+ grad_b = std::get<0>(grad_x.triangular_solve(a, upper, !transpose, unitriangular));
if (output_mask[1]) {
grad_a = transpose ? -x.matmul(grad_b.transpose(-1, -2)) : -grad_b.matmul(x.transpose(-1, -2));
if (upper) {
In-place version of :meth:`~Tensor.transpose`
""")
+add_docstr_all('triangular_solve',
+ r"""
+triangular_solve(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
+
+See :func:`torch.triangular_solve`
+""")
+
add_docstr_all('tril',
r"""
tril(k=0) -> Tensor
In-place version of :meth:`~Tensor.triu`
""")
-add_docstr_all('trtrs',
- r"""
-trtrs(A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
-
-See :func:`torch.trtrs`
-""")
-
add_docstr_all('trunc',
r"""
trunc() -> Tensor
tensor(1.9073e-06)
""")
+add_docstr(torch.isnan,
+ r"""
+Returns a new tensor with boolean elements representing whether each element is `NaN` or not.
+
+Arguments:
+ tensor (Tensor): A tensor to check
+
+Returns:
+ Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
+
+Example::
+
+ >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
+ tensor([ 0, 1, 0], dtype=torch.uint8)
+""")
+
add_docstr(torch.is_floating_point,
r"""
is_floating_point(tensor) -> (bool)
[ 0.5809, 0.4942]])
""")
+add_docstr(torch.triangular_solve,
+ r"""
+triangular_solve(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
+
+Solves a system of equations with a triangular coefficient matrix :math:`A`
+and multiple right-hand sides :attr:`b`.
+
+In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
+with the default keyword arguments.
+
+`torch.triangular_solve(b, A)` accepts 2D inputs `b, A` as well as inputs that
+are batches of 2D matrices. If the inputs are batches, the outputs `X` are
+batched accordingly.
+
+.. note::
+
+    With the :attr:`out` keyword, the results are computed into temporaries and
+    then copied into the given tensors (resizing them if needed), so batched
+    inputs are supported as well.
+
+Args:
+ A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
+ where :math:`*` is zero or more batch dimensions
+ b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where
+                :math:`*` is zero or more batch dimensions
+ upper (bool, optional): whether to solve the upper-triangular system
+ of equations (default) or the lower-triangular system of equations. Default: ``True``.
+ transpose (bool, optional): whether :math:`A` should be transposed before
+ being sent into the solver. Default: ``False``.
+ unitriangular (bool, optional): whether :math:`A` is unit triangular.
+ If True, the diagonal elements of :math:`A` are assumed to be
+ 1 and not referenced from :math:`A`. Default: ``False``.
+
+Returns:
+ A tuple :math:`(X, M)` where :math:`M` is a clone of :math:`A` and :math:`X`
+    is the solution to :math:`AX = b` (or the variant of the system of
+    equations selected by the keyword arguments).
+
+Examples::
+
+ >>> A = torch.randn(2, 2).triu()
+ >>> A
+ tensor([[ 1.1527, -1.0753],
+ [ 0.0000, 0.7986]])
+ >>> b = torch.randn(2, 3)
+ >>> b
+ tensor([[-0.0210, 2.3513, -1.5492],
+ [ 1.5429, 0.7403, -1.0243]])
+ >>> torch.triangular_solve(b, A)
+ (tensor([[ 1.7840, 2.9045, -2.5405],
+ [ 1.9319, 0.9269, -1.2826]]), tensor([[ 1.1527, -1.0753],
+ [ 0.0000, 0.7986]]))
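+
+    >>> # Batched inputs work the same way (values are illustrative; only the
+    >>> # shapes matter here):
+    >>> A = torch.randn(2, 3, 3).tril_()
+    >>> b = torch.randn(2, 3, 4)
+    >>> X = torch.triangular_solve(b, A, upper=False)[0]
+    >>> X.shape
+    torch.Size([2, 3, 4])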
+""")
+
add_docstr(torch.tril,
r"""
tril(input, diagonal=0, out=None) -> Tensor
[1, 2, 2]])
""".format(**factory_common_args))
-add_docstr(torch.trtrs,
- r"""
-trtrs(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)
-
-Solves a system of equations with a triangular coefficient matrix :math:`A`
-and multiple right-hand sides :attr:`b`.
-
-In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
-with the default keyword arguments.
-
-`torch.trtrs(b, A)` can take in 2D inputs `b, A` or inputs that are
-batches of 2D matrices. If the inputs are batches, then returns
-batched outputs `X`
-
-.. note::
-
- The :attr:`out` keyword only supports 2D matrix inputs, that is,
- `b, A` must be 2D matrices.
-
-Args:
- A (Tensor): the input triangular coefficient matrix of size :math:`(*, m, m)`
- where :math:`*` is zero or more batch dimensions
- b (Tensor): multiple right-hand sides of size :math:`(*, m, k)` where
- :math:`*` is zero of more batch dimensions
- upper (bool, optional): whether to solve the upper-triangular system
- of equations (default) or the lower-triangular system of equations. Default: ``True``.
- transpose (bool, optional): whether :math:`A` should be transposed before
- being sent into the solver. Default: ``False``.
- unitriangular (bool, optional): whether :math:`A` is unit triangular.
- If True, the diagonal elements of :math:`A` are assumed to be
- 1 and not referenced from :math:`A`. Default: ``False``.
-
-Returns:
- A tuple :math:`(X, M)` where :math:`M` is a clone of :math:`A` and :math:`X`
- is the solution to :math:`AX = b` (or whatever variant of the system of
- equations, depending on the keyword arguments.)
-
-Examples::
-
- >>> A = torch.randn(2, 2).triu()
- >>> A
- tensor([[ 1.1527, -1.0753],
- [ 0.0000, 0.7986]])
- >>> b = torch.randn(2, 3)
- >>> b
- tensor([[-0.0210, 2.3513, -1.5492],
- [ 1.5429, 0.7403, -1.0243]])
- >>> torch.trtrs(b, A)
- (tensor([[ 1.7840, 2.9045, -2.5405],
- [ 1.9319, 0.9269, -1.2826]]), tensor([[ 1.1527, -1.0753],
- [ 0.0000, 0.7986]]))
-""")
-
add_docstr(torch.trunc,
r"""
trunc(input, out=None) -> Tensor
# = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T)
qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
q._unbroadcasted_cov_diag.unsqueeze(-2))
- A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
+ A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0]
term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1)
term22 = _batch_trace_XXT(p._unbroadcasted_cov_factor *
q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
# = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T
qWt_qDinv = (q._unbroadcasted_cov_factor.transpose(-1, -2) /
q._unbroadcasted_cov_diag.unsqueeze(-2))
- A = torch.trtrs(qWt_qDinv, q._capacitance_tril, upper=False)[0]
+ A = torch.triangular_solve(qWt_qDinv, q._capacitance_tril, upper=False)[0]
term21 = _batch_trace_XXT(p._unbroadcasted_scale_tril *
q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1))
term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril))
(n, p.cov_factor.size(-1)))
p_cov_diag = (torch.diag_embed(p._unbroadcasted_cov_diag.sqrt())
.expand(combined_batch_shape + (n, n)))
- term21 = _batch_trace_XXT(torch.trtrs(p_cov_factor, q_scale_tril, upper=False)[0])
- term22 = _batch_trace_XXT(torch.trtrs(p_cov_diag, q_scale_tril, upper=False)[0])
+ term21 = _batch_trace_XXT(torch.triangular_solve(p_cov_factor, q_scale_tril, upper=False)[0])
+ term22 = _batch_trace_XXT(torch.triangular_solve(p_cov_diag, q_scale_tril, upper=False)[0])
term2 = term21 + term22
return 0.5 * (term1 + term2 + term3 - p.event_shape[0])
n = p.event_shape[0]
q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n))
- term2 = _batch_trace_XXT(torch.trtrs(p_scale_tril, q_scale_tril, upper=False)[0])
+ term2 = _batch_trace_XXT(torch.triangular_solve(p_scale_tril, q_scale_tril, upper=False)[0])
term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc))
return half_term1 + 0.5 * (term2 + term3 - n)
# where :math:`C` is the capacitance matrix.
Wt_Dinv = (self._unbroadcasted_cov_factor.transpose(-1, -2)
/ self._unbroadcasted_cov_diag.unsqueeze(-2))
- A = torch.trtrs(Wt_Dinv, self._capacitance_tril, upper=False)[0]
+ A = torch.triangular_solve(Wt_Dinv, self._capacitance_tril, upper=False)[0]
precision_matrix = (torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal())
- torch.matmul(A.transpose(-1, -2), A))
return precision_matrix.expand(self._batch_shape + self._event_shape +
bx_batch_shape = bx.shape[:-1]
# Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n),
- # we are going to make bx have shape (..., 1, j, i, 1, n) to apply _batch_trtrs_lower
+    # we are going to make bx have shape (..., 1, j, i, 1, n) to apply torch.triangular_solve
bx_batch_dims = len(bx_batch_shape)
bL_batch_dims = bL.dim() - 2
outer_batch_dims = bx_batch_dims - bL_batch_dims
flat_L = bL.reshape(-1, n, n) # shape = b x n x n
flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n
flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c
- M_swap = torch.trtrs(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2) # shape = b x c
+ M_swap = torch.triangular_solve(flat_x_swap, flat_L, upper=False)[0].pow(2).sum(-2) # shape = b x c
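+    # Each column solve yields L^{-1} x, so the sum of squares over the matrix
+    # dimension is x^T (L L^T)^{-1} x, i.e. the squared Mahalanobis distance.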
M = M_swap.t() # shape = c x b
# Now we revert the above reshape and permute operators.
import torch
import torch.nn.functional as F
from torch._six import inf
-from torch._C import _add_docstr
-from operator import mul
-from functools import reduce
-from collections import Iterable
-from torch._utils import annotate
from itertools import product
-import math
import warnings
__all__ = [
'broadcast_tensors',
'isfinite',
'isinf',
- 'isnan',
'norm',
'meshgrid',
'potrf',
'split',
'stft',
'tensordot',
+ 'trtrs',
'unique',
'cartesian_prod',
]
return torch._C._VariableFunctions.stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
-isnan = _add_docstr(torch.isnan, r"""
-Returns a new tensor with boolean elements representing if each element is `NaN` or not.
-
-Arguments:
- tensor (Tensor): A tensor to check
-
-Returns:
- Tensor: A ``torch.ByteTensor`` containing a 1 at each location of `NaN` elements.
-
-Example::
-
- >>> torch.isnan(torch.tensor([1, float('nan'), 2]))
- tensor([ 0, 1, 0], dtype=torch.uint8)
-""")
-
-
def unique(input, sorted=True, return_inverse=False, dim=None):
r"""Returns the unique scalar elements of the input tensor as a 1-D tensor.
warnings.warn("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
"next release. Please use torch.solve instead.", stacklevel=2)
return torch.solve(b, A, out=out)
+
+
+def trtrs(b, A, upper=True, transpose=False, unitriangular=False, out=None):
+ r"""Solves a system of equations with a triangular coefficient matrix :math:`A`
+ and multiple right-hand sides :attr:`b`.
+
+ In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
+ with the default keyword arguments.
+
+ .. warning::
+ :func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
+ removed in the next release. Please use :func:`torch.triangular_solve` instead.
+ """
+ warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
+ "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
+ return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
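+
+
+# A minimal sanity sketch for the shim above (hypothetical values):
+#
+#   A = torch.randn(3, 3).triu()
+#   b = torch.randn(3, 2)
+#   x_old = torch.trtrs(b, A)[0]             # warns, then delegates
+#   x_new = torch.triangular_solve(b, A)[0]  # identical result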
"next release. Please use torch.solve instead.", stacklevel=2)
return super(Tensor, self).solve(A)
+ def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
+ r"""See :func:`torch.triangular_solve`"""
+ warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
+                      "removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
+ return super(Tensor, self).triangular_solve(A, upper=upper,
+ transpose=transpose, unitriangular=unitriangular)
+
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True):
r"""See :func:`torch.stft`