Match NumPy by considering NaNs to be larger than any number when sorting (#15886)
author     Brennan Vincent <btv@fb.com>
           Fri, 11 Jan 2019 16:09:06 +0000 (08:09 -0800)
committer  Facebook Github Bot <facebook-github-bot@users.noreply.github.com>
           Fri, 11 Jan 2019 16:14:11 +0000 (08:14 -0800)
Summary:
Fixes #15764
Pull Request resolved: https://github.com/pytorch/pytorch/pull/15886
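
After this change, `torch.sort` agrees with `numpy.sort` on NaN placement:
NaNs go to the end of an ascending sort and, symmetrically, to the front of
a descending one. A minimal sketch of the expected behavior (illustrative
values, not taken from the test suite):

    import torch

    x = torch.tensor([3.0, float('nan'), 1.0])
    values, indices = torch.sort(x)                   # values: [1., 3., nan]
    values, indices = torch.sort(x, descending=True)  # values: [nan, 3., 1.]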

Differential Revision: D13612971

Pulled By: umanwizard

fbshipit-source-id: 91f552a25d1fd108f2f0b10e09a0ce0364f8c21e

aten/src/TH/generic/THTensorMoreMath.cpp
aten/src/THC/THCSortUtils.cuh
aten/src/THC/THCTensorSort.cuh
aten/src/THC/generic/THCTensorSort.cu
test/test_torch.py

diff --git a/aten/src/TH/generic/THTensorMoreMath.cpp b/aten/src/TH/generic/THTensorMoreMath.cpp
index bb7edf3..ac1b323 100644
@@ -674,6 +674,11 @@ void THTensor_(randperm)(THTensor *r_, THGenerator *_generator, int64_t n)
   REAL_SWAP(ARR(III), ARR(JJJ)); \
   LONG_SWAP(IDX(III), IDX(JJJ))
 
+/* Emulate NumPy behavior of putting NaNs
+ * at the end of an ascending list. */
+#define GT_OR_NAN(x, y) \
+  ((x != x && y == y) || (x > y))
+
 static void THTensor_(quicksortascend)(scalar_t *arr, int64_t *idx, int64_t elements, int64_t stride)
 {
   int64_t beg[MAX_LEVELS], end[MAX_LEVELS], i, j, L, R, P, swap, pid, stack = 0, sz_right, sz_left;
@@ -689,15 +694,15 @@ static void THTensor_(quicksortascend)(scalar_t *arr, int64_t *idx, int64_t elem
       /* Use median of three for pivot choice */
     P=(L+R)>>1;
     BOTH_SWAP(P, L+1);
-    if (ARR(L+1) > ARR(R)) { BOTH_SWAP(L+1, R); }
-    if (ARR(L) > ARR(R)) { BOTH_SWAP(L, R); }
-    if (ARR(L+1) > ARR(L)) { BOTH_SWAP(L+1, L); }
+    if (GT_OR_NAN(ARR(L+1), ARR(R))) { BOTH_SWAP(L+1, R); }
+    if (GT_OR_NAN(ARR(L), ARR(R))) { BOTH_SWAP(L, R); }
+    if (GT_OR_NAN(ARR(L+1), ARR(L))) { BOTH_SWAP(L+1, L); }
 
     i = L+1; j = R; piv = ARR(L); pid = IDX(L);
 
     do {
-      do { i = i+1; } while(ARR(i) < piv);
-      do { j = j-1; } while(ARR(j) > piv);
+      do { i = i+1; } while(GT_OR_NAN(piv, ARR(i)));
+      do { j = j-1; } while(GT_OR_NAN(ARR(j), piv));
       if (j < i)
           break;
       BOTH_SWAP(i, j);
@@ -748,7 +753,7 @@ static void THTensor_(quicksortascend)(scalar_t *arr, int64_t *idx, int64_t elem
   } /* while not done */
   /* Now insertion sort on the concatenation of subfiles */
   for(i=elements-2; i>=0; i--) {
-    if (ARR(i) > ARR(i+1)) {
+    if (GT_OR_NAN(ARR(i),ARR(i+1))) {
       piv = ARR(i);
       pid = IDX(i);
       j = i+1;
@@ -756,7 +761,7 @@ static void THTensor_(quicksortascend)(scalar_t *arr, int64_t *idx, int64_t elem
         ARR(j-1) = ARR(j);
         IDX(j-1) = IDX(j);
         j = j+1;
-      } while(j < elements && ARR(j) < piv);
+      } while(j < elements && GT_OR_NAN(piv, ARR(j)));
       ARR(j-1) = piv;
       IDX(j-1) = pid;
      }
@@ -778,15 +783,15 @@ static void THTensor_(quicksortdescend)(scalar_t *arr, int64_t *idx, int64_t ele
       /* Use median of three for pivot choice */
     P=(L+R)>>1;
     BOTH_SWAP(P, L+1);
-    if (ARR(L+1) < ARR(R)) { BOTH_SWAP(L+1, R); }
-    if (ARR(L) < ARR(R)) { BOTH_SWAP(L, R); }
-    if (ARR(L+1) < ARR(L)) { BOTH_SWAP(L+1, L); }
+    if (GT_OR_NAN(ARR(R), ARR(L+1))) { BOTH_SWAP(L+1, R); }
+    if (GT_OR_NAN(ARR(R), ARR(L))) { BOTH_SWAP(L, R); }
+    if (GT_OR_NAN(ARR(L), ARR(L+1))) { BOTH_SWAP(L+1, L); }
 
     i = L+1; j = R; piv = ARR(L); pid = IDX(L);
 
     do {
-      do { i = i+1; } while(ARR(i) > piv);
-      do { j = j-1; } while(ARR(j) < piv);
+      do { i = i+1; } while(GT_OR_NAN(ARR(i), piv));
+      do { j = j-1; } while(GT_OR_NAN(piv, ARR(j)));
       if (j < i)
           break;
       BOTH_SWAP(i, j);
@@ -837,7 +842,7 @@ static void THTensor_(quicksortdescend)(scalar_t *arr, int64_t *idx, int64_t ele
   } /* while not done */
   /* Now insertion sort on the concatenation of subfiles */
   for(i=elements-2; i>=0; i--) {
-    if (ARR(i) < ARR(i+1)) {
+    if (GT_OR_NAN(ARR(i+1), ARR(i))) {
       piv = ARR(i);
       pid = IDX(i);
       j = i+1;
@@ -845,7 +850,7 @@ static void THTensor_(quicksortdescend)(scalar_t *arr, int64_t *idx, int64_t ele
         ARR(j-1) = ARR(j);
         IDX(j-1) = IDX(j);
         j = j+1;
-      } while(j < elements && ARR(j) > piv);
+      } while(j < elements && GT_OR_NAN(ARR(j), piv));
       ARR(j-1) = piv;
       IDX(j-1) = pid;
      }
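
The GT_OR_NAN macro above relies on the IEEE-754 rule that NaN compares
unequal to itself, so `x != x` is true exactly when x is NaN; a NaN left
operand is then treated as greater than any non-NaN right operand, which
pushes NaNs to the end of ascending output. A rough Python rendering of the
predicate (hypothetical helper name, for illustration only):

    import functools

    def gt_or_nan(x, y):
        # True when x is NaN and y is not, or when x > y numerically.
        return (x != x and y == y) or (x > y)

    vals = [2.0, float('nan'), -1.0, float('nan'), 0.5]
    cmp = lambda a, b: 1 if gt_or_nan(a, b) else (-1 if gt_or_nan(b, a) else 0)
    sorted(vals, key=functools.cmp_to_key(cmp))
    # -> [-1.0, 0.5, 2.0, nan, nan]: NaNs land at the end, as in NumPy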
diff --git a/aten/src/THC/THCSortUtils.cuh b/aten/src/THC/THCSortUtils.cuh
index 038b413..8e9bdd2 100644
@@ -7,17 +7,17 @@
 #include <c10/macros/Macros.h>
 
 // Collection of kernel sort routines
-template <typename T>
+template <typename T, bool handleNaN = false>
 struct LTComp {
   __device__ inline bool operator()(const T& a, const T& b) const {
-    return THCNumerics<T>::lt(a, b);
+    return (handleNaN && THCNumerics<T>::isnan(b) && !THCNumerics<T>::isnan(a)) || THCNumerics<T>::lt(a, b);
   }
 };
 
-template <typename T>
+template <typename T, bool handleNaN = false>
 struct GTComp {
   __device__ inline bool operator()(const T& a, const T& b) const {
-    return THCNumerics<T>::gt(a, b);
+    return (handleNaN && THCNumerics<T>::isnan(a) && !THCNumerics<T>::isnan(b)) || THCNumerics<T>::gt(a, b);
   }
 };
 
diff --git a/aten/src/THC/THCTensorSort.cuh b/aten/src/THC/THCTensorSort.cuh
index 2d8b862..81e0275 100644
 #include <thrust/system/cuda/execution_policy.h>
 #endif
 
-template <typename T>
+template <typename T, bool handleNaN = false>
 struct ThrustGTOp {
   __device__ bool operator()(const T& lhs, const T& rhs) const {
-    return THCNumerics<T>::gt(lhs, rhs);
+    return (handleNaN && THCNumerics<T>::isnan(lhs) && !THCNumerics<T>::isnan(rhs)) || THCNumerics<T>::gt(lhs, rhs);
   }
 };
 
-template <typename T>
+template <typename T, bool handleNaN = false>
 struct ThrustLTOp {
   __device__ bool operator()(const T& lhs, const T& rhs) const {
-    return THCNumerics<T>::lt(lhs, rhs);
+    return (handleNaN && THCNumerics<T>::isnan(rhs) && !THCNumerics<T>::isnan(lhs)) || THCNumerics<T>::lt(lhs, rhs);
   }
 };
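
All four CUDA comparators follow the same pattern: with handleNaN set, the
less-than predicates (LTComp, ThrustLTOp) order every non-NaN before any
NaN, and the greater-than predicates (GTComp, ThrustGTOp) order every NaN
before any non-NaN. Two NaNs compare as equivalent either way, which keeps
each predicate a strict weak ordering, as thrust::stable_sort_by_key
requires. A Python sketch of the less-than variant (hypothetical helper
mirroring LTComp<T, true>):

    import math

    def lt_handle_nan(a, b):
        # b is NaN and a is not  =>  a sorts first;
        # otherwise fall back to the ordinary comparison.
        return (math.isnan(b) and not math.isnan(a)) or (a < b)

    assert lt_handle_nan(1.0, float('nan'))               # numbers precede NaN
    assert not lt_handle_nan(float('nan'), 1.0)
    assert not lt_handle_nan(float('nan'), float('nan'))  # NaNs are equivalent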
 
diff --git a/aten/src/THC/generic/THCTensorSort.cu b/aten/src/THC/generic/THCTensorSort.cu
index 3c7a99a..9379446 100644
@@ -53,7 +53,7 @@ void THCTensor_(sortKeyValueInplace)(THCState* state,
     dim3 block(blockSize);                                              \
                                                                         \
     if (dir) {                                                          \
-      bitonicSortKVInPlace<scalar_t, int64_t, A, -1, GTComp<scalar_t>, TYPE, SIZE> \
+      bitonicSortKVInPlace<scalar_t, int64_t, A, -1, GTComp<scalar_t, true>, TYPE, SIZE> \
         <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
           keyInfo,                                                      \
           keySlices,                                                    \
@@ -61,9 +61,9 @@ void THCTensor_(sortKeyValueInplace)(THCState* state,
           (TYPE) keyInfo.strides[collapseKeyDim],                       \
           valueInfo,                                                    \
           (TYPE) valueInfo.strides[collapseValueDim],                   \
-          GTComp<scalar_t>());                                              \
+          GTComp<scalar_t, true>());                                    \
     } else {                                                            \
-      bitonicSortKVInPlace<scalar_t, int64_t, A, -1, LTComp<scalar_t>, TYPE, SIZE> \
+      bitonicSortKVInPlace<scalar_t, int64_t, A, -1, LTComp<scalar_t, true>, TYPE, SIZE> \
         <<<grid, block, 0, THCState_getCurrentStream(state)>>>(         \
           keyInfo,                                                      \
           keySlices,                                                    \
@@ -71,7 +71,7 @@ void THCTensor_(sortKeyValueInplace)(THCState* state,
           (TYPE) keyInfo.strides[collapseKeyDim],                       \
           valueInfo,                                                    \
           (TYPE) valueInfo.strides[collapseValueDim],                   \
-          LTComp<scalar_t>());                                              \
+          LTComp<scalar_t, true>());                                    \
     }                                                                   \
   } while (0)
 
@@ -234,13 +234,13 @@ void THCTensor_(sortViaThrust)(THCState* state,
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
       thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
-      keyIter, keyIter + totalElements, indexIter, ThrustGTOp<scalar_t>());
+      keyIter, keyIter + totalElements, indexIter, ThrustGTOp<scalar_t, true>());
   } else {
     thrust::stable_sort_by_key(
 #if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
       thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
 #endif
-      keyIter, keyIter + totalElements, indexIter, ThrustLTOp<scalar_t>());
+      keyIter, keyIter + totalElements, indexIter, ThrustLTOp<scalar_t, true>());
   }
 
   // Then, re-sort according to slice that each index is
diff --git a/test/test_torch.py b/test/test_torch.py
index 4d51c0f..6d48ecc 100644
@@ -3551,10 +3551,14 @@ class _TestTorchMixin(object):
         SIZE = 4
         if order == 'descending':
             def check_order(a, b):
-                return a >= b
+                # `a != a` tests for NaN: we put NaNs
+                # at the end of ascending sorted lists,
+                # and at the beginning of descending ones.
+                return a != a or a >= b
         elif order == 'ascending':
             def check_order(a, b):
-                return a <= b
+                # see above
+                return b != b or a <= b
         else:
             error('unknown order "{}", must be "ascending" or "descending"'.format(order))
 
@@ -3629,6 +3633,17 @@ class _TestTorchMixin(object):
         # Test that we still have proper sorting with duplicate keys
         self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys')
 
+        # Test sorting with NaNs
+        x = torch.rand(SIZE, SIZE)
+        x[1][2] = float('NaN')
+        x[3][0] = float('NaN')
+        torch.sort(x, out=(res2val, res2ind))
+        self.assertIsOrdered('ascending', x, res2val, res2ind,
+                             'random with NaNs')
+        torch.sort(x, out=(res2val, res2ind), descending=True)
+        self.assertIsOrdered('descending', x, res2val, res2ind,
+                             'random with NaNs')
+
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
     def test_tensordot(self):
         devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']
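
For reference, the updated check_order predicates accept a NaN wherever the
sort is now allowed to place one: in a descending result a NaN may precede
anything (`a != a`), and in an ascending result anything may precede a NaN
(`b != b`). A standalone sanity check of the predicates (outside the test
harness):

    nan = float('nan')
    asc = lambda a, b: b != b or a <= b
    desc = lambda a, b: a != a or a >= b
    assert asc(1.0, nan) and not asc(nan, 1.0)    # NaNs sort to the end
    assert desc(nan, 1.0) and not desc(1.0, nan)  # ...and to the front, descending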