Implementing cuda kernel for tril_indices and triu_indices (#15203)
authorShen Li <shenli@fb.com>
Thu, 20 Dec 2018 18:21:02 +0000 (10:21 -0800)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Thu, 20 Dec 2018 18:23:38 +0000 (10:23 -0800)
Summary:
Followup PR of #14904, and the stretch goal of #12653.

Directly calculate coordinates in the original tensor using column index in the result tensor. Every GPU thread takes care of a column (two numbers) in the output tensor.

The implementation detects and handles precision loss during calculating the square root of a `int64_t` variable, and supports tensors with up to `row * column = 2 ^ 59` numbers.

Algorithm details are describe in [comments of TensorFactories.cu](https://github.com/pytorch/pytorch/blob/23ddb6f58a1c8a7a660a793f174cf014230176c6/aten/src/ATen/native/cuda/TensorFactories.cu#L109-L255).

zou3519
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15203

Reviewed By: zou3519

Differential Revision: D13517695

Pulled By: mrshenli

fbshipit-source-id: 86b305d22cac08c8962a3b0cf8e9e620b7ec33ea

aten/src/ATen/native/TensorFactories.cpp
aten/src/ATen/native/TensorFactories.h [new file with mode: 0644]
aten/src/ATen/native/cuda/TensorFactories.cu
aten/src/ATen/native/native_functions.yaml
test/common_methods_invocations.py
test/test_cuda.py
test/test_torch.py
torch/_torch_docs.py

index f7d8f65..3dea98a 100644 (file)
@@ -13,6 +13,7 @@
 #include <ATen/LegacyTHDispatcher.h>
 #include <c10/core/ScalarType.h>
 #include <ATen/core/Deprecated.h>
+#include <ATen/native/TensorFactories.h>
 #include <c10/core/TensorOptions.h>
 #include <TH/THRandom.h>
 #include <TH/THGenerator.hpp>
@@ -513,79 +514,14 @@ Tensor& range_out(Tensor& result, Scalar start, Scalar end, Scalar step) {
 
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
-namespace {
-// Different combinations of row, col, and offset can lead to two cases:
-//
-// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
-//    Example A: offset > 0
-//      1 1 0 0 0
-//      1 1 1 0 0
-//      1 1 1 1 0
-//    Example B: offset <= 0
-//      0 0 0
-//      1 0 0
-//      1 1 0
-//    In this case, we calculate the number of elements in the first row and
-//    last row of the tril respectively, and then compute the tril size.
-//
-// Case 2 - Trapezoid + Rectangle: row + offset > col
-//    Example:
-//      1 1 0
-//      1 1 1
-//      1 1 1
-//    In this case, we first calculate the size of top trapezoid, and then
-//    calculate the size of the bottom rectangle.
-inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
-  // number of elements in the first row of the tril
-  auto m_first_row = offset > 0 ?
-    std::min<int64_t>(col, 1 + offset) : // upper bounded by col
-    row + offset > 0; // either 0 or 1
-  // number of elements in the last row of the tril, bounded by [0, col]
-  auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
-  // number of rows, bounded by [0, row]
-  auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
-  auto n_row_trapezoid = (m_last_row - m_first_row + 1);
-
-  // calculate # of elements in the top trapezoid
-  auto n_indices =
-    (m_first_row + m_last_row) * n_row_trapezoid >> 1;
-
-  // calculate # of elements in the bottom rectangle if there is any
-  auto diff_row = n_row_all - n_row_trapezoid;
-  if (diff_row > 0) {
-    n_indices += diff_row * col;
-  }
-
-  return n_indices;
-}
-
-inline void check_args(
-    int64_t row, int64_t col, const TensorOptions& options) {
-  AT_CHECK(row >= 0, "row must be non-negative, got", row);
-  AT_CHECK(col >= 0, "col must be non-negative, got", col);
-  if (options.has_device()) {
-    AT_CHECK(
-      options.device() == at::kCPU,
-      "only support device='cpu', got",
-      options.device());
-  }
-  if (options.has_layout()) {
-    AT_CHECK(
-      options.layout() == at::kStrided,
-      "only support layout=torch.strided, got",
-      options.layout())
-  }
-}
-} // namespace
-
-Tensor tril_indices(
+Tensor tril_indices_cpu(
     int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
   check_args(row, col, options);
 
-  auto n_indices = get_tril_size(row, col, offset);
+  auto tril_size = get_tril_size(row, col, offset);
 
   // create an empty Tensor with correct size
-  auto result = at::empty({2, n_indices}, options);
+  auto result = at::empty({2, tril_size}, options);
 
   // The following three approaches result in very little performance
   // differences. Hence, the 2nd option is taken for simpler code, and to return
@@ -605,9 +541,9 @@ Tensor tril_indices(
     int64_t i = 0;
 
     scalar_t r = std::max<int64_t>(0, -offset), c = 0;
-    while (i < n_indices) {
+    while (i < tril_size) {
       result_data[i] = r;
-      result_data[n_indices + i++] = c;
+      result_data[tril_size + i++] = c;
 
       // move to the next column and check if (r, c) is still in bound
       c += 1;
@@ -615,7 +551,7 @@ Tensor tril_indices(
         r += 1;
         c = 0;
         // NOTE: not necessary to check if r is less than row here, because i
-        // and n_indices provide the guarantee
+        // and tril_size provide the guarantee
       }
     }
   });
@@ -623,14 +559,14 @@ Tensor tril_indices(
   return result;
 }
 
-Tensor triu_indices(
+Tensor triu_indices_cpu(
     int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
   check_args(row, col, options);
 
-  auto n_indices = row * col - get_tril_size(row, col, offset - 1);
+  auto triu_size = row * col - get_tril_size(row, col, offset - 1);
 
   // create an empty Tensor with correct size
-  auto result = at::empty({2, n_indices}, options);
+  auto result = at::empty({2, triu_size}, options);
 
   AT_DISPATCH_ALL_TYPES(result.type(), "triu_indices", [&]() -> void {
     // fill the Tensor with correct values
@@ -638,11 +574,11 @@ Tensor triu_indices(
     int64_t i = 0;
     // not typing std::max with scalar_t as it could be an unsigned type
     // NOTE: no need to check if the returned value of std::max overflows
-    // scalar_t, as i and n_indices act as a guard.
+    // scalar_t, as i and triu_size act as a guard.
     scalar_t c = std::max<int64_t>(0, offset), r = 0;
-    while (i < n_indices) {
+    while (i < triu_size) {
       result_data[i] = r;
-      result_data[n_indices + i++] = c;
+      result_data[triu_size + i++] = c;
 
       // move to the next column and check if (r, c) is still in bound
       c += 1;
@@ -650,7 +586,7 @@ Tensor triu_indices(
         r += 1;
         // not typing std::max with scalar_t as it could be an unsigned type
         // NOTE: not necessary to check if c is less than col or overflows here,
-        // because i and n_indices act as a guard.
+        // because i and triu_size act as a guard.
         c = std::max<int64_t>(0, r + offset);
       }
     }
diff --git a/aten/src/ATen/native/TensorFactories.h b/aten/src/ATen/native/TensorFactories.h
new file mode 100644 (file)
index 0000000..8f68ab3
--- /dev/null
@@ -0,0 +1,60 @@
+#pragma once
+
+namespace at { namespace native {
+// Different combinations of row, col, and offset can lead to two cases:
+//
+// Case 1 - Trapezoid (Triangle as a special case): row + offset <= col
+//    Example A: offset > 0
+//      1 1 0 0 0
+//      1 1 1 0 0
+//      1 1 1 1 0
+//    Example B: offset <= 0
+//      0 0 0
+//      1 0 0
+//      1 1 0
+//    In this case, we calculate the number of elements in the first row and
+//    last row of the tril respectively, and then compute the tril size.
+//
+// Case 2 - Trapezoid + Rectangle: row + offset > col
+//    Example:
+//      1 1 0
+//      1 1 1
+//      1 1 1
+//    In this case, we first calculate the size of top trapezoid, and then
+//    calculate the size of the bottom rectangle.
+inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
+  // number of elements in the first row of the tril
+  auto m_first_row = offset > 0 ?
+    std::min<int64_t>(col, 1 + offset) : // upper bounded by col
+    row + offset > 0; // either 0 or 1
+  // number of elements in the last row of the tril, bounded by [0, col]
+  auto m_last_row = std::max<int64_t>(0, std::min<int64_t>(col, row + offset));
+  // number of rows, bounded by [0, row]
+  auto n_row_all = std::max<int64_t>(0, std::min<int64_t>(row, row + offset));
+  auto n_row_trapezoid = (m_last_row - m_first_row + 1);
+
+  // calculate # of elements in the top trapezoid
+  auto tril_size = (m_first_row + m_last_row) * n_row_trapezoid >> 1;
+
+  // calculate # of elements in the bottom rectangle if there is any
+  auto diff_row = n_row_all - n_row_trapezoid;
+  if (diff_row > 0) {
+    tril_size += diff_row * col;
+  }
+
+  return tril_size;
+}
+
+inline void check_args(
+    int64_t row, int64_t col, const TensorOptions& options) {
+  AT_CHECK(row >= 0, "row must be non-negative, got", row);
+  AT_CHECK(col >= 0, "col must be non-negative, got", col);
+  if (options.has_layout()) {
+    AT_CHECK(
+      options.layout() == at::kStrided,
+      "only support layout=torch.strided, got",
+      options.layout())
+  }
+}
+} // namespace native
+} // namespace at
index 4afded7..744928f 100644 (file)
@@ -1,7 +1,9 @@
 #include <ATen/ATen.h>
 #include <ATen/InitialTensorOptions.h>
 #include <ATen/NativeFunctions.h>
+#include <ATen/cuda/CUDAApplyUtils.cuh>
 #include <ATen/cuda/CUDAContext.h>
+#include <ATen/native/TensorFactories.h>
 #include <c10/util/Exception.h>
 
 #include <THC/THCGeneral.h>
@@ -13,6 +15,7 @@
 
 #include <algorithm>
 #include <cstddef>
+#include <cmath>
 
 namespace at {
 namespace native {
@@ -101,4 +104,305 @@ Tensor& randperm_out_cuda(Tensor& result, int64_t n, Generator* generator) {
   return result;
 }
 
+// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+namespace {
+// To find the max integer that does not exceed the root of an int64_t variable,
+// we could use a loop to test one bit at a time, which takes up to 31
+// iterations. This would give the accurate result, but is relatively slow and
+// is an overkill for most cases where double's precision suffice.
+//
+// If we directly use sqrt to calculate the root, the convertion from int64_t
+// to double would lose 11 bits precision.
+//
+// The following solution uses sqrt directly for most cases, and would only
+// special handle it if there is indeed precision loss.
+__device__
+inline int64_t resolve_root_int(
+    int64_t b, int64_t cX4, int64_t x, int32_t sign) {
+  int64_t bXb_cX4 = b*b - cX4;
+  // potential precision loss could occur here when casting int64_t (63 bits
+  // precision) to double (52 bits precision)
+  double sr = ::sqrt((double)bXb_cX4);
+  int64_t res = ::__double2ll_rd((-b + sign * sr)/2);
+
+  // have to cast double to int64_t, otherwise it would only compare up to the
+  // precision of a double variable, ignoring the precision loss
+  if (bXb_cX4 != (int64_t) (sr * sr)) {
+    // handle precision loss by using binary search
+    int64_t llsr = ::__double2ll_rd(sr);
+    // Use the following math to reduce search space.
+    // Suppose z is the accurate result of sqrt(bXb_cX4) without precision loss
+    // let d = abs(bXb_cX4 - llsr * llsr), then we have:
+    // z = sqrt(bXb_cX4) <= sqrt(llsr * llsr + d) <= llsr + sqrt(d)
+    // z = sqrt(bXb_cX4) >= sqrt(llsr * llsr - d) >= llsr - sqrt(d)
+    // Hence, it is sufficient to search range [llsr - sqrt(d), llsr + sqrt(d)).
+    // And the true value of row would also be with in range,
+    //            [res - sqrt(d), res + sqrt(d) + 1)
+    // as the denominator would only reduce the precision penalty.
+    int64_t diff =
+      ::__double2ll_ru(::sqrt(::fabs((double)(bXb_cX4 - llsr * llsr))));
+    // l never exceeds (could equal to) the target row index
+    auto l = res > diff ? res - diff : 0;
+    // r is always larger than the target row index
+    auto r = res + diff + 1;
+
+    // binary search for the correct answer
+    x <<= 1; // the loop always compares with 2x, so do it once here
+    while (l + 1 < r) {
+      auto m = (l + r) >> 1;
+      // for tril:
+      //    b = 2f - 1, sign = 1, hence (2f + m - 1) * m / 2
+      // for triu:
+      //    b = -2f - 1, sign = -1, hence (2f - m + 1) * m / 2
+      if (sign * (b + m) * m > x) {
+        r = m;
+      } else {
+        l = m;
+      }
+    }
+    res = l;
+  }
+
+  return res;
+}
+
+// f: the number of elements in the first row of the trapezoid.
+// x: the index of the target coordinates ordered by row and then column.
+//
+// View the tril as a top trapezoid stacked on a bottom rectangle. Assume x
+// corresponds to the coordinate (row, col) in the trapezoid, where the row and
+// the col both start from 0, then we have:
+//
+//                   (f + f + row - 1) * row / 2 <= x                       [1]
+//                 (f + f + row) * (row + 1) / 2  > x                       [2]
+//
+// Therefore, row is the maximum integer satisfying the following inequality:
+//
+//                       (row + 2f - 1)row <= 2x
+//                  row^2 + (2f-1)row - 2x <= 0.                            [3]
+//
+// Based on ineuqality [3], we have the following coefficients for formula of
+// root:
+//                               a = 1
+//                               b = 2f - 1
+//                               c = -2x
+// There are two roots, and we should use the largest integer that does not
+// exceed the root on the right. Intuitively, it is because:
+//  i)  the valid solution range of row is between two roots, as it is <= 0;
+//  ii) as we count in more rows, the total # of elements should always
+//      increase, hence so does the left-hand side row^2 + (2f-1)row - 2x.
+//      Therefore, the valid range of row lies in between the nadir point and
+//      the larger root on the right.
+// Full proof can be derived from inequality [2]. So, we calculate the result
+// coordinate as:
+//
+//                   row = floor((-b + sqrt(b^2 - 4c)) / 2)
+//                   col = x - (f + f + row - 1) * row / 2
+__device__
+inline void get_coordinate_in_tril_trapezoid(
+    int64_t f, int64_t x, int64_t & row, int64_t & col) {
+  f <<= 1; // all statements use 2f, so only calculate it once here.
+  auto b = f - 1;
+  auto cX4 = - (x << 3); // 4 * c = 4 * (-2x) = -8x;
+  row = resolve_root_int(b, cX4, x, 1);
+  col = x - ((f + row - 1) * row >> 1);
+}
+
+// f: the number of elements in the first row of the bottom trapezoid.
+// x: the index of the target coordinates ordered by row and then column.
+//
+// View the triu as a top rectangle stacked on a bottom trapezoid, where the
+// trapezoid is upside down. Assume x corresponds to the coordinate (row, col)
+// in the bottom trapezoid, where the row and the col start from 0, then we
+// have:
+//
+//                   (f + f - row + 1) * row / 2 <= x                       [1]
+//                 (f + f - row) * (row + 1) / 2  > x                       [2]
+//
+// Therefore, row is the maximum integer satisfying the following inequality:
+//
+//                       (-row + 2f + 1)row <= 2x
+//                   row^2 - (2f+1)row + 2x >= 0.                           [3]
+//
+// Based on ineuqality [3], we have the following coefficients for formula of
+// root:
+//                               a = 1
+//                               b = -1 - 2f
+//                               c = 2x
+// There are two roots, and we should use the largest integer that does not
+// exceed the root on the left. Intuitively, it is because:
+//  i)  the valid solution range of row is outside of the two roots, as it is <
+//      > 0;
+//  ii) as we count in more rows, the total # of elements should always
+//      increase, hence so does the left-hand side row^2 - (2f+1)row + 2x.
+//      Therefore, the valid range of row lies to the left of the smaller root
+//      on the left.
+// Full proof can be derived from inequality [2]. So, we calculate the result
+// coordinate as:
+//
+//                   row = floor((-b - sqrt(b^2 - 4c)) / 2)
+//                   col = x - (f + f - row + 1) * row / 2
+__device__
+inline void get_coordinate_in_triu_trapezoid(
+    int64_t f, int64_t x, int64_t & row, int64_t & col) {
+  f <<= 1; // all statements use 2f, so only calculate it once here.
+  auto b = -1 - f;
+  auto cX4 = x << 3; // 4 * c = 4 * (2x) = 8x;
+  row = resolve_root_int(b, cX4, x, -1);
+  col = x - ((f - row + 1) * row >> 1) + row;
+}
+
+} // namespace
+
+template <typename scalar_t>
+__global__
+void tril_indices_kernel(scalar_t * tensor,
+                         int64_t row_offset,
+                         int64_t m_first_row,
+                         int64_t col,
+                         int64_t trapezoid_size,
+                         int64_t tril_size) {
+  int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (linear_index < tril_size) {
+    int64_t r, c;
+    if (linear_index < trapezoid_size) {
+      // the coordinate is within the top trapezoid
+      get_coordinate_in_tril_trapezoid(m_first_row, linear_index, r, c);
+    } else {
+      // the coordinate falls in the bottom rectangle
+      auto surplus = linear_index - trapezoid_size;
+      // add the height of trapezoid: m_last_row (col) - m_first_row + 1
+      r = surplus / col + col - m_first_row + 1;
+      c = surplus % col;
+    }
+    r += row_offset;
+
+    tensor[linear_index] = r;
+    tensor[linear_index + tril_size] = c;
+  }
+}
+
+// Some Large test cases for the fallback binary search path is disabled by
+// default to speed up CI tests and to avoid OOM error. When modifying the
+// implementation, please enable them in test/test_cuda.py and make sure they
+// pass on your local server.
+Tensor tril_indices_cuda(
+    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+  check_args(row, col, options);
+
+  auto tril_size = get_tril_size(row, col, offset);
+  auto tensor = empty_cuda({2, tril_size}, options);
+
+  if (tril_size > 0) {
+    auto m_first_row = offset > 0 ?
+      std::min<int64_t>(col, 1 + offset) : // upper bounded by col
+      row + offset > 0; // either 0 or 1
+    auto trapezoid_row_offset = std::max<int64_t>(0, -offset);
+    auto rectangle_row_offset = trapezoid_row_offset + col - m_first_row + 1;
+    int64_t rectangle_size = 0;
+    if (rectangle_row_offset < row) {
+      rectangle_size = (row - rectangle_row_offset) * col;
+    }
+
+    dim3 dim_block = cuda::getApplyBlock();
+    dim3 dim_grid;
+    // using tril_size instead of tensor.numel(), as each thread takes care of
+    // two elements in the tensor.
+    AT_CHECK(
+      cuda::getApplyGrid(tril_size, dim_grid, tensor.get_device()),
+      "unable to get dim grid");
+
+    AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "tril_indices_cuda", [&] {
+      tril_indices_kernel<<<
+          dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+        tensor.data<scalar_t>(),
+        trapezoid_row_offset,
+        m_first_row,
+        col,
+        tril_size - rectangle_size,
+        tril_size);
+    });
+  }
+
+  return tensor;
+}
+
+template <typename scalar_t>
+__global__
+void triu_indices_kernel(scalar_t * tensor,
+                         int64_t col_offset,
+                         int64_t m_first_row,
+                         int64_t col,
+                         int64_t rectangle_size,
+                         int64_t triu_size) {
+  int64_t linear_index = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (linear_index < triu_size) {
+    int64_t r, c;
+    if (linear_index < rectangle_size) {
+      // the coordinate is within the top rectangle
+      r = linear_index / col;
+      c = linear_index % col;
+    } else {
+      // the coordinate falls in the bottom trapezoid
+      get_coordinate_in_triu_trapezoid(
+        m_first_row, linear_index - rectangle_size, r, c);
+      r += rectangle_size / col;
+    }
+
+    c += col_offset;
+    tensor[linear_index] = r;
+    tensor[linear_index + triu_size] = c;
+  }
+}
+
+// Some Large test cases for the fallback binary search path is disabled by
+// default to speed up CI tests and to avoid OOM error. When modifying the
+// implementation, please enable them in test/test_cuda.py and make sure they
+// pass on your local server.
+Tensor triu_indices_cuda(
+    int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
+  check_args(row, col, options);
+
+  auto triu_size = row * col - get_tril_size(row, col, offset - 1);
+  auto tensor = empty_cuda({2, triu_size}, options);
+
+  if (triu_size > 0) {
+    // # of triu elements in the first row
+    auto m_first_row = offset > 0 ?
+      std::max<int64_t>(col - offset, 0) : // upper bounded by col
+      col;
+
+    // size of the top rectangle
+    int64_t rectangle_size = 0;
+    if (offset < 0) {
+      rectangle_size = std::min<int64_t>(row, -offset) * col;
+    }
+
+    dim3 dim_block = cuda::getApplyBlock();
+    dim3 dim_grid;
+
+    // using triu_size instead of tensor.numel(), as each thread takes care of
+    // two elements in the tensor.
+    AT_CHECK(
+      cuda::getApplyGrid(triu_size, dim_grid, tensor.get_device()),
+      "unable to get dim grid");
+
+    AT_DISPATCH_ALL_TYPES_AND_HALF(tensor.type(), "triu_indices_cuda", [&] {
+      triu_indices_kernel<<<
+          dim_grid, dim_block, 0, at::cuda::getCurrentCUDAStream()>>>(
+        tensor.data<scalar_t>(),
+        std::max<int64_t>(0, offset),
+        m_first_row,
+        col,
+        rectangle_size,
+        triu_size);
+    });
+  }
+
+  return tensor;
+}
+
 }} // namespace at::native
index 6c8d98f..6ffb10b 100644 (file)
   variants: method, function
 
 - func: tril_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+  dispatch:
+    CPU: tril_indices_cpu
+    CUDA: tril_indices_cuda
 
 - func: triu_indices(int64_t row, int64_t col, int64_t offset=0, TensorOptions options=at::kLong) -> Tensor
+  dispatch:
+    CPU: triu_indices_cpu
+    CUDA: triu_indices_cuda
 
 - func: trace(Tensor self) -> Tensor
   variants: method, function
index 0b7c54e..e442df8 100644 (file)
@@ -833,6 +833,153 @@ def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwarg
     return args_out, kwargs_out
 
 
+def _compare_trilu_indices(
+        self, row, col, offset=0, dtype=torch.long, device='cpu'):
+    if row == 0 or col == 0:
+        # have to handle this separately as tril and triu does not take
+        # empty matrix as input
+        self.assertEqual(
+            torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
+            torch.tril_indices(row, col, offset, dtype=dtype, device=device))
+
+        self.assertEqual(
+            torch.empty(0, 2, dtype=dtype, device=device).transpose(0, 1),
+            torch.triu_indices(row, col, offset, dtype=dtype, device=device))
+
+    else:
+        self.assertEqual(
+            torch.ones(row, col, dtype=dtype, device='cpu')
+                 .tril(offset).nonzero().transpose(0, 1).to(device),
+            torch.tril_indices(row, col, offset, dtype=dtype, device=device))
+
+        self.assertEqual(
+            torch.ones(row, col, dtype=dtype, device='cpu')
+                 .tril(offset).nonzero().transpose(0, 1).to(device),
+            torch.tril_indices(row, col, offset, dtype=dtype, device=device))
+
+
+def _compare_large_trilu_indices(
+        self, row, col, offset=0, dtype=torch.long, device='cpu'):
+    l = torch.ones(row, col, dtype=dtype, device='cpu').tril(offset) \
+             .nonzero()[-100:-1, :].transpose(0, 1).to(device)
+    torch.cuda.empty_cache()
+
+    r = torch.tril_indices(
+        row, col, offset, dtype=dtype, device=device)[:, -100:-1]
+    self.assertEqual(l, r)
+    torch.cuda.empty_cache()
+
+    l = torch.ones(row, col, dtype=dtype, device='cpu').triu(offset) \
+             .nonzero()[-100:-1, :].transpose(0, 1).to(device)
+    torch.cuda.empty_cache()
+
+    r = torch.triu_indices(
+        row, col, offset, dtype=dtype, device=device)[:, -100:-1]
+    self.assertEqual(l, r)
+    torch.cuda.empty_cache()
+
+# (
+#   row
+#   col
+#   offset (optional)
+#   dtype (optional)
+# )
+tri_tests_args = [
+    (1, 1),
+    (3, 3),
+    (3, 3, 1),
+    (3, 3, 2),
+    (3, 3, 200),
+    (3, 3, -1),
+    (3, 3, -2),
+    (3, 3, -200),
+    (0, 3, 0),
+    (0, 3, 1),
+    (0, 3, -1),
+    (3, 0, 0),
+    (3, 0, 1),
+    (3, 0, -1),
+    (0, 0, 0),
+    (0, 0, 1),
+    (0, 0, -1),
+    (3, 6, 0),
+    (3, 6, 1),
+    (3, 6, 3),
+    (3, 6, 9),
+    (3, 6, -1),
+    (3, 6, -3),
+    (3, 6, -9),
+    (6, 3, 0),
+    (6, 3, 1),
+    (6, 3, 3),
+    (6, 3, 9),
+    (6, 3, -1),
+    (6, 3, -3),
+    (6, 3, -9),
+    (258, 253, 1, torch.float32),
+    (257, 258, 1, torch.float64),
+    (258, 258, 1, torch.short),
+    (3, 513, 1, torch.long),
+    (513, 3, 1, torch.int),
+    (513, 0, 1, torch.double),
+    (1024, 1024),
+    (1024, 1024, 500, torch.float32),
+    (1024, 1024, 1023),
+    (1024, 1024, -500),
+    (1023, 1025),
+    (1025, 1023, 1022),
+    (1024, 1024, -500),
+    (3, 2028),
+    (3, 2028, 1),
+    (3, 2028, -1),
+    (2028, 3),
+    (2028, 1),
+    (2028, 1, -1)
+]
+
+tri_large_tests_args = [
+    (1, 268435455),
+    # Large test cases below are deliberately commented out to speed up CI
+    # tests and to avoid OOM error. When modifying implementations of
+    # tril_indices and triu_indices, please enable these tests and make sure
+    # they pass.
+    #
+    # (5000, 5000),
+    # (10000, 10000),
+    # (268435455, 1),
+    # (134217727, 2, 1),
+    # (2, 134217727, 1),
+    # (536870901, 1),
+    # (1, 536870901),
+    # (268435455, 2, 1),
+    # (2, 268435455, 1)
+]
+
+
+def run_additional_tri_tests(self, device):
+    x = torch.ones(
+        3, 3, dtype=torch.long, device=device, layout=torch.strided)
+    l = x.tril(0).nonzero().transpose(0, 1)
+    u = x.triu(0).nonzero().transpose(0, 1)
+    self.assertEqual(l, torch.tril_indices(3, 3, device=device))
+    self.assertEqual(
+        l, torch.tril_indices(3, 3, device=device, layout=torch.strided))
+
+    self.assertEqual(u, torch.triu_indices(3, 3, device=device))
+    self.assertEqual(
+        u, torch.triu_indices(3, 3, device=device, layout=torch.strided))
+
+    self.assertRaises(
+        RuntimeError,
+        lambda: torch.triu_indices(
+            1, 1, device=device, layout=torch.sparse_coo))
+
+    self.assertRaises(
+        RuntimeError,
+        lambda: torch.tril_indices(
+            1, 1, device=device, layout=torch.sparse_coo))
+
+
 def unpack_variables(args):
     if isinstance(args, tuple):
         return tuple(unpack_variables(elem) for elem in args)
index 9ddb95a..32a49ca 100644 (file)
@@ -16,6 +16,8 @@ from torch._six import inf, nan
 
 from test_torch import _TestTorchMixin
 
+from common_methods_invocations import tri_tests_args, tri_large_tests_args, \
+    run_additional_tri_tests, _compare_trilu_indices, _compare_large_trilu_indices
 from common_utils import TestCase, get_gpu_type, to_gpu, freeze_rng_state, run_tests, \
     PY3, IS_WINDOWS, NO_MULTIPROCESSING_SPAWN, skipIfRocm, TEST_NUMPY, TEST_WITH_ROCM, load_tests
 
@@ -2118,16 +2120,23 @@ class TestCuda(TestCase):
                 y = torch.randn(2, 1, device='cuda')
                 z = x + y
 
-    def test_tril_and_triu_indices(self):
-        self.assertRaises(
-            RuntimeError,
-            lambda: torch.triu_indices(
-                1, 1, device='cuda', layout=torch.strided))
-
-        self.assertRaises(
-            RuntimeError,
-            lambda: torch.tril_indices(
-                1, 1, device='cuda', layout=torch.strided))
+    def test_trilu_indices(self):
+        for test_args in tri_tests_args:
+            _compare_trilu_indices(self, *test_args, device='cuda')
+
+        # test default options
+        x = torch.ones(
+            3, 3, dtype=torch.long, device='cuda', layout=torch.strided)
+        self.assertEqual(
+            x.tril(0).nonzero().transpose(0, 1),
+            torch.tril_indices(3, 3, device='cuda'))
+        self.assertEqual(
+            x.triu(0).nonzero().transpose(0, 1),
+            torch.triu_indices(3, 3, device='cuda'))
+
+    def test_large_trilu_indices(self):
+        for test_args in tri_large_tests_args:
+            _compare_large_trilu_indices(self, *test_args, device='cuda')
 
 
 def load_ignore_file():
index 0386ba2..2d16c2d 100644 (file)
@@ -22,6 +22,8 @@ from torch._six import inf, nan, string_classes
 from itertools import product, combinations
 from functools import reduce
 from torch import multiprocessing as mp
+from common_methods_invocations import tri_tests_args, run_additional_tri_tests, \
+    _compare_trilu_indices
 from common_utils import TestCase, iter_indices, TEST_NUMPY, TEST_SCIPY, TEST_MKL, \
     TEST_LIBROSA, run_tests, download_file, skipIfNoLapack, suppress_warnings, \
     IS_WINDOWS, PY3, NO_MULTIPROCESSING_SPAWN, skipIfRocm, do_test_dtypes, do_test_empty_full, \
@@ -3760,91 +3762,19 @@ class _TestTorchMixin(object):
         torch.tril(x, out=res2)
         self.assertEqual(res1, res2, 0)
 
-    def _compare_trilu_indices(self, row, col, offset=0, dtype=torch.long):
-        if row == 0 or col == 0:
-            # have to handle this separately as tril and triu does not take
-            # empty matrix as input
-            self.assertEqual(
-                torch.empty(0, 2, dtype=dtype).transpose(0, 1),
-                torch.tril_indices(row, col, offset, dtype=dtype))
+    def test_trilu_indices(self):
+        for test_args in tri_tests_args:
+            _compare_trilu_indices(self, *test_args)
 
-            self.assertEqual(
-                torch.empty(0, 2, dtype=dtype).transpose(0, 1),
-                torch.triu_indices(row, col, offset, dtype=dtype))
-
-        else:
-            self.assertEqual(
-                torch.ones(row, col, dtype=dtype)
-                     .tril(offset).nonzero().transpose(0, 1),
-                torch.tril_indices(row, col, offset, dtype=dtype))
-
-            self.assertEqual(
-                torch.ones(row, col, dtype=dtype)
-                     .triu(offset).nonzero().transpose(0, 1),
-                torch.triu_indices(row, col, offset, dtype=dtype))
-
-    def test_tril_and_triu_indices(self):
-        self._compare_trilu_indices(1, 1)
-        self._compare_trilu_indices(3, 3)
-        self._compare_trilu_indices(3, 3, offset=1)
-        self._compare_trilu_indices(3, 3, offset=2)
-        self._compare_trilu_indices(3, 3, offset=200)
-        self._compare_trilu_indices(3, 3, offset=-1)
-        self._compare_trilu_indices(3, 3, offset=-2)
-        self._compare_trilu_indices(3, 3, offset=-200)
-        self._compare_trilu_indices(0, 3, offset=0)
-        self._compare_trilu_indices(0, 3, offset=1)
-        self._compare_trilu_indices(0, 3, offset=-1)
-        self._compare_trilu_indices(3, 0, offset=0)
-        self._compare_trilu_indices(3, 0, offset=1)
-        self._compare_trilu_indices(3, 0, offset=-1)
-        self._compare_trilu_indices(0, 0, offset=0)
-        self._compare_trilu_indices(0, 0, offset=1)
-        self._compare_trilu_indices(0, 0, offset=-1)
-        self._compare_trilu_indices(3, 6, offset=0)
-        self._compare_trilu_indices(3, 6, offset=1)
-        self._compare_trilu_indices(3, 6, offset=3)
-        self._compare_trilu_indices(3, 6, offset=9)
-        self._compare_trilu_indices(3, 6, offset=-1)
-        self._compare_trilu_indices(3, 6, offset=-3)
-        self._compare_trilu_indices(3, 6, offset=-9)
-        self._compare_trilu_indices(6, 3, offset=0)
-        self._compare_trilu_indices(6, 3, offset=1)
-        self._compare_trilu_indices(6, 3, offset=3)
-        self._compare_trilu_indices(6, 3, offset=9)
-        self._compare_trilu_indices(6, 3, offset=-1)
-        self._compare_trilu_indices(6, 3, offset=-3)
-        self._compare_trilu_indices(6, 3, offset=-9)
-        self._compare_trilu_indices(258, 253, offset=1, dtype=torch.float32)
-        self._compare_trilu_indices(257, 258, offset=1, dtype=torch.float64)
-        self._compare_trilu_indices(258, 258, offset=1, dtype=torch.short)
-        self._compare_trilu_indices(3, 513, offset=1, dtype=torch.long)
-        self._compare_trilu_indices(513, 3, offset=1, dtype=torch.int)
-        self._compare_trilu_indices(513, 0, offset=1, dtype=torch.double)
+        run_additional_tri_tests(self, 'cpu')
 
+        # test default options
         x = torch.ones(
             3, 3, dtype=torch.long, device='cpu', layout=torch.strided)
-        l = x.tril(0).nonzero().transpose(0, 1)
-        u = x.triu(0).nonzero().transpose(0, 1)
-        self.assertEqual(l, torch.tril_indices(3, 3))
-        self.assertEqual(l, torch.tril_indices(3, 3, device='cpu'))
         self.assertEqual(
-            l, torch.tril_indices(3, 3, device='cpu', layout=torch.strided))
-
-        self.assertEqual(u, torch.triu_indices(3, 3))
-        self.assertEqual(u, torch.triu_indices(3, 3, device='cpu'))
+            x.tril(0).nonzero().transpose(0, 1), torch.tril_indices(3, 3))
         self.assertEqual(
-            u, torch.triu_indices(3, 3, device='cpu', layout=torch.strided))
-
-        self.assertRaises(
-            RuntimeError,
-            lambda: torch.triu_indices(
-                1, 1, device='cpu', layout=torch.sparse_coo))
-
-        self.assertRaises(
-            RuntimeError,
-            lambda: torch.tril_indices(
-                1, 1, device='cpu', layout=torch.sparse_coo))
+            x.triu(0).nonzero().transpose(0, 1), torch.triu_indices(3, 3))
 
     def test_triu(self):
         x = torch.rand(SIZE, SIZE)
index 55fde66..47bb873 100644 (file)
@@ -5011,6 +5011,9 @@ the main diagonal. The main diagonal are the set of indices
 :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
 where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
 
+NOTE: when running on 'cuda', row * col must be less than :math:`2^{59}` to
+prevent overflow during calculation.
+""" + r"""
 Args:
     row (``int``): number of rows in the 2-D matrix.
     column (``int``): number of columns in the 2-D matrix.
@@ -5018,7 +5021,7 @@ Args:
         Default: if not provided, 0.
     dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
         Default: if ``None``, ``torch.long``.
-    device (:class:`torch.device`, optional): currently only support ``cpu``.
+    {device}
     layout (:class:`torch.layout`, optional): currently only support ``torch.strided``.
 
 Example::
@@ -5036,7 +5039,7 @@ Example::
     >>> a
     tensor([[0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3],
             [0, 1, 0, 1, 2, 0, 1, 2, 0, 1, 2]])
-""")
+""".format(**factory_common_args))
 
 add_docstr(torch.triu,
            r"""
@@ -5121,6 +5124,9 @@ the main diagonal. The main diagonal are the set of indices
 :math:`\lbrace (i, i) \rbrace` for :math:`i \in [0, \min\{d_{1}, d_{2}\} - 1]`
 where :math:`d_{1}, d_{2}` are the dimensions of the matrix.
 
+NOTE: when running on 'cuda', row * col must be less than :math:`2^{59}` to
+prevent overflow during calculation.
+""" + r"""
 Args:
     row (``int``): number of rows in the 2-D matrix.
     column (``int``): number of columns in the 2-D matrix.
@@ -5128,7 +5134,7 @@ Args:
         Default: if not provided, 0.
     dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
         Default: if ``None``, ``torch.long``.
-    device (:class:`torch.device`, optional): currently only support ``cpu``.
+    {device}
     layout (:class:`torch.layout`, optional): currently only support ``torch.strided``.
 
 Example::
@@ -5146,7 +5152,7 @@ Example::
     >>> a
     tensor([[0, 0, 1],
             [1, 2, 2]])
-""")
+""".format(**factory_common_args))
 
 add_docstr(torch.trtrs,
            r"""