Implement torch.tril_indices and torch.triu_indices (#12653) (#14904)
authorShen Li <shenli@fb.com>
Wed, 12 Dec 2018 23:18:57 +0000 (15:18 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Wed, 12 Dec 2018 23:40:14 +0000 (15:40 -0800)
Summary:
This is an optimized implementation that does the following:

1. created an empty Tensor of correct size.
2. fill the Tensor with correct values.

The following three designs to fill in the Tensor result in roughly the same performance. Hence, the 2nd option is taken for simpler code, and to return contiguous tensors.

1. Sequential: fill row coordinates first, then columns. This results in two for-loop and more arithmetic operations.
2. Interleaved: fill in index coordinates one by one, which jumps between the two output Tensor rows in every iteration.
3. Transpose: create a n X 2 Tensor, fill the Tensor sequentially, and then transpose it.

<img width="352" alt="screen shot 2018-12-10 at 3 54 39 pm" src="https://user-images.githubusercontent.com/16999635/49769172-07bd3580-fc94-11e8-8164-41839185e9f9.png">

NOTE:

This implementation returns a 2D tensor, instead of a tuple of two tensors. It means that users will not be able to do the following:

```python
x = torch.ones(3, 3)
i = torch.tril_indices(3, 3)
x[i]  # need to first convert the 2D tensor into a tuple of two 1D tensors.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14904

Reviewed By: zou3519

Differential Revision: D13433027

Pulled By: mrshenli

fbshipit-source-id: 41c876aafcf584832d7069f7c5929ffb59e0ae6a

aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/native_functions.yaml
test/test_cuda.py
test/test_torch.py
tools/autograd/gen_python_functions.py
torch/_torch_docs.py

index ed4537f..ef9608f 100644 (file)
@@ -523,6 +523,154 @@ Tensor& range_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
   return at::legacy::th::_th_range_out(result, start, end, step);
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+namespace {
+// Different combinations of row, col, and offset can lead to two cases:
+//
+// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
+//    Example A: offset > 0
+//      1 1 0 0 0
+//      1 1 1 0 0
+//      1 1 1 1 0
+//    Example B: offset <= 0
+//      0 0 0
+//      1 0 0
+//      1 1 0
+//    In this case, we calculate the number of elements in the first row and
+//    last row of the tril respectively, and then compute the tril size.
+//
+// Case 2 - Trapezoid + Rectangle: row + offset > col
+//    Example:
+//      1 1 0
+//      1 1 1
+//      1 1 1
+//    In this case, we first calculate the size of top trapezoid, and then
+//    calculate the size of the bottom rectangle.
+inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
+  // number of elements in the first row of the tril
+  auto m_first_row = offset > 0 ?
+    std::min<int64_t>(col, 1 + offset) : // upper bounded by col
+    row + offset > 0; // either 0 or 1
+  // number of elements in the last row of the tril, bounded by [0, col]
+  auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
+  // number of rows, bounded by [0, row]
+  auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
+  auto n_row_trapezoid = (m_last_row - m_first_row + 1);
+
+  // calculate # of elements in the top trapezoid
+  auto n_indices =
+    (m_first_row + m_last_row) * n_row_trapezoid >> 1;
+
+  // calculate # of elements in the bottom rectangle if there is any
+  auto diff_row = n_row_all - n_row_trapezoid;
+  if (diff_row > 0) {
+    n_indices += diff_row * col;
+  }
+
+  return n_indices;
+}
+
+inline void check_args(
+    int64_t row, int64_t col, const TensorOptions& options) {
+  AT_CHECK(row >= 0, "row must be non-negative, got", row);
+  AT_CHECK(col >= 0, "col must be non-negative, got", col);
+  if (options.has_device()) {
+    AT_CHECK(
+      options.device() == at::kCPU,
+      "only support device='cpu', got",
+      options.device());
+  }
+  if (options.has_layout()) {
+    AT_CHECK(
+      options.layout() == at::kStrided,
+      "only support layout=torch.strided, got",
+      options.layout())
+  }
+}
+} // namespace
+
+Tensor tril_indices(
+    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+  check_args(row, col, options);
+
+  auto n_indices = get_tril_size(row, col, offset);
+
+  // create an empty Tensor with correct size
+  auto result = at::empty({2, n_indices}, options);
+
+  // The following three approaches result in very little performance
+  // differences. Hence, the 2nd option is taken for simpler code, and to return
+  // contiguous tensors. Refer to #14904 for more details.
+  //
+  // 1. sequential RAM access: fill row coordinates first, then columns. This
+  //    results in two for-loop and more arithmetic operations.
+  //
+  // 2. interleaved RAM access: fill in index coordinates one by one, which
+  //    jumps between the two output Tensor rows in every iteration.
+  //
+  // 3. sequential RAM + transpose: create an n X 2 Tensor, fill the Tensor
+  //    sequentially, and then transpose it.
+  AT_DISPATCH_ALL_TYPES(result.type(), "tril_indices", [&]() -> void {
+    // fill the Tensor with correct values
+    scalar_t* result_data = result.data<scalar_t>();
+    int64_t i = 0;
+
+    scalar_t r = std::max<int64_t>(0, -offset), c = 0;
+    while (i < n_indices) {
+      result_data[i] = r;
+      result_data[n_indices + i++] = c;
+
+      // move to the next column and check if (r, c) is still in bound
+      c += 1;
+      if (c > r + offset || c >= col) {
+        r += 1;
+        c = 0;
+        // NOTE: not necessary to check if r is less than row here, because i
+        // and n_indices provide the guarantee
+      }
+    }
+  });
+
+  return result;
+}
+
+Tensor triu_indices(
+    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+  check_args(row, col, options);
+
+  auto n_indices = row * col - get_tril_size(row, col, offset - 1);
+
+  // create an empty Tensor with correct size
+  auto result = at::empty({2, n_indices}, options);
+
+  AT_DISPATCH_ALL_TYPES(result.type(), "triu_indices", [&]() -> void {
+    // fill the Tensor with correct values
+    scalar_t* result_data = result.data<scalar_t>();
+    int64_t i = 0;
+    // not typing std::max with scalar_t as it could be an unsigned type
+    // NOTE: no need to check if the returned value of std::max overflows
+    // scalar_t, as i and n_indices act as a guard.
+    scalar_t c = std::max<int64_t>(0, offset), r = 0;
+    while (i < n_indices) {
+      result_data[i] = r;
+      result_data[n_indices + i++] = c;
+
+      // move to the next column and check if (r, c) is still in bound
+      c += 1;
+      if (c >= col) {
+        r += 1;
+        // not typing std::max with scalar_t as it could be an unsigned type
+        // NOTE: not necessary to check if c is less than col or overflows here,
+        // because i and n_indices act as a guard.
+        c = std::max<int64_t>(0, r + offset);
+      }
+    }
+  });
+
+  return result;
+}
+
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 Tensor zeros(IntList size, const TensorOptions& options) {
index 949bc5a..5729e0d 100644 (file)
 - func: tril(Tensor self, int64_t diagonal=0) -> Tensor
   variants: method, function
 
+- func: tril_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+
+- func: triu_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+
 - func: trace(Tensor self) -> Tensor
   variants: method, function
 
index 8226479..ffb6281 100644 (file)
@@ -2118,6 +2118,17 @@ class TestCuda(TestCase):
                 y = torch.randn(2, 1, device='cuda')
                 z = x + y
 
+    def test_tril_and_triu_indices(self):
+        self.assertRaises(
+            RuntimeError,
+            lambda: torch.triu_indices(
+                1, 1, device='cuda', layout=torch.strided))
+
+        self.assertRaises(
+            RuntimeError,
+            lambda: torch.tril_indices(
+                1, 1, device='cuda', layout=torch.strided))
+
 
 def load_ignore_file():
     from os.path import join, dirname
index 9be44f7..726c9b1 100644 (file)
@@ -3764,6 +3764,92 @@ class _TestTorchMixin(object):
         torch.tril(x, out=res2)
         self.assertEqual(res1, res2, 0)
 
+    def _compare_trilu_indices(self, row, col, offset=0, dtype=torch.long):
+        if row == 0 or col == 0:
+            # have to handle this separately as tril and triu does not take
+            # empty matrix as input
+            self.assertEqual(
+                torch.empty(0, 2, dtype=dtype).transpose(0, 1),
+                torch.tril_indices(row, col, offset, dtype=dtype))
+
+            self.assertEqual(
+                torch.empty(0, 2, dtype=dtype).transpose(0, 1),
+                torch.triu_indices(row, col, offset, dtype=dtype))
+
+        else:
+            self.assertEqual(
+                torch.ones(row, col, dtype=dtype)
+                     .tril(offset).nonzero().transpose(0, 1),
+                torch.tril_indices(row, col, offset, dtype=dtype))
+
+            self.assertEqual(
+                torch.ones(row, col, dtype=dtype)
+                     .triu(offset).nonzero().transpose(0, 1),
+                torch.triu_indices(row, col, offset, dtype=dtype))
+
+    def test_tril_and_triu_indices(self):
+        self._compare_trilu_indices(1, 1)
+        self._compare_trilu_indices(3, 3)
+        self._compare_trilu_indices(3, 3, offset=1)
+        self._compare_trilu_indices(3, 3, offset=2)
+        self._compare_trilu_indices(3, 3, offset=200)
+        self._compare_trilu_indices(3, 3, offset=-1)
+        self._compare_trilu_indices(3, 3, offset=-2)
+        self._compare_trilu_indices(3, 3, offset=-200)
+        self._compare_trilu_indices(0, 3, offset=0)
+        self._compare_trilu_indices(0, 3, offset=1)
+        self._compare_trilu_indices(0, 3, offset=-1)
+        self._compare_trilu_indices(3, 0, offset=0)
+        self._compare_trilu_indices(3, 0, offset=1)
+        self._compare_trilu_indices(3, 0, offset=-1)
+        self._compare_trilu_indices(0, 0, offset=0)
+        self._compare_trilu_indices(0, 0, offset=1)
+        self._compare_trilu_indices(0, 0, offset=-1)
+        self._compare_trilu_indices(3, 6, offset=0)
+        self._compare_trilu_indices(3, 6, offset=1)
+        self._compare_trilu_indices(3, 6, offset=3)
+        self._compare_trilu_indices(3, 6, offset=9)
+        self._compare_trilu_indices(3, 6, offset=-1)
+        self._compare_trilu_indices(3, 6, offset=-3)
+        self._compare_trilu_indices(3, 6, offset=-9)
+        self._compare_trilu_indices(6, 3, offset=0)
+        self._compare_trilu_indices(6, 3, offset=1)
+        self._compare_trilu_indices(6, 3, offset=3)
+        self._compare_trilu_indices(6, 3, offset=9)
+        self._compare_trilu_indices(6, 3, offset=-1)
+        self._compare_trilu_indices(6, 3, offset=-3)
+        self._compare_trilu_indices(6, 3, offset=-9)
+        self._compare_trilu_indices(258, 253, offset=1, dtype=torch.float32)
+        self._compare_trilu_indices(257, 258, offset=1, dtype=torch.float64)
+        self._compare_trilu_indices(258, 258, offset=1, dtype=torch.short)
+        self._compare_trilu_indices(3, 513, offset=1, dtype=torch.long)
+        self._compare_trilu_indices(513, 3, offset=1, dtype=torch.int)
+        self._compare_trilu_indices(513, 0, offset=1, dtype=torch.double)
+
+        x = torch.ones(
+            3, 3, dtype=torch.long, device='cpu', layout=torch.strided)
+        l = x.tril(0).nonzero().transpose(0, 1)
+        u = x.triu(0).nonzero().transpose(0, 1)
+        self.assertEqual(l, torch.tril_indices(3, 3))
+        self.assertEqual(l, torch.tril_indices(3, 3, device='cpu'))
+        self.assertEqual(
+            l, torch.tril_indices(3, 3, device='cpu', layout=torch.strided))
+
+        self.assertEqual(u, torch.triu_indices(3, 3))
+        self.assertEqual(u, torch.triu_indices(3, 3, device='cpu'))
+        self.assertEqual(
+            u, torch.triu_indices(3, 3, device='cpu', layout=torch.strided))
+
+        self.assertRaises(
+            RuntimeError,
+            lambda: torch.triu_indices(
+                1, 1, device='cpu', layout=torch.sparse_coo))
+
+        self.assertRaises(
+            RuntimeError,
+            lambda: torch.tril_indices(
+                1, 1, device='cpu', layout=torch.sparse_coo))
+
     def test_triu(self):
         x = torch.rand(SIZE, SIZE)
         res1 = torch.triu(x)
index e8042b3..1a0fc7f 100644 (file)
@@ -228,7 +228,9 @@ def group_declarations_by_name(declarations, should_bind_fn):
 
 
 def get_type_default(declaration):
-    if declaration['name'].startswith('randperm'):
+    if declaration['name'].startswith('randperm') or \
+            declaration['name'] == 'tril_indices' or \
+            declaration['name'] == 'triu_indices':
         return 'torch.int64'
     else:
         return 'None'
index 389d0c2..744b111 100644 (file)
@@ -4987,6 +4987,55 @@ Example::
             [-0.0614, -0.7344, -1.3164,  0.0000,  0.0000,  0.0000]])
 """)
 
+# docstr is split in two parts to avoid format mis-captureing :math: braces '{}'
+# as common args.
+add_docstr(torch.tril_indices,
+           r"""
+tril_indices(row, column, offset=0, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor
+
+Returns the indices of the lower triangular part of a :attr:`row`-by-
+:attr:`column` matrix in a 2-by-N Tensor, where the first row contains row
+coordinates of all indices and the second row contains column coordinates.
+Indices are ordered based on rows and then columns.
+
+The lower triangular part of the matrix is defined as the elements on and
+below the diagonal.
+
+The argument :attr:`offset` controls which diagonal to consider. If
+:attr:`offset` = 0, all elements on and below the main diagonal are
+retained. A positive value includes just as many diagonals above the main
+diagonal, and similarly a negative value excludes just as many diagonals below
+the main diagonal. The main diagonal are the set of indices
+:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
+where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
+
+Args:
+    row (``int``): number of rows in the 2-D matrix.
+    column (``int``): number of columns in the 2-D matrix.
+    offset (``int``): diagonal offset from the main diagonal.
+        Default: if not provided, 0.
+    dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+        Default: if ``None``, ``torch.long``.
+    device (:class:`torch.device`, optional): currently only support ``cpu``.
+    layout (:class:`torch.layout`, optional): currently only support ``torch.strided``.
+
+Example::
+    >>> a = torch.tril_indices(3, 3)
+    >>> a
+    tensor([[0, 1, 1, 2, 2, 2],
+            [0, 0, 1, 0, 1, 2]])
+
+    >>> a = torch.tril_indices(4, 3, -1)
+    >>> a
+    tensor([[1, 2, 2, 3, 3, 3],
+            [0, 0, 1, 0, 1, 2]])
+
+    >>> a = torch.tril_indices(4, 3, 1)
+    >>> a
+    tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
+            [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]])
+""")
+
 add_docstr(torch.triu,
            r"""
 triu(input, diagonal=0, out=None) -> Tensor
@@ -5048,6 +5097,55 @@ Example::
             [-0.9888,  1.0679, -1.3337,  0.0000,  0.0000,  0.0000]])
 """)
 
+# docstr is split in two parts to avoid format mis-captureing :math: braces '{}'
+# as common args.
+add_docstr(torch.triu_indices,
+           r"""
+triu_indices(row, column, offset=0, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor
+
+Returns the indices of the upper triangular part of a :attr:`row` by
+:attr:`column` matrix in a 2-by-N Tensor, where the first row contains row
+coordinates of all indices and the second row contains column coordinates.
+Indices are ordered based on rows and then columns.
+
+The upper triangular part of the matrix is defined as the elements on and
+above the diagonal.
+
+The argument :attr:`offset` controls which diagonal to consider. If
+:attr:`offset` = 0, all elements on and above the main diagonal are
+retained. A positive value excludes just as many diagonals above the main
+diagonal, and similarly a negative value includes just as many diagonals below
+the main diagonal. The main diagonal are the set of indices
+:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
+where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
+
+Args:
+    row (``int``): number of rows in the 2-D matrix.
+    column (``int``): number of columns in the 2-D matrix.
+    offset (``int``): diagonal offset from the main diagonal.
+        Default: if not provided, 0.
+    dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+        Default: if ``None``, ``torch.long``.
+    device (:class:`torch.device`, optional): currently only support ``cpu``.
+    layout (:class:`torch.layout`, optional): currently only support ``torch.strided``.
+
+Example::
+    >>> a = torch.triu_indices(3, 3)
+    >>> a
+    tensor([[0, 0, 0, 1, 1, 2],
+            [0, 1, 2, 1, 2, 2]])
+
+    >>> a = torch.triu_indices(4, 3, -1)
+    >>> a
+    tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3],
+            [0, 1, 2, 0, 1, 2, 1, 2, 2]])
+
+    >>> a = torch.triu_indices(4, 3, 1)
+    >>> a
+    tensor([[0, 0, 1],
+            [1, 2, 2]])
+""")
+
 add_docstr(torch.trtrs,
            r"""
 trtrs(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)