Batched upper triangular, lower triangular (#15257)
authorvishwakftw <cs15btech11043@iith.ac.in>
Thu, 10 Jan 2019 03:36:20 +0000 (19:36 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 10 Jan 2019 03:46:39 +0000 (19:46 -0800)
Summary:
Changelog:

- Implements `triu` and `tril` for batches of 2D tensors.
- Remove TH/THC binding for `tril`
- Fix CUDA implementation
- Update docstrings for tril and triu.
- Remove mask-based `triu` and `tril` in cholesky forward and backward.
- Remove batched tril in torch.distributions.utils
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15257

Differential Revision: D13613888

Pulled By: mrshenli

fbshipit-source-id: 0949a05b9b8e974c1acfaf02a6284848ec5cc1c4

18 files changed:
aten/src/ATen/Declarations.cwrap
aten/src/ATen/native/BatchLinearAlgebra.cpp
aten/src/ATen/native/LegacyDefinitions.cpp
aten/src/ATen/native/LinearAlgebraUtils.h
aten/src/ATen/native/cuda/BatchLinearAlgebra.cu
aten/src/ATen/native/native_functions.yaml
aten/src/TH/generic/THTensorMath.h
aten/src/TH/generic/THTensorMoreMath.cpp
aten/src/THC/generic/THCTensorMath.h
aten/src/THC/generic/THCTensorMathPairwise.cu
test/common_methods_invocations.py
test/test_cuda.py
test/test_torch.py
tools/autograd/templates/Functions.cpp
torch/_torch_docs.py
torch/distributions/constraints.py
torch/distributions/utils.py
torch/functional.py

index 0648523..38cb05b 100644 (file)
     - arg: THTensor* tensor
 ]]
 [[
-  name: _th_tril
-  cname: tril
-  variants:
-    - function
-  return: argument 0
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - THTensor* self
-    - arg: long diagonal
-      default: 0
-]]
-[[
-  name: _th_tril_
-  cname: tril
-  variants: function
-  return: self
-  arguments:
-    - THTensor* self
-    - THTensor* self
-    - arg: long diagonal
-      default: 0
-]]
-[[
-  name: _th_triu
-  cname: triu
-  variants:
-    - function
-  return: argument 0
-  arguments:
-    - arg: THTensor* result
-      output: True
-    - THTensor* self
-    - arg: long diagonal
-      default: 0
-]]
-[[
-  name: _th_triu_
-  cname: triu
-  variants:
-    - function
-  return: self
-  arguments:
-    - THTensor* self
-    - THTensor* self
-    - arg: long diagonal
-      default: 0
-]]
-[[
   name: _th_cross
   cname: cross
   variants:
index a01f33e..4cb57dc 100644 (file)
@@ -372,20 +372,12 @@ Tensor cholesky(const Tensor &self, bool upper) {
   }
   squareCheckInputs(self);
 
-  // TODO: (#14071) Once `triu`, `tril` is implemented for batched tensors,
-  // this can be simplified. Currently, we are zero-ing out values in the
-  // batch of matrices by using a mask and the `where` function.
-  // The simplification with batched `triu` and `tril` would be this:
-  // if (upper) {
-  //   return raw_cholesky_output.triu();
-  // } else {
-  //   return raw_cholesky_output.tril();
-  // }
   auto raw_cholesky_output = at::_cholesky_helper(self, upper);
-  int64_t n = self.size(-1);
-  auto indices = at::ones({n, n}, self.options().dtype(at::kByte));
-  indices = upper ? indices.tril(-1).expand_as(self) : indices.triu(1).expand_as(self);
-  return at::where(indices, at::zeros({}, self.options()), raw_cholesky_output);
+  if (upper) {
+    return raw_cholesky_output.triu_();
+  } else {
+    return raw_cholesky_output.tril_();
+  }
 }
 
 Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
@@ -396,4 +388,136 @@ Tensor& cholesky_out(Tensor &result, const Tensor &self, bool upper) {
   return result;
 }
 
+template <typename scalar_t, bool inplace, bool upper>
+static void apply_triu_tril_single(
+    scalar_t* result, scalar_t* self,
+    int64_t k, int64_t n, int64_t m,
+    int64_t res_row_stride, int64_t res_col_stride,
+    int64_t self_row_stride, int64_t self_col_stride) {
+
+  constexpr int64_t zero = 0;
+  int64_t i;
+
+  if (upper) {
+    #pragma omp parallel for private(i)
+    for (i = 0; i < n; i++) {
+      for (int64_t j = 0; j < std::min(m, i + k); j++) {
+        result[i * res_row_stride + j * res_col_stride] = 0;
+      }
+      if (!inplace) {  // copy the rest of the self if not inplace
+        for (int64_t j = std::max(zero, i + k); j < m; j++) {
+          result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
+        }
+      }
+    }
+  } else {
+    #pragma omp parallel for private(i)
+    for (i = 0; i < n; i++) {
+      for (int64_t j = std::max(zero, i + k + 1); j < m; j++) {
+        result[i * res_row_stride + j * res_col_stride] = 0;
+      }
+      if (!inplace) {  // copy the rest of the self if not inplace
+        for (int64_t j = zero; j < std::min(m, i + k + 1); j++) {
+          result[i * res_row_stride + j * res_col_stride] = self[i * self_row_stride + j * self_col_stride];
+        }
+      }
+    }
+  }
+}
+
+template <typename scalar_t, bool inplace, bool upper>
+void apply_triu_tril(Tensor& result, const Tensor& self, int64_t k) {
+  auto n = self.size(-2);
+  auto m = self.size(-1);
+  auto self_data = self.data<scalar_t>();
+  auto self_stride = self.dim() > 2 ? self.stride(-3) : 1;
+  auto batchsize = batchCount(self);
+  auto self_row_stride = self.stride(-2);
+  auto self_column_stride = self.stride(-1);
+
+  auto result_data = result.data<scalar_t>();
+  int64_t result_stride, result_row_stride, result_column_stride;
+  if (result_data != self_data) {
+    result_stride = result.dim() > 2 ? result.stride(-3) : 1;
+    result_row_stride = result.stride(-2);
+    result_column_stride = result.stride(-1);
+  } else {
+    result_stride = self_stride;
+    result_row_stride = self_row_stride;
+    result_column_stride = self_column_stride;
+  }
+
+  int64_t b;
+  #pragma omp parallel for private(b)
+  for (b = 0; b < batchsize; b++) {
+    scalar_t* self_batch = &self_data[b * self_stride];
+    scalar_t* result_batch = &result_data[b * result_stride];
+    apply_triu_tril_single<scalar_t, inplace, upper>(
+        result_batch, self_batch, k, n, m,
+        result_row_stride, result_column_stride, self_row_stride, self_column_stride);
+  }
+}
+
+Tensor tril(const Tensor& self, int64_t k) {
+  Tensor result = at::empty({0}, self.options());
+  at::tril_out(result, self, k);
+  return result;
+}
+
+Tensor& tril_cpu_(Tensor &self, int64_t k) {
+  if (self.numel() == 0) {
+    return self;
+  }
+  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
+    apply_triu_tril<scalar_t, true, false>(self, self, k);
+  });
+  return self;
+}
+
+Tensor& tril_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
+  if (result.sizes() != self.sizes()) {
+    result.resize_as_(self);
+  }
+  if (self.numel() == 0) {
+    return result;
+  }
+  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+  AT_DISPATCH_ALL_TYPES(self.type(), "tril", [&]{
+    apply_triu_tril<scalar_t, false, false>(result, self_c, k);
+  });
+  return result;
+}
+
+Tensor triu(const Tensor& self, int64_t k) {
+  Tensor result = at::empty({0}, self.options());
+  at::triu_out(result, self, k);
+  return result;
+}
+
+Tensor& triu_cpu_(Tensor &self, int64_t k) {
+  if (self.numel() == 0) {
+    return self;
+  }
+  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
+    apply_triu_tril<scalar_t, true, true>(self, self, k);
+  });
+  return self;
+}
+
+Tensor& triu_cpu_out(Tensor &result, const Tensor& self, int64_t k) {
+  if (result.sizes() != self.sizes()) {
+    result.resize_as_(self);
+  }
+  if (self.numel() == 0) {
+    return result;
+  }
+  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+  AT_DISPATCH_ALL_TYPES(self.type(), "triu", [&]{
+    apply_triu_tril<scalar_t, false, true>(result, self_c, k);
+  });
+  return result;
+}
+
 }}  // namespace at::native
index eb6a254..abd3378 100644 (file)
@@ -130,14 +130,6 @@ Tensor & atan2_(Tensor& self, const Tensor & other) {
   return at::legacy::th::_th_atan2_(self, other);
 }
 
-Tensor & tril_(Tensor& self, int64_t diagonal) {
-  return at::legacy::th::_th_tril_(self, diagonal);
-}
-
-Tensor & triu_(Tensor& self, int64_t diagonal) {
-  return at::legacy::th::_th_triu_(self, diagonal);
-}
-
 Tensor & digamma_(Tensor& self) {
   return at::legacy::th::_th_digamma_(self);
 }
@@ -272,22 +264,6 @@ Tensor cross(const Tensor & self, const Tensor & other, int64_t dim) {
   return at::legacy::th::_th_cross(self, other, dim);
 }
 
-Tensor & triu_out(Tensor & result, const Tensor & self, int64_t diagonal) {
-  return at::legacy::th::_th_triu_out(result, self, diagonal);
-}
-
-Tensor triu(const Tensor & self, int64_t diagonal) {
-  return at::legacy::th::_th_triu(self, diagonal);
-}
-
-Tensor & tril_out(Tensor & result, const Tensor & self, int64_t diagonal) {
-  return at::legacy::th::_th_tril_out(result, self, diagonal);
-}
-
-Tensor tril(const Tensor & self, int64_t diagonal) {
-  return at::legacy::th::_th_tril(self, diagonal);
-}
-
 Tensor trace(const Tensor & self) {
   return at::legacy::th::_th_trace(self);
 }
index dbec1fa..34d3104 100644 (file)
@@ -41,6 +41,28 @@ static inline int64_t matrixStride(const Tensor& batched_matrices) {
   return batched_matrices.size(-1) * batched_matrices.size(-2);
 }
 
+/* Checks a necessary property for the triu and tril implementations, hence the name.
+ * Here batch contiguity is checked for tensors with greater than 4 dimensions.
+ * Contiguous tensors and tensors with less than 3 dimensions pass this check
+ */ 
+static inline bool checkTrilTriuBatchContiguous(const Tensor& tensor) {
+  // Complete contiguity is the most desired property, which is why
+  // we return true if the tensor is contiguous
+  if (tensor.is_contiguous()) return true;
+
+  int64_t dims = tensor.dim();
+
+  // Tensors with dimension less than 4 are handled by default
+  if (dims <= 3) return true;
+
+  int64_t expected_stride = tensor.size(-1) * tensor.size(-2);
+  for (int64_t i = dims - 3; i >= 0; i--) {
+    if (expected_stride != tensor.stride(i)) return false;
+    expected_stride *= tensor.size(i);
+  }
+  return true;
+}
+
 // Returns the epsilon value for floating types except half
 static inline double _get_epsilon(const ScalarType& sc_type) {
   switch (sc_type) {
index a6f1dd2..b45b4d8 100644 (file)
@@ -461,6 +461,81 @@ Tensor _cholesky_helper_cuda(const Tensor& self, bool upper) {
   }
 }
 
+template <typename scalar_t, bool upper>
+__global__
+void triu_tril_kernel(
+    scalar_t* result, scalar_t* self, int64_t k, int64_t N,
+    int64_t res_batch_stride, int64_t res_row_stride, int64_t res_col_stride,
+    int64_t self_batch_stride, int64_t self_row_stride, int64_t self_col_stride, int64_t self_ncol) {
+  int64_t linear_idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (linear_idx >= N) {
+    return;
+  }
+
+  int64_t self_batch_idx = blockIdx.y;
+  int64_t row = linear_idx / self_ncol;
+  int64_t col = linear_idx % self_ncol;
+
+  bool mask = upper ? (col - row >= k) : (col - row <= k);
+
+  // Now compute the offset for the self and result tensor
+  int64_t res_offset = self_batch_idx * res_batch_stride + row * res_row_stride + col * res_col_stride;
+  int64_t self_offset = self_batch_idx * self_batch_stride + row * self_row_stride + col * self_col_stride;
+  result[res_offset] = mask ? self[self_offset] : scalar_t(0);
+}
+
+template <bool upper>
+Tensor& triu_tril_cuda_template(Tensor& result, const Tensor& self, int64_t k, const char* name) {
+  int64_t n_batches = batchCount(self), mat_size = self.size(-1) * self.size(-2),
+          res_batch_stride = result.dim() > 2 ? result.stride(-3) : 1,
+          res_row_stride = result.stride(-2), res_col_stride = result.stride(-1),
+          self_batch_stride = self.dim() > 2 ? self.stride(-3) : 1,
+          self_row_stride = self.stride(-2), self_col_stride = self.stride(-1);
+  dim3 dim_block = cuda::getApplyBlock();
+  dim3 dim_grid((mat_size + dim_block.x - 1) / dim_block.x, n_batches);
+  AT_DISPATCH_ALL_TYPES_AND_HALF(self.type(), name, [&]{
+    triu_tril_kernel<scalar_t, upper>
+      <<<dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+        result.data<scalar_t>(), self.data<scalar_t>(), k, mat_size,
+        res_batch_stride, res_row_stride, res_col_stride,
+        self_batch_stride, self_row_stride, self_col_stride, self.size(-1));
+  });
+  AT_CUDA_CHECK(cudaGetLastError());
+  return result;
+}
+
+Tensor& tril_cuda_(Tensor &self, int64_t k) {
+  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  return tril_cuda_out(self, self, k);
+}
+
+Tensor& tril_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
+  if (result.sizes() != self.sizes()) {
+    result.resize_as_(self);
+  }
+  if (self.numel() == 0) {
+    return result;
+  }
+  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+  return triu_tril_cuda_template<false>(result, self_c, k, "tril");
+}
+
+Tensor& triu_cuda_(Tensor &self, int64_t k) {
+  if (!checkTrilTriuBatchContiguous(self)) self = self.contiguous();
+  return triu_cuda_out(self, self, k);
+}
+
+Tensor& triu_cuda_out(Tensor &result, const Tensor& self, int64_t k) {
+  if (result.sizes() != self.sizes()) {
+    result.resize_as_(self);
+  }
+  if (self.numel() == 0) {
+    return result;
+  }
+  Tensor self_c = checkTrilTriuBatchContiguous(self) ? self : self.contiguous();
+  return triu_tril_cuda_template<true>(result, self_c, k, "triu");
+}
+
 }}  // namespace at::native
 
 #undef ALLOCATE_ARRAY
index fa01c99..d7a5b8b 100644 (file)
 
 - func: tril_(Tensor self, int64_t diagonal=0) -> Tensor
   variants: method
+  dispatch:
+    CPU: tril_cpu_
+    CUDA: tril_cuda_
 
 - func: triu_(Tensor self,  int64_t diagonal=0) -> Tensor
   variants: method
+  dispatch:
+    CPU: triu_cpu_
+    CUDA: triu_cuda_
 
 - func: digamma_(Tensor self) -> Tensor
   variants: method
   variants: method, function
 
 - func: triu_out(Tensor result, Tensor self, int64_t diagonal=0) -> Tensor
+  dispatch:
+    CPU: triu_cpu_out
+    CUDA: triu_cuda_out
 
 - func: triu(Tensor self, int64_t diagonal=0) -> Tensor
   variants: method, function
 
 - func: tril_out(Tensor result, Tensor self, int64_t diagonal=0) -> Tensor
+  dispatch:
+    CPU: tril_cpu_out
+    CUDA: tril_cuda_out
 
 - func: tril(Tensor self, int64_t diagonal=0) -> Tensor
   variants: method, function
index d7bb1e6..adabf5b 100644 (file)
@@ -102,7 +102,6 @@ TH_API void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n
 
 TH_API void THTensor_(sort)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int dimension, int descendingOrder);
 TH_API void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, int dim, int dir, int sorted);
-TH_API void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k);
 TH_API void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k);
 TH_API void THTensor_(cat)(THTensor *r_, THTensor *ta, THTensor *tb, int dimension);
 TH_API void THTensor_(catArray)(THTensor *result, THTensor **inputs, int numInputs, int dimension);
index d6c29a7..bb7edf3 100644 (file)
@@ -1202,37 +1202,6 @@ void THTensor_(topk)(THTensor *rt_, THLongTensor *ri_, THTensor *t, int64_t k, i
   THLongTensor_free(tmpIndices);
 }
 
-void THTensor_(tril)(THTensor *r_, THTensor *t, int64_t k)
-{
-  int64_t t_size_0, t_size_1;
-  int64_t t_stride_0, t_stride_1;
-  int64_t r__stride_0, r__stride_1;
-  scalar_t *t_data, *r__data;
-  int64_t r, c;
-
-  THArgCheck(THTensor_(nDimensionLegacyAll)(t) == 2, 1, "expected a matrix");
-
-  THTensor_(resizeAs)(r_, t);
-
-  t_size_0 = THTensor_(size)(t, 0);
-  t_size_1 = THTensor_(size)(t, 1);
-  t_stride_0 = THTensor_(stride)(t, 0);
-  t_stride_1 = THTensor_(stride)(t, 1);
-  r__stride_0 = THTensor_(stride)(r_, 0);
-  r__stride_1 = THTensor_(stride)(r_, 1);
-  r__data = r_->data<scalar_t>();
-  t_data = t->data<scalar_t>();
-
-  for(r = 0; r < t_size_0; r++)
-  {
-    int64_t sz = THMin(r+k+1, t_size_1);
-    for(c = THMax(0, r+k+1); c < t_size_1; c++)
-      r__data[r*r__stride_0+c*r__stride_1] = 0;
-    for(c = 0; c < sz; c++)
-      r__data[r*r__stride_0+c*r__stride_1] = t_data[r*t_stride_0+c*t_stride_1];
-  }
-}
-
 void THTensor_(triu)(THTensor *r_, THTensor *t, int64_t k)
 {
   int64_t t_size_0, t_size_1;
index 8221a37..8ac8ac4 100644 (file)
@@ -12,7 +12,6 @@ THC_API void THCTensor_(cat)(THCState *state, THCTensor *result, THCTensor *ta,
 THC_API void THCTensor_(catArray)(THCState *state, THCTensor *result, THCTensor **inputs, int numInputs, int dimension);
 THC_API void THCTensor_(nonzero)(THCState* state, THCudaLongTensor *tensor, THCTensor *self);
 
-THC_API void THCTensor_(tril)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API void THCTensor_(triu)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API void THCTensor_(diag)(THCState *state, THCTensor *self, THCTensor *src, int64_t k);
 THC_API void THCTensor_(eye)(THCState *state, THCTensor *self, int64_t n, int64_t k);
index b3b994d..c16e369 100644 (file)
@@ -168,35 +168,6 @@ void THCTensor_(remainder)(THCState *state, THCTensor *self_, THCTensor *src_, s
   THCudaCheck(cudaGetLastError());
 }
 
-void THCTensor_(tril)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k)
-{
-  THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
-  THArgCheck(!src_->is_empty() && src_->dim() == 2, 1, "expected a matrix");
-
-  if (self_ != src_)
-    THCTensor_(resizeAs)(state, self_, src_);
-
-  int64_t stride0 = self_->stride(0);
-  int64_t stride1 = self_->stride(1);
-  scalar_t *start = THCTensor_(data)(state, self_);
-
-  TensorTriOp<scalar_t, 0> op(start, stride0, stride1, k);
-
-  if (self_ == src_) {
-    if (!THC_pointwiseApply1<scalar_t>(state, src_, op)) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  } else {
-    THCTensor_(resizeAs)(state, self_, src_);
-
-    if (!THC_pointwiseApply2<scalar_t, scalar_t>(state, self_, src_, op)) {
-      THArgCheck(false, 2, CUTORCH_DIM_WARNING);
-    }
-  }
-
-  THCudaCheck(cudaGetLastError());
-}
-
 void THCTensor_(triu)(THCState *state, THCTensor *self_, THCTensor *src_, int64_t k)
 {
   THCAssertSameGPU(THCTensor_(checkGPU)(state, 2, self_, src_));
index 8d2b790..0be58d4 100644 (file)
@@ -569,8 +569,14 @@ def method_tests():
         ('diagonal', (M, M, M), (-2, 0, 1), '3d_3'),
         ('tril', (M, M), NO_ARGS),
         ('tril', (M, M), (2,), 'idx'),
+        ('tril', (S, M, M), NO_ARGS, 'batched'),
+        ('tril', (S, M, M), (2,), 'batched_idx'),
+        ('tril', (3, 3, S, S), NO_ARGS, 'more_batched'),
         ('triu', (M, M), NO_ARGS),
         ('triu', (M, M), (2,), 'idx'),
+        ('triu', (S, M, M), NO_ARGS, 'batched'),
+        ('triu', (S, M, M), (2,), 'batched_idx'),
+        ('triu', (3, 3, S, S), NO_ARGS, 'more_batched'),
         ('trace', (M, M), NO_ARGS),
         ('cross', (S, 3), ((S, 3),)),
         ('cross', (S, 3, S), ((S, 3, S), 1), 'dim'),
index df1f21e..26eab2a 100644 (file)
@@ -2226,6 +2226,9 @@ class TestCuda(TestCase):
         for test_args in tri_large_tests_args:
             _compare_large_trilu_indices(self, *test_args, device='cuda')
 
+    def test_triu_tril(self):
+        _TestTorchMixin._test_triu_tril(self, lambda t: t.cuda())
+
 
 def load_ignore_file():
     from os.path import join, dirname
index d42db3b..f35cda5 100644 (file)
@@ -3816,17 +3816,9 @@ class _TestTorchMixin(object):
         # input unchanged
         self.assertEqual(x, x0, 0)
 
-    def test_tril(self):
-        x = torch.rand(SIZE, SIZE)
-        res1 = torch.tril(x)
-        res2 = torch.Tensor()
-        torch.tril(x, out=res2)
-        self.assertEqual(res1, res2, 0)
-
     def test_trilu_indices(self):
         for test_args in tri_tests_args:
             _compare_trilu_indices(self, *test_args)
-
         run_additional_tri_tests(self, 'cpu')
 
         # test default options
@@ -3837,12 +3829,90 @@ class _TestTorchMixin(object):
         self.assertEqual(
             x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3))
 
-    def test_triu(self):
-        x = torch.rand(SIZE, SIZE)
-        res1 = torch.triu(x)
-        res2 = torch.Tensor()
-        torch.triu(x, out=res2)
-        self.assertEqual(res1, res2, 0)
+    @staticmethod
+    def _test_triu_tril(self, cast):
+        def gen_mask(shape, diagonal, cast, upper):
+            mask = torch.zeros(*shape[-2:]).byte()
+            for i in range(shape[-2]):
+                for j in range(shape[-1]):
+                    cond = j - i < diagonal if upper else j - i > diagonal
+                    if cond:
+                        mask[i, j] = 1
+            return cast(mask.expand(*shape))
+
+        torch_functions = {True: torch.triu, False: torch.tril}
+        if TEST_NUMPY:
+            numpy_functions = {True: np.triu, False: np.tril}
+
+        def run_test(shape, cast, diagonal):
+            x_cpu = torch.randn(*shape)
+            x = cast(x_cpu)
+
+            for upper in [True, False]:
+                # normal test with mask
+                torch_tri_func = torch_functions[upper]
+                res1 = torch_tri_func(x, diagonal=diagonal)
+                res2 = cast(torch.Tensor())
+                torch_tri_func(x, diagonal=diagonal, out=res2)
+                exp_mask = gen_mask(shape, diagonal, cast, upper)
+                expected = torch.where(exp_mask, torch.tensor(0).type_as(x), x)
+                self.assertEqual(res1, res2, 0)
+                self.assertEqual(expected, res1, 0)
+
+                # non-contiguous and expanded tensors test
+                if not (0 in shape or 1 in shape):
+                    for s in range(-len(shape), -1):
+                        # non-contiguous tensors
+                        x_nc = x.clone().transpose(s, s + 1)
+                        exp_mask = gen_mask(x_nc.size(), diagonal, cast, upper)
+                        assert not x_nc.is_contiguous(), "x is intentionally non-contiguous"
+                        exp_nc = torch.where(exp_mask, torch.tensor(0).type_as(x), x_nc)
+                        self.assertEqual(torch_tri_func(x_nc, diagonal), exp_nc, 0)
+                        if upper:
+                            self.assertEqual(x_nc.triu_(diagonal), exp_nc, 0)
+                        else:
+                            self.assertEqual(x_nc.tril_(diagonal), exp_nc, 0)
+
+                        # any 3-dimensional tensor should be fine
+                        if len(shape) <= 3 or s == -2:
+                            self.assertFalse(x_nc.is_contiguous(),
+                                             "x_nc should remain non-contiguous")
+                        elif s < -3:
+                            self.assertTrue(x_nc.is_contiguous(),
+                                            "x_nc should become contiguous")
+
+                    # expanded tensors
+                    expanded_size = (x.size(0),) + x.size()
+                    x_expanded = x.clone().expand(*expanded_size)
+                    assert 0 in x_expanded.stride(), "x intentionally has 0 in its stride"
+                    output = torch_tri_func(x_expanded, diagonal)
+                    self.assertEqual(output, expected.expand(expanded_size), 0)
+                    self.assertTrue(0 in x_expanded.stride(),
+                                    "geometry of x_expanded should be the same")
+                    if upper:
+                        self.assertEqual(output, x_expanded.triu_(diagonal), 0)
+                    else:
+                        self.assertEqual(output, x_expanded.tril_(diagonal), 0)
+
+                if not TEST_NUMPY:
+                    continue
+
+                # numpy test
+                numpy_tri_func = numpy_functions[upper]
+                self.assertEqual(numpy_tri_func(x_cpu.numpy(), diagonal), res1.cpu().numpy())
+
+        diagonals = [-2, -1, 0, 1, 2]
+        shapes = [(3, 3), (5, 3, 3), (7, 5, 3, 3),  # square matrices
+                  (7, 3), (5, 7, 3), (7, 5, 7, 3),  # fat matrices
+                  (3, 7), (5, 3, 7), (7, 5, 3, 7),  # thin matrices
+                  (3, 0), (0, 3, 3), (3, 3, 0, 0),  # no numel matrices
+                  (3, 1), (5, 3, 1), (7, 5, 3, 1),  # very fat matrices
+                  (1, 3), (5, 1, 3), (7, 5, 1, 3)]  # very thin matrices
+        for s, d in product(shapes, diagonals):
+            run_test(s, cast, d)
+
+    def test_triu_tril(self):
+        self._test_triu_tril(self, lambda t: t)
 
     def test_cat(self):
         SIZE = 10
index 776f238..4f1a158 100644 (file)
@@ -675,22 +675,15 @@ Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
     L = L.transpose(-1, -2);
   }
 
-  auto batch_tril = [](const Tensor & A) -> Tensor {
-    int64_t n = A.size(-1);
-    auto indices = at::ones({n, n}, A.options().dtype(at::kByte));
-    indices = indices.triu(1).expand_as(A);
-    return at::where(indices, at::zeros({}, A.options()), A);
-  };
-
-  auto phi = [&batch_tril](const Tensor & A) -> Tensor {
-    auto B = A.dim() == 2 ? A.tril() : batch_tril(A);
+  auto phi = [](const Tensor & A) -> Tensor {
+    auto B = A.tril();
     B = B - 0.5 * at::diag_embed(B.diagonal(0, -1, -2), 0, -2, -1);
     return B;
   };
 
   // make sure not to double-count variation, since
   // only half of output matrix is unique
-  auto Lbar = grad.dim() == 2 ? grad.tril() : batch_tril(grad);
+  auto Lbar = grad.tril();
 
   auto P = phi(at::matmul(L, Lbar));
   Tensor S;
index 162c2af..f7056cb 100644 (file)
@@ -4951,8 +4951,8 @@ add_docstr(torch.tril,
            r"""
 tril(input, diagonal=0, out=None) -> Tensor
 
-Returns the lower triangular part of the matrix (2-D tensor) :attr:`input`,
-the other elements of the result tensor :attr:`out` are set to 0.
+Returns the lower triangular part of the matrix (2-D tensor) or batch of matrices
+:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0.
 
 The lower triangular part of the matrix is defined as the elements on and
 below the diagonal.
@@ -5056,8 +5056,8 @@ add_docstr(torch.triu,
            r"""
 triu(input, diagonal=0, out=None) -> Tensor
 
-Returns the upper triangular part of the matrix (2-D tensor) :attr:`input`,
-the other elements of the result tensor :attr:`out` are set to 0.
+Returns the upper triangular part of a matrix (2-D tensor) or batch of matrices
+:attr:`input`, the other elements of the result tensor :attr:`out` are set to 0.
 
 The upper triangular part of the matrix is defined as the elements on and
 above the diagonal.
@@ -5101,19 +5101,19 @@ Example::
             [-0.2447,  0.9556, -1.2919,  1.3378, -0.1768, -1.0857],
             [ 0.4333,  0.3146,  0.6576, -1.0432,  0.9348, -0.4410],
             [-0.9888,  1.0679, -1.3337, -1.6556,  0.4798,  0.2830]])
-    >>> torch.tril(b, diagonal=1)
-    tensor([[ 0.5876, -0.0794,  0.0000,  0.0000,  0.0000,  0.0000],
-            [-0.2447,  0.9556, -1.2919,  0.0000,  0.0000,  0.0000],
-            [ 0.4333,  0.3146,  0.6576, -1.0432,  0.0000,  0.0000],
-            [-0.9888,  1.0679, -1.3337, -1.6556,  0.4798,  0.0000]])
-    >>> torch.tril(b, diagonal=-1)
-    tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
-            [-0.2447,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
-            [ 0.4333,  0.3146,  0.0000,  0.0000,  0.0000,  0.0000],
-            [-0.9888,  1.0679, -1.3337,  0.0000,  0.0000,  0.0000]])
+    >>> torch.triu(b, diagonal=1)
+    tensor([[ 0.0000, -0.0794, -1.8373,  0.6654,  0.2604,  1.5235],
+            [ 0.0000,  0.0000, -1.2919,  1.3378, -0.1768, -1.0857],
+            [ 0.0000,  0.0000,  0.0000, -1.0432,  0.9348, -0.4410],
+            [ 0.0000,  0.0000,  0.0000,  0.0000,  0.4798,  0.2830]])
+    >>> torch.triu(b, diagonal=-1)
+    tensor([[ 0.5876, -0.0794, -1.8373,  0.6654,  0.2604,  1.5235],
+            [-0.2447,  0.9556, -1.2919,  1.3378, -0.1768, -1.0857],
+            [ 0.0000,  0.3146,  0.6576, -1.0432,  0.9348, -0.4410],
+            [ 0.0000,  0.0000, -1.3337, -1.6556,  0.4798,  0.2830]])
 """)
 
-# docstr is split in two parts to avoid format mis-captureing :math: braces '{}'
+# docstr is split in two parts to avoid format mis-capturing :math: braces '{}'
 # as common args.
 add_docstr(torch.triu_indices,
            r"""
index 8320535..7f6adf9 100644 (file)
@@ -19,7 +19,6 @@ The following constraints are implemented:
 """
 
 import torch
-from torch.distributions.utils import batch_tril
 
 __all__ = [
     'Constraint',
@@ -256,7 +255,7 @@ class _LowerTriangular(Constraint):
     Constrain to lower-triangular square matrices.
     """
     def check(self, value):
-        value_tril = batch_tril(value)
+        value_tril = value.tril()
         return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
 
 
@@ -265,12 +264,10 @@ class _LowerCholesky(Constraint):
     Constrain to lower-triangular square matrices with positive diagonals.
     """
     def check(self, value):
-        value_tril = batch_tril(value)
+        value_tril = value.tril()
         lower_triangular = (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0]
 
-        n = value.size(-1)
-        diag_mask = torch.eye(n, n, dtype=value.dtype, device=value.device)
-        positive_diagonal = (value * diag_mask > (diag_mask - 1)).min(-1)[0].min(-1)[0]
+        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0]
         return lower_triangular & positive_diagonal
 
 
index 698c5a2..0f5a2f1 100644 (file)
@@ -90,18 +90,6 @@ def probs_to_logits(probs, is_binary=False):
     return torch.log(ps_clamped)
 
 
-def batch_tril(bmat, diagonal=0):
-    """
-    Given a batch of matrices, returns the lower triangular part of each matrix, with
-    the other entries set to 0. The argument `diagonal` has the same meaning as in
-    `torch.tril`.
-    """
-    if bmat.dim() == 2:
-        return bmat.tril(diagonal=diagonal)
-    else:
-        return bmat * torch.tril(bmat.new(*bmat.shape[-2:]).fill_(1.0), diagonal=diagonal)
-
-
 class lazy_property(object):
     r"""
     Used as a decorator for lazy loading of class attributes. This uses a
index 81873ee..54e87f2 100644 (file)
@@ -101,10 +101,8 @@ def btriunpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True):
     sz = LU_data.size(-1)
 
     if unpack_data:
-        I_U = torch.ones(sz, sz, device=LU_data.device, dtype=torch.uint8).triu_().expand_as(LU_data)
-        zero = torch.tensor(0.).type_as(LU_data)
-        U = torch.where(I_U, LU_data, zero)
-        L = torch.where(I_U, zero, LU_data)
+        U = LU_data.triu()
+        L = LU_data.tril()
         L.diagonal(dim1=-2, dim2=-1).fill_(1)
     else:
         L = U = None