From: vishwakftw Date: Thu, 10 Jan 2019 03:36:20 +0000 (-0800) Subject: Batched upper triangular, lower triangular (#15257) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~1936 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=b4c3268b23c30cb14b1a249e9566e0bd54c9bcd8;p=platform%2Fupstream%2Fpytorch.git Batched upper triangular, lower triangular (#15257) Summary: Changelog: - Implements `triu` and `tril` for batches of 2D tensors. - Remove TH/THC binding for `tril` - Fix CUDA implementation - Update docstrings for tril and triu. - Remove mask-based `triu` and `tril` in cholesky forward and backward. - Remove batched tril in torch.distributions.utils Pull Request resolved: https://github.com/pytorch/pytorch/pull/15257 Differential Revision: D13613888 Pulled By: mrshenli fbshipit-source-id: 0949a05b9b8e974c1acfaf02a6284848ec5cc1c4 --- diff --git a/aten/src/ATen/Declarations.cwrap b/aten/src/ATen/Declarations.cwrap index 0648523..38cb05b 100644 --- a/aten/src/ATen/Declarations.cwrap +++ b/aten/src/ATen/Declarations.cwrap @@ -2044,55 +2044,6 @@ - arg: THTensor* tensor ]] [[ - name: _th_tril - cname: tril - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: long diagonal - default: 0 -]] -[[ - name: _th_tril_ - cname: tril - variants: function - return: self - arguments: - - THTensor* self - - THTensor* self - - arg: long diagonal - default: 0 -]] -[[ - name: _th_triu - cname: triu - variants: - - function - return: argument 0 - arguments: - - arg: THTensor* result - output: True - - THTensor* self - - arg: long diagonal - default: 0 -]] -[[ - name: _th_triu_ - cname: triu - variants: - - function - return: self - arguments: - - THTensor* self - - THTensor* self - - arg: long diagonal - default: 0 -]] -[[ name: _th_cross cname: cross variants: diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index a01f33e..4cb57dc 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -372,20 +372,12 @@ Tensor cholesky(const Tensor &self, bool upper) { } squareCheckInputs(self); - // TODO: (#14071) Once `triu`, `tril` is implemented for batched tensors, - // this can be simplified. Currently, we are zero-ing out values in the - // batch of matrices by using a mask and the `where` function. - // The simplification with batched `triu` and `tril` would be this: - // if (upper) { - // return raw_cholesky_output.triu(); - // } else { - // return raw_cholesky_output.tril(); - // } auto raw_cholesky_output = at::_cholesky_helper(self, upper); - int64_t n = self.size(-1); - auto indices = at::ones({n, n}, self.options().dtype(at::kByte)); - indices = upper ? 
indices.tril(-1).expand_as(self) : indices.triu(1).expand_as(self); - return at::where(indices, at::zeros({}, self.options()), raw_cholesky_output); + if (upper) { + return raw_cholesky_output.triu_(); + } else { + return raw_cholesky_output.tril_(); + } } Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) { @@ -396,4 +388,136 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) { return result; } +template +static void apply_triu_tril_single( + scalar_t* result, scalar_t* self, + int64_t k, int64_t n, int64_t m, + int64_t res_row_stride, int64_t res_col_stride, + int64_t self_row_stride, int64_t self_col_stride) { + + constexpr int64_t zero = 0; + int64_t i; + + if (upper) { + #pragma omp parallel for private(i) + for (i = 0; i < n; i++) { + for (int64_t j = 0; j < std::min(m, i + k); j++) { + result[i * res_row_stride + j * res_col_stride] = 0; + } + if (!inplace) { // copy the rest of the self if not inplace + for (int64_t j = std::max(zero, i + k); j < m; j++) { + result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride]; + } + } + } + } else { + #pragma omp parallel for private(i) + for (i = 0; i < n; i++) { + for (int64_t j = std::max(zero, i + k + 1); j < m; j++) { + result[i * res_row_stride + j * res_col_stride] = 0; + } + if (!inplace) { // copy the rest of the self if not inplace + for (int64_t j = zero; j < std::min(m, i + k + 1); j++) { + result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride]; + } + } + } + } +} + +template +void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) { + auto n = self.size(-2); + auto m = self.size(-1); + auto self_data = self.data(); + auto self_stride = self.dim() > 2 ? self.stride(-3) : 1; + auto batchsize = batchCount(self); + auto self_row_stride = self.stride(-2); + auto self_column_stride = self.stride(-1); + + auto result_data = result.data(); + int64_t result_stride, result_row_stride, result_column_stride; + if (result_data != self_data) { + result_stride = result.dim() > 2 ? result.stride(-3) : 1; + result_row_stride = result.stride(-2); + result_column_stride = result.stride(-1); + } else { + result_stride = self_stride; + result_row_stride = self_row_stride; + result_column_stride = self_column_stride; + } + + int64_t b; + #pragma omp parallel for private(b) + for (b = 0; b < batchsize; b++) { + scalar_t* self_batch = &self_data[b * self_stride]; + scalar_t* result_batch = &result_data[b * result_stride]; + apply_triu_tril_single( + result_batch, self_batch, k, n, m, + result_row_stride, result_column_stride, self_row_stride, self_column_stride); + } +} + +Tensor tril(const Tensor& self, int64_t k) { + Tensor result = at::empty({0}, self.options()); + at::tril_out(result, self, k); + return result; +} + +Tensor& tril_cpu_(Tensor &self, int64_t k) { + if (self.numel() == 0) { + return self; + } + if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous(); + AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{ + apply_triu_tril(self, self, k); + }); + return self; +} + +Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) { + if (result.sizes() != self.sizes()) { + result.resize_as_(self); + } + if (self.numel() == 0) { + return result; + } + Tensor self_c = checkTrilTriuBatchContiguous(self) ? 
self : self.contiguous(); + AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{ + apply_triu_tril(result, self_c, k); + }); + return result; +} + +Tensor triu(const Tensor& self, int64_t k) { + Tensor result = at::empty({0}, self.options()); + at::triu_out(result, self, k); + return result; +} + +Tensor& triu_cpu_(Tensor &self, int64_t k) { + if (self.numel() == 0) { + return self; + } + if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous(); + AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{ + apply_triu_tril(self, self, k); + }); + return self; +} + +Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) { + if (result.sizes() != self.sizes()) { + result.resize_as_(self); + } + if (self.numel() == 0) { + return result; + } + Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous(); + AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{ + apply_triu_tril(result, self_c, k); + }); + return result; +} + }} // namespace at::native diff --git a/aten/src/ATen/native/LegacyDefinitions.cpp b/aten/src/ATen/native/LegacyDefinitions.cpp index eb6a254..abd3378 100644 --- a/aten/src/ATen/native/LegacyDefinitions.cpp +++ b/aten/src/ATen/native/LegacyDefinitions.cpp @@ -130,14 +130,6 @@ Tensor & atan2_(Tensor& self, const Tensor & other) { return at::legacy::th::_th_atan2_(self, other); } -Tensor & tril_(Tensor& self, int64_t diagonal) { - return at::legacy::th::_th_tril_(self, diagonal); -} - -Tensor & triu_(Tensor& self, int64_t diagonal) { - return at::legacy::th::_th_triu_(self, diagonal); -} - Tensor & digamma_(Tensor& self) { return at::legacy::th::_th_digamma_(self); } @@ -272,22 +264,6 @@ Tensor cross(const Tensor & self, const Tensor & other, int64_t dim) { return at::legacy::th::_th_cross(self, other, dim); } -Tensor & triu_out(Tensor & result, const Tensor & self, int64_t diagonal) { - return at::legacy::th::_th_triu_out(result, self, diagonal); -} - -Tensor triu(const Tensor & self, int64_t diagonal) { - return at::legacy::th::_th_triu(self, diagonal); -} - -Tensor & tril_out(Tensor & result, const Tensor & self, int64_t diagonal) { - return at::legacy::th::_th_tril_out(result, self, diagonal); -} - -Tensor tril(const Tensor & self, int64_t diagonal) { - return at::legacy::th::_th_tril(self, diagonal); -} - Tensor trace(const Tensor & self) { return at::legacy::th::_th_trace(self); } diff --git a/aten/src/ATen/native/LinearAlgebraUtils.h b/aten/src/ATen/native/LinearAlgebraUtils.h index dbec1fa..34d3104 100644 --- a/aten/src/ATen/native/LinearAlgebraUtils.h +++ b/aten/src/ATen/native/LinearAlgebraUtils.h @@ -41,6 +41,28 @@ static inline int64_t matrixStride(const Tensor& batched_matrices) { return batched_matrices.size(-1) * batched_matrices.size(-2); } +/* Checks a necessary property for the triu and tril implementations, hence the name. + * Here batch contiguity is checked for tensors with greater than 4 dimensions. 
+ * Contiguous tensors and tensors with less than 3 dimensions pass this check + */ +static inline bool checkTrilTriuBatchContiguous(const Tensor& tensor) { + // Complete contiguity is the most desired property, which is why + // we return true if the tensor is contiguous + if (tensor.is_contiguous()) return true; + + int64_t dims = tensor.dim(); + + // Tensors with dimension less than 4 are handled by default + if (dims <= 3) return true; + + int64_t expected_stride = tensor.size(-1) * tensor.size(-2); + for (int64_t i = dims - 3; i >= 0; i--) { + if (expected_stride != tensor.stride(i)) return false; + expected_stride *= tensor.size(i); + } + return true; +} + // Returns the epsilon value for floating types except half static inline double _get_epsilon(const ScalarType& sc_type) { switch (sc_type) { diff --git a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu index a6f1dd2..b45b4d8 100644 --- a/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu +++ b/aten/src/ATen/native/cuda/BatchLinearAlgebra.cu @@ -461,6 +461,81 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) { } } +template +__global__ +void triu_tril_kernel( + scalar_t* result, scalar_t* self, int64_t k, int64_t N, + int64_t res_batch_stride, int64_t res_row_stride, int64_t res_col_stride, + int64_t self_batch_stride, int64_t self_row_stride, int64_t self_col_stride, int64_t self_ncol) { + int64_t linear_idx = blockIdx.x * blockDim.x + threadIdx.x; + if (linear_idx >= N) { + return; + } + + int64_t self_batch_idx = blockIdx.y; + int64_t row = linear_idx / self_ncol; + int64_t col = linear_idx % self_ncol; + + bool mask = upper ? (col - row >= k) : (col - row <= k); + + // Now compute the offset for the self and result tensor + int64_t res_offset = self_batch_idx * res_batch_stride + row * res_row_stride + col * res_col_stride; + int64_t self_offset = self_batch_idx * self_batch_stride + row * self_row_stride + col * self_col_stride; + result[res_offset] = mask ? self[self_offset] : scalar_t(0); +} + +template +Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, const char* name) { + int64_t n_batches = batchCount(self), mat_size = self.size(-1) * self.size(-2), + res_batch_stride = result.dim() > 2 ? result.stride(-3) : 1, + res_row_stride = result.stride(-2), res_col_stride = result.stride(-1), + self_batch_stride = self.dim() > 2 ? self.stride(-3) : 1, + self_row_stride = self.stride(-2), self_col_stride = self.stride(-1); + dim3 dim_block = cuda::getApplyBlock(); + dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches); + AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), name, [&]{ + triu_tril_kernel + <<>>( + result.data(), self.data(), k, mat_size, + res_batch_stride, res_row_stride, res_col_stride, + self_batch_stride, self_row_stride, self_col_stride, self.size(-1)); + }); + AT_CUDA_CHECK(cudaGetLastError()); + return result; +} + +Tensor& tril_cuda_(Tensor &self, int64_t k) { + if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous(); + return tril_cuda_out(self, self, k); +} + +Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) { + if (result.sizes() != self.sizes()) { + result.resize_as_(self); + } + if (self.numel() == 0) { + return result; + } + Tensor self_c = checkTrilTriuBatchContiguous(self) ? 
self : self.contiguous(); + return triu_tril_cuda_template(result, self_c, k, "tril"); +} + +Tensor& triu_cuda_(Tensor &self, int64_t k) { + if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous(); + return triu_cuda_out(self, self, k); +} + +Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) { + if (result.sizes() != self.sizes()) { + result.resize_as_(self); + } + if (self.numel() == 0) { + return result; + } + Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous(); + return triu_tril_cuda_template(result, self_c, k, "triu"); +} + }} // namespace at::native #undef ALLOCATE_ARRAY diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index fa01c99..d7a5b8b 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2558,9 +2558,15 @@ - func: tril_(Tensor self, int64_t diagonal=0) -> Tensor variants: method + dispatch: + CPU: tril_cpu_ + CUDA: tril_cuda_ - func: triu_(Tensor self, int64_t diagonal=0) -> Tensor variants: method + dispatch: + CPU: triu_cpu_ + CUDA: triu_cuda_ - func: digamma_(Tensor self) -> Tensor variants: method @@ -2661,11 +2667,17 @@ variants: method, function - func: triu_out(Tensor result, Tensor self, int64_t diagonal=0) -> Tensor + dispatch: + CPU: triu_cpu_out + CUDA: triu_cuda_out - func: triu(Tensor self, int64_t diagonal=0) -> Tensor variants: method, function - func: tril_out(Tensor result, Tensor self, int64_t diagonal=0) -> Tensor + dispatch: + CPU: tril_cpu_out + CUDA: tril_cuda_out - func: tril(Tensor self, int64_t diagonal=0) -> Tensor variants: method, function diff --git a/aten/src/TH/generic/THTensorMath.h b/aten/src/TH/generic/THTensorMath.h index d7bb1e6..adabf5b 100644 --- a/aten/src/TH/generic/THTensorMath.h +++ b/aten/src/TH/generic/THTensorMath.h @@ -102,7 +102,6 @@ TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder); TH_API void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted); -TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k); TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k); TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension); TH_API void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension); diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp index d6c29a7..bb7edf3 100644 --- a/aten/src/TH/generic/THTensorMoreMath.cpp +++ b/aten/src/TH/generic/THTensorMoreMath.cpp @@ -1202,37 +1202,6 @@ void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, i THLongTensor_free(tmpIndices); } -void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k) -{ - int64_t t_size_0, t_size_1; - int64_t t_stride_0, t_stride_1; - int64_t r__stride_0, r__stride_1; - scalar_t *t_data, *r__data; - int64_t r, c; - - THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix"); - - THTensor_(resizeAs)(r_, t); - - t_size_0 = THTensor_(size)(t, 0); - t_size_1 = THTensor_(size)(t, 1); - t_stride_0 = THTensor_(stride)(t, 0); - t_stride_1 = THTensor_(stride)(t, 1); - r__stride_0 = THTensor_(stride)(r_, 0); - r__stride_1 = THTensor_(stride)(r_, 1); - r__data = r_->data(); - t_data = t->data(); - - for(r = 0; r < t_size_0; r++) - { - int64_t sz = THMin(r+k+1, t_size_1); 
- for(c = THMax(0, r+k+1); c < t_size_1; c++) - r__data[r*r__stride_0+c*r__stride_1] = 0; - for(c = 0; c < sz; c++) - r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1]; - } -} - void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k) { int64_t t_size_0, t_size_1; diff --git a/aten/src/THC/generic/THCTensorMath.h b/aten/src/THC/generic/THCTensorMath.h index 8221a37..8ac8ac4 100644 --- a/aten/src/THC/generic/THCTensorMath.h +++ b/aten/src/THC/generic/THCTensorMath.h @@ -12,7 +12,6 @@ THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta, THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension); THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self); -THC_API void THCTensor_(tril)(THCState *state, THCTensor *self, THCTensor *src, int64_t k); THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k); THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k); THC_API void THCTensor_(eye)(THCState *state, THCTensor *self, int64_t n, int64_t k); diff --git a/aten/src/THC/generic/THCTensorMathPairwise.cu b/aten/src/THC/generic/THCTensorMathPairwise.cu index b3b994d..c16e369 100644 --- a/aten/src/THC/generic/THCTensorMathPairwise.cu +++ b/aten/src/THC/generic/THCTensorMathPairwise.cu @@ -168,35 +168,6 @@ void THCTensor_(remainder)(THCState *state, THCTensor *self_, THCTensor *src_, s THCudaCheck(cudaGetLastError()); } -void THCTensor_(tril)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k) -{ - THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_)); - THArgCheck(!src_->is_empty() && src_->dim() == 2, 1, "expected a matrix"); - - if (self_ != src_) - THCTensor_(resizeAs)(state, self_, src_); - - int64_t stride0 = self_->stride(0); - int64_t stride1 = self_->stride(1); - scalar_t *start = THCTensor_(data)(state, self_); - - TensorTriOp op(start, stride0, stride1, k); - - if (self_ == src_) { - if (!THC_pointwiseApply1(state, src_, op)) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } else { - THCTensor_(resizeAs)(state, self_, src_); - - if (!THC_pointwiseApply2(state, self_, src_, op)) { - THArgCheck(false, 2, CUTORCH_DIM_WARNING); - } - } - - THCudaCheck(cudaGetLastError()); -} - void THCTensor_(triu)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k) { THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_)); diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index 8d2b790..0be58d4 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -569,8 +569,14 @@ def method_tests(): ('diagonal', (M, M, M), (-2, 0, 1), '3d_3'), ('tril', (M, M), NO_ARGS), ('tril', (M, M), (2,), 'idx'), + ('tril', (S, M, M), NO_ARGS, 'batched'), + ('tril', (S, M, M), (2,), 'batched_idx'), + ('tril', (3, 3, S, S), NO_ARGS, 'more_batched'), ('triu', (M, M), NO_ARGS), ('triu', (M, M), (2,), 'idx'), + ('triu', (S, M, M), NO_ARGS, 'batched'), + ('triu', (S, M, M), (2,), 'batched_idx'), + ('triu', (3, 3, S, S), NO_ARGS, 'more_batched'), ('trace', (M, M), NO_ARGS), ('cross', (S, 3), ((S, 3),)), ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'), diff --git a/test/test_cuda.py b/test/test_cuda.py index df1f21e..26eab2a 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2226,6 +2226,9 @@ class TestCuda(TestCase): for test_args in tri_large_tests_args: _compare_large_trilu_indices(self, *test_args, 
device='cuda') + def test_triu_tril(self): + _TestTorchMixin._test_triu_tril(self, lambda t: t.cuda()) + def load_ignore_file(): from os.path import join, dirname diff --git a/test/test_torch.py b/test/test_torch.py index d42db3b..f35cda5 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -3816,17 +3816,9 @@ class _TestTorchMixin(object): # input unchanged self.assertEqual(x, x0, 0) - def test_tril(self): - x = torch.rand(SIZE, SIZE) - res1 = torch.tril(x) - res2 = torch.Tensor() - torch.tril(x, out=res2) - self.assertEqual(res1, res2, 0) - def test_trilu_indices(self): for test_args in tri_tests_args: _compare_trilu_indices(self, *test_args) - run_additional_tri_tests(self, 'cpu') # test default options @@ -3837,12 +3829,90 @@ class _TestTorchMixin(object): self.assertEqual( x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3)) - def test_triu(self): - x = torch.rand(SIZE, SIZE) - res1 = torch.triu(x) - res2 = torch.Tensor() - torch.triu(x, out=res2) - self.assertEqual(res1, res2, 0) + @staticmethod + def _test_triu_tril(self, cast): + def gen_mask(shape, diagonal, cast, upper): + mask = torch.zeros(*shape[-2:]).byte() + for i in range(shape[-2]): + for j in range(shape[-1]): + cond = j - i < diagonal if upper else j - i > diagonal + if cond: + mask[i, j] = 1 + return cast(mask.expand(*shape)) + + torch_functions = {True: torch.triu, False: torch.tril} + if TEST_NUMPY: + numpy_functions = {True: np.triu, False: np.tril} + + def run_test(shape, cast, diagonal): + x_cpu = torch.randn(*shape) + x = cast(x_cpu) + + for upper in [True, False]: + # normal test with mask + torch_tri_func = torch_functions[upper] + res1 = torch_tri_func(x, diagonal=diagonal) + res2 = cast(torch.Tensor()) + torch_tri_func(x, diagonal=diagonal, out=res2) + exp_mask = gen_mask(shape, diagonal, cast, upper) + expected = torch.where(exp_mask, torch.tensor(0).type_as(x), x) + self.assertEqual(res1, res2, 0) + self.assertEqual(expected, res1, 0) + + # non-contiguous and expanded tensors test + if not (0 in shape or 1 in shape): + for s in range(-len(shape), -1): + # non-contiguous tensors + x_nc = x.clone().transpose(s, s + 1) + exp_mask = gen_mask(x_nc.size(), diagonal, cast, upper) + assert not x_nc.is_contiguous(), "x is intentionally non-contiguous" + exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc) + self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0) + if upper: + self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0) + else: + self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0) + + # any 3-dimensional tensor should be fine + if len(shape) <= 3 or s == -2: + self.assertFalse(x_nc.is_contiguous(), + "x_nc should remain non-contiguous") + elif s < -3: + self.assertTrue(x_nc.is_contiguous(), + "x_nc should become contiguous") + + # expanded tensors + expanded_size = (x.size(0),) + x.size() + x_expanded = x.clone().expand(*expanded_size) + assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride" + output = torch_tri_func(x_expanded, diagonal) + self.assertEqual(output, expected.expand(expanded_size), 0) + self.assertTrue(0 in x_expanded.stride(), + "geometry of x_expanded should be the same") + if upper: + self.assertEqual(output, x_expanded.triu_(diagonal), 0) + else: + self.assertEqual(output, x_expanded.tril_(diagonal), 0) + + if not TEST_NUMPY: + continue + + # numpy test + numpy_tri_func = numpy_functions[upper] + self.assertEqual(numpy_tri_func(x_cpu.numpy(), diagonal), res1.cpu().numpy()) + + diagonals = [-2, -1, 0, 1, 2] + shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 
3), # square matrices + (7, 3), (5, 7, 3), (7, 5, 7, 3), # fat matrices + (3, 7), (5, 3, 7), (7, 5, 3, 7), # thin matrices + (3, 0), (0, 3, 3), (3, 3, 0, 0), # no numel matrices + (3, 1), (5, 3, 1), (7, 5, 3, 1), # very fat matrices + (1, 3), (5, 1, 3), (7, 5, 1, 3)] # very thin matrices + for s, d in product(shapes, diagonals): + run_test(s, cast, d) + + def test_triu_tril(self): + self._test_triu_tril(self, lambda t: t) def test_cat(self): SIZE = 10 diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 776f238..4f1a158 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -675,22 +675,15 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { L = L.transpose(-1, -2); } - auto batch_tril = [](const Tensor & A) -> Tensor { - int64_t n = A.size(-1); - auto indices = at::ones({n, n}, A.options().dtype(at::kByte)); - indices = indices.triu(1).expand_as(A); - return at::where(indices, at::zeros({}, A.options()), A); - }; - - auto phi = [&batch_tril](const Tensor & A) -> Tensor { - auto B = A.dim() == 2 ? A.tril() : batch_tril(A); + auto phi = [](const Tensor & A) -> Tensor { + auto B = A.tril(); B = B - 0.5 * at::diag_embed(B.diagonal(0, -1, -2), 0, -2, -1); return B; }; // make sure not to double-count variation, since // only half of output matrix is unique - auto Lbar = grad.dim() == 2 ? grad.tril() : batch_tril(grad); + auto Lbar = grad.tril(); auto P = phi(at::matmul(L, Lbar)); Tensor S; diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 162c2af..f7056cb 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -4951,8 +4951,8 @@ add_docstr(torch.tril, r""" tril(input, diagonal=0, out=None) -> Tensor -Returns the lower triangular part of the matrix (2-D tensor) :attr:`input`, -the other elements of the result tensor :attr:`out` are set to 0. +Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices +:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. The lower triangular part of the matrix is defined as the elements on and below the diagonal. @@ -5056,8 +5056,8 @@ add_docstr(torch.triu, r""" triu(input, diagonal=0, out=None) -> Tensor -Returns the upper triangular part of the matrix (2-D tensor) :attr:`input`, -the other elements of the result tensor :attr:`out` are set to 0. +Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices +:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0. The upper triangular part of the matrix is defined as the elements on and above the diagonal. 
@@ -5101,19 +5101,19 @@ Example:: [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], [ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]]) - >>> torch.tril(b, diagonal=1) - tensor([[ 0.5876, -0.0794, 0.0000, 0.0000, 0.0000, 0.0000], - [-0.2447, 0.9556, -1.2919, 0.0000, 0.0000, 0.0000], - [ 0.4333, 0.3146, 0.6576, -1.0432, 0.0000, 0.0000], - [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.0000]]) - >>> torch.tril(b, diagonal=-1) - tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], - [-0.2447, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000], - [ 0.4333, 0.3146, 0.0000, 0.0000, 0.0000, 0.0000], - [-0.9888, 1.0679, -1.3337, 0.0000, 0.0000, 0.0000]]) + >>> torch.triu(b, diagonal=1) + tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]]) + >>> torch.triu(b, diagonal=-1) + tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235], + [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857], + [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410], + [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]]) """) -# docstr is split in two parts to avoid format mis-captureing :math: braces '{}' +# docstr is split in two parts to avoid format mis-capturing :math: braces '{}' # as common args. add_docstr(torch.triu_indices, r""" diff --git a/torch/distributions/constraints.py b/torch/distributions/constraints.py index 8320535..7f6adf9 100644 --- a/torch/distributions/constraints.py +++ b/torch/distributions/constraints.py @@ -19,7 +19,6 @@ The following constraints are implemented: """ import torch -from torch.distributions.utils import batch_tril __all__ = [ 'Constraint', @@ -256,7 +255,7 @@ class _LowerTriangular(Constraint): Constrain to lower-triangular square matrices. """ def check(self, value): - value_tril = batch_tril(value) + value_tril = value.tril() return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] @@ -265,12 +264,10 @@ class _LowerCholesky(Constraint): Constrain to lower-triangular square matrices with positive diagonals. """ def check(self, value): - value_tril = batch_tril(value) + value_tril = value.tril() lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] - n = value.size(-1) - diag_mask = torch.eye(n, n, dtype=value.dtype, device=value.device) - positive_diagonal = (value * diag_mask > (diag_mask - 1)).min(-1)[0].min(-1)[0] + positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0] return lower_triangular & positive_diagonal diff --git a/torch/distributions/utils.py b/torch/distributions/utils.py index 698c5a2..0f5a2f1 100644 --- a/torch/distributions/utils.py +++ b/torch/distributions/utils.py @@ -90,18 +90,6 @@ def probs_to_logits(probs, is_binary=False): return torch.log(ps_clamped) -def batch_tril(bmat, diagonal=0): - """ - Given a batch of matrices, returns the lower triangular part of each matrix, with - the other entries set to 0. The argument `diagonal` has the same meaning as in - `torch.tril`. - """ - if bmat.dim() == 2: - return bmat.tril(diagonal=diagonal) - else: - return bmat * torch.tril(bmat.new(*bmat.shape[-2:]).fill_(1.0), diagonal=diagonal) - - class lazy_property(object): r""" Used as a decorator for lazy loading of class attributes. 
This uses a diff --git a/torch/functional.py b/torch/functional.py index 81873ee..54e87f2 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -101,10 +101,8 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): sz = LU_data.size(-1) if unpack_data: - I_U = torch.ones(sz, sz, device=LU_data.device, dtype=torch.uint8).triu_().expand_as(LU_data) - zero = torch.tensor(0.).type_as(LU_data) - U = torch.where(I_U, LU_data, zero) - L = torch.where(I_U, zero, LU_data) + U = LU_data.triu() + L = LU_data.tril() L.diagonal(dim1=-2, dim2=-1).fill_(1) else: L = U = None
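
Usage note (not part of the patch): the diff above only shows the new behavior indirectly through the kernels and tests, so here is a minimal sketch of the batched `tril`/`triu` semantics this PR introduces. The shapes and variable names below are illustrative assumptions, not taken from the commit.

```python
# Sketch of the batched tril/triu behavior added by this PR.
# Shapes are arbitrary; tril/triu act on the last two dimensions.
import torch

x = torch.randn(4, 5, 3, 3)            # a (4, 5) batch of 3x3 matrices

# Batched call: the diagonal offset applies to every matrix in the batch.
lower = torch.tril(x, diagonal=-1)
upper = x.triu(1)                       # method variant; x.triu_(1) works in place

# Equivalent (but slower) per-matrix loop over the flattened batch.
expected = torch.stack([m.tril(-1) for m in x.reshape(-1, 3, 3)]).reshape_as(x)
assert torch.equal(lower, expected)

# The mask convention: tril keeps entries with j - i <= diagonal,
# triu keeps entries with j - i >= diagonal.
assert torch.equal(upper, x - torch.tril(x, diagonal=0))
```

This is the same batched behavior the simplified `cholesky` forward now relies on via `raw_cholesky_output.triu_()` / `tril_()`, and it is what lets the mask-and-`where` workarounds be removed from `cholesky_backward`, `torch.distributions`, and `btriunpack`.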
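
The `checkTrilTriuBatchContiguous` helper added in LinearAlgebraUtils.h is easiest to read as the Python mirror below (my own sketch, not part of the patch): a tensor passes if it is fully contiguous, has at most three dimensions, or if its batch dimensions are laid out as though the matrices were packed back to back. Only then can the kernels address each batch element with a single batch stride; otherwise the CPU/CUDA entry points above fall back to `.contiguous()`.

```python
import torch

def check_tril_triu_batch_contiguous(t: torch.Tensor) -> bool:
    # Fully contiguous tensors always pass.
    if t.is_contiguous():
        return True
    # Tensors with at most 3 dimensions (a single matrix or one batch dim)
    # are handled without the stride check.
    if t.dim() <= 3:
        return True
    # Every batch dimension must stride over whole (rows x cols) matrices.
    expected = t.size(-1) * t.size(-2)
    for i in range(t.dim() - 3, -1, -1):
        if t.stride(i) != expected:
            return False
        expected *= t.size(i)
    return True

x = torch.randn(2, 4, 3, 3)
print(check_tril_triu_batch_contiguous(x))                    # True
print(check_tril_triu_batch_contiguous(x.transpose(-1, -2)))  # True: only batch strides matter
print(check_tril_triu_batch_contiguous(x.transpose(0, 1)))    # False: batch dims are permuted
```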