});
}
-template <typename scalar_t, typename Fn>
+template <typename scalar_t, typename Comp, typename Fn>
void quick_select_template(
TensorAccessor<scalar_t, 1> arr,
int64_t k,
+ Comp gt_or_nan,
Fn swap_fn) {
int64_t P, L, R, i, j, swap;
scalar_t rswap, piv;
return;
if (R == L + 1) { // Two elements only
- if (arr[L] > arr[R]) {
+ if (gt_or_nan(arr[L], arr[R])) {
swap_fn(L, R);
}
return;
// Use median of three for pivot choice
P = (L + R) >> 1;
swap_fn(P, L + 1);
- if (arr[L + 1] > arr[R]) {
+ if (gt_or_nan(arr[L + 1], arr[R])) {
swap_fn(L + 1, R);
}
- if (arr[L] > arr[R]) {
+ if (gt_or_nan(arr[L], arr[R])) {
swap_fn(L, R);
}
- if (arr[L + 1] > arr[L]) {
+ if (gt_or_nan(arr[L + 1], arr[L])) {
swap_fn(L + 1, L);
}
do {
do
i++;
- while (arr[i] < piv);
+ while (gt_or_nan(piv, arr[i]));
do
j--;
- while (arr[j] > piv);
+ while (gt_or_nan(arr[j], piv));
if (j < i)
break;
swap_fn(i, j);
for (int64_t j = 0; j < tmp_indices.size(0); j++) {
tmp_indices[j] = j;
}
- quick_select_template(tmp_values, k - 1, [&](int64_t i, int64_t j) {
- std::swap(tmp_values[i], tmp_values[j]);
- std::swap(tmp_indices[i], tmp_indices[j]);
- });
+ // We want NaNs to sort to the top (greater than any number) for NumPy compatibility.
+ quick_select_template(
+ tmp_values,
+ k - 1,
+ [](scalar_t x, scalar_t y) -> bool {
+ return ((x != x && y == y) || (x > y));
+ },
+ [&](int64_t i, int64_t j) {
+ std::swap(tmp_values[i], tmp_values[j]);
+ std::swap(tmp_indices[i], tmp_indices[j]);
+ });
*mode_value = tmp_values[k - 1];
*mode_index = tmp_indices[k - 1];
});
namespace {
-
template <typename scalar_t, typename index_t, int Dim>
__global__ void gatherKthValue(
cuda::detail::TensorInfo<scalar_t, index_t> input,
bool inRange = (i < inputSliceSize);
scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
: static_cast<scalar_t>(0);
- bool isKValue = inRange && (THCNumerics<scalar_t>::eq(v, kValue));
+ bool isKValue = inRange && THCNumerics<scalar_t>::eq_with_nan(v, kValue);
if (isKValue) {
kValueIndex = i;
// We use this to enable radix selection of floating-point values.
// This also gives a relative order for NaNs, but that's ok, as they
// will all be adjacent
+ // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff..
+ // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00..
+ // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0
+ // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x<ff...
+ // Without a special case, a negative NaN would get a radix below neg inf and
+ // sort beneath everything, so convert() maps every NaN to the all-ones radix.
static inline __device__ RadixType convert(float v) {
RadixType x = __float_as_int(v);
RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
- return (x ^ mask);
+ return (v == v) ? (x ^ mask) : 0xffffffff;
}
static inline __device__ float deconvert(RadixType v) {
static inline __device__ RadixType convert(double v) {
RadixType x = __double_as_longlong(v);
RadixType mask = -((x >> 63)) | 0x8000000000000000;
- return (x ^ mask);
+ return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
}
static inline __device__ double deconvert(RadixType v) {
#if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
RadixType x = __half_as_ushort(v);
RadixType mask = -((x >> 15)) | 0x8000;
- return (x ^ mask);
+ return (v == v) ? (x ^ mask) : 0xffff;
#else
assert(false);
return 0u;
static inline __host__ __device__ bool gt(uint8_t a, uint8_t b) { return a > b; }
static inline __host__ __device__ bool ge(uint8_t a, uint8_t b) { return a >= b; }
static inline __host__ __device__ bool eq(uint8_t a, uint8_t b) { return a == b; }
+ static inline __device__ bool eq_with_nan(uint8_t a, uint8_t b) { return a == b; }
static inline __host__ __device__ bool ne(uint8_t a, uint8_t b) { return a != b; }
static inline __host__ __device__ uint8_t neg(int8_t a) { return -a; }
static inline __host__ __device__ bool gt(int8_t a, int8_t b) { return a > b; }
static inline __host__ __device__ bool ge(int8_t a, int8_t b) { return a >= b; }
static inline __host__ __device__ bool eq(int8_t a, int8_t b) { return a == b; }
+ static inline __device__ bool eq_with_nan(int8_t a, int8_t b) { return a == b; }
static inline __host__ __device__ bool ne(int8_t a, int8_t b) { return a != b; }
static inline __host__ __device__ int8_t neg(int8_t a) { return -a; }
static inline __host__ __device__ bool gt(int16_t a, int16_t b) { return a > b; }
static inline __host__ __device__ bool ge(int16_t a, int16_t b) { return a >= b; }
static inline __host__ __device__ bool eq(int16_t a, int16_t b) { return a == b; }
+ static inline __device__ bool eq_with_nan(int16_t a, int16_t b) { return a == b; }
static inline __host__ __device__ bool ne(int16_t a, int16_t b) { return a != b; }
static inline __host__ __device__ int16_t neg(int16_t a) { return -a; }
static inline __host__ __device__ bool gt(int32_t a, int32_t b) { return a > b; }
static inline __host__ __device__ bool ge(int32_t a, int32_t b) { return a >= b; }
static inline __host__ __device__ bool eq(int32_t a, int32_t b) { return a == b; }
+ static inline __device__ bool eq_with_nan(int32_t a, int32_t b) { return a == b; }
static inline __host__ __device__ bool ne(int32_t a, int32_t b) { return a != b; }
static inline __host__ __device__ int32_t neg(int32_t a) { return -a; }
static inline __host__ __device__ bool gt(int64_t a, int64_t b) { return a > b; }
static inline __host__ __device__ bool ge(int64_t a, int64_t b) { return a >= b; }
static inline __host__ __device__ bool eq(int64_t a, int64_t b) { return a == b; }
+ static inline __device__ bool eq_with_nan(int64_t a, int64_t b) { return a == b; }
static inline __host__ __device__ bool ne(int64_t a, int64_t b) { return a != b; }
static inline __host__ __device__ bool gt(at::Half a, at::Half b) { return a > b; }
static inline __host__ __device__ bool ge(at::Half a, at::Half b) { return a >= b; }
static inline __host__ __device__ bool eq(at::Half a, at::Half b) { return a == b; }
+ static inline __device__ bool eq_with_nan(at::Half a, at::Half b) { return __half_as_ushort(a) == __half_as_ushort(b); }
static inline __host__ __device__ bool ne(at::Half a, at::Half b) { return a != b; }
static inline __host__ __device__ at::Half exp(at::Half a) { return std::exp(a); }
static inline __host__ __device__ bool gt(float a, float b) { return a > b; }
static inline __host__ __device__ bool ge(float a, float b) { return a >= b; }
static inline __host__ __device__ bool eq(float a, float b) { return a == b; }
+ static inline __device__ bool eq_with_nan(float a, float b) { return __float_as_int(a) == __float_as_int(b); }
static inline __host__ __device__ bool ne(float a, float b) { return a != b; }
static inline __host__ __device__ float lgamma(float a) { return lgammaf(a);}
static inline __host__ __device__ bool gt(double a, double b) { return a > b; }
static inline __host__ __device__ bool ge(double a, double b) { return a >= b; }
static inline __host__ __device__ bool eq(double a, double b) { return a == b; }
+ static inline __device__ bool eq_with_nan(double a, double b) { return __double_as_longlong(a) == __double_as_longlong(b); }
static inline __host__ __device__ bool ne(double a, double b) { return a != b; }
static inline __host__ __device__ double lgamma(double a) { return ::lgamma(a);}
def test_advancedindex_big(self):
_TestTorchMixin._test_advancedindex_big(self, lambda t: t.cuda())
+ def test_kthvalue(self):
+ _TestTorchMixin._test_kthvalue(self, device='cuda')
+
@skipIfRocm
def test_btrifact(self):
_TestTorchMixin._test_btrifact(self, lambda t: t.cuda())
self.assertEqual(top1, top2)
self.assertEqual(idx1, idx2)
- def test_kthvalue(self):
+ @staticmethod
+ def _test_kthvalue(self, device='cpu'):
SIZE = 50
- x = torch.rand(SIZE, SIZE, SIZE)
+ x = torch.rand(SIZE, SIZE, SIZE, device=device)
x0 = x.clone()
k = random.randint(1, SIZE)
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
# test use of result tensors
k = random.randint(1, SIZE)
- res1val = torch.Tensor()
- res1ind = torch.LongTensor()
+ res1val = torch.tensor([], device=device)
+ res1ind = torch.tensor([], dtype=torch.long, device=device)
torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind))
res2val, res2ind = torch.sort(x)
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
self.assertEqual(x, x0, 0)
# simple test case (with repetitions)
- y = torch.Tensor((3, 5, 4, 1, 1, 5))
+ y = torch.tensor((3., 5, 4, 1, 1, 5), device=device)
self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0)
self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0)
+ # simple test case (with NaN)
+ SIZE = 50
+ x = torch.rand(SIZE, SIZE, SIZE, device=device)
+ x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan
+ ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1]
+ res2val, res2ind = torch.sort(x)
+ for k in ks:
+ res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
+ self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
+ self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
+
+ def test_kthvalue(self):
+ self._test_kthvalue(self)
+
def test_median(self):
for size in (155, 156):
x = torch.rand(size, size)