std::tuple<Tensor,Tensor> geqrf() const;
Tensor orgqr(const Tensor & input2) const;
Tensor ormqr(const Tensor & input2, const Tensor & input3, bool left=true, bool transpose=false) const;
- std::tuple<Tensor,Tensor> btrifact(bool pivot=true) const;
- std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(bool pivot=true) const;
Tensor btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const;
Tensor multinomial(int64_t num_samples, bool replacement=false, Generator * generator=nullptr) const;
Tensor lgamma() const;
inline Tensor Tensor::ormqr(const Tensor & input2, const Tensor & input3, bool left, bool transpose) const {
return type().ormqr(*this, input2, input3, left, transpose);
}
-inline std::tuple<Tensor,Tensor> Tensor::btrifact(bool pivot) const {
- return type().btrifact(*this, pivot);
-}
-inline std::tuple<Tensor,Tensor,Tensor> Tensor::btrifact_with_info(bool pivot) const {
- return type().btrifact_with_info(*this, pivot);
-}
inline Tensor Tensor::btrisolve(const Tensor & LU_data, const Tensor & LU_pivots) const {
return type().btrisolve(*this, LU_data, LU_pivots);
}
virtual std::tuple<Tensor,Tensor> geqrf(const Tensor & self) const = 0;
virtual Tensor orgqr(const Tensor & self, const Tensor & input2) const = 0;
virtual Tensor ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose) const = 0;
- virtual std::tuple<Tensor,Tensor> btrifact(const Tensor & self, bool pivot) const = 0;
- virtual std::tuple<Tensor,Tensor,Tensor> btrifact_with_info(const Tensor & self, bool pivot) const = 0;
virtual Tensor btrisolve(const Tensor & self, const Tensor & LU_data, const Tensor & LU_pivots) const = 0;
virtual Tensor multinomial(const Tensor & self, int64_t num_samples, bool replacement, Generator * generator) const = 0;
virtual Tensor lgamma(const Tensor & self) const = 0;
_(aten, _log1p) \
_(aten, _log2) \
_(aten, _logspace) \
+_(aten, _lu_with_info) \
_(aten, _masked_scale) \
_(aten, _mm) \
_(aten, _mv) \
_(aten, blackman_window) \
_(aten, bmm) \
_(aten, broadcast_tensors) \
-_(aten, btrifact) \
-_(aten, btrifact_with_info) \
_(aten, btrisolve) \
_(aten, cartesian_prod) \
_(aten, cat) \
}
template<class scalar_t>
-void lapackGetrf(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
- AT_ERROR("getrf only takes float or double Tensors");
+void lapackLu(int m, int n, scalar_t *a, int lda, int *ipiv, int *info) {
+ AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
sgetri_(&n, a, &lda, ipiv, work, &lwork, info);
}
-template<> void lapackGetrf<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
+template<> void lapackLu<double>(int m, int n, double *a, int lda, int *ipiv, int *info) {
dgetrf_(&m, &n, a, &lda, ipiv, info);
}
-template<> void lapackGetrf<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
+template<> void lapackLu<float>(int m, int n, float *a, int lda, int *ipiv, int *info) {
sgetrf_(&m, &n, a, &lda, ipiv, info);
}
for (int64_t i = 0; i < batch_size; i++) {
int info;
scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
- lapackGetrf<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
+ lapackLu<scalar_t>(n, n, self_working_ptr, n, ipiv.data<int>(), &info);
infos[i] = info;
if (info != 0) {
return;
return result;
}
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template<typename scalar_t>
-static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
+static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos) {
#ifndef USE_LAPACK
- AT_ERROR("btrifact: LAPACK library not found in compilation");
+ AT_ERROR("lu: LAPACK library not found in compilation");
#else
auto self_data = self.data<scalar_t>();
- auto self_matrix_stride = matrixStride(self);
- auto batch_size = batchCount(self);
-
auto pivots_data = pivots.data<int>();
- auto pivots_matrix_stride = pivots.size(-1);
auto infos_data = infos.data<int>();
auto n = self.size(-1);
- for (int64_t i = 0; i < batch_size; i++) {
- scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
- int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
- int* infos_working_ptr = &infos_data[i];
- lapackGetrf<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
+ if (self.dim() == 2) {
+ lapackLu<scalar_t>(n, n, self_data, n, pivots_data, infos_data);
+ } else {
+ auto self_matrix_stride = matrixStride(self);
+ auto batch_size = batchCount(self);
+ auto pivots_matrix_stride = pivots.size(-1);
+ for (int64_t i = 0; i < batch_size; i++) {
+ scalar_t* self_working_ptr = &self_data[i * self_matrix_stride];
+ int* pivots_working_ptr = &pivots_data[i * pivots_matrix_stride];
+ int* infos_working_ptr = &infos_data[i];
+ lapackLu<scalar_t>(n, n, self_working_ptr, n, pivots_working_ptr, infos_working_ptr);
+ }
}
#endif
}
-std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cpu(const Tensor& self, bool pivot) {
- AT_CHECK(pivot, "btrifact without pivoting is not implemented on the CPU");
- AT_CHECK(self.dim() > 2,
- "expected tensor with more than 2 dimensions, got size: ", self.sizes(),
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cpu(const Tensor& self, bool pivot, bool check_errors) {
+ AT_CHECK(pivot, "lu without pivoting is not implemented on the CPU");
+ AT_CHECK(self.dim() >= 2,
+ "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
" instead");
squareCheckInputs(self);
auto req_size = self.sizes().vec();
req_size.pop_back();
- auto pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
+ auto pivots_tensor = at::empty(req_size, self.options().dtype(kInt));
req_size.pop_back();
auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
self_working_copy = at::empty_like(self);
} else {
self_working_copy = cloneBatchedColumnMajor(self);
- AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cpu", [&]{
- apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cpu", [&]{
+ apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
});
}
+ if (check_errors) {
+ if (self.dim() == 2) {
+ singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
+ } else {
+ batchCheckErrors(infos_tensor, "lu");
+ }
+ }
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
}
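For orientation, an illustrative Python sketch (not part of this patch) of what the renamed internal helper returns, following the shape logic above: the LU factors keep the input's shape, the pivots drop the last dimension, and the infos drop the last two.

import torch

A = torch.randn(3, 3, dtype=torch.double)       # single square matrix
LU, piv, info = torch._lu_with_info(A, pivot=True, check_errors=False)
print(LU.shape, piv.shape, info.shape)          # torch.Size([3, 3]) torch.Size([3]) torch.Size([])

B = torch.randn(4, 3, 3, dtype=torch.double)    # batch of 4 square matrices
LU, piv, info = torch._lu_with_info(B, pivot=True, check_errors=True)
print(LU.shape, piv.shape, info.shape)          # torch.Size([4, 3, 3]) torch.Size([4, 3]) torch.Size([4])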
-std::tuple<Tensor, Tensor> btrifact(const Tensor& self, bool pivot) {
- Tensor LU_fact, pivots, infos;
- std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
- batchCheckErrors(infos, "btrifact");
- return std::make_tuple(LU_fact, pivots);
-}
-
-std::tuple<Tensor&, Tensor&> btrifact_out(
- Tensor& A_LU,
- Tensor& pivots,
- const Tensor& self,
- bool pivot) {
- Tensor infos, A_LU_tmp, pivots_tmp;
- std::tie(A_LU_tmp, pivots_tmp, infos) = at::_btrifact_helper(self, pivot);
- batchCheckErrors(infos, "btrifact");
- A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
- pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
- return std::tuple<Tensor&, Tensor&>(A_LU, pivots);
-}
-
-std::tuple<Tensor, Tensor, Tensor> btrifact_with_info(
- const Tensor& self,
- bool pivot) {
- Tensor LU_fact, pivots, infos;
- std::tie(LU_fact, pivots, infos) = at::_btrifact_helper(self, pivot);
- return std::make_tuple(LU_fact, pivots, infos);
-}
-
-std::tuple<Tensor&, Tensor&, Tensor&> btrifact_with_info_out(
- Tensor& A_LU,
- Tensor& pivots,
- Tensor& info,
- const Tensor& self,
- bool pivot) {
- Tensor info_tmp, A_LU_tmp, pivots_tmp;
- std::tie(A_LU_tmp, pivots_tmp, info_tmp) = at::_btrifact_helper(self, pivot);
- A_LU.resize_as_(A_LU_tmp).copy_(A_LU_tmp);
- pivots.resize_as_(pivots_tmp).copy_(pivots_tmp);
- info.resize_as_(info_tmp).copy_(info_tmp);
- return std::tuple<Tensor&, Tensor&, Tensor&>(A_LU, pivots, info);
-}
-
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triu/tril ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t, bool upper>
// where info helps us identify singular matrices.
static inline std::tuple<double, Tensor, int> _lu_det_P_diag_U_info(const Tensor& self) {
Tensor p, lu, info;
- std::tie(lu, p, info) = self.unsqueeze(0).btrifact_with_info();
- p.squeeze_(0);
- lu.squeeze_(0);
- int int_info = info.squeeze_().item<int32_t>();
+ std::tie(lu, p, info) = at::_lu_with_info(self, /*pivot=*/true, /*check_errors=*/false);
+ int int_info = info.item<int32_t>();
AT_CHECK(int_info >= 0, "LU factorization (getrf) failed with info = ", int_info);
auto n = self.size(0);
auto num_exchanges = (at::arange(1, n + 1, p.options()) != p).nonzero().size(0);
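The same determinant-from-LU reasoning, sketched in Python for a single matrix (illustrative only, using the torch.lu wrapper added later in this patch): the sign of the determinant comes from the number of row exchanges recorded in the pivots, and its magnitude from the diagonal of U.

import torch

A = torch.randn(3, 3, dtype=torch.double)
LU, piv = torch.lu(A)

# A pivot entry that differs from its 1-indexed position records one row exchange,
# mirroring the num_exchanges computation in the C++ helper above.
num_exchanges = (torch.arange(1, A.size(0) + 1, dtype=piv.dtype) != piv).sum().item()
sign = -1.0 if num_exchanges % 2 else 1.0

det_from_lu = sign * LU.diagonal().prod()
print(torch.allclose(det_from_lu, torch.det(A)))    # expected: True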
" but each b matrix is ", self.size(-2), " by ", self.size(-1));
}
-// Validates input shapes for operations on batches of square matrices (inverse, cholesky)
+// Validates input shapes for operations on batches of square matrices (inverse, cholesky, lu)
static inline void squareCheckInputs(const Tensor& self) {
AT_CHECK(self.size(-1) == self.size(-2),
"A must be batches of square matrices, "
}
template<class scalar_t>
-void magmaGetrfBatched(
+void magmaLu(
+ magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
+ magma_int_t* ipiv, magma_int_t* info) {
+ AT_ERROR("lu only takes float or double Tensors");
+}
+
+template<class scalar_t>
+void magmaLuBatched(
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
- AT_ERROR("getrf only takes float or double Tensors");
+ AT_ERROR("lu only takes float or double Tensors");
+}
+
+template<class scalar_t>
+void magmaLuNoPiv(
+ magma_int_t m, magma_int_t n, scalar_t* dA, magma_int_t ldda,
+ magma_int_t* info) {
+ AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
-void magmaGetrfNoPivBatched(
+void magmaLuNoPivBatched(
magma_int_t m, magma_int_t n, scalar_t** dA_array, magma_int_t ldda,
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
- AT_ERROR("getrf only takes float or double Tensors");
+ AT_ERROR("lu only takes float or double Tensors");
}
template<class scalar_t>
}
template<>
-void magmaGetrfBatched<double>(
+void magmaLu<double>(
+ magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
+ magma_int_t* ipiv, magma_int_t* info) {
+ magma_dgetrf_gpu(m, n, dA, ldda, ipiv, info);
+}
+
+template<>
+void magmaLu<float>(
+ magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
+ magma_int_t* ipiv, magma_int_t* info) {
+ magma_sgetrf_gpu(m, n, dA, ldda, ipiv, info);
+}
+
+template<>
+void magmaLuBatched<double>(
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
}
template<>
-void magmaGetrfBatched<float>(
+void magmaLuBatched<float>(
magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
magma_int_t** ipiv_array, magma_int_t* info_array, magma_int_t batchsize,
const MAGMAQueue& magma_queue) {
}
template<>
-void magmaGetrfNoPivBatched<double>(
+void magmaLuNoPiv<double>(
+ magma_int_t m, magma_int_t n, double* dA, magma_int_t ldda,
+ magma_int_t* info) {
+ magma_dgetrf_nopiv_gpu(m, n, dA, ldda, info);
+}
+
+template<>
+void magmaLuNoPiv<float>(
+ magma_int_t m, magma_int_t n, float* dA, magma_int_t ldda,
+ magma_int_t* info) {
+ magma_sgetrf_nopiv_gpu(m, n, dA, ldda, info);
+}
+
+template<>
+void magmaLuNoPivBatched<double>(
magma_int_t m, magma_int_t n, double** dA_array, magma_int_t ldda,
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
magma_dgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
}
template<>
-void magmaGetrfNoPivBatched<float>(
+void magmaLuNoPivBatched<float>(
magma_int_t m, magma_int_t n, float** dA_array, magma_int_t ldda,
magma_int_t* info_array, magma_int_t batchsize, const MAGMAQueue& magma_queue) {
magma_sgetrf_nopiv_batched(m, n, dA_array, ldda, info_array, batchsize, magma_queue.get_queue());
}
MAGMAQueue magma_queue(self.get_device());
- magmaGetrfBatched<scalar_t>(
+ magmaLuBatched<scalar_t>(
n, n, self_array, n, ipiv_array, info_array,
batch_size, magma_queue);
}
}
-// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ btrifact ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lu ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
template <typename scalar_t>
-static void apply_btrifact(Tensor& self, Tensor& pivots, Tensor& infos) {
+static void apply_lu(Tensor& self, Tensor& pivots, Tensor& infos, bool get_pivots) {
#ifndef USE_MAGMA
-AT_ERROR("btrifact: MAGMA library not found in "
+AT_ERROR("lu: MAGMA library not found in "
"compilation. Please rebuild with MAGMA.");
#else
auto self_data = self.data<scalar_t>();
- auto self_matrix_stride = matrixStride(self);
- magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
magma_int_t n = magma_int_cast(self.size(-1), "n");
- scalar_t** self_array;
- ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
-
- // Set up the created arrays
- for (int64_t i = 0; i < batch_size; i++) {
- self_array[i] = &self_data[i * self_matrix_stride];
- }
+ if (self.dim() == 2) {
+    // If `get_pivots` is true, we compute the LU factorization together with the
+    // pivots using the standard getrf routine; otherwise we use the non-pivoting
+    // variant. The temporary tensors are created on the CPU, because tensors on the
+    // GPU cause a segfault when passed to magmaLu and magmaLuNoPiv. The data is
+    // later copied to the appropriate tensors.
+ Tensor info_tmp = at::zeros({}, at::kInt);
+ if (get_pivots) {
+ Tensor piv_tmp = at::empty({n}, at::kInt);
+ magmaLu<scalar_t>(
+ n, n, self_data, n, piv_tmp.data<magma_int_t>(), info_tmp.data<magma_int_t>());
+ pivots.copy_(piv_tmp);
+ } else {
+ magmaLuNoPiv<scalar_t>(n, n, self_data, n, info_tmp.data<magma_int_t>());
+ }
+ infos.copy_(info_tmp);
+ } else {
+ auto self_matrix_stride = matrixStride(self);
+ magma_int_t batch_size = magma_int_cast(batchCount(self), "batchCount");
- MAGMAQueue magma_queue(self.get_device());
+ scalar_t** self_array;
+ ALLOCATE_ARRAY(self_array, scalar_t*, batch_size, self);
- // If `pivots` is defined, then we have to compute them.
- // We will use the normal getrf function to compute the LU factorization
- // and the pivots
- if (pivots.defined()) {
- auto pivots_data = pivots.data<magma_int_t>();
- auto pivots_matrix_stride = pivots.size(-1);
- magma_int_t** pivots_array;
- ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
+ // Set up the created arrays
for (int64_t i = 0; i < batch_size; i++) {
- pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
+ self_array[i] = &self_data[i * self_matrix_stride];
}
- magmaGetrfBatched<scalar_t>(
- n, n, self_array, n, pivots_array,
- infos.data<magma_int_t>(), batch_size, magma_queue);
- } else {
- magmaGetrfNoPivBatched<scalar_t>(
- n, n, self_array, n, infos.data<magma_int_t>(),
- batch_size, magma_queue);
+ MAGMAQueue magma_queue(self.get_device());
+
+ // Same comment as in the case of single matrix above.
+ if (get_pivots) {
+ auto pivots_data = pivots.data<magma_int_t>();
+ auto pivots_matrix_stride = pivots.size(-1);
+ magma_int_t** pivots_array;
+ ALLOCATE_ARRAY(pivots_array, magma_int_t*, batch_size, pivots);
+ for (int64_t i = 0; i < batch_size; i++) {
+ pivots_array[i] = &pivots_data[i * pivots_matrix_stride];
+ }
+ magmaLuBatched<scalar_t>(
+ n, n, self_array, n, pivots_array,
+ infos.data<magma_int_t>(), batch_size, magma_queue);
+ } else {
+ magmaLuNoPivBatched<scalar_t>(
+ n, n, self_array, n, infos.data<magma_int_t>(),
+ batch_size, magma_queue);
+ }
}
#endif
}
-std::tuple<Tensor, Tensor, Tensor> _btrifact_helper_cuda(const Tensor& self, bool pivot) {
- AT_CHECK(self.dim() > 2,
- "expected tensor with more than 2 dimensions, got size: ", self.sizes(),
+std::tuple<Tensor, Tensor, Tensor> _lu_with_info_cuda(const Tensor& self, bool pivot, bool check_errors) {
+ AT_CHECK(self.dim() >= 2,
+ "expected tensor with 2 or more dimensions, got size: ", self.sizes(),
" instead");
squareCheckInputs(self);
auto req_size = self.sizes().vec();
req_size.pop_back();
- Tensor pivots_tensor;
- if (pivot) {
- pivots_tensor = at::zeros(req_size, self.options().dtype(kInt));
- }
+ Tensor pivots_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
req_size.pop_back();
- auto infos_tensor = at::zeros(req_size, self.options().dtype(kInt));
+ auto infos_tensor = at::zeros(req_size, self.options().dtype(at::kInt));
Tensor self_working_copy;
if (self.numel() == 0) {
self_working_copy = at::empty_like(self);
} else {
self_working_copy = cloneBatchedColumnMajor(self);
- AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "btrifact_cuda", [&]{
- apply_btrifact<scalar_t>(self_working_copy, pivots_tensor, infos_tensor);
+ AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "lu_cuda", [&]{
+ apply_lu<scalar_t>(self_working_copy, pivots_tensor, infos_tensor, pivot);
});
}
+ if (check_errors) {
+ if (self.dim() == 2) {
+ singleCheckErrors(infos_tensor.item<int64_t>(), "lu");
+ } else {
+ batchCheckErrors(infos_tensor, "lu");
+ }
+ }
return std::make_tuple(self_working_copy, pivots_tensor, infos_tensor);
}
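A usage sketch of the CUDA-only no-pivot path handled above (illustrative, assuming a CUDA build with MAGMA): with pivot=False the returned pivot tensor is simply left as zeros.

import torch

if torch.cuda.is_available():
    A = torch.randn(2, 3, 3, device='cuda')
    LU, piv, info = torch.lu(A, pivot=False, get_infos=True)
    print(piv)                  # zeros of shape (2, 3), dtype torch.int32
    print(info.abs().sum())     # 0 if every factorization succeeded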
matches_jit_signature: True
variants: method, function
-- func: btrifact(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots) -> (Tensor(a!), Tensor(b!))
- matches_jit_signature: True
-
-- func: btrifact(Tensor self, *, bool pivot=True) -> (Tensor, Tensor)
- matches_jit_signature: True
- variants: method, function
-
-- func: btrifact_with_info(Tensor self, *, bool pivot=True, Tensor(a!) A_LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!), Tensor(b!), Tensor(c!))
- matches_jit_signature: True
-
-- func: btrifact_with_info(Tensor self, *, bool pivot=True) -> (Tensor, Tensor, Tensor)
- matches_jit_signature: True
- variants: method, function
-
-- func: _btrifact_helper(Tensor self, bool pivot) -> (Tensor, Tensor, Tensor)
+- func: _lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor, Tensor, Tensor)
matches_jit_signature: True
variants: function
dispatch:
- CPU: _btrifact_helper_cpu
- CUDA: _btrifact_helper_cuda
+ CPU: _lu_with_info_cpu
+ CUDA: _lu_with_info_cuda
- func: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)
matches_jit_signature: True
.. automethod:: long
.. automethod:: lt
.. automethod:: lt_
+ .. automethod:: lu
.. automethod:: map_
.. automethod:: masked_scatter_
.. automethod:: masked_scatter
.. autofunction:: det
.. autofunction:: logdet
.. autofunction:: slogdet
+.. autofunction:: lu
.. autofunction:: matmul
.. autofunction:: matrix_power
.. autofunction:: matrix_rank
@skipIfRocm
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
- def test_btrifact(self):
- _TestTorchMixin._test_btrifact(self, lambda t: t.cuda())
+ def test_lu(self):
+ _TestTorchMixin._test_lu(self, lambda t: t.cuda())
@skipIfRocm
@unittest.skipIf(not TEST_MAGMA, "no MAGMA library detected")
_test_mm(n, m, p, torch.int64, lambda x, y: torch.randint(0, 100, (x, y), dtype=torch.int64))
@staticmethod
- def _test_btrifact(self, cast):
+ def _test_lu(self, cast):
from common_utils import random_fullrank_matrix_distinct_singular_value as fullrank
def run_test(matrix_size, batches, cast):
a = cast(fullrank(matrix_size, *batches))
- a_LU_info, pivots_info, info_ = a.btrifact_with_info()
+ a_LU_info, pivots_info, info_ = a.lu(get_infos=True)
self.assertEqual(a_LU_info.size(), torch.Size(batches + (matrix_size, matrix_size)))
self.assertEqual(pivots_info.size(), torch.Size(batches + (matrix_size,)))
self.assertEqual(info_.size(), torch.Size(batches))
self.assertEqual(info_.abs().sum(), 0)
- a_LU, pivots = a.btrifact()
+ a_LU, pivots = a.lu()
self.assertEqual(a_LU, a_LU_info)
self.assertEqual(pivots_info, pivots)
if a.is_cuda:
- a_LU_info_nopiv, nopiv, info_nopiv = a.btrifact_with_info(pivot=False)
- self.assertIsNone(nopiv)
+ a_LU_info_nopiv, nopiv, info_nopiv = a.lu(pivot=False, get_infos=True)
+ self.assertEqual(nopiv, cast(torch.zeros(a.shape[:-1], dtype=torch.int32)))
self.assertEqual(info_, info_nopiv)
P, L, U = torch.btriunpack(a_LU, pivots)
self.assertEqual(P.matmul(L.matmul(U)), a)
- for ms, batch in product([3, 5, 7], [(2,), (3,), (3, 5)]):
+ for ms, batch in product([3, 5, 7], [(), (2,), (3,), (3, 5)]):
run_test(ms, batch, cast)
# Info should be positive for rank deficient matrices
- a = cast(fullrank(3, 5))
+ a = cast(torch.ones(5, 3, 3))
if not (a.is_cuda and any(x in torch.version.cuda for x in ['8.0', '9.2'])):
- a[0, 1] = 2 * a[0, 0] # Row 2 of a[0] is 2 times Row 1 of a[0], thereby causing a rank deficiency
- self.assertGreater(a.btrifact_with_info()[2][0], 0)
+ self.assertGreater(a.lu(get_infos=True)[2][0], 0)
# Error checking, no pivoting variant on CPU
with self.assertRaisesRegex(RuntimeError,
- 'btrifact without pivoting is not implemented on the CPU'):
- torch.btrifact(torch.empty(1, 2, 2), pivot=False)
+ 'lu without pivoting is not implemented on the CPU'):
+ torch.lu(torch.empty(1, 2, 2), pivot=False)
@skipIfNoLapack
@skipIfRocm
- def test_btrifact(self):
- self._test_btrifact(self, lambda t: t)
+ def test_lu(self):
+ self._test_lu(self, lambda t: t)
@staticmethod
def _test_btrisolve(self, cast):
(-1.56, 4.00),
(9.81, -4.09)))
a, b = cast(a), cast(b)
- LU_data, pivots, info = a.btrifact_with_info()
+ LU_data, pivots, info = a.lu(get_infos=True)
self.assertEqual(info.abs().sum(), 0)
x = torch.btrisolve(b, LU_data, pivots)
b_ = torch.bmm(a, x.unsqueeze(2)).squeeze()
def _test_btriunpack(self, cast):
def run_test(shape, cast):
a = cast(torch.randn(*shape))
- a_lu, p = torch.btrifact(a.reshape(-1, shape[-1], shape[-1]))
- a_lu = a_lu.reshape_as(a)
- p = p.reshape(a.shape[:-1])
+ a_lu, p = torch.lu(a)
p_ref, l_ref, u_ref = torch.btriunpack(a_lu, p)
self.assertEqual(p_ref.matmul(l_ref.matmul(u_ref)), a)
+ run_test((3, 3), cast)
run_test((5, 3, 3), cast)
run_test((7, 3, 5, 5), cast)
run_test((7, 5, 3, 3, 3), cast)
A = torch.randn(3, 3, device=A_device)
err_str = "Expected b and A to be on the same device"
with self.assertRaisesRegex(RuntimeError, err_str):
- torch.gesv(b, A)
+ torch.solve(b, A)
with self.assertRaisesRegex(RuntimeError, err_str):
torch.cholesky_solve(b, A)
+ with self.assertRaisesRegex(RuntimeError, err_str):
+ torch.triangular_solve(b, A)
+
@skipIfNoLapack
def test_qr(self):
self.assertEqual(torch.tensor(0., device=device), fn(torch.dot, (0,), (0,)))
if torch._C.has_lapack:
- # btrifact
- A_LU, pivots = fn(torch.btrifact, (0, 5, 5))
+ # lu
+ A_LU, pivots = fn(torch.lu, (0, 5, 5))
self.assertEqual([(0, 5, 5), (0, 5)], [A_LU.shape, pivots.shape])
- A_LU, pivots = fn(torch.btrifact, (0, 0, 0))
+ A_LU, pivots = fn(torch.lu, (0, 0, 0))
self.assertEqual([(0, 0, 0), (0, 0)], [A_LU.shape, pivots.shape])
- A_LU, pivots = fn(torch.btrifact, (2, 0, 0))
+ A_LU, pivots = fn(torch.lu, (2, 0, 0))
self.assertEqual([(2, 0, 0), (2, 0)], [A_LU.shape, pivots.shape])
@skipIfRocm
self: grad.bmm(mat2.transpose(1, 2))
mat2: self.transpose(1, 2).bmm(grad)
-- name: btrifact(Tensor self, bool pivot)
- self: not_implemented("btrifact")
-
-- name: btrifact_with_info(Tensor self, bool pivot)
- self: not_implemented("btrifact_with_info")
-
- name: btrisolve(Tensor self, Tensor LU_data, Tensor LU_pivots)
self: not_implemented("btrisolve")
self: zeros_like(self)
other: zeros_like(other)
+- name: _lu_with_info(Tensor self, bool pivot, bool check_errors)
+ self: not_implemented("lu_with_info")
+
- name: masked_fill_(Tensor self, Tensor mask, Scalar value)
self: grad.clone().masked_fill_(mask, 0)
'_cumsum.*', '_cumprod.*', '_sum.*', '_prod.*',
'_th_.*', '_thnn_.*',
'arange.*', 'range.*', '_solve.*', '_getri.*', '_inverse.*',
- '_cholesky.*', '_btrifact.*', '_triangular_solve.*',
+ '_cholesky.*', '_triangular_solve.*',
'slice', 'randint(_out)?',
'item', '_local_scalar_dense',
'max_pool1d', 'max_pool2d', 'max_pool3d', 'linear', 'to',
center=True, pad_mode='reflect', normalized=False, onesided=True): ...
def split(self, split_size, dim=0): ...
def unique(self, sorted=True, return_inverse=False, dim=None): ...
+ def lu(self, pivot=True, get_infos=False): ...
${function_hints}
See :func:`torch.bmm`
""")
-add_docstr_all('btrifact',
- r"""
-btrifact(pivot=True) -> (Tensor, Tensor)
-
-See :func:`torch.btrifact`
-""")
-
-add_docstr_all('btrifact_with_info',
- r"""
-btrifact_with_info(pivot=True) -> (Tensor, Tensor, Tensor)
-
-See :func:`torch.btrifact_with_info`
-""")
-
add_docstr_all('btrisolve',
r"""
btrisolve(LU_data, LU_pivots) -> Tensor
See :func:`torch.ger`
""")
-add_docstr_all('solve',
- r"""
-solve(A) -> Tensor, Tensor
-
-See :func:`torch.solve`
-""")
-
add_docstr_all('indices',
r"""
indices() -> Tensor
""")
+add_docstr_all('solve',
+ r"""
+solve(A) -> Tensor, Tensor
+
+See :func:`torch.solve`
+""")
+
add_docstr_all('sort',
r"""
sort(dim=-1, descending=False) -> (Tensor, LongTensor)
[ 0., 0., 0.]])
""".format(**factory_like_common_args))
-add_docstr(torch.btrifact,
- r"""
-btrifact(A, pivot=True) -> (Tensor, IntTensor)
-
-Batch LU factorization.
-
-Returns a tuple containing the LU factorization and pivots. Pivoting is done if
-:attr:`pivot` is set.
-
-.. note::
- LU factorization with :attr:`pivot` = ``True`` is not available for CPU, and attempting
- to do so will throw an error. However, LU factorization with :attr:`pivot` = ``True`` is
- available for CUDA.
-
-Arguments:
- A (Tensor): the tensor to factor
- pivot (bool, optional): controls whether pivoting is done
-
-Returns:
- A tuple containing factorization and pivots.
-
-Example::
-
- >>> A = torch.randn(2, 3, 3)
- >>> A_LU, pivots = torch.btrifact(A)
- >>> A_LU
- tensor([[[ 1.3506, 2.5558, -0.0816],
- [ 0.1684, 1.1551, 0.1940],
- [ 0.1193, 0.6189, -0.5497]],
-
- [[ 0.4526, 1.2526, -0.3285],
- [-0.7988, 0.7175, -0.9701],
- [ 0.2634, -0.9255, -0.3459]]])
-
- >>> pivots
- tensor([[ 3, 3, 3],
- [ 3, 3, 3]], dtype=torch.int32)
-""")
-
-add_docstr(torch.btrifact_with_info,
- r"""
-btrifact_with_info(A, pivot=True) -> (Tensor, IntTensor, IntTensor)
-
-Batch LU factorization with additional error information.
-
-This is a version of :meth:`torch.btrifact` that always creates an info
-`IntTensor`, and returns it as the third return value.
-
-Arguments:
- A (Tensor): the tensor to factor
- pivot (bool, optional): controls whether pivoting is done
-
-Returns:
- A tuple containing factorization, pivots, and an `IntTensor` where non-zero
- values indicate whether factorization for each minibatch sample succeeds.
-
-Example::
-
- >>> A = torch.randn(2, 3, 3)
- >>> A_LU, pivots, info = A.btrifact_with_info()
- >>> if info.nonzero().size(0) == 0:
- >>> print('LU factorization succeeded for all samples!')
- LU factorization succeeded for all samples!
-""")
-
add_docstr(torch.btrisolve,
r"""
btrisolve(b, LU_data, LU_pivots) -> Tensor
Arguments:
b (Tensor): the RHS tensor
- LU_data (Tensor): the pivoted LU factorization of A from :meth:`btrifact`.
+ LU_data (Tensor): the pivoted LU factorization of A from :meth:`torch.lu`.
LU_pivots (IntTensor): the pivots of the LU factorization
Example::
>>> A = torch.randn(2, 3, 3)
>>> b = torch.randn(2, 3)
- >>> A_LU = torch.btrifact(A)
+ >>> A_LU = torch.lu(A)
>>> x = torch.btrisolve(b, *A_LU)
>>> torch.norm(torch.bmm(A, x.unsqueeze(2)) - b.unsqueeze(2))
tensor(1.00000e-07 *
__all__ = [
'btriunpack',
+ 'broadcast_tensors',
+ 'btrifact',
+ 'btrifact_with_info',
+ 'cartesian_prod',
'chain_matmul',
'einsum',
- 'broadcast_tensors',
+ 'gesv',
'isfinite',
'isinf',
+ 'lu',
'norm',
'meshgrid',
'potrf',
'pstrf',
'potrs',
- 'gesv',
'split',
'stft',
'tensordot',
'trtrs',
'unique',
- 'cartesian_prod',
]
def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
- r"""Unpacks the data and pivots from a batched LU factorization (btrifact) of a tensor.
+ r"""Unpacks the data and pivots from a LU factorization of a tensor.
Returns a tuple of tensors as ``(the pivots, the L tensor, the U tensor)``.
Example::
>>> A = torch.randn(2, 3, 3)
- >>> A_LU, pivots = A.btrifact()
+ >>> A_LU, pivots = A.lu()
>>> P, A_L, A_U = torch.btriunpack(A_LU, pivots)
>>>
>>> # can recover A from factorization
L = U = None
if unpack_pivots:
- P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
- LU_pivots = LU_pivots - 1
- for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
+ LU_pivots_zero_idx = LU_pivots - 1
+ if LU_data.dim() > 2:
+ P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype).expand_as(LU_data).clone()
+ for idx in product(*map(lambda x: list(range(x)), LU_data.shape[:-2])):
+ final_order = list(range(sz))
+ for k, j in enumerate(LU_pivots_zero_idx[idx]):
+ final_order[k], final_order[j] = final_order[j], final_order[k]
+ P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
+ else:
+ P = torch.eye(sz, device=LU_data.device, dtype=LU_data.dtype)
final_order = list(range(sz))
- for k, j in enumerate(LU_pivots[idx]):
+            for k, j in enumerate(LU_pivots_zero_idx):
final_order[k], final_order[j] = final_order[j], final_order[k]
- P[idx] = P[idx].index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
+ P = P.index_select(1, torch.as_tensor(final_order, device=LU_pivots.device))
else:
P = None
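A small worked example of the pivot-to-permutation step above (hypothetical pivot values, single-matrix case): the 1-indexed pivots are interpreted as successive row swaps applied to the identity's columns.

import torch

piv = torch.tensor([3, 3, 3], dtype=torch.int32)    # hypothetical pivots for a 3x3 factorization
sz = 3

final_order = list(range(sz))
for k, j in enumerate((piv - 1).tolist()):           # same swap logic as above, 0-indexed
    final_order[k], final_order[j] = final_order[j], final_order[k]
P = torch.eye(sz).index_select(1, torch.as_tensor(final_order))
print(P)
# tensor([[0., 1., 0.],
#         [0., 0., 1.],
#         [1., 0., 0.]])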
In particular, solves :math:`AX = b` and assumes :math:`A` is upper-triangular
with the default keyword arguments.
+ For more information regarding :func:`torch.trtrs`, please check :func:`torch.triangular_solve`.
+
.. warning::
:func:`torch.trtrs` is deprecated in favour of :func:`torch.triangular_solve` and will be
removed in the next release. Please use :func:`torch.triangular_solve` instead.
warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
"removed in the next release. Please use torch.triangular_solve instead.", stacklevel=2)
return torch.triangular_solve(b, A, upper=upper, transpose=transpose, unitriangular=unitriangular, out=out)
+
+
+def btrifact(A, pivot=True, out=None):
+ r"""Returns a tuple containing the LU factorization and pivots of :attr:`A`.
+ Pivoting is done if :attr:`pivot` is set.
+
+ For more information regarding :func:`torch.btrifact`, please check :func:`torch.lu`.
+
+ .. warning::
+ :func:`torch.btrifact` is deprecated in favour of :func:`torch.lu` and will be
+ removed in the next release. Please use :func:`torch.lu` instead.
+ """
+ warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be "
+ "removed in the next release. Please use torch.lu instead.", stacklevel=2)
+ return lu(A, pivot=pivot, get_infos=False, out=out)
+
+
+def btrifact_with_info(A, pivot=True, out=None):
+    r"""Performs an LU factorization and returns status information in addition to the LU
+    factorization and pivots.
+
+ For more information regarding :func:`torch.btrifact_with_info`, please check :func:`torch.lu`.
+
+ .. warning::
+ :func:`torch.btrifact_with_info` is deprecated in favour of :func:`torch.lu` and will
+ be removed in the next release. Please use :func:`torch.lu` with the :attr:`get_infos`
+ argument set to ``True`` instead.
+ """
+ warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu and will be "
+ "removed in the next release. Please use torch.lu with the get_infos argument "
+ "set to True instead.",
+ stacklevel=2)
+ return lu(A, pivot=pivot, get_infos=True, out=out)
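An illustrative migration sketch for callers of the deprecated names (not part of this patch):

import torch

A = torch.randn(2, 3, 3)

# Deprecated spellings (now emit a deprecation warning):
#   A_LU, pivots = torch.btrifact(A)
#   A_LU, pivots, info = torch.btrifact_with_info(A)

# Preferred spellings:
A_LU, pivots = torch.lu(A)
A_LU, pivots, info = torch.lu(A, get_infos=True)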
+
+
+def lu(A, pivot=True, get_infos=False, out=None):
+ r"""Computes the LU factorization of a square matrix or batches of square matrices
+ :attr:`A`. Returns a tuple containing the LU factorization and pivots of :attr:`A`.
+ Pivoting is done if :attr:`pivot` is set to ``True``.
+
+    .. note::
+        The pivots returned by the function are 1-indexed. If :attr:`pivot` is ``False``,
+        then the returned pivot tensor is filled with zeros of the appropriate size.
+
+    .. note::
+        LU factorization with :attr:`pivot` = ``False`` is not available on the CPU and
+        attempting it will throw an error; it is only available on CUDA.
+
+    .. note::
+        When :attr:`get_infos` is ``True``, this function does not check whether the
+        factorization succeeded, since the status of the factorization is returned as the
+        third element of the tuple.
+
+ Arguments:
+ A (Tensor): the tensor to factor of size :math:`(*, m, m)`
+ pivot (bool, optional): controls whether pivoting is done. Default: ``True``
+ get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
+ Default: ``False``
+ out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
+ then the elements in the tuple are Tensor, IntTensor,
+ and IntTensor. If :attr:`get_infos` is ``False``, then the
+ elements in the tuple are Tensor, IntTensor. Default: ``None``
+
+ Returns:
+ (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
+
+ - **factorization** (*Tensor*): the factorization of size :math:`(*, m, m)`
+
+ - **pivots** (*IntTensor*): the pivots of size :math:`(*, m)`
+
+        - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
+          size :math:`(*)` where a non-zero value indicates that the factorization of the
+          corresponding matrix (or minibatch sample) failed
+
+ Example::
+
+ >>> A = torch.randn(2, 3, 3)
+ >>> A_LU, pivots = torch.lu(A)
+ >>> A_LU
+ tensor([[[ 1.3506, 2.5558, -0.0816],
+ [ 0.1684, 1.1551, 0.1940],
+ [ 0.1193, 0.6189, -0.5497]],
+
+ [[ 0.4526, 1.2526, -0.3285],
+ [-0.7988, 0.7175, -0.9701],
+ [ 0.2634, -0.9255, -0.3459]]])
+ >>> pivots
+ tensor([[ 3, 3, 3],
+ [ 3, 3, 3]], dtype=torch.int32)
+ >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
+ >>> if info.nonzero().size(0) == 0:
+ ... print('LU factorization succeeded for all samples!')
+ LU factorization succeeded for all samples!
+ """
+ # If get_infos is True, then we don't need to check for errors and vice versa
+ result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
+ if out is not None:
+ if not isinstance(out, (tuple, list)):
+            raise TypeError("argument 'out' must be a tuple of Tensors, not {}"
+                            .format(type(out).__name__))
+        if len(out) - int(get_infos) != 2:
+            raise TypeError("expected tuple of {} elements but got {}"
+                            .format(2 + int(get_infos), len(out)))
+        return tuple(out[i].resize_as_(result[i]).copy_(result[i]) for i in range(len(out)))
+ if get_infos:
+ return result # A_LU, pivots, infos
+ else:
+ return result[0], result[1] # A_LU, pivots
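A short sketch of the ``out=`` path of the wrapper above (illustrative): the provided tensors are resized and filled in place.

import torch

A = torch.randn(2, 3, 3)
A_LU = torch.empty(0)
pivots = torch.empty(0, dtype=torch.int32)

torch.lu(A, out=(A_LU, pivots))
print(A_LU.shape, pivots.shape)    # torch.Size([2, 3, 3]) torch.Size([2, 3])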
def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
r"""See :func:`torch.triangular_solve`"""
warnings.warn("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
- "removed in the next release. Please use torch.triangular_solve.", stacklevel=2)
+ "removed in the next release. Please use torch.triangular_solve instead.",
+ stacklevel=2)
return super(Tensor, self).triangular_solve(A, upper=upper,
transpose=transpose, unitriangular=unitriangular)
+ def btrifact(self, pivot=True):
+ r"""See :func:`torch.lu`"""
+ warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be removed in "
+ "the next release. Please use torch.lu instead.", stacklevel=2)
+        LU, pivots, _ = torch._lu_with_info(self, pivot=pivot, check_errors=True)
+        return LU, pivots
+
+ def btrifact_with_info(self, pivot=True):
+ r"""See :func:`torch.lu`"""
+        warnings.warn("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
+                      "get_infos argument and will be removed in the next release. Please use "
+                      "torch.lu with the get_infos argument set to True instead.", stacklevel=2)
+ return torch._lu_with_info(self, pivot=pivot, check_errors=False)
+
+ def lu(self, pivot=True, get_infos=False):
+ r"""See :func:`torch.lu`"""
+ # If get_infos is True, then we don't need to check for errors and vice versa
+ LU, pivots, infos = torch._lu_with_info(self, pivot=pivot, check_errors=(not get_infos))
+ if get_infos:
+ return LU, pivots, infos
+ else:
+ return LU, pivots
+
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
center=True, pad_mode='reflect', normalized=False, onesided=True):
r"""See :func:`torch.stft`