From: Thomas Viehmann Date: Tue, 12 Mar 2019 15:45:17 +0000 (-0700) Subject: kthvalue consistency with sort in the presence of NaN (#17824) X-Git-Tag: accepted/tizen/6.5/unified/20211028.231830~868 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=aba9051a65598f4cc624ab4b517ca2bf164b249f;p=platform%2Fupstream%2Fpytorch.git kthvalue consistency with sort in the presence of NaN (#17824) Summary: This PR causes kthvalue to be consistent with sort (i.e. treat NaN as larger than any number), so that `a.kthvalue(n) == a.sort()[n - 1]`. One drawback is that median with a NaN argument does not return NaN, which is a deviation from NumPy. Thank you, ngimel, for raising this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/17824 Differential Revision: D14410092 Pulled By: ezyang fbshipit-source-id: bdec2d8272dc4c65bcf2f9b8995e237774c44c02 --- diff --git a/aten/src/ATen/native/Sorting.cpp b/aten/src/ATen/native/Sorting.cpp index 9f52b2a..01adda7 100644 --- a/aten/src/ATen/native/Sorting.cpp +++ b/aten/src/ATen/native/Sorting.cpp @@ -65,10 +65,11 @@ void dim_apply(TensorList tensors, int64_t dim, Fn f) { }); } -template +template void quick_select_template( TensorAccessor arr, int64_t k, + Comp gt_or_nan, Fn swap_fn) { int64_t P, L, R, i, j, swap; scalar_t rswap, piv; @@ -80,7 +81,7 @@ void quick_select_template( return; if (R == L + 1) { // Two elements only - if (arr[L] > arr[R]) { + if (gt_or_nan(arr[L], arr[R])) { swap_fn(L, R); } return; @@ -89,13 +90,13 @@ void quick_select_template( // Use median of three for pivot choice P = (L + R) >> 1; swap_fn(P, L + 1); - if (arr[L + 1] > arr[R]) { + if (gt_or_nan(arr[L + 1], arr[R])) { swap_fn(L + 1, R); } - if (arr[L] > arr[R]) { + if (gt_or_nan(arr[L], arr[R])) { swap_fn(L, R); } - if (arr[L + 1] > arr[L]) { + if (gt_or_nan(arr[L + 1], arr[L])) { swap_fn(L + 1, L); } @@ -105,10 +106,10 @@ void quick_select_template( do { do i++; - while (arr[i] < piv); + while (gt_or_nan(piv, arr[i])); do j--; - while (arr[j] > piv); + while (gt_or_nan(arr[j], piv)); if (j < i) break; swap_fn(i, j); @@ -165,10 +166,17 @@ std::tuple kthvalue_out_cpu( for (int64_t j = 0; j < tmp_indices.size(0); j++) { tmp_indices[j] = j; } - quick_select_template(tmp_values, k - 1, [&](int64_t i, int64_t j) { - std::swap(tmp_values[i], tmp_values[j]); - std::swap(tmp_indices[i], tmp_indices[j]); - }); + // we want NaN to be sorted as top for numpy compatibility + quick_select_template( + tmp_values, + k - 1, + [](scalar_t x, scalar_t y) -> bool { + return ((x != x && y == y) || (x > y)); + }, + [&](int64_t i, int64_t j) { + std::swap(tmp_values[i], tmp_values[j]); + std::swap(tmp_indices[i], tmp_indices[j]); + }); *mode_value = tmp_values[k - 1]; *mode_index = tmp_indices[k - 1]; }); diff --git a/aten/src/ATen/native/cuda/SortingKthValue.cu b/aten/src/ATen/native/cuda/SortingKthValue.cu index 50d0b90..ebf1dc8 100644 --- a/aten/src/ATen/native/cuda/SortingKthValue.cu +++ b/aten/src/ATen/native/cuda/SortingKthValue.cu @@ -27,7 +27,6 @@ namespace native { namespace { - template __global__ void gatherKthValue( cuda::detail::TensorInfo input, @@ -82,7 +81,7 @@ __global__ void gatherKthValue( bool inRange = (i < inputSliceSize); scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride]) : static_cast(0); - bool isKValue = inRange && (THCNumerics::eq(v, kValue)); + bool isKValue = inRange && THCNumerics::eq_with_nan(v, kValue); if (isKValue) { kValueIndex = i; diff --git a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh index 2c5d81a..9efd035 100644 --- a/aten/src/ATen/native/cuda/SortingRadixSelect.cuh +++ b/aten/src/ATen/native/cuda/SortingRadixSelect.cuh @@ -14,11 +14,15 @@ struct TopKTypeConfig { // We use this to enable radix selection of floating-point values. // This also gives a relative order for NaNs, but that's ok, as they // will all be adjacent + // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff.. + // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00.. + // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0 + // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x { static inline __device__ RadixType convert(double v) { RadixType x = __double_as_longlong(v); RadixType mask = -((x >> 63)) | 0x8000000000000000; - return (x ^ mask); + return (v == v) ? (x ^ mask) : 0xffffffffffffffff; } static inline __device__ double deconvert(RadixType v) { @@ -120,7 +124,7 @@ struct TopKTypeConfig { #if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__ RadixType x = __half_as_ushort(v); RadixType mask = -((x >> 15)) | 0x8000; - return (x ^ mask); + return (v == v) ? (x ^ mask) : 0xffff; #else assert(false); return 0u; diff --git a/aten/src/THC/THCNumerics.cuh b/aten/src/THC/THCNumerics.cuh index 1a1a402..2547277 100644 --- a/aten/src/THC/THCNumerics.cuh +++ b/aten/src/THC/THCNumerics.cuh @@ -50,6 +50,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(uint8_t a, uint8_t b) { return a > b; } static inline __host__ __device__ bool ge(uint8_t a, uint8_t b) { return a >= b; } static inline __host__ __device__ bool eq(uint8_t a, uint8_t b) { return a == b; } + static inline __device__ bool eq_with_nan(uint8_t a, uint8_t b) { return a == b; } static inline __host__ __device__ bool ne(uint8_t a, uint8_t b) { return a != b; } static inline __host__ __device__ uint8_t neg(int8_t a) { return -a; } @@ -75,6 +76,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(int8_t a, int8_t b) { return a > b; } static inline __host__ __device__ bool ge(int8_t a, int8_t b) { return a >= b; } static inline __host__ __device__ bool eq(int8_t a, int8_t b) { return a == b; } + static inline __device__ bool eq_with_nan(int8_t a, int8_t b) { return a == b; } static inline __host__ __device__ bool ne(int8_t a, int8_t b) { return a != b; } static inline __host__ __device__ int8_t neg(int8_t a) { return -a; } @@ -100,6 +102,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(int16_t a, int16_t b) { return a > b; } static inline __host__ __device__ bool ge(int16_t a, int16_t b) { return a >= b; } static inline __host__ __device__ bool eq(int16_t a, int16_t b) { return a == b; } + static inline __device__ bool eq_with_nan(int16_t a, int16_t b) { return a == b; } static inline __host__ __device__ bool ne(int16_t a, int16_t b) { return a != b; } static inline __host__ __device__ int16_t neg(int16_t a) { return -a; } @@ -125,6 +128,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(int32_t a, int32_t b) { return a > b; } static inline __host__ __device__ bool ge(int32_t a, int32_t b) { return a >= b; } static inline __host__ __device__ bool eq(int32_t a, int32_t b) { return a == b; } + static inline __device__ bool eq_with_nan(int32_t a, int32_t b) { return a == b; } static inline __host__ __device__ bool ne(int32_t a, int32_t b) { return a != b; } static inline __host__ __device__ int32_t neg(int32_t a) { return -a; } @@ -150,6 +154,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(int64_t a, int64_t b) { return a > b; } static inline __host__ __device__ bool ge(int64_t a, int64_t b) { return a >= b; } static inline __host__ __device__ bool eq(int64_t a, int64_t b) { return a == b; } + static inline __device__ bool eq_with_nan(int64_t a, int64_t b) { return a == b; } static inline __host__ __device__ bool ne(int64_t a, int64_t b) { return a != b; } @@ -177,6 +182,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(at::Half a, at::Half b) { return a > b; } static inline __host__ __device__ bool ge(at::Half a, at::Half b) { return a >= b; } static inline __host__ __device__ bool eq(at::Half a, at::Half b) { return a == b; } + static inline __device__ bool eq_with_nan(at::Half a, at::Half b) { return __half_as_ushort(a) == __half_as_ushort(b); } static inline __host__ __device__ bool ne(at::Half a, at::Half b) { return a != b; } static inline __host__ __device__ at::Half exp(at::Half a) { return std::exp(a); } @@ -261,6 +267,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(float a, float b) { return a > b; } static inline __host__ __device__ bool ge(float a, float b) { return a >= b; } static inline __host__ __device__ bool eq(float a, float b) { return a == b; } + static inline __device__ bool eq_with_nan(float a, float b) { return __float_as_int(a) == __float_as_int(b); } static inline __host__ __device__ bool ne(float a, float b) { return a != b; } static inline __host__ __device__ float lgamma(float a) { return lgammaf(a);} @@ -320,6 +327,7 @@ struct THCNumerics { static inline __host__ __device__ bool gt(double a, double b) { return a > b; } static inline __host__ __device__ bool ge(double a, double b) { return a >= b; } static inline __host__ __device__ bool eq(double a, double b) { return a == b; } + static inline __device__ bool eq_with_nan(double a, double b) { return __double_as_longlong(a) == __double_as_longlong(b); } static inline __host__ __device__ bool ne(double a, double b) { return a != b; } static inline __host__ __device__ double lgamma(double a) { return ::lgamma(a);} diff --git a/test/test_cuda.py b/test/test_cuda.py index 4ec3472..d6b15db 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -2353,6 +2353,9 @@ class TestCuda(TestCase): def test_advancedindex_big(self): _TestTorchMixin._test_advancedindex_big(self, lambda t: t.cuda()) + def test_kthvalue(self): + _TestTorchMixin._test_kthvalue(self, device='cuda') + @skipIfRocm def test_btrifact(self): _TestTorchMixin._test_btrifact(self, lambda t: t.cuda()) diff --git a/test/test_torch.py b/test/test_torch.py index a50ab5c..85e59ea 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -4048,9 +4048,10 @@ class _TestTorchMixin(object): self.assertEqual(top1, top2) self.assertEqual(idx1, idx2) - def test_kthvalue(self): + @staticmethod + def _test_kthvalue(self, device='cpu'): SIZE = 50 - x = torch.rand(SIZE, SIZE, SIZE) + x = torch.rand(SIZE, SIZE, SIZE, device=device) x0 = x.clone() k = random.randint(1, SIZE) @@ -4061,8 +4062,8 @@ class _TestTorchMixin(object): self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) # test use of result tensors k = random.randint(1, SIZE) - res1val = torch.Tensor() - res1ind = torch.LongTensor() + res1val = torch.tensor([], device=device) + res1ind = torch.tensor([], dtype=torch.long, device=device) torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind)) res2val, res2ind = torch.sort(x) self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) @@ -4088,10 +4089,24 @@ class _TestTorchMixin(object): self.assertEqual(x, x0, 0) # simple test case (with repetitions) - y = torch.Tensor((3, 5, 4, 1, 1, 5)) + y = torch.tensor((3., 5, 4, 1, 1, 5), device=device) self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0) self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0) + # simple test case (with NaN) + SIZE = 50 + x = torch.rand(SIZE, SIZE, SIZE, device=device) + x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan + ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1] + res2val, res2ind = torch.sort(x) + for k in ks: + res1val, res1ind = torch.kthvalue(x, k, keepdim=False) + self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0) + self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0) + + def test_kthvalue(self): + self._test_kthvalue(self) + def test_median(self): for size in (155, 156): x = torch.rand(size, size)