#define BOTH_SWAP(III, JJJ) \
  REAL_SWAP(ARR(III), ARR(JJJ)); \
  LONG_SWAP(IDX(III), IDX(JJJ))
+/* Emulate NumPy behavior of putting NaNs
+ * at the end of an ascending list. */
+#define GT_OR_NAN(x, y) \
+  (((x) != (x) && (y) == (y)) || ((x) > (y)))
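+/* Illustrative values (not from the patch): GT_OR_NAN(NAN, 1.0) -> true,
+ * GT_OR_NAN(1.0, NAN) -> false, and GT_OR_NAN(NAN, NAN) -> false. NaN
+ * therefore compares greater than every real value, while NaNs compare
+ * equivalent to each other, so an ascending sort settles them at the end. */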
+
static void THTensor_(quicksortascend)(scalar_t *arr, int64_t *idx, int64_t elements, int64_t stride)
{
  int64_t beg[MAX_LEVELS], end[MAX_LEVELS], i, j, L, R, P, swap, pid, stack = 0, sz_right, sz_left;
  scalar_t rswap, piv;
/* Use median of three for pivot choice */
P=(L+R)>>1;
BOTH_SWAP(P, L+1);
- if (ARR(L+1) > ARR(R)) { BOTH_SWAP(L+1, R); }
- if (ARR(L) > ARR(R)) { BOTH_SWAP(L, R); }
- if (ARR(L+1) > ARR(L)) { BOTH_SWAP(L+1, L); }
+ if (GT_OR_NAN(ARR(L+1), ARR(R))) { BOTH_SWAP(L+1, R); }
+ if (GT_OR_NAN(ARR(L), ARR(R))) { BOTH_SWAP(L, R); }
+ if (GT_OR_NAN(ARR(L+1), ARR(L))) { BOTH_SWAP(L+1, L); }
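+  /* After the three swaps above, ARR(L+1) <= ARR(L) <= ARR(R) in the
+   * NaN-aware order: ARR(L) is the median-of-three pivot, and the two
+   * ends act as sentinels that keep the scans below inside [L, R]. */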
i = L+1; j = R; piv = ARR(L); pid = IDX(L);
do {
- do { i = i+1; } while(ARR(i) < piv);
- do { j = j-1; } while(ARR(j) > piv);
+ do { i = i+1; } while(GT_OR_NAN(piv, ARR(i)));
+ do { j = j-1; } while(GT_OR_NAN(ARR(j), piv));
if (j < i)
break;
BOTH_SWAP(i, j);
} /* while not done */
/* Now insertion sort on the concatenation of subfiles */
for(i=elements-2; i>=0; i--) {
- if (ARR(i) > ARR(i+1)) {
+ if (GT_OR_NAN(ARR(i),ARR(i+1))) {
piv = ARR(i);
pid = IDX(i);
  j = i+1;
  do {
    ARR(j-1) = ARR(j);
IDX(j-1) = IDX(j);
j = j+1;
- } while(j < elements && ARR(j) < piv);
+ } while(j < elements && GT_OR_NAN(piv, ARR(j)));
ARR(j-1) = piv;
IDX(j-1) = pid;
}
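+/* The descending variant, THTensor_(quicksortdescend), mirrors the code
+ * above with the GT_OR_NAN arguments swapped at every comparison, so NaNs
+ * collect at the front of a descending sort: */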
/* Use median of three for pivot choice */
P=(L+R)>>1;
BOTH_SWAP(P, L+1);
- if (ARR(L+1) < ARR(R)) { BOTH_SWAP(L+1, R); }
- if (ARR(L) < ARR(R)) { BOTH_SWAP(L, R); }
- if (ARR(L+1) < ARR(L)) { BOTH_SWAP(L+1, L); }
+ if (GT_OR_NAN(ARR(R), ARR(L+1))) { BOTH_SWAP(L+1, R); }
+ if (GT_OR_NAN(ARR(R), ARR(L))) { BOTH_SWAP(L, R); }
+ if (GT_OR_NAN(ARR(L), ARR(L+1))) { BOTH_SWAP(L+1, L); }
i = L+1; j = R; piv = ARR(L); pid = IDX(L);
do {
- do { i = i+1; } while(ARR(i) > piv);
- do { j = j-1; } while(ARR(j) < piv);
+ do { i = i+1; } while(GT_OR_NAN(ARR(i), piv));
+ do { j = j-1; } while(GT_OR_NAN(piv, ARR(j)));
if (j < i)
break;
BOTH_SWAP(i, j);
} /* while not done */
/* Now insertion sort on the concatenation of subfiles */
for(i=elements-2; i>=0; i--) {
- if (ARR(i) < ARR(i+1)) {
+ if (GT_OR_NAN(ARR(i+1), ARR(i))) {
piv = ARR(i);
pid = IDX(i);
  j = i+1;
  do {
    ARR(j-1) = ARR(j);
IDX(j-1) = IDX(j);
j = j+1;
- } while(j < elements && ARR(j) > piv);
+ } while(j < elements && GT_OR_NAN(ARR(j), piv));
ARR(j-1) = piv;
IDX(j-1) = pid;
}
#include <c10/macros/Macros.h>
// Collection of kernel sort routines
-template <typename T>
+template <typename T, bool handleNaN = false>
struct LTComp {
__device__ inline bool operator()(const T& a, const T& b) const {
- return THCNumerics<T>::lt(a, b);
+ return (handleNaN && THCNumerics<T>::isnan(b) && !THCNumerics<T>::isnan(a)) || THCNumerics<T>::lt(a, b);
}
};
-template <typename T>
+template <typename T, bool handleNaN = false>
struct GTComp {
__device__ inline bool operator()(const T& a, const T& b) const {
- return THCNumerics<T>::gt(a, b);
+ return (handleNaN && THCNumerics<T>::isnan(a) && !THCNumerics<T>::isnan(b)) || THCNumerics<T>::gt(a, b);
}
};
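+// With handleNaN = true, both comparators treat NaN as the largest key:
+// LTComp never orders a NaN before a number (ascending sorts leave NaNs
+// at the end) and GTComp orders any NaN above a non-NaN (descending sorts
+// begin with NaNs). NaN-vs-NaN yields false either way, so each remains
+// a strict weak ordering.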
#include <thrust/system/cuda/execution_policy.h>
#endif
-template <typename T>
+template <typename T, bool handleNaN = false>
struct ThrustGTOp {
__device__ bool operator()(const T& lhs, const T& rhs) const {
- return THCNumerics<T>::gt(lhs, rhs);
+ return (handleNaN && THCNumerics<T>::isnan(lhs) && !THCNumerics<T>::isnan(rhs)) || THCNumerics<T>::gt(lhs, rhs);
}
};
-template <typename T>
+template <typename T, bool handleNaN = false>
struct ThrustLTOp {
__device__ bool operator()(const T& lhs, const T& rhs) const {
- return THCNumerics<T>::lt(lhs, rhs);
+ return (handleNaN && THCNumerics<T>::isnan(rhs) && !THCNumerics<T>::isnan(lhs)) || THCNumerics<T>::lt(lhs, rhs);
}
};
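+// ThrustLTOp and ThrustGTOp mirror LTComp/GTComp above for the Thrust
+// fallback path, so both sort implementations place NaNs identically.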
dim3 block(blockSize); \
\
if (dir) { \
- bitonicSortKVInPlace<scalar_t, int64_t, A, -1, GTComp<scalar_t>, TYPE, SIZE> \
+ bitonicSortKVInPlace<scalar_t, int64_t, A, -1, GTComp<scalar_t, true>, TYPE, SIZE> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
keyInfo, \
keySlices, \
(TYPE) keyInfo.strides[collapseKeyDim], \
valueInfo, \
(TYPE) valueInfo.strides[collapseValueDim], \
- GTComp<scalar_t>()); \
+ GTComp<scalar_t, true>()); \
} else { \
- bitonicSortKVInPlace<scalar_t, int64_t, A, -1, LTComp<scalar_t>, TYPE, SIZE> \
+ bitonicSortKVInPlace<scalar_t, int64_t, A, -1, LTComp<scalar_t, true>, TYPE, SIZE> \
<<<grid, block, 0, THCState_getCurrentStream(state)>>>( \
keyInfo, \
keySlices, \
(TYPE) keyInfo.strides[collapseKeyDim], \
valueInfo, \
(TYPE) valueInfo.strides[collapseValueDim], \
- LTComp<scalar_t>()); \
+ LTComp<scalar_t, true>()); \
} \
} while (0)
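+// The in-place bitonic path above and the Thrust path below must agree on
+// NaN placement; both now instantiate their comparators with handleNaN=true.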
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
- keyIter, keyIter + totalElements, indexIter, ThrustGTOp<scalar_t>());
+ keyIter, keyIter + totalElements, indexIter, ThrustGTOp<scalar_t, true>());
} else {
thrust::stable_sort_by_key(
#if CUDA_VERSION >= 7000 || defined __HIP_PLATFORM_HCC__
thrust::cuda::par(thrustAlloc).on(THCState_getCurrentStream(state)),
#endif
- keyIter, keyIter + totalElements, indexIter, ThrustLTOp<scalar_t>());
+ keyIter, keyIter + totalElements, indexIter, ThrustLTOp<scalar_t, true>());
}
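+  // NaNs compare equivalent to each other under both ops, so the stable
+  // sort preserves their original relative order together with the
+  // paired indices.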
// Then, re-sort according to the slice that each index is in.
SIZE = 4
if order == 'descending':
def check_order(a, b):
- return a >= b
+ # `a != a` because we put NaNs
+ # at the end of ascending sorted lists,
+ # and the beginning of descending ones.
+ return a != a or a >= b
elif order == 'ascending':
def check_order(a, b):
- return a <= b
+ # see above
+ return b != b or a <= b
else:
    raise ValueError('unknown order "{}", must be "ascending" or "descending"'.format(order))
# Test that we still have proper sorting with duplicate keys
self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys')
+ # Test sorting with NaNs
+ x = torch.rand(SIZE, SIZE)
+ x[1][2] = float('NaN')
+ x[3][0] = float('NaN')
+ torch.sort(x, out=(res2val, res2ind))
+ self.assertIsOrdered('ascending', x, res2val, res2ind,
+ 'random with NaNs')
+ torch.sort(x, out=(res2val, res2ind), descending=True)
+ self.assertIsOrdered('descending', x, res2val, res2ind,
+ 'random with NaNs')
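+        # Illustrative session (outputs assumed, not asserted by the test):
+        #   >>> t = torch.tensor([float('nan'), 1.0, float('nan'), 0.5])
+        #   >>> torch.sort(t)[0]
+        #   tensor([0.5000, 1.0000,    nan,    nan])
+        #   >>> torch.sort(t, descending=True)[0]
+        #   tensor([   nan,    nan, 1.0000, 0.5000])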
+
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
def test_tensordot(self):
devices = ['cpu'] if not torch.cuda.is_available() else ['cpu', 'cuda']