#include <ATen/LegacyTHDispatcher.h>
#include <c10/core/ScalarType.h>
#include <ATen/core/Deprecated.h>
+#include <ATen/native/TensorFactories.h>
#include <c10/core/TensorOptions.h>
#include <TH/THRandom.h>
#include <TH/THGenerator.hpp>
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-namespace {
-// Different combinations of row, col, and offset can lead to two cases:
-//
-// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
-// Example A: offset > 0
-// 1 1 0 0 0
-// 1 1 1 0 0
-// 1 1 1 1 0
-// Example B: offset <= 0
-// 0 0 0
-// 1 0 0
-// 1 1 0
-// In this case, we calculate the number of elements in the first row and
-// last row of the tril respectively, and then compute the tril size.
-//
-// Case 2 - Trapezoid + Rectangle: row + offset > col
-// Example:
-// 1 1 0
-// 1 1 1
-// 1 1 1
-// In this case, we first calculate the size of top trapezoid, and then
-// calculate the size of the bottom rectangle.
-inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
- // number of elements in the first row of the tril
- auto m_first_row = offset > 0 ?
- std::min<int64_t>(col, 1 + offset) : // upper bounded by col
- row + offset > 0; // either 0 or 1
- // number of elements in the last row of the tril, bounded by [0, col]
- auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
- // number of rows, bounded by [0, row]
- auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
- auto n_row_trapezoid = (m_last_row - m_first_row + 1);
-
- // calculate # of elements in the top trapezoid
- auto n_indices =
- (m_first_row + m_last_row) * n_row_trapezoid >> 1;
-
- // calculate # of elements in the bottom rectangle if there is any
- auto diff_row = n_row_all - n_row_trapezoid;
- if (diff_row > 0) {
- n_indices += diff_row * col;
- }
-
- return n_indices;
-}
-
-inline void check_args(
- int64_t row, int64_t col, const TensorOptions& options) {
- AT_CHECK(row >= 0, "row must be non-negative, got", row);
- AT_CHECK(col >= 0, "col must be non-negative, got", col);
- if (options.has_device()) {
- AT_CHECK(
- options.device() == at::kCPU,
- "only support device='cpu', got",
- options.device());
- }
- if (options.has_layout()) {
- AT_CHECK(
- options.layout() == at::kStrided,
- "only support layout=torch.strided, got",
- options.layout())
- }
-}
-} // namespace
-
-Tensor tril_indices(
+Tensor tril_indices_cpu(
int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
check_args(row, col, options);
- auto n_indices = get_tril_size(row, col, offset);
+ auto tril_size = get_tril_size(row, col, offset);
// create an empty Tensor with correct size
- auto result = at::empty({2, n_indices}, options);
+ auto result = at::empty({2, tril_size}, options);
// The following three approaches result in very little performance
// differences. Hence, the 2nd option is taken for simpler code, and to return
int64_t i = 0;
scalar_t r = std::max<int64_t>(0, -offset), c = 0;
- while (i < n_indices) {
+ while (i < tril_size) {
result_data[i] = r;
- result_data[n_indices + i++] = c;
+ result_data[tril_size + i++] = c;
// move to the next column and check if (r, c) is still in bound
c += 1;
r += 1;
c = 0;
// NOTE: not necessary to check if r is less than row here, because i
- // and n_indices provide the guarantee
+ // and tril_size provide the guarantee
}
}
});
return result;
}
-Tensor triu_indices(
+Tensor triu_indices_cpu(
int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
check_args(row, col, options);
- auto n_indices = row * col - get_tril_size(row, col, offset - 1);
+ auto triu_size = row * col - get_tril_size(row, col, offset - 1);
// create an empty Tensor with correct size
- auto result = at::empty({2, n_indices}, options);
+ auto result = at::empty({2, triu_size}, options);
AT_DISPATCH_ALL_TYPES(result.type(), "triu_indices", [&]() -> void {
// fill the Tensor with correct values
int64_t i = 0;
// not typing std::max with scalar_t as it could be an unsigned type
// NOTE: no need to check if the returned value of std::max overflows
- // scalar_t, as i and n_indices act as a guard.
+ // scalar_t, as i and triu_size act as a guard.
scalar_t c = std::max<int64_t>(0, offset), r = 0;
- while (i < n_indices) {
+ while (i < triu_size) {
result_data[i] = r;
- result_data[n_indices + i++] = c;
+ result_data[triu_size + i++] = c;
// move to the next column and check if (r, c) is still in bound
c += 1;
r += 1;
// not typing std::max with scalar_t as it could be an unsigned type
// NOTE: not necessary to check if c is less than col or overflows here,
- // because i and n_indices act as a guard.
+ // because i and triu_size act as a guard.
c = std::max<int64_t>(0, r + offset);
}
}
--- /dev/null
+#pragma once
+
+namespace at { namespace native {
+// Different combinations of row, col, and offset can lead to two cases:
+//
+// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
+// Example A: offset > 0
+// 1 1 0 0 0
+// 1 1 1 0 0
+// 1 1 1 1 0
+// Example B: offset <= 0
+// 0 0 0
+// 1 0 0
+// 1 1 0
+// In this case, we calculate the number of elements in the first row and
+// last row of the tril respectively, and then compute the tril size.
+//
+// Case 2 - Trapezoid + Rectangle: row + offset > col
+// Example:
+// 1 1 0
+// 1 1 1
+// 1 1 1
+// In this case, we first calculate the size of top trapezoid, and then
+// calculate the size of the bottom rectangle.
+// Returns the number of elements on or below the `offset`-th diagonal of a
+// row x col matrix, i.e. the number of columns in the tril_indices output.
+//
+// row, col: matrix dimensions; expected to be non-negative.
+// offset:   diagonal offset; 0 selects the main diagonal, positive values
+//           move above it and negative values move below it.
+inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
+  // An empty matrix has no lower triangle. Without this guard the trapezoid
+  // formula below reports a non-zero size when row == 0 and offset >= col
+  // (e.g. row = 0, col = 3, offset = 3 would yield 3), which also turns
+  // row * col - get_tril_size(row, col, offset - 1) negative for triu.
+  if (row == 0 || col == 0) {
+    return 0;
+  }
+  // number of elements in the first row of the tril
+  auto m_first_row = offset > 0 ?
+    std::min<int64_t>(col, 1 + offset) : // upper bounded by col
+    row + offset > 0; // either 0 or 1
+  // number of elements in the last row of the tril, bounded by [0, col]
+  auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
+  // number of rows, bounded by [0, row]
+  auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
+  auto n_row_trapezoid = (m_last_row - m_first_row + 1);
+
+  // calculate # of elements in the top trapezoid
+  auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1;
+
+  // calculate # of elements in the bottom rectangle if there is any
+  auto diff_row = n_row_all - n_row_trapezoid;
+  if (diff_row > 0) {
+    tril_size += diff_row * col;
+  }
+
+  return tril_size;
+}
+
+// Validates the arguments shared by tril_indices and triu_indices.
+//
+// Fails (via AT_CHECK) when row or col is negative, or when a layout was
+// explicitly requested and it is anything other than at::kStrided.
+inline void check_args(
+    int64_t row, int64_t col, const TensorOptions& options) {
+  // the trailing space keeps the offending value from being glued to "got"
+  AT_CHECK(row >= 0, "row must be non-negative, got ", row);
+  AT_CHECK(col >= 0, "col must be non-negative, got ", col);
+  if (options.has_layout()) {
+    AT_CHECK(
+        options.layout() == at::kStrided,
+        "only support layout=torch.strided, got ",
+        options.layout());
+  }
+}
+} // namespace native
+} // namespace at
#include <ATen/ATen.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/NativeFunctions.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
+#include <ATen/native/TensorFactories.h>
#include <c10/util/Exception.h>
#include <THC/THCGeneral.h>
#include <algorithm>
#include <cstddef>
+#include <cmath>
namespace at {
namespace native {
return result;
}
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+namespace {
+// To find the max integer that does not exceed the root of an int64_t variable,
+// we could use a loop to test one bit at a time, which takes up to 31
+// iterations. This would give the accurate result, but is relatively slow and
+// is overkill for most cases where double's precision suffices.
+//
+// If we directly use sqrt to calculate the root, the conversion from int64_t
+// to double would lose 11 bits of precision.
+//
+// The following solution uses sqrt directly for most cases, and only falls
+// back to special handling when there is indeed precision loss.
+// Computes floor((-b + sign * sqrt(b*b - cX4)) / 2) for the row-index quadratic.
+//
+// b:    linear coefficient of the quadratic (the leading coefficient is 1).
+// cX4:  four times the constant coefficient; the callers in this file pass
+//       -8x (tril) or 8x (triu).
+// x:    linear index of the target element inside the trapezoid; only read by
+//       the binary-search fallback.
+// sign: +1 selects the larger root (used for tril), -1 the smaller root (used
+//       for triu).
+//
+// The double-precision estimate is returned unchanged when squaring the
+// computed square root and truncating to int64_t reproduces the discriminant;
+// otherwise the estimate is refined with an integer binary search.
+__device__
+inline int64_t resolve_root_int(
+ int64_t b, int64_t cX4, int64_t x, int32_t sign) {
+ // discriminant b^2 - 4c, kept in exact integer arithmetic
+ int64_t bXb_cX4 = b*b - cX4;
+ // potential precision loss could occur here when casting int64_t (63 bits
+ // precision) to double (52 bits precision)
+ double sr = ::sqrt((double)bXb_cX4);
+ // first estimate of the root, rounded towards negative infinity
+ int64_t res = ::__double2ll_rd((-b + sign * sr)/2);
+
+ // have to cast double to int64_t, otherwise it would only compare up to the
+ // precision of a double variable, ignoring the precision loss
+ if (bXb_cX4 != (int64_t) (sr * sr)) {
+ // handle precision loss by using binary search
+ int64_t llsr = ::__double2ll_rd(sr);
+ // Use the following math to reduce search space.
+ // Suppose z is the accurate result of sqrt(bXb_cX4) without precision loss
+ // let d = abs(bXb_cX4 - llsr * llsr), then we have:
+ // z = sqrt(bXb_cX4) <= sqrt(llsr * llsr + d) <= llsr + sqrt(d)
+ // z = sqrt(bXb_cX4) >= sqrt(llsr * llsr - d) >= llsr - sqrt(d)
+ // Hence, it is sufficient to search range [llsr - sqrt(d), llsr + sqrt(d)).
+ // And the true value of row would also be within the range,
+ // [res - sqrt(d), res + sqrt(d) + 1)
+ // as the denominator would only reduce the precision penalty.
+ int64_t diff =
+ ::__double2ll_ru(::sqrt(::fabs((double)(bXb_cX4 - llsr * llsr))));
+ // l never exceeds (could equal to) the target row index
+ // (clamped at 0 so the search never starts from a negative row)
+ auto l = res > diff ? res - diff : 0;
+ // r is always larger than the target row index
+ auto r = res + diff + 1;
+
+ // binary search for the correct answer
+ x <<= 1; // the loop always compares with 2x, so do it once here
+ // invariant: l is a valid lower bound and r is an exclusive upper bound
+ while (l + 1 < r) {
+ auto m = (l + r) >> 1;
+ // for tril:
+ // b = 2f - 1, sign = 1, hence (2f + m - 1) * m / 2
+ // for triu:
+ // b = -2f - 1, sign = -1, hence (2f - m + 1) * m / 2
+ if (sign * (b + m) * m > x) {
+ // m rows already hold more than x elements, so the answer is below m
+ r = m;
+ } else {
+ l = m;
+ }
+ }
+ res = l;
+ }
+
+ return res;
+}
+
+// f: the number of elements in the first row of the trapezoid.
+// x: the index of the target coordinates ordered by row and then column.
+//
+// View the tril as a top trapezoid stacked on a bottom rectangle. Assume x
+// corresponds to the coordinate (row, col) in the trapezoid, where the row and
+// the col both start from 0, then we have:
+//
+// (f + f + row - 1) * row / 2 <= x [1]
+// (f + f + row) * (row + 1) / 2 > x [2]
+//
+// Therefore, row is the maximum integer satisfying the following inequality:
+//
+// (row + 2f - 1)row <= 2x
+// row^2 + (2f-1)row - 2x <= 0. [3]
+//
+// Based on inequality [3], we have the following coefficients for the
+// quadratic formula:
+// a = 1
+// b = 2f - 1
+// c = -2x
+// There are two roots, and we should use the largest integer that does not
+// exceed the root on the right. Intuitively, it is because:
+// i) the valid solution range of row is between two roots, as it is <= 0;
+// ii) as we count in more rows, the total # of elements should always
+// increase, hence so does the left-hand side row^2 + (2f-1)row - 2x.
+// Therefore, the valid range of row lies in between the nadir point and
+// the larger root on the right.
+// Full proof can be derived from inequality [2]. So, we calculate the result
+// coordinate as:
+//
+// row = floor((-b + sqrt(b^2 - 4c)) / 2)
+// col = x - (f + f + row - 1) * row / 2
+// Maps the linear index x inside the top trapezoid of a tril to its (row, col)
+// coordinate, where f is the number of elements in the trapezoid's first row.
+// Both outputs are zero-based and relative to the trapezoid.
+__device__
+inline void get_coordinate_in_tril_trapezoid(
+    int64_t f, int64_t x, int64_t & row, int64_t & col) {
+  // every term below needs 2f, so compute it once
+  const int64_t two_f = f << 1;
+  // coefficients of row^2 + (2f - 1) * row - 2x <= 0
+  const int64_t linear_coeff = two_f - 1;
+  const int64_t four_c = -(x << 3);  // 4 * c = 4 * (-2x) = -8x
+  row = resolve_root_int(linear_coeff, four_c, x, 1);
+  // number of elements stored in the rows above `row`
+  const int64_t preceding = (two_f + row - 1) * row >> 1;
+  col = x - preceding;
+}
+
+// f: the number of elements in the first row of the bottom trapezoid.
+// x: the index of the target coordinates ordered by row and then column.
+//
+// View the triu as a top rectangle stacked on a bottom trapezoid, where the
+// trapezoid is upside down. Assume x corresponds to the coordinate (row, col)
+// in the bottom trapezoid, where the row and the col start from 0, then we
+// have:
+//
+// (f + f - row + 1) * row / 2 <= x [1]
+// (f + f - row) * (row + 1) / 2 > x [2]
+//
+// Therefore, row is the maximum integer satisfying the following inequality:
+//
+// (-row + 2f + 1)row <= 2x
+// row^2 - (2f+1)row + 2x >= 0. [3]
+//
+// Based on inequality [3], we have the following coefficients for the
+// quadratic formula:
+// a = 1
+// b = -1 - 2f
+// c = 2x
+// There are two roots, and we should use the largest integer that does not
+// exceed the root on the left. Intuitively, it is because:
+// i) the valid solution range of row is outside of the two roots, as the
+// left-hand side of [3] is >= 0;
+// ii) as we count in more rows, the total # of elements should always
+// increase, hence so does the left-hand side row^2 - (2f+1)row + 2x.
+// Therefore, the valid range of row lies to the left of the smaller root
+// on the left.
+// Full proof can be derived from inequality [2]. So, we calculate the result
+// coordinate as:
+//
+// row = floor((-b - sqrt(b^2 - 4c)) / 2)
+// col = x - (f + f - row + 1) * row / 2
+// Maps the linear index x inside the bottom (upside-down) trapezoid of a triu
+// to its (row, col) coordinate, where f is the number of elements in the
+// trapezoid's first row. The row is zero-based and relative to the trapezoid.
+__device__
+inline void get_coordinate_in_triu_trapezoid(
+    int64_t f, int64_t x, int64_t & row, int64_t & col) {
+  // every term below needs 2f, so compute it once
+  const int64_t two_f = f << 1;
+  // coefficients of row^2 - (2f + 1) * row + 2x >= 0
+  const int64_t linear_coeff = -1 - two_f;
+  const int64_t four_c = x << 3;  // 4 * c = 4 * (2x) = 8x
+  row = resolve_root_int(linear_coeff, four_c, x, -1);
+  // number of elements stored in the rows above `row`
+  const int64_t preceding = (two_f - row + 1) * row >> 1;
+  // each row of the trapezoid starts one column further to the right
+  col = x - preceding + row;
+}
+
+} // namespace
+
+// Writes the (row, col) coordinate of one lower-triangle element per thread.
+//
+// tensor:         output buffer; row indices are written to
+//                 [0, tril_size) and column indices to
+//                 [tril_size, 2 * tril_size).
+// row_offset:     added to every computed row index.
+// m_first_row:    number of elements in the first row of the top trapezoid.
+// col:            number of columns in the matrix.
+// trapezoid_size: number of elements in the top trapezoid; larger linear
+//                 indices fall into the bottom rectangle.
+// tril_size:      total number of coordinates to write.
+template <typename scalar_t>
+__global__
+void tril_indices_kernel(scalar_t * tensor,
+                         int64_t row_offset,
+                         int64_t m_first_row,
+                         int64_t col,
+                         int64_t trapezoid_size,
+                         int64_t tril_size) {
+  // blockIdx.x and blockDim.x are 32-bit unsigned values; widen before
+  // multiplying so the product cannot wrap around for very large outputs.
+  int64_t linear_index =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+
+  if (linear_index < tril_size) {
+    int64_t r, c;
+    if (linear_index < trapezoid_size) {
+      // the coordinate is within the top trapezoid
+      get_coordinate_in_tril_trapezoid(m_first_row, linear_index, r, c);
+    } else {
+      // the coordinate falls in the bottom rectangle
+      auto surplus = linear_index - trapezoid_size;
+      // add the height of trapezoid: m_last_row (col) - m_first_row + 1
+      r = surplus / col + col - m_first_row + 1;
+      c = surplus % col;
+    }
+    r += row_offset;
+
+    tensor[linear_index] = r;
+    tensor[linear_index + tril_size] = c;
+  }
+}
+
+// Some large test cases for the fallback binary search path are disabled by
+// default to speed up CI tests and to avoid OOM errors. When modifying the
+// implementation, please enable them (see tri_large_tests_args, used by
+// test/test_cuda.py) and make sure they pass on your local server.
+// CUDA implementation of tril_indices: returns a 2 x N tensor holding the row
+// indices followed by the column indices of the lower-triangular part of a
+// row x col matrix, where N = get_tril_size(row, col, offset).
+Tensor tril_indices_cuda(
+    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+  check_args(row, col, options);
+
+  const auto tril_size = get_tril_size(row, col, offset);
+  auto tensor = empty_cuda({2, tril_size}, options);
+  if (tril_size <= 0) {
+    // nothing to fill, so no kernel launch is needed
+    return tensor;
+  }
+
+  // number of tril elements in the first row of the top trapezoid
+  int64_t m_first_row;
+  if (offset > 0) {
+    m_first_row = std::min<int64_t>(col, 1 + offset);  // upper bounded by col
+  } else {
+    m_first_row = row + offset > 0;  // either 0 or 1
+  }
+
+  // first matrix row of the trapezoid and of the rectangle below it
+  const int64_t trapezoid_row_offset = std::max<int64_t>(0, -offset);
+  const int64_t rectangle_row_offset =
+      trapezoid_row_offset + col - m_first_row + 1;
+  const int64_t rectangle_size = rectangle_row_offset < row
+      ? (row - rectangle_row_offset) * col
+      : 0;
+  const int64_t trapezoid_size = tril_size - rectangle_size;
+
+  dim3 dim_block = cuda::getApplyBlock();
+  dim3 dim_grid;
+  // the grid is sized by tril_size rather than tensor.numel(), because each
+  // thread writes two elements of the tensor
+  AT_CHECK(
+      cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
+      "unable to get dim grid");
+
+  AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "tril_indices_cuda", [&] {
+    tril_indices_kernel<<<
+        dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+      tensor.data<scalar_t>(),
+      trapezoid_row_offset,
+      m_first_row,
+      col,
+      trapezoid_size,
+      tril_size);
+  });
+
+  return tensor;
+}
+
+// Writes the (row, col) coordinate of one upper-triangle element per thread.
+//
+// tensor:         output buffer; row indices are written to
+//                 [0, triu_size) and column indices to
+//                 [triu_size, 2 * triu_size).
+// col_offset:     added to every computed column index.
+// m_first_row:    number of elements in the first row of the bottom trapezoid.
+// col:            number of columns in the matrix.
+// rectangle_size: number of elements in the top rectangle; larger linear
+//                 indices fall into the bottom trapezoid.
+// triu_size:      total number of coordinates to write.
+template <typename scalar_t>
+__global__
+void triu_indices_kernel(scalar_t * tensor,
+                         int64_t col_offset,
+                         int64_t m_first_row,
+                         int64_t col,
+                         int64_t rectangle_size,
+                         int64_t triu_size) {
+  // blockIdx.x and blockDim.x are 32-bit unsigned values; widen before
+  // multiplying so the product cannot wrap around for very large outputs.
+  int64_t linear_index =
+      static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
+
+  if (linear_index < triu_size) {
+    int64_t r, c;
+    if (linear_index < rectangle_size) {
+      // the coordinate is within the top rectangle
+      r = linear_index / col;
+      c = linear_index % col;
+    } else {
+      // the coordinate falls in the bottom trapezoid
+      get_coordinate_in_triu_trapezoid(
+        m_first_row, linear_index - rectangle_size, r, c);
+      // shift past the rows occupied by the top rectangle
+      r += rectangle_size / col;
+    }
+
+    c += col_offset;
+    tensor[linear_index] = r;
+    tensor[linear_index + triu_size] = c;
+  }
+}
+
+// Some large test cases for the fallback binary search path are disabled by
+// default to speed up CI tests and to avoid OOM errors. When modifying the
+// implementation, please enable them (see tri_large_tests_args, used by
+// test/test_cuda.py) and make sure they pass on your local server.
+// CUDA implementation of triu_indices: returns a 2 x N tensor holding the row
+// indices followed by the column indices of the upper-triangular part of a
+// row x col matrix, where N = row * col - get_tril_size(row, col, offset - 1).
+Tensor triu_indices_cuda(
+    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+  check_args(row, col, options);
+
+  const auto triu_size = row * col - get_tril_size(row, col, offset - 1);
+  auto tensor = empty_cuda({2, triu_size}, options);
+  if (triu_size <= 0) {
+    // nothing to fill, so no kernel launch is needed
+    return tensor;
+  }
+
+  // # of triu elements in the first row, upper bounded by col
+  int64_t m_first_row = col;
+  if (offset > 0) {
+    m_first_row = std::max<int64_t>(col - offset, 0);
+  }
+
+  // size of the top rectangle, which only exists for negative offsets
+  int64_t rectangle_size = 0;
+  if (offset < 0) {
+    rectangle_size = std::min<int64_t>(row, -offset) * col;
+  }
+
+  // column of the first triu element in the first row
+  const int64_t col_offset = std::max<int64_t>(0, offset);
+
+  dim3 dim_block = cuda::getApplyBlock();
+  dim3 dim_grid;
+  // the grid is sized by triu_size rather than tensor.numel(), because each
+  // thread writes two elements of the tensor
+  AT_CHECK(
+      cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
+      "unable to get dim grid");
+
+  AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "triu_indices_cuda", [&] {
+    triu_indices_kernel<<<
+        dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+      tensor.data<scalar_t>(),
+      col_offset,
+      m_first_row,
+      col,
+      rectangle_size,
+      triu_size);
+  });
+
+  return tensor;
+}
+
}} // namespace at::native
variants: method, function
- func: tril_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+ dispatch:
+ CPU: tril_indices_cpu
+ CUDA: tril_indices_cuda
- func: triu_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+ dispatch:
+ CPU: triu_indices_cpu
+ CUDA: triu_indices_cuda
- func: trace(Tensor self) -> Tensor
variants: method, function
return args_out, kwargs_out
+def _compare_trilu_indices(
+        self, row, col, offset=0, dtype=torch.long, device='cpu'):
+    """Check tril_indices and triu_indices against a dense reference.
+
+    The reference is built on the CPU from a ``row`` x ``col`` matrix of ones
+    and moved to ``device`` before being compared with the indices produced
+    on ``device``. ``self`` is the test case providing ``assertEqual``.
+    """
+    if row == 0 or col == 0:
+        # have to handle this separately as tril and triu does not take
+        # empty matrix as input
+        self.assertEqual(
+            torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
+            torch.tril_indices(row, col, offset, dtype=dtype, device=device))
+
+        self.assertEqual(
+            torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
+            torch.triu_indices(row, col, offset, dtype=dtype, device=device))
+
+    else:
+        self.assertEqual(
+            torch.ones(row, col, dtype=dtype, device='cpu')
+                 .tril(offset).nonzero().transpose(0, 1).to(device),
+            torch.tril_indices(row, col, offset, dtype=dtype, device=device))
+
+        # the upper triangle must be checked as well; comparing tril twice
+        # would leave triu_indices untested for non-empty inputs
+        self.assertEqual(
+            torch.ones(row, col, dtype=dtype, device='cpu')
+                 .triu(offset).nonzero().transpose(0, 1).to(device),
+            torch.triu_indices(row, col, offset, dtype=dtype, device=device))
+
+
+def _compare_large_trilu_indices(
+        self, row, col, offset=0, dtype=torch.long, device='cpu'):
+    """Compare a slice near the end of tril/triu indices with a dense reference.
+
+    Only the ``[-100:-1]`` slice of the coordinates is compared. The CUDA
+    cache is released after every step.
+    """
+    op_pairs = (
+        (torch.Tensor.tril, torch.tril_indices),
+        (torch.Tensor.triu, torch.triu_indices),
+    )
+    for dense_op, indices_op in op_pairs:
+        # reference coordinates computed from a dense CPU matrix of ones
+        expected = dense_op(
+            torch.ones(row, col, dtype=dtype, device='cpu'), offset)
+        expected = expected.nonzero()[-100:-1, :].transpose(0, 1).to(device)
+        torch.cuda.empty_cache()
+
+        # coordinates produced directly on the target device
+        actual = indices_op(row, col, offset, dtype=dtype, device=device)
+        actual = actual[:, -100:-1]
+        self.assertEqual(expected, actual)
+        torch.cuda.empty_cache()
+
+# Argument tuples shared by the CPU and CUDA tril_indices / triu_indices tests.
+# Each entry is unpacked positionally into _compare_trilu_indices:
+# (
+#   row
+#   col
+#   offset (optional)
+#   dtype (optional)
+# )
+tri_tests_args = [
+    (1, 1),
+    (3, 3),
+    (3, 3, 1),
+    (3, 3, 2),
+    (3, 3, 200),
+    (3, 3, -1),
+    (3, 3, -2),
+    (3, 3, -200),
+    (0, 3, 0),
+    (0, 3, 1),
+    (0, 3, -1),
+    (3, 0, 0),
+    (3, 0, 1),
+    (3, 0, -1),
+    (0, 0, 0),
+    (0, 0, 1),
+    (0, 0, -1),
+    (3, 6, 0),
+    (3, 6, 1),
+    (3, 6, 3),
+    (3, 6, 9),
+    (3, 6, -1),
+    (3, 6, -3),
+    (3, 6, -9),
+    (6, 3, 0),
+    (6, 3, 1),
+    (6, 3, 3),
+    (6, 3, 9),
+    (6, 3, -1),
+    (6, 3, -3),
+    (6, 3, -9),
+    (258, 253, 1, torch.float32),
+    (257, 258, 1, torch.float64),
+    (258, 258, 1, torch.short),
+    (3, 513, 1, torch.long),
+    (513, 3, 1, torch.int),
+    (513, 0, 1, torch.double),
+    (1024, 1024),
+    (1024, 1024, 500, torch.float32),
+    (1024, 1024, 1023),
+    (1024, 1024, -500),
+    (1023, 1025),
+    (1025, 1023, 1022),
+    (3, 2028),
+    (3, 2028, 1),
+    (3, 2028, -1),
+    (2028, 3),
+    (2028, 1),
+    (2028, 1, -1)
+]
+
+# Argument tuples (row, col[, offset]) for _compare_large_trilu_indices.
+tri_large_tests_args = [
+ (1, 268435455),
+ # The large test cases below are deliberately commented out to speed up CI
+ # tests and to avoid OOM errors. When modifying the implementations of
+ # tril_indices and triu_indices, please enable these tests and make sure
+ # they pass.
+ #
+ # (5000, 5000),
+ # (10000, 10000),
+ # (268435455, 1),
+ # (134217727, 2, 1),
+ # (2, 134217727, 1),
+ # (536870901, 1),
+ # (1, 536870901),
+ # (268435455, 2, 1),
+ # (2, 268435455, 1)
+]
+
+
+def run_additional_tri_tests(self, device):
+    """Check option handling of tril_indices / triu_indices on ``device``.
+
+    Verifies that omitting the layout and passing ``torch.strided`` both match
+    a dense 3 x 3 reference, and that requesting ``torch.sparse_coo`` raises a
+    ``RuntimeError``. ``self`` is the test case providing the assertions.
+    """
+    ones = torch.ones(
+        3, 3, dtype=torch.long, device=device, layout=torch.strided)
+    expected_lower = ones.tril(0).nonzero().transpose(0, 1)
+    expected_upper = ones.triu(0).nonzero().transpose(0, 1)
+
+    for expected, indices_fn in (
+            (expected_lower, torch.tril_indices),
+            (expected_upper, torch.triu_indices)):
+        # default layout first, then the explicitly strided layout
+        self.assertEqual(expected, indices_fn(3, 3, device=device))
+        self.assertEqual(
+            expected,
+            indices_fn(3, 3, device=device, layout=torch.strided))
+
+    # only the strided layout is supported
+    for indices_fn in (torch.triu_indices, torch.tril_indices):
+        self.assertRaises(
+            RuntimeError,
+            lambda: indices_fn(
+                1, 1, device=device, layout=torch.sparse_coo))
+
+
def unpack_variables(args):
if isinstance(args, tuple):
return tuple(unpack_variables(elem) for elem in args)
from test_torch import _TestTorchMixin
+from common_methods_invocations import tri_tests_args, tri_large_tests_args, \
+ run_additional_tri_tests, _compare_trilu_indices, _compare_large_trilu_indices
from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \
PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, load_tests
y = torch.randn(2, 1, device='cuda')
z = x + y
- def test_tril_and_triu_indices(self):
- self.assertRaises(
- RuntimeError,
- lambda: torch.triu_indices(
- 1, 1, device='cuda', layout=torch.strided))
-
- self.assertRaises(
- RuntimeError,
- lambda: torch.tril_indices(
- 1, 1, device='cuda', layout=torch.strided))
+ def test_trilu_indices(self):
+     # compare every shared argument combination against the dense reference
+     for test_args in tri_tests_args:
+         _compare_trilu_indices(self, *test_args, device='cuda')
+
+     # layout handling and unsupported-layout errors, mirroring the CPU
+     # suite; this helper is imported above and was otherwise never called
+     run_additional_tri_tests(self, 'cuda')
+
+     # test default options
+     x = torch.ones(
+         3, 3, dtype=torch.long, device='cuda', layout=torch.strided)
+     self.assertEqual(
+         x.tril(0).nonzero().transpose(0, 1),
+         torch.tril_indices(3, 3, device='cuda'))
+     self.assertEqual(
+         x.triu(0).nonzero().transpose(0, 1),
+         torch.triu_indices(3, 3, device='cuda'))
+
+ def test_large_trilu_indices(self):
+     # runs every enabled large case on the GPU; each case is unpacked
+     # positionally into _compare_large_trilu_indices
+     for large_case in tri_large_tests_args:
+         _compare_large_trilu_indices(self, *large_case, device='cuda')
def load_ignore_file():
from itertools import product, combinations
from functools import reduce
from torch import multiprocessing as mp
+from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
+ _compare_trilu_indices
from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
torch.tril(x, out=res2)
self.assertEqual(res1, res2, 0)
- def _compare_trilu_indices(self, row, col, offset=0, dtype=torch.long):
- if row == 0 or col == 0:
- # have to handle this separately as tril and triu does not take
- # empty matrix as input
- self.assertEqual(
- torch.empty(0, 2, dtype=dtype).transpose(0, 1),
- torch.tril_indices(row, col, offset, dtype=dtype))
+ def test_trilu_indices(self):
+ for test_args in tri_tests_args:
+ _compare_trilu_indices(self, *test_args)
- self.assertEqual(
- torch.empty(0, 2, dtype=dtype).transpose(0, 1),
- torch.triu_indices(row, col, offset, dtype=dtype))
-
- else:
- self.assertEqual(
- torch.ones(row, col, dtype=dtype)
- .tril(offset).nonzero().transpose(0, 1),
- torch.tril_indices(row, col, offset, dtype=dtype))
-
- self.assertEqual(
- torch.ones(row, col, dtype=dtype)
- .triu(offset).nonzero().transpose(0, 1),
- torch.triu_indices(row, col, offset, dtype=dtype))
-
- def test_tril_and_triu_indices(self):
- self._compare_trilu_indices(1, 1)
- self._compare_trilu_indices(3, 3)
- self._compare_trilu_indices(3, 3, offset=1)
- self._compare_trilu_indices(3, 3, offset=2)
- self._compare_trilu_indices(3, 3, offset=200)
- self._compare_trilu_indices(3, 3, offset=-1)
- self._compare_trilu_indices(3, 3, offset=-2)
- self._compare_trilu_indices(3, 3, offset=-200)
- self._compare_trilu_indices(0, 3, offset=0)
- self._compare_trilu_indices(0, 3, offset=1)
- self._compare_trilu_indices(0, 3, offset=-1)
- self._compare_trilu_indices(3, 0, offset=0)
- self._compare_trilu_indices(3, 0, offset=1)
- self._compare_trilu_indices(3, 0, offset=-1)
- self._compare_trilu_indices(0, 0, offset=0)
- self._compare_trilu_indices(0, 0, offset=1)
- self._compare_trilu_indices(0, 0, offset=-1)
- self._compare_trilu_indices(3, 6, offset=0)
- self._compare_trilu_indices(3, 6, offset=1)
- self._compare_trilu_indices(3, 6, offset=3)
- self._compare_trilu_indices(3, 6, offset=9)
- self._compare_trilu_indices(3, 6, offset=-1)
- self._compare_trilu_indices(3, 6, offset=-3)
- self._compare_trilu_indices(3, 6, offset=-9)
- self._compare_trilu_indices(6, 3, offset=0)
- self._compare_trilu_indices(6, 3, offset=1)
- self._compare_trilu_indices(6, 3, offset=3)
- self._compare_trilu_indices(6, 3, offset=9)
- self._compare_trilu_indices(6, 3, offset=-1)
- self._compare_trilu_indices(6, 3, offset=-3)
- self._compare_trilu_indices(6, 3, offset=-9)
- self._compare_trilu_indices(258, 253, offset=1, dtype=torch.float32)
- self._compare_trilu_indices(257, 258, offset=1, dtype=torch.float64)
- self._compare_trilu_indices(258, 258, offset=1, dtype=torch.short)
- self._compare_trilu_indices(3, 513, offset=1, dtype=torch.long)
- self._compare_trilu_indices(513, 3, offset=1, dtype=torch.int)
- self._compare_trilu_indices(513, 0, offset=1, dtype=torch.double)
+ run_additional_tri_tests(self, 'cpu')
+ # test default options
x = torch.ones(
3, 3, dtype=torch.long, device='cpu', layout=torch.strided)
- l = x.tril(0).nonzero().transpose(0, 1)
- u = x.triu(0).nonzero().transpose(0, 1)
- self.assertEqual(l, torch.tril_indices(3, 3))
- self.assertEqual(l, torch.tril_indices(3, 3, device='cpu'))
self.assertEqual(
- l, torch.tril_indices(3, 3, device='cpu', layout=torch.strided))
-
- self.assertEqual(u, torch.triu_indices(3, 3))
- self.assertEqual(u, torch.triu_indices(3, 3, device='cpu'))
+ x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3))
self.assertEqual(
- u, torch.triu_indices(3, 3, device='cpu', layout=torch.strided))
-
- self.assertRaises(
- RuntimeError,
- lambda: torch.triu_indices(
- 1, 1, device='cpu', layout=torch.sparse_coo))
-
- self.assertRaises(
- RuntimeError,
- lambda: torch.tril_indices(
- 1, 1, device='cpu', layout=torch.sparse_coo))
+ x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3))
def test_triu(self):
x = torch.rand(SIZE, SIZE)
:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
+NOTE: when running on 'cuda', row * col must be less than :math:`2^{59}` to
+prevent overflow during calculation.
+""" + r"""
Args:
row (``int``): number of rows in the 2-D matrix.
column (``int``): number of columns in the 2-D matrix.
Default: if not provided, 0.
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, ``torch.long``.
- device (:class:`torch.device`, optional): currently only support ``cpu``.
+ {device}
layout (:class:`torch.layout`, optional): currently only support ``torch.strided``.
Example::
>>> a
tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
[0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]])
-""")
+""".format(**factory_common_args))
add_docstr(torch.triu,
r"""
:math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
+NOTE: when running on 'cuda', row * col must be less than :math:`2^{59}` to
+prevent overflow during calculation.
+""" + r"""
Args:
row (``int``): number of rows in the 2-D matrix.
column (``int``): number of columns in the 2-D matrix.
Default: if not provided, 0.
dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
Default: if ``None``, ``torch.long``.
- device (:class:`torch.device`, optional): currently only support ``cpu``.
+ {device}
layout (:class:`torch.layout`, optional): currently only support ``torch.strided``.
Example::
>>> a
tensor([[0, 0, 1],
[1, 2, 2]])
-""")
+""".format(**factory_common_args))
add_docstr(torch.trtrs,
r"""