- arg: THTensor* tensor
]]
[[
- name: _th_tril
- cname: tril
- variants:
- - function
- return: argument 0
- arguments:
- - arg: THTensor* result
- output: True
- - THTensor* self
- - arg: long diagonal
- default: 0
-]]
-[[
- name: _th_tril_
- cname: tril
- variants: function
- return: self
- arguments:
- - THTensor* self
- - THTensor* self
- - arg: long diagonal
- default: 0
-]]
-[[
- name: _th_triu
- cname: triu
- variants:
- - function
- return: argument 0
- arguments:
- - arg: THTensor* result
- output: True
- - THTensor* self
- - arg: long diagonal
- default: 0
-]]
-[[
- name: _th_triu_
- cname: triu
- variants:
- - function
- return: self
- arguments:
- - THTensor* self
- - THTensor* self
- - arg: long diagonal
- default: 0
-]]
-[[
name: _th_cross
cname: cross
variants:
}
squareCheckInputs(self);
- // TODO: (#14071) Once `triu`, `tril` is implemented for batched tensors,
- // this can be simplified. Currently, we are zero-ing out values in the
- // batch of matrices by using a mask and the `where` function.
- // The simplification with batched `triu` and `tril` would be this:
- // if (upper) {
- // return raw_cholesky_output.triu();
- // } else {
- // return raw_cholesky_output.tril();
- // }
auto raw_cholesky_output = at::_cholesky_helper(self, upper);
- int64_t n = self.size(-1);
- auto indices = at::ones({n, n}, self.options().dtype(at::kByte));
- indices = upper ? indices.tril(-1).expand_as(self) : indices.triu(1).expand_as(self);
- return at::where(indices, at::zeros({}, self.options()), raw_cholesky_output);
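+ // Zero out the triangle that is not part of the requested factor.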
+ if (upper) {
+ return raw_cholesky_output.triu_();
+ } else {
+ return raw_cholesky_output.tril_();
+ }
}
Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
return result;
}
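+// Writes the triu/tril of a single matrix given raw data pointers and strides.
+// For `upper`, entries below the k-th diagonal are zeroed; otherwise entries
+// above it are zeroed. When `inplace` is false, the kept entries are also
+// copied from `self` into `result`.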
+template <typename scalar_t, bool inplace, bool upper>
+static void apply_triu_tril_single(
+ scalar_t* result, scalar_t* self,
+ int64_t k, int64_t n, int64_t m,
+ int64_t res_row_stride, int64_t res_col_stride,
+ int64_t self_row_stride, int64_t self_col_stride) {
+
+ constexpr int64_t zero = 0;
+ int64_t i;
+
+ if (upper) {
+ #pragma omp parallel for private(i)
+ for (i = 0; i < n; i++) {
+ for (int64_t j = 0; j < std::min(m, i + k); j++) {
+ result[i * res_row_stride + j * res_col_stride] = 0;
+ }
+ if (!inplace) { // copy the kept elements from self when not operating in place
+ for (int64_t j = std::max(zero, i + k); j < m; j++) {
+ result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
+ }
+ }
+ }
+ } else {
+ #pragma omp parallel for private(i)
+ for (i = 0; i < n; i++) {
+ for (int64_t j = std::max(zero, i + k + 1); j < m; j++) {
+ result[i * res_row_stride + j * res_col_stride] = 0;
+ }
+ if (!inplace) { // copy the kept elements from self when not operating in place
+ for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
+ result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
+ }
+ }
+ }
+ }
+}
+
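+// Applies apply_triu_tril_single to every matrix in the (possibly batched)
+// input; each batch element is located via the batch stride (stride(-3)).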
+template <typename scalar_t, bool inplace, bool upper>
+void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
+ auto n = self.size(-2);
+ auto m = self.size(-1);
+ auto self_data = self.data<scalar_t>();
+ auto self_stride = self.dim() > 2 ? self.stride(-3) : 1;
+ auto batchsize = batchCount(self);
+ auto self_row_stride = self.stride(-2);
+ auto self_column_stride = self.stride(-1);
+
+ auto result_data = result.data<scalar_t>();
+ int64_t result_stride, result_row_stride, result_column_stride;
+ if (result_data != self_data) {
+ result_stride = result.dim() > 2 ? result.stride(-3) : 1;
+ result_row_stride = result.stride(-2);
+ result_column_stride = result.stride(-1);
+ } else {
+ result_stride = self_stride;
+ result_row_stride = self_row_stride;
+ result_column_stride = self_column_stride;
+ }
+
+ int64_t b;
+ #pragma omp parallel for private(b)
+ for (b = 0; b < batchsize; b++) {
+ scalar_t* self_batch = &self_data[b * self_stride];
+ scalar_t* result_batch = &result_data[b * result_stride];
+ apply_triu_tril_single<scalar_t, inplace, upper>(
+ result_batch, self_batch, k, n, m,
+ result_row_stride, result_column_stride, self_row_stride, self_column_stride);
+ }
+}
+
+Tensor tril(const Tensor& self, int64_t k) {
+ Tensor result = at::empty({0}, self.options());
+ at::tril_out(result, self, k);
+ return result;
+}
+
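+// In-place variant: if the batch layout is not supported, `self` is first
+// replaced by a contiguous copy and the triangle is computed on that copy.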
+Tensor& tril_cpu_(Tensor &self, int64_t k) {
+ if (self.numel() == 0) {
+ return self;
+ }
+ if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+ AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
+ apply_triu_tril<scalar_t, true, false>(self, self, k);
+ });
+ return self;
+}
+
+Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
+ if (result.sizes() != self.sizes()) {
+ result.resize_as_(self);
+ }
+ if (self.numel() == 0) {
+ return result;
+ }
+ Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+ AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
+ apply_triu_tril<scalar_t, false, false>(result, self_c, k);
+ });
+ return result;
+}
+
+Tensor triu(const Tensor& self, int64_t k) {
+ Tensor result = at::empty({0}, self.options());
+ at::triu_out(result, self, k);
+ return result;
+}
+
+Tensor& triu_cpu_(Tensor &self, int64_t k) {
+ if (self.numel() == 0) {
+ return self;
+ }
+ if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+ AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
+ apply_triu_tril<scalar_t, true, true>(self, self, k);
+ });
+ return self;
+}
+
+Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
+ if (result.sizes() != self.sizes()) {
+ result.resize_as_(self);
+ }
+ if (self.numel() == 0) {
+ return result;
+ }
+ Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+ AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
+ apply_triu_tril<scalar_t, false, true>(result, self_c, k);
+ });
+ return result;
+}
+
}} // namespace at::native
return at::legacy::th::_th_atan2_(self, other);
}
-Tensor & tril_(Tensor& self, int64_t diagonal) {
- return at::legacy::th::_th_tril_(self, diagonal);
-}
-
-Tensor & triu_(Tensor& self, int64_t diagonal) {
- return at::legacy::th::_th_triu_(self, diagonal);
-}
-
Tensor & digamma_(Tensor& self) {
return at::legacy::th::_th_digamma_(self);
}
return at::legacy::th::_th_cross(self, other, dim);
}
-Tensor & triu_out(Tensor & result, const Tensor & self, int64_t diagonal) {
- return at::legacy::th::_th_triu_out(result, self, diagonal);
-}
-
-Tensor triu(const Tensor & self, int64_t diagonal) {
- return at::legacy::th::_th_triu(self, diagonal);
-}
-
-Tensor & tril_out(Tensor & result, const Tensor & self, int64_t diagonal) {
- return at::legacy::th::_th_tril_out(result, self, diagonal);
-}
-
-Tensor tril(const Tensor & self, int64_t diagonal) {
- return at::legacy::th::_th_tril(self, diagonal);
-}
-
Tensor trace(const Tensor & self) {
return at::legacy::th::_th_trace(self);
}
return batched_matrices.size(-1) * batched_matrices.size(-2);
}
+/* Checks a necessary property for the triu and tril implementations, hence the name.
+ * Batch contiguity is checked for tensors with 4 or more dimensions.
+ * Contiguous tensors and tensors with 3 or fewer dimensions pass this check.
+ */
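+// The check passes when every batch dimension has the stride it would have in
+// a contiguous layout; the last two (matrix) dimensions themselves are not
+// required to be contiguous (e.g. they may be transposed).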
+static inline bool checkTrilTriuBatchContiguous(const Tensor& tensor) {
+ // Complete contiguity is the most desired property, which is why
+ // we return true if the tensor is contiguous
+ if (tensor.is_contiguous()) return true;
+
+ int64_t dims = tensor.dim();
+
+ // Tensors with fewer than 4 dimensions pass this check by default
+ if (dims <= 3) return true;
+
+ int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
+ for (int64_t i = dims - 3; i >= 0; i--) {
+ if (expected_stride != tensor.stride(i)) return false;
+ expected_stride *= tensor.size(i);
+ }
+ return true;
+}
+
// Returns the epsilon value for floating types except half
static inline double _get_epsilon(const ScalarType& sc_type) {
switch (sc_type) {
}
}
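+// One CUDA thread per element of a matrix; blockIdx.y indexes the batch.
+// `mask` is true for the elements that are kept (on the kept side of the
+// k-th diagonal); all other elements of `result` are set to zero.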
+template <typename scalar_t, bool upper>
+__global__
+void triu_tril_kernel(
+ scalar_t* result, scalar_t* self, int64_t k, int64_t N,
+ int64_t res_batch_stride, int64_t res_row_stride, int64_t res_col_stride,
+ int64_t self_batch_stride, int64_t self_row_stride, int64_t self_col_stride, int64_t self_ncol) {
+ int64_t linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
+ if (linear_idx >= N) {
+ return;
+ }
+
+ int64_t self_batch_idx = blockIdx.y;
+ int64_t row = linear_idx / self_ncol;
+ int64_t col = linear_idx % self_ncol;
+
+ bool mask = upper ? (col - row >= k) : (col - row <= k);
+
+ // Compute the offsets into the self and result tensors
+ int64_t res_offset = self_batch_idx * res_batch_stride + row * res_row_stride + col * res_col_stride;
+ int64_t self_offset = self_batch_idx * self_batch_stride + row * self_row_stride + col * self_col_stride;
+ result[res_offset] = mask ? self[self_offset] : scalar_t(0);
+}
+
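+// Resolves strides, configures the launch geometry (one grid row per batch
+// element), and dispatches triu_tril_kernel over all supported scalar types.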
+template <bool upper>
+Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, const char* name) {
+ int64_t n_batches = batchCount(self), mat_size = self.size(-1) * self.size(-2),
+ res_batch_stride = result.dim() > 2 ? result.stride(-3) : 1,
+ res_row_stride = result.stride(-2), res_col_stride = result.stride(-1),
+ self_batch_stride = self.dim() > 2 ? self.stride(-3) : 1,
+ self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
+ dim3 dim_block = cuda::getApplyBlock();
+ dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
+ AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), name, [&]{
+ triu_tril_kernel<scalar_t, upper>
+ <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+ result.data<scalar_t>(), self.data<scalar_t>(), k, mat_size,
+ res_batch_stride, res_row_stride, res_col_stride,
+ self_batch_stride, self_row_stride, self_col_stride, self.size(-1));
+ });
+ AT_CUDA_CHECK(cudaGetLastError());
+ return result;
+}
+
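+// CUDA entry points; as on the CPU, inputs whose batch layout is not
+// contiguous are made contiguous before the kernel is launched.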
+Tensor& tril_cuda_(Tensor &self, int64_t k) {
+ if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+ return tril_cuda_out(self, self, k);
+}
+
+Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
+ if (result.sizes() != self.sizes()) {
+ result.resize_as_(self);
+ }
+ if (self.numel() == 0) {
+ return result;
+ }
+ Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+ return triu_tril_cuda_template<false>(result, self_c, k, "tril");
+}
+
+Tensor& triu_cuda_(Tensor &self, int64_t k) {
+ if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+ return triu_cuda_out(self, self, k);
+}
+
+Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
+ if (result.sizes() != self.sizes()) {
+ result.resize_as_(self);
+ }
+ if (self.numel() == 0) {
+ return result;
+ }
+ Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+ return triu_tril_cuda_template<true>(result, self_c, k, "triu");
+}
+
}} // namespace at::native
#undef ALLOCATE_ARRAY
- func: tril_(Tensor self, int64_t diagonal=0) -> Tensor
variants: method
+ dispatch:
+ CPU: tril_cpu_
+ CUDA: tril_cuda_
- func: triu_(Tensor self, int64_t diagonal=0) -> Tensor
variants: method
+ dispatch:
+ CPU: triu_cpu_
+ CUDA: triu_cuda_
- func: digamma_(Tensor self) -> Tensor
variants: method
variants: method, function
- func: triu_out(Tensor result, Tensor self, int64_t diagonal=0) -> Tensor
+ dispatch:
+ CPU: triu_cpu_out
+ CUDA: triu_cuda_out
- func: triu(Tensor self, int64_t diagonal=0) -> Tensor
variants: method, function
- func: tril_out(Tensor result, Tensor self, int64_t diagonal=0) -> Tensor
+ dispatch:
+ CPU: tril_cpu_out
+ CUDA: tril_cuda_out
- func: tril(Tensor self, int64_t diagonal=0) -> Tensor
variants: method, function
TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder);
TH_API void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted);
-TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k);
TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k);
TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);
TH_API void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension);
THLongTensor_free(tmpIndices);
}
-void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k)
-{
- int64_t t_size_0, t_size_1;
- int64_t t_stride_0, t_stride_1;
- int64_t r__stride_0, r__stride_1;
- scalar_t *t_data, *r__data;
- int64_t r, c;
-
- THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix");
-
- THTensor_(resizeAs)(r_, t);
-
- t_size_0 = THTensor_(size)(t, 0);
- t_size_1 = THTensor_(size)(t, 1);
- t_stride_0 = THTensor_(stride)(t, 0);
- t_stride_1 = THTensor_(stride)(t, 1);
- r__stride_0 = THTensor_(stride)(r_, 0);
- r__stride_1 = THTensor_(stride)(r_, 1);
- r__data = r_->data<scalar_t>();
- t_data = t->data<scalar_t>();
-
- for(r = 0; r < t_size_0; r++)
- {
- int64_t sz = THMin(r+k+1, t_size_1);
- for(c = THMax(0, r+k+1); c < t_size_1; c++)
- r__data[r*r__stride_0+c*r__stride_1] = 0;
- for(c = 0; c < sz; c++)
- r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1];
- }
-}
-
void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k)
{
int64_t t_size_0, t_size_1;
THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
-THC_API void THCTensor_(tril)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
THC_API void THCTensor_(eye)(THCState *state, THCTensor *self, int64_t n, int64_t k);
THCudaCheck(cudaGetLastError());
}
-void THCTensor_(tril)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k)
-{
- THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
- THArgCheck(!src_->is_empty() && src_->dim() == 2, 1, "expected a matrix");
-
- if (self_ != src_)
- THCTensor_(resizeAs)(state, self_, src_);
-
- int64_t stride0 = self_->stride(0);
- int64_t stride1 = self_->stride(1);
- scalar_t *start = THCTensor_(data)(state, self_);
-
- TensorTriOp<scalar_t, 0> op(start, stride0, stride1, k);
-
- if (self_ == src_) {
- if (!THC_pointwiseApply1<scalar_t>(state, src_, op)) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- } else {
- THCTensor_(resizeAs)(state, self_, src_);
-
- if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src_, op)) {
- THArgCheck(false, 2, CUTORCH_DIM_WARNING);
- }
- }
-
- THCudaCheck(cudaGetLastError());
-}
-
void THCTensor_(triu)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k)
{
THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
('diagonal', (M, M, M), (-2, 0, 1), '3d_3'),
('tril', (M, M), NO_ARGS),
('tril', (M, M), (2,), 'idx'),
+ ('tril', (S, M, M), NO_ARGS, 'batched'),
+ ('tril', (S, M, M), (2,), 'batched_idx'),
+ ('tril', (3, 3, S, S), NO_ARGS, 'more_batched'),
('triu', (M, M), NO_ARGS),
('triu', (M, M), (2,), 'idx'),
+ ('triu', (S, M, M), NO_ARGS, 'batched'),
+ ('triu', (S, M, M), (2,), 'batched_idx'),
+ ('triu', (3, 3, S, S), NO_ARGS, 'more_batched'),
('trace', (M, M), NO_ARGS),
('cross', (S, 3), ((S, 3),)),
('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
for test_args in tri_large_tests_args:
_compare_large_trilu_indices(self, *test_args, device='cuda')
+    def test_triu_tril(self):
+        _TestTorchMixin._test_triu_tril(self, lambda t: t.cuda())
+
def load_ignore_file():
from os.path import join, dirname
# input unchanged
self.assertEqual(x, x0, 0)
- def test_tril(self):
- x = torch.rand(SIZE, SIZE)
- res1 = torch.tril(x)
- res2 = torch.Tensor()
- torch.tril(x, out=res2)
- self.assertEqual(res1, res2, 0)
-
def test_trilu_indices(self):
for test_args in tri_tests_args:
_compare_trilu_indices(self, *test_args)
-
run_additional_tri_tests(self, 'cpu')
# test default options
self.assertEqual(
x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3))
- def test_triu(self):
- x = torch.rand(SIZE, SIZE)
- res1 = torch.triu(x)
- res2 = torch.Tensor()
- torch.triu(x, out=res2)
- self.assertEqual(res1, res2, 0)
+    @staticmethod
+    def _test_triu_tril(self, cast):
+        def gen_mask(shape, diagonal, cast, upper):
+            mask = torch.zeros(*shape[-2:]).byte()
+            for i in range(shape[-2]):
+                for j in range(shape[-1]):
+                    cond = j - i < diagonal if upper else j - i > diagonal
+                    if cond:
+                        mask[i, j] = 1
+            return cast(mask.expand(*shape))
+
+        torch_functions = {True: torch.triu, False: torch.tril}
+        if TEST_NUMPY:
+            numpy_functions = {True: np.triu, False: np.tril}
+
+        def run_test(shape, cast, diagonal):
+            x_cpu = torch.randn(*shape)
+            x = cast(x_cpu)
+
+            for upper in [True, False]:
+                # normal test with mask
+                torch_tri_func = torch_functions[upper]
+                res1 = torch_tri_func(x, diagonal=diagonal)
+                res2 = cast(torch.Tensor())
+                torch_tri_func(x, diagonal=diagonal, out=res2)
+                exp_mask = gen_mask(shape, diagonal, cast, upper)
+                expected = torch.where(exp_mask, torch.tensor(0).type_as(x), x)
+                self.assertEqual(res1, res2, 0)
+                self.assertEqual(expected, res1, 0)
+
+                # non-contiguous and expanded tensors test
+                if not (0 in shape or 1 in shape):
+                    for s in range(-len(shape), -1):
+                        # non-contiguous tensors
+                        x_nc = x.clone().transpose(s, s + 1)
+                        exp_mask = gen_mask(x_nc.size(), diagonal, cast, upper)
+                        assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
+                        exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
+                        self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
+                        if upper:
+                            self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
+                        else:
+                            self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
+
+                        # any 3-dimensional tensor should be fine
+                        if len(shape) <= 3 or s == -2:
+                            self.assertFalse(x_nc.is_contiguous(),
+                                             "x_nc should remain non-contiguous")
+                        elif s < -3:
+                            self.assertTrue(x_nc.is_contiguous(),
+                                            "x_nc should become contiguous")
+
+                    # expanded tensors
+                    expanded_size = (x.size(0),) + x.size()
+                    x_expanded = x.clone().expand(*expanded_size)
+                    assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride"
+                    output = torch_tri_func(x_expanded, diagonal)
+                    self.assertEqual(output, expected.expand(expanded_size), 0)
+                    self.assertTrue(0 in x_expanded.stride(),
+                                    "geometry of x_expanded should be the same")
+                    if upper:
+                        self.assertEqual(output, x_expanded.triu_(diagonal), 0)
+                    else:
+                        self.assertEqual(output, x_expanded.tril_(diagonal), 0)
+
+                if not TEST_NUMPY:
+                    continue
+
+                # numpy test
+                numpy_tri_func = numpy_functions[upper]
+                self.assertEqual(numpy_tri_func(x_cpu.numpy(), diagonal), res1.cpu().numpy())
+
+        diagonals = [-2, -1, 0, 1, 2]
+        shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3),  # square matrices
+                  (7, 3), (5, 7, 3), (7, 5, 7, 3),  # fat matrices
+                  (3, 7), (5, 3, 7), (7, 5, 3, 7),  # thin matrices
+                  (3, 0), (0, 3, 3), (3, 3, 0, 0),  # no numel matrices
+                  (3, 1), (5, 3, 1), (7, 5, 3, 1),  # very fat matrices
+                  (1, 3), (5, 1, 3), (7, 5, 1, 3)]  # very thin matrices
+        for s, d in product(shapes, diagonals):
+            run_test(s, cast, d)
+
+    def test_triu_tril(self):
+        self._test_triu_tril(self, lambda t: t)
def test_cat(self):
SIZE = 10
L = L.transpose(-1, -2);
}
- auto batch_tril = [](const Tensor & A) -> Tensor {
- int64_t n = A.size(-1);
- auto indices = at::ones({n, n}, A.options().dtype(at::kByte));
- indices = indices.triu(1).expand_as(A);
- return at::where(indices, at::zeros({}, A.options()), A);
- };
-
- auto phi = [&batch_tril](const Tensor & A) -> Tensor {
- auto B = A.dim() == 2 ? A.tril() : batch_tril(A);
+ auto phi = [](const Tensor & A) -> Tensor {
+ auto B = A.tril();
B = B - 0.5 * at::diag_embed(B.diagonal(0, -1, -2), 0, -2, -1);
return B;
};
// make sure not to double-count variation, since
// only half of output matrix is unique
- auto Lbar = grad.dim() == 2 ? grad.tril() : batch_tril(grad);
+ auto Lbar = grad.tril();
auto P = phi(at::matmul(L, Lbar));
Tensor S;
r"""
tril(input, diagonal=0, out=None) -> Tensor
-Returns the lower triangular part of the matrix (2-D tensor) :attr:`input`,
-the other elements of the result tensor :attr:`out` are set to 0.
+Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices
+:attr:`input`; the other elements of the result tensor :attr:`out` are set to 0.
The lower triangular part of the matrix is defined as the elements on and
below the diagonal.
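+
+If :attr:`input` has more than two dimensions, the last two dimensions are treated
+as the matrix dimensions and the operation is applied to each matrix in the batch.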
r"""
triu(input, diagonal=0, out=None) -> Tensor
-Returns the upper triangular part of the matrix (2-D tensor) :attr:`input`,
-the other elements of the result tensor :attr:`out` are set to 0.
+Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices
+:attr:`input`; the other elements of the result tensor :attr:`out` are set to 0.
The upper triangular part of the matrix is defined as the elements on and
above the diagonal.
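+
+If :attr:`input` has more than two dimensions, the last two dimensions are treated
+as the matrix dimensions and the operation is applied to each matrix in the batch.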
[-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857],
[ 0.4333, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410],
[-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.2830]])
- >>> torch.tril(b, diagonal=1)
- tensor([[ 0.5876, -0.0794, 0.0000, 0.0000, 0.0000, 0.0000],
- [-0.2447, 0.9556, -1.2919, 0.0000, 0.0000, 0.0000],
- [ 0.4333, 0.3146, 0.6576, -1.0432, 0.0000, 0.0000],
- [-0.9888, 1.0679, -1.3337, -1.6556, 0.4798, 0.0000]])
- >>> torch.tril(b, diagonal=-1)
- tensor([[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
- [-0.2447, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
- [ 0.4333, 0.3146, 0.0000, 0.0000, 0.0000, 0.0000],
- [-0.9888, 1.0679, -1.3337, 0.0000, 0.0000, 0.0000]])
+ >>> torch.triu(b, diagonal=1)
+ tensor([[ 0.0000, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235],
+ [ 0.0000, 0.0000, -1.2919, 1.3378, -0.1768, -1.0857],
+ [ 0.0000, 0.0000, 0.0000, -1.0432, 0.9348, -0.4410],
+ [ 0.0000, 0.0000, 0.0000, 0.0000, 0.4798, 0.2830]])
+ >>> torch.triu(b, diagonal=-1)
+ tensor([[ 0.5876, -0.0794, -1.8373, 0.6654, 0.2604, 1.5235],
+ [-0.2447, 0.9556, -1.2919, 1.3378, -0.1768, -1.0857],
+ [ 0.0000, 0.3146, 0.6576, -1.0432, 0.9348, -0.4410],
+ [ 0.0000, 0.0000, -1.3337, -1.6556, 0.4798, 0.2830]])
""")
-# docstr is split in two parts to avoid format mis-captureing :math: braces '{}'
+# docstr is split in two parts to avoid format mis-capturing :math: braces '{}'
# as common args.
add_docstr(torch.triu_indices,
r"""
"""
import torch
-from torch.distributions.utils import batch_tril
__all__ = [
'Constraint',
Constrain to lower-triangular square matrices.
"""
def check(self, value):
- value_tril = batch_tril(value)
+ value_tril = value.tril()
return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
Constrain to lower-triangular square matrices with positive diagonals.
"""
def check(self, value):
- value_tril = batch_tril(value)
+ value_tril = value.tril()
lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
- n = value.size(-1)
- diag_mask = torch.eye(n, n, dtype=value.dtype, device=value.device)
- positive_diagonal = (value * diag_mask > (diag_mask - 1)).min(-1)[0].min(-1)[0]
+ positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
return lower_triangular & positive_diagonal
return torch.log(ps_clamped)
-def batch_tril(bmat, diagonal=0):
- """
- Given a batch of matrices, returns the lower triangular part of each matrix, with
- the other entries set to 0. The argument `diagonal` has the same meaning as in
- `torch.tril`.
- """
- if bmat.dim() == 2:
- return bmat.tril(diagonal=diagonal)
- else:
- return bmat * torch.tril(bmat.new(*bmat.shape[-2:]).fill_(1.0), diagonal=diagonal)
-
-
class lazy_property(object):
r"""
Used as a decorator for lazy loading of class attributes. This uses a
sz = LU_data.size(-1)
if unpack_data:
- I_U = torch.ones(sz, sz, device=LU_data.device, dtype=torch.uint8).triu_().expand_as(LU_data)
- zero = torch.tensor(0.).type_as(LU_data)
- U = torch.where(I_U, LU_data, zero)
- L = torch.where(I_U, zero, LU_data)
+ U = LU_data.triu()
+ L = LU_data.tril()
L.diagonal(dim1=-2, dim2=-1).fill_(1)
else:
L = U = None