Introduce branchless sorting functions for sort3, sort4 and sort5.
authorMarco Gelmi <marcogelmi@google.com>
Fri, 8 Apr 2022 06:58:48 +0000 (08:58 +0200)
committerNikolas Klauser <nikolasklauser@berlin.de>
Fri, 8 Apr 2022 07:00:30 +0000 (09:00 +0200)
We are introducing branchless variants for sort3, sort4 and sort5.
These sorting functions have been generated using Reinforcement
Learning and aim to replace __sort3, __sort4 and __sort5 variants
for integral types.

The libc++ benchmarks were run on isolated machines for Skylake, ARM and
AMD architectures and achieve statistically significant improvement in
sorting random integers on test cases from sort1 to sort262144 for
uint32 and uint64.

A full performance overview for Intel Skylake, AMD and Arm can be
found here: https://bit.ly/3AtesYf

Reviewed By: ldionne, #libc, philnik

Spies: daniel.mankowitz, mgrang, Quuxplusone, andreamichi, philnik, libcxx-commits, nilayvaish, kristof.beyls

Differential Revision: https://reviews.llvm.org/D118029

libcxx/benchmarks/algorithms.bench.cpp
libcxx/include/__algorithm/sort.h
libcxx/test/libcxx/algorithms/robust_against_copying_comparators.pass.cpp
libcxx/test/std/algorithms/alg.sorting/alg.sort/sort/sort.pass.cpp

index b3724ec..5a97df1 100644 (file)
 
 namespace {
 
-enum class ValueType { Uint32, Uint64, Pair, Tuple, String };
-struct AllValueTypes : EnumValuesAsTuple<AllValueTypes, ValueType, 5> {
-  static constexpr const char* Names[] = {
-      "uint32", "uint64", "pair<uint32, uint32>",
-      "tuple<uint32, uint64, uint32>", "string"};
+enum class ValueType { Uint32, Uint64, Pair, Tuple, String, Float };
+struct AllValueTypes : EnumValuesAsTuple<AllValueTypes, ValueType, 6> {
+  static constexpr const char* Names[] = {"uint32", "uint64", "pair<uint32, uint32>", "tuple<uint32, uint64, uint32>",
+                                          "string", "float"};
 };
 
+using Types = std::tuple< uint32_t, uint64_t, std::pair<uint32_t, uint32_t>, std::tuple<uint32_t, uint64_t, uint32_t>,
+                          std::string, float >;
+
 template <class V>
-using Value = std::conditional_t<
-    V() == ValueType::Uint32, uint32_t,
-    std::conditional_t<
-        V() == ValueType::Uint64, uint64_t,
-        std::conditional_t<
-            V() == ValueType::Pair, std::pair<uint32_t, uint32_t>,
-            std::conditional_t<V() == ValueType::Tuple,
-                               std::tuple<uint32_t, uint64_t, uint32_t>,
-                               std::string> > > >;
+using Value = std::tuple_element_t<(int)V::value, Types>;
 
 enum class Order {
   Random,
index 27ce647..3faff6b 100644 (file)
@@ -123,6 +123,96 @@ __sort5(_ForwardIterator __x1, _ForwardIterator __x2, _ForwardIterator __x3,
     return __r;
 }
 
+template <class _Tp>
+struct __is_simple_comparator : false_type {};
+template <class _Tp>
+struct __is_simple_comparator<__less<_Tp>&> : true_type {};
+template <class _Tp>
+struct __is_simple_comparator<less<_Tp>&> : true_type {};
+template <class _Tp>
+struct __is_simple_comparator<greater<_Tp>&> : true_type {};
+
+template <class _Compare, class _Iter, class _Tp = typename iterator_traits<_Iter>::value_type>
+using __use_branchless_sort =
+    integral_constant<bool, __is_cpp17_contiguous_iterator<_Iter>::value && sizeof(_Tp) <= sizeof(void*) &&
+                                is_arithmetic<_Tp>::value && __is_simple_comparator<_Compare>::value>;
+
+// Ensures that __c(*__x, *__y) is true by swapping *__x and *__y if necessary.
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI void __cond_swap(_RandomAccessIterator __x, _RandomAccessIterator __y, _Compare __c) {
+  using value_type = typename iterator_traits<_RandomAccessIterator>::value_type;
+  bool __r = __c(*__x, *__y);
+  value_type __tmp = __r ? *__x : *__y;
+  *__y = __r ? *__y : *__x;
+  *__x = __tmp;
+}
+
+// Ensures that *__x, *__y and *__z are ordered according to the comparator __c,
+// under the assumption that *__y and *__z are already ordered.
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI void __partially_sorted_swap(_RandomAccessIterator __x, _RandomAccessIterator __y,
+                                                          _RandomAccessIterator __z, _Compare __c) {
+  using value_type = typename iterator_traits<_RandomAccessIterator>::value_type;
+  bool __r = __c(*__z, *__x);
+  value_type __tmp = __r ? *__z : *__x;
+  *__z = __r ? *__x : *__z;
+  __r = __c(__tmp, *__y);
+  *__x = __r ? *__x : *__y;
+  *__y = __r ? *__y : __tmp;
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort3_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _Compare __c) {
+  _VSTD::__cond_swap<_Compare>(__x2, __x3, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x1, __x2, __x3, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<!__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort3_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _Compare __c) {
+  _VSTD::__sort3<_Compare>(__x1, __x2, __x3, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort4_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _Compare __c) {
+  _VSTD::__cond_swap<_Compare>(__x1, __x3, __c);
+  _VSTD::__cond_swap<_Compare>(__x2, __x4, __c);
+  _VSTD::__cond_swap<_Compare>(__x1, __x2, __c);
+  _VSTD::__cond_swap<_Compare>(__x3, __x4, __c);
+  _VSTD::__cond_swap<_Compare>(__x2, __x3, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<!__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort4_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _Compare __c) {
+  _VSTD::__sort4<_Compare>(__x1, __x2, __x3, __x4, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort5_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _RandomAccessIterator __x5, _Compare __c) {
+  _VSTD::__cond_swap<_Compare>(__x1, __x2, __c);
+  _VSTD::__cond_swap<_Compare>(__x4, __x5, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x3, __x4, __x5, __c);
+  _VSTD::__cond_swap<_Compare>(__x2, __x5, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x1, __x3, __x4, __c);
+  _VSTD::__partially_sorted_swap<_Compare>(__x2, __x3, __x4, __c);
+}
+
+template <class _Compare, class _RandomAccessIterator>
+inline _LIBCPP_HIDE_FROM_ABI __enable_if_t<!__use_branchless_sort<_Compare, _RandomAccessIterator>::value, void>
+__sort5_maybe_branchless(_RandomAccessIterator __x1, _RandomAccessIterator __x2, _RandomAccessIterator __x3,
+                         _RandomAccessIterator __x4, _RandomAccessIterator __x5, _Compare __c) {
+  _VSTD::__sort5<_Compare>(__x1, __x2, __x3, __x4, __x5, __c);
+}
+
 // Assumes size > 0
 template <class _Compare, class _BidirectionalIterator>
 _LIBCPP_CONSTEXPR_AFTER_CXX11 void
@@ -163,7 +253,7 @@ __insertion_sort_3(_RandomAccessIterator __first, _RandomAccessIterator __last,
     typedef typename iterator_traits<_RandomAccessIterator>::difference_type difference_type;
     typedef typename iterator_traits<_RandomAccessIterator>::value_type value_type;
     _RandomAccessIterator __j = __first+difference_type(2);
-    _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), __j, __comp);
+    _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), __j, __comp);
     for (_RandomAccessIterator __i = __j+difference_type(1); __i != __last; ++__i)
     {
         if (__comp(*__i, *__j))
@@ -197,18 +287,20 @@ __insertion_sort_incomplete(_RandomAccessIterator __first, _RandomAccessIterator
             swap(*__first, *__last);
         return true;
     case 3:
-        _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), --__last, __comp);
-        return true;
+      _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), --__last, __comp);
+      return true;
     case 4:
-        _VSTD::__sort4<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), --__last, __comp);
-        return true;
+      _VSTD::__sort4_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                --__last, __comp);
+      return true;
     case 5:
-        _VSTD::__sort5<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), __first+difference_type(3), --__last, __comp);
-        return true;
+      _VSTD::__sort5_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                __first + difference_type(3), --__last, __comp);
+      return true;
     }
     typedef typename iterator_traits<_RandomAccessIterator>::value_type value_type;
     _RandomAccessIterator __j = __first+difference_type(2);
-    _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), __j, __comp);
+    _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), __j, __comp);
     const unsigned __limit = 8;
     unsigned __count = 0;
     for (_RandomAccessIterator __i = __j+difference_type(1); __i != __last; ++__i)
@@ -290,14 +382,16 @@ __introsort(_RandomAccessIterator __first, _RandomAccessIterator __last, _Compar
                 swap(*__first, *__last);
             return;
         case 3:
-            _VSTD::__sort3<_Compare>(__first, __first+difference_type(1), --__last, __comp);
-            return;
+          _VSTD::__sort3_maybe_branchless<_Compare>(__first, __first + difference_type(1), --__last, __comp);
+          return;
         case 4:
-            _VSTD::__sort4<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), --__last, __comp);
-            return;
+          _VSTD::__sort4_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                    --__last, __comp);
+          return;
         case 5:
-            _VSTD::__sort5<_Compare>(__first, __first+difference_type(1), __first+difference_type(2), __first+difference_type(3), --__last, __comp);
-            return;
+          _VSTD::__sort5_maybe_branchless<_Compare>(__first, __first + difference_type(1), __first + difference_type(2),
+                                                    __first + difference_type(3), --__last, __comp);
+          return;
         }
         if (__len <= __limit)
         {
index 7fbdf1f..66b8b36 100644 (file)
 
 #include "test_macros.h"
 
+template <class T>
 struct Less {
     int *copies_;
     TEST_CONSTEXPR explicit Less(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 Less(const Less& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 Less& operator=(const Less&) = default;
-    TEST_CONSTEXPR bool operator()(void*, void*) const { return false; }
+    TEST_CONSTEXPR bool operator()(T, T) const { return false; }
 };
 
+template <class T>
 struct Equal {
     int *copies_;
     TEST_CONSTEXPR explicit Equal(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 Equal(const Equal& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 Equal& operator=(const Equal&) = default;
-    TEST_CONSTEXPR bool operator()(void*, void*) const { return true; }
+    TEST_CONSTEXPR bool operator()(T, T) const { return true; }
 };
 
+template <class T>
 struct UnaryVoid {
     int *copies_;
     TEST_CONSTEXPR explicit UnaryVoid(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 UnaryVoid(const UnaryVoid& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 UnaryVoid& operator=(const UnaryVoid&) = default;
-    TEST_CONSTEXPR_CXX14 void operator()(void*) const {}
+    TEST_CONSTEXPR_CXX14 void operator()(T) const {}
 };
 
+template <class T>
 struct UnaryTrue {
     int *copies_;
     TEST_CONSTEXPR explicit UnaryTrue(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 UnaryTrue(const UnaryTrue& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 UnaryTrue& operator=(const UnaryTrue&) = default;
-    TEST_CONSTEXPR bool operator()(void*) const { return true; }
+    TEST_CONSTEXPR bool operator()(T) const { return true; }
 };
 
+template <class T>
 struct NullaryValue {
     int *copies_;
     TEST_CONSTEXPR explicit NullaryValue(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 NullaryValue(const NullaryValue& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 NullaryValue& operator=(const NullaryValue&) = default;
-    TEST_CONSTEXPR std::nullptr_t operator()() const { return nullptr; }
+    TEST_CONSTEXPR T operator()() const { return 0; }
 };
 
+template <class T>
 struct UnaryTransform {
     int *copies_;
     TEST_CONSTEXPR explicit UnaryTransform(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 UnaryTransform(const UnaryTransform& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 UnaryTransform& operator=(const UnaryTransform&) = default;
-    TEST_CONSTEXPR std::nullptr_t operator()(void*) const { return nullptr; }
+    TEST_CONSTEXPR T operator()(T) const { return 0; }
 };
 
+template <class T>
 struct BinaryTransform {
     int *copies_;
     TEST_CONSTEXPR explicit BinaryTransform(int *copies) : copies_(copies) {}
     TEST_CONSTEXPR_CXX14 BinaryTransform(const BinaryTransform& rhs) : copies_(rhs.copies_) { *copies_ += 1; }
     TEST_CONSTEXPR_CXX14 BinaryTransform& operator=(const BinaryTransform&) = default;
-    TEST_CONSTEXPR std::nullptr_t operator()(void*, void*) const { return nullptr; }
+    TEST_CONSTEXPR T operator()(T, T) const { return 0; }
 };
 
 #if TEST_STD_VER > 17
@@ -81,124 +88,130 @@ struct ThreeWay {
 };
 #endif
 
+template <class T>
 TEST_CONSTEXPR_CXX20 bool all_the_algorithms()
 {
-    void *a[10] = {};
-    void *b[10] = {};
-    void **first = a;
-    void **mid = a+5;
-    void **last = a+10;
-    void **first2 = b;
-    void **mid2 = b+5;
-    void **last2 = b+10;
-    void *value = nullptr;
+    a[10] = {};
+    b[10] = {};
+    *first = a;
+    *mid = a+5;
+    *last = a+10;
+    *first2 = b;
+    *mid2 = b+5;
+    *last2 = b+10;
+    T value = 0;
     int count = 1;
 
     int copies = 0;
-    (void)std::adjacent_find(first, last, Equal(&copies)); assert(copies == 0);
+    (void)std::adjacent_find(first, last, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::all_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::any_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::all_of(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::any_of(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::binary_search(first, last, value, Less(&copies)); assert(copies == 0);
+    (void)std::binary_search(first, last, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 17
-    (void)std::clamp(value, value, value, Less(&copies)); assert(copies == 0);
+    (void)std::clamp(value, value, value, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::count_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::copy_if(first, last, first2, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::equal(first, last, first2, Equal(&copies)); assert(copies == 0);
+    (void)std::count_if(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::copy_if(first, last, first2, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::equal(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 11
-    (void)std::equal(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+    (void)std::equal(first, last, first2, last2, Equal<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::equal_range(first, last, value, Less(&copies)); assert(copies == 0);
-    (void)std::find_end(first, last, first2, mid2, Equal(&copies)); assert(copies == 0);
+    (void)std::equal_range(first, last, value, Less<T>(&copies)); assert(copies == 0);
+    (void)std::find_end(first, last, first2, mid2, Equal<T>(&copies)); assert(copies == 0);
     //(void)std::find_first_of(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
-    (void)std::find_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::find_if_not(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::for_each(first, last, UnaryVoid(&copies)); assert(copies == 1); copies = 0;
+    (void)std::find_if(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::find_if_not(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::for_each(first, last, UnaryVoid<T>(&copies)); assert(copies == 1); copies = 0;
 #if TEST_STD_VER > 14
-    (void)std::for_each_n(first, count, UnaryVoid(&copies)); assert(copies == 0);
+    (void)std::for_each_n(first, count, UnaryVoid<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::generate(first, last, NullaryValue(&copies)); assert(copies == 0);
-    (void)std::generate_n(first, count, NullaryValue(&copies)); assert(copies == 0);
-    (void)std::includes(first, last, first2, last2, Less(&copies)); assert(copies == 0);
-    (void)std::is_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::is_heap_until(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::is_partitioned(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::is_permutation(first, last, first2, Equal(&copies)); assert(copies == 0);
+    (void)std::generate(first, last, NullaryValue<T>(&copies)); assert(copies == 0);
+    (void)std::generate_n(first, count, NullaryValue<T>(&copies)); assert(copies == 0);
+    (void)std::includes(first, last, first2, last2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_heap_until(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_partitioned(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::is_permutation(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 11
-    (void)std::is_permutation(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+    (void)std::is_permutation(first, last, first2, last2, Equal<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::is_sorted(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::is_sorted_until(first, last, Less(&copies)); assert(copies == 0);
-    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::inplace_merge(first, mid, last, Less(&copies)); assert(copies == 0); }
-    (void)std::lexicographical_compare(first, last, first2, last2, Less(&copies)); assert(copies == 0);
+    (void)std::is_sorted(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::is_sorted_until(first, last, Less<T>(&copies)); assert(copies == 0);
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::inplace_merge(first, mid, last, Less<T>(&copies)); assert(copies == 0); }
+    (void)std::lexicographical_compare(first, last, first2, last2, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 17
     //(void)std::lexicographical_compare_three_way(first, last, first2, last2, ThreeWay(&copies)); assert(copies == 0);
 #endif
-    (void)std::lower_bound(first, last, value, Less(&copies)); assert(copies == 0);
-    (void)std::make_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::max(value, value, Less(&copies)); assert(copies == 0);
+    (void)std::lower_bound(first, last, value, Less<T>(&copies)); assert(copies == 0);
+    (void)std::make_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::max(value, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::max({ value, value }, Less(&copies)); assert(copies == 0);
+    (void)std::max({ value, value }, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::max_element(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::merge(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::min(value, value, Less(&copies)); assert(copies == 0);
+    (void)std::max_element(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::merge(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::min(value, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::min({ value, value }, Less(&copies)); assert(copies == 0);
+    (void)std::min({ value, value }, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::min_element(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::minmax(value, value, Less(&copies)); assert(copies == 0);
+    (void)std::min_element(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::minmax(value, value, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::minmax({ value, value }, Less(&copies)); assert(copies == 0);
+    (void)std::minmax({ value, value }, Less<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::minmax_element(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::mismatch(first, last, first2, Equal(&copies)); assert(copies == 0);
+    (void)std::minmax_element(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::mismatch(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER > 11
-    (void)std::mismatch(first, last, first2, last2, Equal(&copies)); assert(copies == 0);
+    (void)std::mismatch(first, last, first2, last2, Equal<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::next_permutation(first, last, Less(&copies)); assert(copies == 0);
+    (void)std::next_permutation(first, last, Less<T>(&copies)); assert(copies == 0);
 #if TEST_STD_VER >= 11
-    (void)std::none_of(first, last, UnaryTrue(&copies)); assert(copies == 0);
+    (void)std::none_of(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
 #endif
-    (void)std::nth_element(first, mid, last, Less(&copies)); assert(copies == 0);
-    (void)std::partial_sort(first, mid, last, Less(&copies)); assert(copies == 0);
-    (void)std::partial_sort_copy(first, last, first2, mid2, Less(&copies)); assert(copies == 0);
-    (void)std::partition(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::partition_copy(first, last, first2, last2, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::partition_point(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::pop_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::prev_permutation(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::push_heap(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::remove_copy_if(first, last, first2, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::remove_if(first, last, UnaryTrue(&copies)); assert(copies == 0);
-    (void)std::replace_copy_if(first, last, first2, UnaryTrue(&copies), value); assert(copies == 0);
-    (void)std::replace_if(first, last, UnaryTrue(&copies), value); assert(copies == 0);
-    (void)std::search(first, last, first2, mid2, Equal(&copies)); assert(copies == 0);
-    (void)std::search_n(first, last, count, value, Equal(&copies)); assert(copies == 0);
-    (void)std::set_difference(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::set_intersection(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::set_symmetric_difference(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::set_union(first, mid, mid, last, first2, Less(&copies)); assert(copies == 0);
-    (void)std::sort(first, last, Less(&copies)); assert(copies == 0);
-    (void)std::sort_heap(first, last, Less(&copies)); assert(copies == 0);
-    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_partition(first, last, UnaryTrue(&copies)); assert(copies == 0); }
-    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_sort(first, last, Less(&copies)); assert(copies == 0); }
-    (void)std::transform(first, last, first2, UnaryTransform(&copies)); assert(copies == 0);
-    (void)std::transform(first, mid, mid, first2, BinaryTransform(&copies)); assert(copies == 0);
-    (void)std::unique(first, last, Equal(&copies)); assert(copies == 0);
-    (void)std::unique_copy(first, last, first2, Equal(&copies)); assert(copies == 0);
-    (void)std::upper_bound(first, last, value, Less(&copies)); assert(copies == 0);
+    (void)std::nth_element(first, mid, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::partial_sort(first, mid, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::partial_sort_copy(first, last, first2, mid2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::partition(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::partition_copy(first, last, first2, last2, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::partition_point(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::pop_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::prev_permutation(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::push_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::remove_copy_if(first, last, first2, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::remove_if(first, last, UnaryTrue<T>(&copies)); assert(copies == 0);
+    (void)std::replace_copy_if(first, last, first2, UnaryTrue<T>(&copies), value); assert(copies == 0);
+    (void)std::replace_if(first, last, UnaryTrue<T>(&copies), value); assert(copies == 0);
+    (void)std::search(first, last, first2, mid2, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::search_n(first, last, count, value, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::set_difference(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::set_intersection(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::set_symmetric_difference(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::set_union(first, mid, mid, last, first2, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, first+3, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, first+4, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, first+5, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort(first, last, Less<T>(&copies)); assert(copies == 0);
+    (void)std::sort_heap(first, last, Less<T>(&copies)); assert(copies == 0);
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_partition(first, last, UnaryTrue<T>(&copies)); assert(copies == 0); }
+    if (!TEST_IS_CONSTANT_EVALUATED) { (void)std::stable_sort(first, last, Less<T>(&copies)); assert(copies == 0); }
+    (void)std::transform(first, last, first2, UnaryTransform<T>(&copies)); assert(copies == 0);
+    (void)std::transform(first, mid, mid, first2, BinaryTransform<T>(&copies)); assert(copies == 0);
+    (void)std::unique(first, last, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::unique_copy(first, last, first2, Equal<T>(&copies)); assert(copies == 0);
+    (void)std::upper_bound(first, last, value, Less<T>(&copies)); assert(copies == 0);
 
     return true;
 }
 
 int main(int, char**)
 {
-    all_the_algorithms();
+    all_the_algorithms<void*>();
+    all_the_algorithms<int>();
 #if TEST_STD_VER > 17
-    static_assert(all_the_algorithms());
+    static_assert(all_the_algorithms<void*>());
+    static_assert(all_the_algorithms<int>());
 #endif
 
     return 0;
index a36981c..b3c0a0d 100644 (file)
 #include <numeric>
 #include <random>
 #include <cassert>
+#include <vector>
+#include <deque>
 
 #include "test_macros.h"
 
 std::mt19937 randomness;
 
-template <class RI>
+template <class Container, class RI>
 void
 test_sort_helper(RI f, RI l)
 {
-    typedef typename std::iterator_traits<RI>::value_type value_type;
-    typedef typename std::iterator_traits<RI>::difference_type difference_type;
-
     if (f != l)
     {
-        difference_type len = l - f;
-        value_type* save(new value_type[len]);
+        Container save(l - f);
         do
         {
-            std::copy(f, l, save);
-            std::sort(save, save+len);
-            assert(std::is_sorted(save, save+len));
+            std::copy(f, l, save.begin());
+            std::sort(save.begin(), save.end());
+            assert(std::is_sorted(save.begin(), save.end()));
+            assert(std::is_permutation(save.begin(), save.end(), f));
         } while (std::next_permutation(f, l));
-        delete [] save;
     }
 }
 
-template <class RI>
+template <class T>
+void set_value(T& dest, int value)
+{
+    dest = value;
+}
+
+inline void set_value(std::pair<int, int>& dest, int value)
+{
+    dest.first = value;
+    dest.second = value;
+}
+
+template <class Container, class RI>
 void
 test_sort_driver_driver(RI f, RI l, int start, RI real_last)
 {
     for (RI i = l; i > f + start;)
     {
-        *--i = start;
+        set_value(*--i, start);
         if (f == i)
         {
-            test_sort_helper(f, real_last);
+            test_sort_helper<Container>(f, real_last);
         }
-    if (start > 0)
-        test_sort_driver_driver(f, i, start-1, real_last);
+        if (start > 0)
+            test_sort_driver_driver<Container>(f, i, start-1, real_last);
     }
 }
 
-template <class RI>
+template <class Container, class RI>
 void
 test_sort_driver(RI f, RI l, int start)
 {
-    test_sort_driver_driver(f, l, start, l);
+    test_sort_driver_driver<Container>(f, l, start, l);
 }
 
-template <int sa>
+template <class Container, int sa>
 void
 test_sort_()
 {
-    int ia[sa];
+    Container ia(sa);
     for (int i = 0; i < sa; ++i)
     {
-        test_sort_driver(ia, ia+sa, i);
+        test_sort_driver<Container>(ia.begin(), ia.end(), i);
     }
 }
 
+template <class T>
+T increment_or_reset(T value, int max_value)
+{
+    return value == max_value - 1 ? 0 : value + 1;
+}
+
+inline std::pair<int, int> increment_or_reset(std::pair<int, int> value,
+                                              int max_value)
+{
+    int new_value = value.first + 1;
+    if (new_value == max_value)
+    {
+        new_value = 0;
+    }
+    return std::make_pair(new_value, new_value);
+}
+
+template <class Container>
 void
 test_larger_sorts(int N, int M)
 {
+    using Iter = typename Container::iterator;
+    using ValueType = typename Container::value_type;
     assert(N != 0);
     assert(M != 0);
-    // create array length N filled with M different numbers
-    int* array = new int[N];
-    int x = 0;
+    // create container of length N filled with M different objects
+    Container array(N);
+    ValueType x = ValueType();
     for (int i = 0; i < N; ++i)
     {
         array[i] = x;
-        if (++x == M)
-            x = 0;
+        x = increment_or_reset(x, M);
     }
+    Container original = array;
+    Iter iter = array.begin();
+    Iter original_iter = original.begin();
+
     // test saw tooth pattern
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test random pattern
-    std::shuffle(array, array+N, randomness);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::shuffle(iter, iter+N, randomness);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test sorted pattern
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test reverse sorted pattern
-    std::reverse(array, array+N);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::reverse(iter, iter+N);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test swap ranges 2 pattern
-    std::swap_ranges(array, array+N/2, array+N/2);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
+    std::swap_ranges(iter, iter+N/2, iter+N/2);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
     // test reverse swap ranges 2 pattern
-    std::reverse(array, array+N);
-    std::swap_ranges(array, array+N/2, array+N/2);
-    std::sort(array, array+N);
-    assert(std::is_sorted(array, array+N));
-    delete [] array;
+    std::reverse(iter, iter+N);
+    std::swap_ranges(iter, iter+N/2, iter+N/2);
+    std::sort(iter, iter+N);
+    assert(std::is_sorted(iter, iter+N));
+    assert(std::is_permutation(iter, iter+N, original_iter));
 }
 
+template <class Container>
 void
 test_larger_sorts(int N)
 {
-    test_larger_sorts(N, 1);
-    test_larger_sorts(N, 2);
-    test_larger_sorts(N, 3);
-    test_larger_sorts(N, N/2-1);
-    test_larger_sorts(N, N/2);
-    test_larger_sorts(N, N/2+1);
-    test_larger_sorts(N, N-2);
-    test_larger_sorts(N, N-1);
-    test_larger_sorts(N, N);
+    test_larger_sorts<Container>(N, 1);
+    test_larger_sorts<Container>(N, 2);
+    test_larger_sorts<Container>(N, 3);
+    test_larger_sorts<Container>(N, N/2-1);
+    test_larger_sorts<Container>(N, N/2);
+    test_larger_sorts<Container>(N, N/2+1);
+    test_larger_sorts<Container>(N, N-2);
+    test_larger_sorts<Container>(N, N-1);
+    test_larger_sorts<Container>(N, N);
 }
 
 void
@@ -205,28 +244,40 @@ void test_adversarial_quicksort(int N) {
   assert(std::is_sorted(V.begin(), V.end()));
 }
 
-int main(int, char**)
+template <class Container>
+void run_sort_tests()
 {
     // test null range
-    int d = 0;
+    using ValueType = typename Container::value_type;
+    ValueType d = ValueType();
     std::sort(&d, &d);
+
     // exhaustively test all possibilities up to length 8
-    test_sort_<1>();
-    test_sort_<2>();
-    test_sort_<3>();
-    test_sort_<4>();
-    test_sort_<5>();
-    test_sort_<6>();
-    test_sort_<7>();
-    test_sort_<8>();
-
-    test_larger_sorts(256);
-    test_larger_sorts(257);
-    test_larger_sorts(499);
-    test_larger_sorts(500);
-    test_larger_sorts(997);
-    test_larger_sorts(1000);
-    test_larger_sorts(1009);
+    test_sort_<Container, 1>();
+    test_sort_<Container, 2>();
+    test_sort_<Container, 3>();
+    test_sort_<Container, 4>();
+    test_sort_<Container, 5>();
+    test_sort_<Container, 6>();
+    test_sort_<Container, 7>();
+    test_sort_<Container, 8>();
+
+    test_larger_sorts<Container>(256);
+    test_larger_sorts<Container>(257);
+    test_larger_sorts<Container>(499);
+    test_larger_sorts<Container>(500);
+    test_larger_sorts<Container>(997);
+    test_larger_sorts<Container>(1000);
+    test_larger_sorts<Container>(1009);
+}
+
+int main(int, char**)
+{
+    // test various combinations of contiguous/non-contiguous containers with
+    // arithmetic/non-arithmetic types
+    run_sort_tests<std::vector<int> >();
+    run_sort_tests<std::deque<int> >();
+    run_sort_tests<std::vector<std::pair<int, int> > >();
 
     test_pointer_sort();
     test_adversarial_quicksort(1 << 20);