return at::legacy::th::_th_range_out(result, start, end, step);
}
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+namespace {
+// Different combinations of row, col, and offset can lead to two cases:
+//
+// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
+// Example A: offset > 0
+// 1 1 0 0 0
+// 1 1 1 0 0
+// 1 1 1 1 0
+// Example B: offset <= 0
+// 0 0 0
+// 1 0 0
+// 1 1 0
+// In this case, we calculate the number of elements in the first row and
+// last row of the tril respectively, and then compute the tril size.
+//
+// Case 2 - Trapezoid + Rectangle: row + offset > col
+// Example:
+// 1 1 0
+// 1 1 1
+// 1 1 1
+// In this case, we first calculate the size of the top trapezoid, and then
+// calculate the size of the bottom rectangle.
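+//
+// For example, row = 4, col = 3, offset = 1 falls into Case 2: the first row
+// of the tril has min(3, 1 + 1) = 2 elements and the last row has
+// min(3, 4 + 1) = 3, so the top trapezoid spans 3 - 2 + 1 = 2 rows and holds
+// (2 + 3) * 2 / 2 = 5 elements; the remaining 4 - 2 = 2 rows form the bottom
+// rectangle and contribute 2 * 3 = 6 more, giving a tril size of 11.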
+inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
+ // number of elements in the first row of the tril
+ auto m_first_row = offset > 0 ?
+ std::min<int64_t>(col, 1 + offset) : // upper bounded by col
+ row + offset > 0; // either 0 or 1
+ // number of elements in the last row of the tril, bounded by [0, col]
+ auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
+ // number of rows, bounded by [0, row]
+ auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
+ auto n_row_trapezoid = (m_last_row - m_first_row + 1);
+
+ // calculate # of elements in the top trapezoid
+ auto n_indices =
+ (m_first_row + m_last_row) * n_row_trapezoid >> 1;
+
+ // calculate # of elements in the bottom rectangle if there is any
+ auto diff_row = n_row_all - n_row_trapezoid;
+ if (diff_row > 0) {
+ n_indices += diff_row * col;
+ }
+
+ return n_indices;
+}
+
+inline void check_args(
+ int64_t row, int64_t col, const TensorOptions& options) {
+ AT_CHECK(row >= 0, "row must be non-negative, got ", row);
+ AT_CHECK(col >= 0, "col must be non-negative, got ", col);
+ if (options.has_device()) {
+ AT_CHECK(
+ options.device() == at::kCPU,
+ "only support device='cpu', got",
+ options.device());
+ }
+ if (options.has_layout()) {
+ AT_CHECK(
+ options.layout() == at::kStrided,
+ "only support layout=torch.strided, got",
+ options.layout())
+ }
+}
+} // namespace
+
+Tensor tril_indices(
+ int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+ check_args(row, col, options);
+
+ auto n_indices = get_tril_size(row, col, offset);
+
+ // create an empty Tensor with correct size
+ auto result = at::empty({2, n_indices}, options);
+
+ // The following three approaches result in very little performance
+ // difference. Hence, the 2nd option is taken for simpler code, and to return
+ // contiguous tensors. Refer to #14904 for more details.
+ //
+ // 1. sequential RAM access: fill row coordinates first, then columns. This
+ // results in two for-loop and more arithmetic operations.
+ //
+ // 2. interleaved RAM access: fill in index coordinates one by one, which
+ // jumps between the two output Tensor rows in every iteration.
+ //
+ // 3. sequential RAM + transpose: create an n X 2 Tensor, fill the Tensor
+ // sequentially, and then transpose it.
+ AT_DISPATCH_ALL_TYPES(result.type(), "tril_indices", [&]() -> void {
+ // fill the Tensor with correct values
+ scalar_t* result_data = result.data<scalar_t>();
+ int64_t i = 0;
+
+ scalar_t r = std::max<int64_t>(0, -offset), c = 0;
+ while (i < n_indices) {
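+ // result is contiguous, so row coordinates fill the first n_indices entries
+ // of result_data and column coordinates fill the next n_indices entries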
+ result_data[i] = r;
+ result_data[n_indices + i++] = c;
+
+ // move to the next column and check if (r, c) is still in bounds
+ c += 1;
+ if (c > r + offset || c >= col) {
+ r += 1;
+ c = 0;
+ // NOTE: not necessary to check if r is less than row here, because i
+ // and n_indices provide the guarantee
+ }
+ }
+ });
+
+ return result;
+}
+
+Tensor triu_indices(
+ int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+ check_args(row, col, options);
+
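+ // triu with diagonal `offset` is the complement of tril with diagonal
+ // `offset - 1`, so its size is the full matrix size minus that tril size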
+ auto n_indices = row * col - get_tril_size(row, col, offset - 1);
+
+ // create an empty Tensor with correct size
+ auto result = at::empty({2, n_indices}, options);
+
+ AT_DISPATCH_ALL_TYPES(result.type(), "triu_indices", [&]() -> void {
+ // fill the Tensor with correct values
+ scalar_t* result_data = result.data<scalar_t>();
+ int64_t i = 0;
+ // not typing std::max with scalar_t as it could be an unsigned type
+ // NOTE: no need to check if the returned value of std::max overflows
+ // scalar_t, as i and n_indices act as a guard.
+ scalar_t c = std::max<int64_t>(0, offset), r = 0;
+ while (i < n_indices) {
+ result_data[i] = r;
+ result_data[n_indices + i++] = c;
+
+ // move to the next column and check if (r, c) is still in bounds
+ c += 1;
+ if (c >= col) {
+ r += 1;
+ // not typing std::max with scalar_t as it could be an unsigned type
+ // NOTE: not necessary to check if c is less than col or overflows here,
+ // because i and n_indices act as a guard.
+ c = std::max<int64_t>(0, r + offset);
+ }
+ }
+ });
+
+ return result;
+}
+
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ zeros ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Tensor zeros(IntList size, const TensorOptions& options) {
- func: tril(Tensor self, int64_t diagonal=0) -> Tensor
variants: method, function
+- func: tril_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+
+- func: triu_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+
- func: trace(Tensor self) -> Tensor
variants: method, function
y = torch.randn(2, 1, device='cuda')
z = x + y
+ def test_tril_and_triu_indices(self):
+ self.assertRaises(
+ RuntimeError,
+ lambda: torch.triu_indices(
+ 1, 1, device='cuda', layout=torch.strided))
+
+ self.assertRaises(
+ RuntimeError,
+ lambda: torch.tril_indices(
+ 1, 1, device='cuda', layout=torch.strided))
+
def load_ignore_file():
from os.path import join, dirname
torch.tril(x, out=res2)
self.assertEqual(res1, res2, 0)
+ def _compare_trilu_indices(self, row, col, offset=0, dtype=torch.long):
+ if row == 0 or col == 0:
+ # have to handle this separately as tril and triu do not take
+ # an empty matrix as input
+ self.assertEqual(
+ torch.empty(0, 2, dtype=dtype).transpose(0, 1),
+ torch.tril_indices(row, col, offset, dtype=dtype))
+
+ self.assertEqual(
+ torch.empty(0, 2, dtype=dtype).transpose(0, 1),
+ torch.triu_indices(row, col, offset, dtype=dtype))
+
+ else:
+ self.assertEqual(
+ torch.ones(row, col, dtype=dtype)
+ .tril(offset).nonzero().transpose(0, 1),
+ torch.tril_indices(row, col, offset, dtype=dtype))
+
+ self.assertEqual(
+ torch.ones(row, col, dtype=dtype)
+ .triu(offset).nonzero().transpose(0, 1),
+ torch.triu_indices(row, col, offset, dtype=dtype))
+
+ def test_tril_and_triu_indices(self):
+ self._compare_trilu_indices(1, 1)
+ self._compare_trilu_indices(3, 3)
+ self._compare_trilu_indices(3, 3, offset=1)
+ self._compare_trilu_indices(3, 3, offset=2)
+ self._compare_trilu_indices(3, 3, offset=200)
+ self._compare_trilu_indices(3, 3, offset=-1)
+ self._compare_trilu_indices(3, 3, offset=-2)
+ self._compare_trilu_indices(3, 3, offset=-200)
+ self._compare_trilu_indices(0, 3, offset=0)
+ self._compare_trilu_indices(0, 3, offset=1)
+ self._compare_trilu_indices(0, 3, offset=-1)
+ self._compare_trilu_indices(3, 0, offset=0)
+ self._compare_trilu_indices(3, 0, offset=1)
+ self._compare_trilu_indices(3, 0, offset=-1)
+ self._compare_trilu_indices(0, 0, offset=0)
+ self._compare_trilu_indices(0, 0, offset=1)
+ self._compare_trilu_indices(0, 0, offset=-1)
+ self._compare_trilu_indices(3, 6, offset=0)
+ self._compare_trilu_indices(3, 6, offset=1)
+ self._compare_trilu_indices(3, 6, offset=3)
+ self._compare_trilu_indices(3, 6, offset=9)
+ self._compare_trilu_indices(3, 6, offset=-1)
+ self._compare_trilu_indices(3, 6, offset=-3)
+ self._compare_trilu_indices(3, 6, offset=-9)
+ self._compare_trilu_indices(6, 3, offset=0)
+ self._compare_trilu_indices(6, 3, offset=1)
+ self._compare_trilu_indices(6, 3, offset=3)
+ self._compare_trilu_indices(6, 3, offset=9)
+ self._compare_trilu_indices(6, 3, offset=-1)
+ self._compare_trilu_indices(6, 3, offset=-3)
+ self._compare_trilu_indices(6, 3, offset=-9)
+ self._compare_trilu_indices(258, 253, offset=1, dtype=torch.float32)
+ self._compare_trilu_indices(257, 258, offset=1, dtype=torch.float64)
+ self._compare_trilu_indices(258, 258, offset=1, dtype=torch.short)
+ self._compare_trilu_indices(3, 513, offset=1, dtype=torch.long)
+ self._compare_trilu_indices(513, 3, offset=1, dtype=torch.int)
+ self._compare_trilu_indices(513, 0, offset=1, dtype=torch.double)
+
+ x = torch.ones(
+ 3, 3, dtype=torch.long, device='cpu', layout=torch.strided)
+ l = x.tril(0).nonzero().transpose(0, 1)
+ u = x.triu(0).nonzero().transpose(0, 1)
+ self.assertEqual(l, torch.tril_indices(3, 3))
+ self.assertEqual(l, torch.tril_indices(3, 3, device='cpu'))
+ self.assertEqual(
+ l, torch.tril_indices(3, 3, device='cpu', layout=torch.strided))
+
+ self.assertEqual(u, torch.triu_indices(3, 3))
+ self.assertEqual(u, torch.triu_indices(3, 3, device='cpu'))
+ self.assertEqual(
+ u, torch.triu_indices(3, 3, device='cpu', layout=torch.strided))
+
+ self.assertRaises(
+ RuntimeError,
+ lambda: torch.triu_indices(
+ 1, 1, device='cpu', layout=torch.sparse_coo))
+
+ self.assertRaises(
+ RuntimeError,
+ lambda: torch.tril_indices(
+ 1, 1, device='cpu', layout=torch.sparse_coo))
+
def test_triu(self):
x = torch.rand(SIZE, SIZE)
res1 = torch.triu(x)
def get_type_default(declaration):
- if declaration['name'].startswith('randperm'):
+ if declaration['name'].startswith('randperm') or \
+ declaration['name'] == 'tril_indices' or \
+ declaration['name'] == 'triu_indices':
return 'torch.int64'
else:
return 'None'
[-0.0614, -0.7344, -1.3164, 0.0000, 0.0000, 0.0000]])
""")
+# docstr is split in two parts to avoid format() mis-capturing the '{}' braces
+# inside :math: expressions as common args.
+add_docstr(torch.tril_indices,
+ r"""
+tril_indices(row, column, offset=0, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor
+
+Returns the indices of the lower triangular part of a :attr:`row`-by-
+:attr:`column` matrix in a 2-by-N Tensor, where the first row contains row
+coordinates of all indices and the second row contains column coordinates.
+Indices are ordered based on rows and then columns.
+
+The lower triangular part of the matrix is defined as the elements on and
+below the diagonal.
+
+The argument :attr:`offset` controls which diagonal to consider. If
+:attr:`offset` = 0, all elements on and below the main diagonal are
+retained. A positive value includes just as many diagonals above the main
+diagonal, and similarly a negative value excludes just as many diagonals below
+the main diagonal. The main diagonal is the set of indices
+:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
+where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
+
+Args:
+ row (``int``): number of rows in the 2-D matrix.
+ column (``int``): number of columns in the 2-D matrix.
+ offset (``int``): diagonal offset from the main diagonal.
+ Default: if not provided, 0.
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ Default: if ``None``, ``torch.long``.
+ device (:class:`torch.device`, optional): currently only supports ``cpu``.
+ layout (:class:`torch.layout`, optional): currently only supports ``torch.strided``.
+
+Example::
+ >>> a = torch.tril_indices(3, 3)
+ >>> a
+ tensor([[0, 1, 1, 2, 2, 2],
+ [0, 0, 1, 0, 1, 2]])
+
+ >>> a = torch.tril_indices(4, 3, -1)
+ >>> a
+ tensor([[1, 2, 2, 3, 3, 3],
+ [0, 0, 1, 0, 1, 2]])
+
+ >>> a = torch.tril_indices(4, 3, 1)
+ >>> a
+ tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
+ [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]])
+""")
+
add_docstr(torch.triu,
r"""
triu(input, diagonal=0, out=None) -> Tensor
[-0.9888, 1.0679, -1.3337, 0.0000, 0.0000, 0.0000]])
""")
+# docstr is split in two parts to avoid format() mis-capturing the '{}' braces
+# inside :math: expressions as common args.
+add_docstr(torch.triu_indices,
+ r"""
+triu_indices(row, column, offset=0, dtype=torch.long, device='cpu', layout=torch.strided) -> Tensor
+
+Returns the indices of the upper triangular part of a :attr:`row`-by-
+:attr:`column` matrix in a 2-by-N Tensor, where the first row contains row
+coordinates of all indices and the second row contains column coordinates.
+Indices are ordered based on rows and then columns.
+
+The upper triangular part of the matrix is defined as the elements on and
+above the diagonal.
+
+The argument :attr:`offset` controls which diagonal to consider. If
+:attr:`offset` = 0, all elements on and above the main diagonal are
+retained. A positive value excludes just as many diagonals above the main
+diagonal, and similarly a negative value includes just as many diagonals below
+the main diagonal. The main diagonal is the set of indices
+:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
+where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
+
+Args:
+ row (``int``): number of rows in the 2-D matrix.
+ column (``int``): number of columns in the 2-D matrix.
+ offset (``int``): diagonal offset from the main diagonal.
+ Default: if not provided, 0.
+ dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
+ Default: if ``None``, ``torch.long``.
+ device (:class:`torch.device`, optional): currently only supports ``cpu``.
+ layout (:class:`torch.layout`, optional): currently only supports ``torch.strided``.
+
+Example::
+ >>> a = torch.triu_indices(3, 3)
+ >>> a
+ tensor([[0, 0, 0, 1, 1, 2],
+ [0, 1, 2, 1, 2, 2]])
+
+ >>> a = torch.triu_indices(4, 3, -1)
+ >>> a
+ tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3],
+ [0, 1, 2, 0, 1, 2, 1, 2, 2]])
+
+ >>> a = torch.triu_indices(4, 3, 1)
+ >>> a
+ tensor([[0, 0, 1],
+ [1, 2, 2]])
+""")
+
add_docstr(torch.trtrs,
r"""
trtrs(b, A, upper=True, transpose=False, unitriangular=False) -> (Tensor, Tensor)