kthvalue consistency with sort in the presence of NaN (#17824)
authorThomas Viehmann <tv.code@beamnet.de>
Tue, 12 Mar 2019 15:45:17 +0000 (08:45 -0700)
committerFacebook Github Bot <facebook-github-bot@users.noreply.github.com>
Tue, 12 Mar 2019 15:49:19 +0000 (08:49 -0700)
Summary:
This PR causes kthvalue to be consistent with sort
(i.e. treat NaN as larger than any number), so that
`a.kthvalue(n) == a.sort()[n - 1]`.

One drawback is that median with a NaN argument does not return NaN,
which is a deviation from NumPy.

Thank you, ngimel, for raising this.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/17824

Differential Revision: D14410092

Pulled By: ezyang

fbshipit-source-id: bdec2d8272dc4c65bcf2f9b8995e237774c44c02

aten/src/ATen/native/Sorting.cpp
aten/src/ATen/native/cuda/SortingKthValue.cu
aten/src/ATen/native/cuda/SortingRadixSelect.cuh
aten/src/THC/THCNumerics.cuh
test/test_cuda.py
test/test_torch.py

index 9f52b2a..01adda7 100644 (file)
@@ -65,10 +65,11 @@ void dim_apply(TensorList tensors, int64_t dim, Fn f) {
   });
 }
 
-template <typename scalar_t, typename Fn>
+template <typename scalar_t, typename Comp, typename Fn>
 void quick_select_template(
     TensorAccessor<scalar_t, 1> arr,
     int64_t k,
+    Comp gt_or_nan,
     Fn swap_fn) {
   int64_t P, L, R, i, j, swap;
   scalar_t rswap, piv;
@@ -80,7 +81,7 @@ void quick_select_template(
       return;
 
     if (R == L + 1) { // Two elements only
-      if (arr[L] > arr[R]) {
+      if (gt_or_nan(arr[L], arr[R])) {
         swap_fn(L, R);
       }
       return;
@@ -89,13 +90,13 @@ void quick_select_template(
     // Use median of three for pivot choice
     P = (L + R) >> 1;
     swap_fn(P, L + 1);
-    if (arr[L + 1] > arr[R]) {
+    if (gt_or_nan(arr[L + 1], arr[R])) {
       swap_fn(L + 1, R);
     }
-    if (arr[L] > arr[R]) {
+    if (gt_or_nan(arr[L], arr[R])) {
       swap_fn(L, R);
     }
-    if (arr[L + 1] > arr[L]) {
+    if (gt_or_nan(arr[L + 1], arr[L])) {
       swap_fn(L + 1, L);
     }
 
@@ -105,10 +106,10 @@ void quick_select_template(
     do {
       do
         i++;
-      while (arr[i] < piv);
+      while (gt_or_nan(piv, arr[i]));
       do
         j--;
-      while (arr[j] > piv);
+      while (gt_or_nan(arr[j], piv));
       if (j < i)
         break;
       swap_fn(i, j);
@@ -165,10 +166,17 @@ std::tuple<Tensor&, Tensor&> kthvalue_out_cpu(
           for (int64_t j = 0; j < tmp_indices.size(0); j++) {
             tmp_indices[j] = j;
           }
-          quick_select_template(tmp_values, k - 1, [&](int64_t i, int64_t j) {
-            std::swap(tmp_values[i], tmp_values[j]);
-            std::swap(tmp_indices[i], tmp_indices[j]);
-          });
+          // we want NaN to be sorted as top for numpy compatibility
+          quick_select_template(
+              tmp_values,
+              k - 1,
+              [](scalar_t x, scalar_t y) -> bool {
+                return ((x != x && y == y) || (x > y));
+              },
+              [&](int64_t i, int64_t j) {
+                std::swap(tmp_values[i], tmp_values[j]);
+                std::swap(tmp_indices[i], tmp_indices[j]);
+              });
           *mode_value = tmp_values[k - 1];
           *mode_index = tmp_indices[k - 1];
         });
index 50d0b90..ebf1dc8 100644 (file)
@@ -27,7 +27,6 @@ namespace native {
 
 namespace {
 
-
 template <typename scalar_t, typename index_t, int Dim>
 __global__ void gatherKthValue(
     cuda::detail::TensorInfo<scalar_t, index_t> input,
@@ -82,7 +81,7 @@ __global__ void gatherKthValue(
     bool inRange = (i < inputSliceSize);
     scalar_t v = inRange ? doLdg(&inputSliceStart[i * inputWithinSliceStride])
                          : static_cast<scalar_t>(0);
-    bool isKValue = inRange && (THCNumerics<scalar_t>::eq(v, kValue));
+    bool isKValue = inRange && THCNumerics<scalar_t>::eq_with_nan(v, kValue);
 
     if (isKValue) {
       kValueIndex = i;
index 2c5d81a..9efd035 100644 (file)
@@ -14,11 +14,15 @@ struct TopKTypeConfig<float> {
   // We use this to enable radix selection of floating-point values.
   // This also gives a relative order for NaNs, but that's ok, as they
   // will all be adjacent
+  // neg inf: signbit=1 exp=ff fraction=0 --> radix = 0 00 ff..
+  // pos inf: signbit=0 exp=ff fraction=0 --> radix = 1 ff 00..
+  // pos nan: signbit=0 exp=ff fraction>0 --> radix = 1 ff x>0
+  // neg nan: signbit=1 exp=ff fraction>0 --> radix = 0 00 x<ff...
   static inline __device__ RadixType convert(float v) {
     RadixType x = __float_as_int(v);
     RadixType mask = (x & 0x80000000) ? 0xffffffff : 0x80000000;
 
-    return (x ^ mask);
+    return (v == v) ? (x ^ mask) : 0xffffffff;
   }
 
   static inline __device__ float deconvert(RadixType v) {
@@ -103,7 +107,7 @@ struct TopKTypeConfig<double> {
   static inline __device__ RadixType convert(double v) {
     RadixType x = __double_as_longlong(v);
     RadixType mask = -((x >> 63)) | 0x8000000000000000;
-    return (x ^ mask);
+    return (v == v) ? (x ^ mask) : 0xffffffffffffffff;
   }
 
   static inline __device__ double deconvert(RadixType v) {
@@ -120,7 +124,7 @@ struct TopKTypeConfig<at::Half> {
 #if CUDA_VERSION >= 8000 || defined __HIP_PLATFORM_HCC__
     RadixType x = __half_as_ushort(v);
     RadixType mask = -((x >> 15)) | 0x8000;
-    return (x ^ mask);
+    return (v == v) ? (x ^ mask) : 0xffff;
 #else
     assert(false);
     return 0u;
index 1a1a402..2547277 100644 (file)
@@ -50,6 +50,7 @@ struct THCNumerics<uint8_t> {
   static inline __host__ __device__ bool gt(uint8_t a, uint8_t b) { return a > b; }
   static inline __host__ __device__ bool ge(uint8_t a, uint8_t b) { return a >= b; }
   static inline __host__ __device__ bool eq(uint8_t a, uint8_t b) { return a == b; }
+  static inline __device__ bool eq_with_nan(uint8_t a, uint8_t b) { return a == b; }
   static inline __host__ __device__ bool ne(uint8_t a, uint8_t b) { return a != b; }
 
   static inline __host__ __device__  uint8_t neg(int8_t a) { return -a; }
@@ -75,6 +76,7 @@ struct THCNumerics<int8_t> {
   static inline __host__ __device__ bool gt(int8_t a, int8_t b) { return a > b; }
   static inline __host__ __device__ bool ge(int8_t a, int8_t b) { return a >= b; }
   static inline __host__ __device__ bool eq(int8_t a, int8_t b) { return a == b; }
+  static inline __device__ bool eq_with_nan(int8_t a, int8_t b) { return a == b; }
   static inline __host__ __device__ bool ne(int8_t a, int8_t b) { return a != b; }
 
   static inline __host__ __device__  int8_t neg(int8_t a) { return -a; }
@@ -100,6 +102,7 @@ struct THCNumerics<int16_t> {
   static inline __host__ __device__ bool gt(int16_t a, int16_t b) { return a > b; }
   static inline __host__ __device__ bool ge(int16_t a, int16_t b) { return a >= b; }
   static inline __host__ __device__ bool eq(int16_t a, int16_t b) { return a == b; }
+  static inline __device__ bool eq_with_nan(int16_t a, int16_t b) { return a == b; }
   static inline __host__ __device__ bool ne(int16_t a, int16_t b) { return a != b; }
 
   static inline __host__ __device__  int16_t neg(int16_t a) { return -a; }
@@ -125,6 +128,7 @@ struct THCNumerics<int32_t> {
   static inline __host__ __device__ bool gt(int32_t a, int32_t b) { return a > b; }
   static inline __host__ __device__ bool ge(int32_t a, int32_t b) { return a >= b; }
   static inline __host__ __device__ bool eq(int32_t a, int32_t b) { return a == b; }
+  static inline __device__ bool eq_with_nan(int32_t a, int32_t b) { return a == b; }
   static inline __host__ __device__ bool ne(int32_t a, int32_t b) { return a != b; }
 
   static inline __host__ __device__  int32_t neg(int32_t a) { return -a; }
@@ -150,6 +154,7 @@ struct THCNumerics<int64_t> {
   static inline __host__ __device__ bool gt(int64_t a, int64_t b) { return a > b; }
   static inline __host__ __device__ bool ge(int64_t a, int64_t b) { return a >= b; }
   static inline __host__ __device__ bool eq(int64_t a, int64_t b) { return a == b; }
+  static inline __device__ bool eq_with_nan(int64_t a, int64_t b) { return a == b; }
   static inline __host__ __device__ bool ne(int64_t a, int64_t b) { return a != b; }
 
 
@@ -177,6 +182,7 @@ struct THCNumerics<at::Half> {
   static inline __host__ __device__ bool gt(at::Half a, at::Half b) { return a > b; }
   static inline __host__ __device__ bool ge(at::Half a, at::Half b) { return a >= b; }
   static inline __host__ __device__ bool eq(at::Half a, at::Half b) { return a == b; }
+  static inline __device__ bool eq_with_nan(at::Half a, at::Half b) { return __half_as_ushort(a) == __half_as_ushort(b); }
   static inline __host__ __device__ bool ne(at::Half a, at::Half b) { return a != b; }
 
   static inline __host__ __device__ at::Half exp(at::Half a) { return std::exp(a); }
@@ -261,6 +267,7 @@ struct THCNumerics<float> {
   static inline __host__ __device__ bool gt(float a, float b) { return a > b; }
   static inline __host__ __device__ bool ge(float a, float b) { return a >= b; }
   static inline __host__ __device__ bool eq(float a, float b) { return a == b; }
+  static inline __device__ bool eq_with_nan(float a, float b) { return __float_as_int(a) == __float_as_int(b); }
   static inline __host__ __device__ bool ne(float a, float b) { return a != b; }
 
   static inline __host__ __device__  float lgamma(float a) { return lgammaf(a);}
@@ -320,6 +327,7 @@ struct THCNumerics<double> {
   static inline __host__ __device__ bool gt(double a, double b) { return a > b; }
   static inline __host__ __device__ bool ge(double a, double b) { return a >= b; }
   static inline __host__ __device__ bool eq(double a, double b) { return a == b; }
+  static inline __device__ bool eq_with_nan(double a, double b) { return __double_as_longlong(a) == __double_as_longlong(b); }
   static inline __host__ __device__ bool ne(double a, double b) { return a != b; }
 
   static inline __host__ __device__  double lgamma(double a) { return ::lgamma(a);}
index 4ec3472..d6b15db 100644 (file)
@@ -2353,6 +2353,9 @@ class TestCuda(TestCase):
     def test_advancedindex_big(self):
         _TestTorchMixin._test_advancedindex_big(self, lambda t: t.cuda())
 
+    def test_kthvalue(self):
+        _TestTorchMixin._test_kthvalue(self, device='cuda')
+
     @skipIfRocm
     def test_btrifact(self):
         _TestTorchMixin._test_btrifact(self, lambda t: t.cuda())
index a50ab5c..85e59ea 100644 (file)
@@ -4048,9 +4048,10 @@ class _TestTorchMixin(object):
         self.assertEqual(top1, top2)
         self.assertEqual(idx1, idx2)
 
-    def test_kthvalue(self):
+    @staticmethod
+    def _test_kthvalue(self, device='cpu'):
         SIZE = 50
-        x = torch.rand(SIZE, SIZE, SIZE)
+        x = torch.rand(SIZE, SIZE, SIZE, device=device)
         x0 = x.clone()
 
         k = random.randint(1, SIZE)
@@ -4061,8 +4062,8 @@ class _TestTorchMixin(object):
         self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
         # test use of result tensors
         k = random.randint(1, SIZE)
-        res1val = torch.Tensor()
-        res1ind = torch.LongTensor()
+        res1val = torch.tensor([], device=device)
+        res1ind = torch.tensor([], dtype=torch.long, device=device)
         torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind))
         res2val, res2ind = torch.sort(x)
         self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
@@ -4088,10 +4089,24 @@ class _TestTorchMixin(object):
         self.assertEqual(x, x0, 0)
 
         # simple test case (with repetitions)
-        y = torch.Tensor((3, 5, 4, 1, 1, 5))
+        y = torch.tensor((3., 5, 4, 1, 1, 5), device=device)
         self.assertEqual(torch.kthvalue(y, 3)[0], 3, 0)
         self.assertEqual(torch.kthvalue(y, 2)[0], 1, 0)
 
+        # simple test case (with NaN)
+        SIZE = 50
+        x = torch.rand(SIZE, SIZE, SIZE, device=device)
+        x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan
+        ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1]
+        res2val, res2ind = torch.sort(x)
+        for k in ks:
+            res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
+            self.assertEqual(res1val[:, :], res2val[:, :, k - 1], 0)
+            self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], 0)
+
+    def test_kthvalue(self):
+        self._test_kthvalue(self)
+
     def test_median(self):
         for size in (155, 156):
             x = torch.rand(size, size)