[libc++] Forward to std::{,w}memchr in std::find
authorNikolas Klauser <nikolasklauser@berlin.de>
Wed, 24 May 2023 21:37:22 +0000 (14:37 -0700)
committerNikolas Klauser <nikolasklauser@berlin.de>
Thu, 25 May 2023 14:59:50 +0000 (07:59 -0700)
Reviewed By: #libc, ldionne

Spies: Mordante, libcxx-commits, ldionne, mikhail.ramalho

Differential Revision: https://reviews.llvm.org/D144394

17 files changed:
libcxx/benchmarks/CMakeLists.txt
libcxx/benchmarks/algorithms/find.bench.cpp [new file with mode: 0644]
libcxx/include/__algorithm/find.h
libcxx/include/__algorithm/ranges_find.h
libcxx/include/__string/char_traits.h
libcxx/include/__string/constexpr_c_functions.h
libcxx/include/cwchar
libcxx/test/libcxx/strings/c.strings/constexpr.cstring.compile.pass.cpp
libcxx/test/libcxx/transitive_includes/cxx03.csv
libcxx/test/libcxx/transitive_includes/cxx11.csv
libcxx/test/libcxx/transitive_includes/cxx14.csv
libcxx/test/libcxx/transitive_includes/cxx17.csv
libcxx/test/libcxx/transitive_includes/cxx20.csv
libcxx/test/libcxx/transitive_includes/cxx23.csv
libcxx/test/std/algorithms/alg.nonmodifying/alg.find/find.pass.cpp
libcxx/test/std/algorithms/alg.nonmodifying/alg.find/ranges.find.pass.cpp
libcxx/test/support/type_algorithms.h

index bf7c4b4..daa6fa2 100644 (file)
@@ -159,6 +159,7 @@ endfunction()
 set(BENCHMARK_TESTS
     algorithms.partition_point.bench.cpp
     algorithms/equal.bench.cpp
+    algorithms/find.bench.cpp
     algorithms/lower_bound.bench.cpp
     algorithms/make_heap.bench.cpp
     algorithms/make_heap_then_sort_heap.bench.cpp
diff --git a/libcxx/benchmarks/algorithms/find.bench.cpp b/libcxx/benchmarks/algorithms/find.bench.cpp
new file mode 100644 (file)
index 0000000..65b2cda
--- /dev/null
@@ -0,0 +1,49 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <benchmark/benchmark.h>
+#include <cstring>
+#include <random>
+#include <vector>
+
+template <class T>
+static void bm_find(benchmark::State& state) {
+  std::vector<T> vec1(state.range(), '1');
+  std::mt19937_64 rng(std::random_device{}());
+
+  for (auto _ : state) {
+    auto idx  = rng() % vec1.size();
+    vec1[idx] = '2';
+    benchmark::DoNotOptimize(vec1);
+    benchmark::DoNotOptimize(std::find(vec1.begin(), vec1.end(), T('2')));
+    vec1[idx] = '1';
+  }
+}
+BENCHMARK(bm_find<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_find<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_find<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+
+template <class T>
+static void bm_ranges_find(benchmark::State& state) {
+  std::vector<T> vec1(state.range(), '1');
+  std::mt19937_64 rng(std::random_device{}());
+
+  for (auto _ : state) {
+    auto idx  = rng() % vec1.size();
+    vec1[idx] = '2';
+    benchmark::DoNotOptimize(vec1);
+    benchmark::DoNotOptimize(std::ranges::find(vec1, T('2')));
+    vec1[idx] = '1';
+  }
+}
+BENCHMARK(bm_ranges_find<char>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_ranges_find<short>)->DenseRange(1, 8)->Range(16, 1 << 20);
+BENCHMARK(bm_ranges_find<int>)->DenseRange(1, 8)->Range(16, 1 << 20);
+
+BENCHMARK_MAIN();
index e51dc9b..e0de503 100644 (file)
 #ifndef _LIBCPP___ALGORITHM_FIND_H
 #define _LIBCPP___ALGORITHM_FIND_H
 
+#include <__algorithm/unwrap_iter.h>
 #include <__config>
+#include <__functional/identity.h>
+#include <__functional/invoke.h>
+#include <__string/constexpr_c_functions.h>
+#include <__type_traits/is_same.h>
+
+#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS
+#  include <cwchar>
+#endif
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
 #  pragma GCC system_header
 
 _LIBCPP_BEGIN_NAMESPACE_STD
 
-template <class _InputIterator, class _Tp>
-_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 _InputIterator
-find(_InputIterator __first, _InputIterator __last, const _Tp& __value) {
+template <class _Iter, class _Sent, class _Tp, class _Proj>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Iter
+__find_impl(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
   for (; __first != __last; ++__first)
-    if (*__first == __value)
+    if (std::__invoke(__proj, *__first) == __value)
       break;
   return __first;
 }
 
+template <class _Tp,
+          class _Up,
+          class _Proj,
+          __enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value &&
+                            sizeof(_Tp) == 1,
+                        int> = 0>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp*
+__find_impl(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
+  if (auto __ret = std::__constexpr_memchr(__first, __value, __last - __first))
+    return __ret;
+  return __last;
+}
+
+#ifndef _LIBCPP_HAS_NO_WIDE_CHARACTERS
+template <class _Tp,
+          class _Up,
+          class _Proj,
+          __enable_if_t<__is_identity<_Proj>::value && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value &&
+                            sizeof(_Tp) == sizeof(wchar_t) && _LIBCPP_ALIGNOF(_Tp) >= _LIBCPP_ALIGNOF(wchar_t),
+                        int> = 0>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp*
+__find_impl(_Tp* __first, _Tp* __last, const _Up& __value, _Proj&) {
+  if (auto __ret = std::__constexpr_wmemchr(__first, __value, __last - __first))
+    return __ret;
+  return __last;
+}
+#endif // _LIBCPP_HAS_NO_WIDE_CHARACTERS
+
+template <class _InputIterator, class _Tp>
+_LIBCPP_NODISCARD_EXT inline _LIBCPP_INLINE_VISIBILITY _LIBCPP_CONSTEXPR_SINCE_CXX20 _InputIterator
+find(_InputIterator __first, _InputIterator __last, const _Tp& __value) {
+  __identity __proj;
+  return std::__rewrap_iter(
+      __first, std::__find_impl(std::__unwrap_iter(__first), std::__unwrap_iter(__last), __value, __proj));
+}
+
 _LIBCPP_END_NAMESPACE_STD
 
 #endif // _LIBCPP___ALGORITHM_FIND_H
index 87f25d1..084cdff 100644 (file)
@@ -9,7 +9,9 @@
 #ifndef _LIBCPP___ALGORITHM_RANGES_FIND_H
 #define _LIBCPP___ALGORITHM_RANGES_FIND_H
 
+#include <__algorithm/find.h>
 #include <__algorithm/ranges_find_if.h>
+#include <__algorithm/unwrap_range.h>
 #include <__config>
 #include <__functional/identity.h>
 #include <__functional/invoke.h>
@@ -33,20 +35,30 @@ _LIBCPP_BEGIN_NAMESPACE_STD
 namespace ranges {
 namespace __find {
 struct __fn {
+  template <class _Iter, class _Sent, class _Tp, class _Proj>
+  _LIBCPP_HIDE_FROM_ABI static constexpr _Iter
+  __find_unwrap(_Iter __first, _Sent __last, const _Tp& __value, _Proj& __proj) {
+    if constexpr (forward_iterator<_Iter>) {
+      auto [__first_un, __last_un] = std::__unwrap_range(__first, std::move(__last));
+      return std::__rewrap_range<_Sent>(
+          std::move(__first), std::__find_impl(std::move(__first_un), std::move(__last_un), __value, __proj));
+    } else {
+      return std::__find_impl(std::move(__first), std::move(__last), __value, __proj);
+    }
+  }
+
   template <input_iterator _Ip, sentinel_for<_Ip> _Sp, class _Tp, class _Proj = identity>
     requires indirect_binary_predicate<ranges::equal_to, projected<_Ip, _Proj>, const _Tp*>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr
   _Ip operator()(_Ip __first, _Sp __last, const _Tp& __value, _Proj __proj = {}) const {
-    auto __pred = [&](auto&& __e) { return std::forward<decltype(__e)>(__e) == __value; };
-    return ranges::__find_if_impl(std::move(__first), std::move(__last), __pred, __proj);
+    return __find_unwrap(std::move(__first), std::move(__last), __value, __proj);
   }
 
   template <input_range _Rp, class _Tp, class _Proj = identity>
     requires indirect_binary_predicate<ranges::equal_to, projected<iterator_t<_Rp>, _Proj>, const _Tp*>
   _LIBCPP_NODISCARD_EXT _LIBCPP_HIDE_FROM_ABI constexpr
   borrowed_iterator_t<_Rp> operator()(_Rp&& __r, const _Tp& __value, _Proj __proj = {}) const {
-    auto __pred = [&](auto&& __e) { return std::forward<decltype(__e)>(__e) == __value; };
-    return ranges::__find_if_impl(ranges::begin(__r), ranges::end(__r), __pred, __proj);
+    return __find_unwrap(ranges::begin(__r), ranges::end(__r), __value, __proj);
   }
 };
 } // namespace __find
index c4dfcaf..61975cc 100644 (file)
@@ -244,7 +244,7 @@ struct _LIBCPP_TEMPLATE_VIS char_traits<char>
     const char_type* find(const char_type* __s, size_t __n, const char_type& __a) _NOEXCEPT {
       if (__n == 0)
           return nullptr;
-      return std::__constexpr_char_memchr(__s, static_cast<int>(__a), __n);
+      return std::__constexpr_memchr(__s, __a, __n);
     }
 
     static inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX20
index ffba782..9a768cf 100644 (file)
@@ -14,6 +14,7 @@
 #include <__type_traits/is_equality_comparable.h>
 #include <__type_traits/is_same.h>
 #include <__type_traits/is_trivially_lexicographically_comparable.h>
+#include <__type_traits/remove_cv.h>
 #include <cstddef>
 
 #if !defined(_LIBCPP_HAS_NO_PRAGMA_SYSTEM_HEADER)
@@ -95,20 +96,29 @@ __constexpr_memcmp_equal(const _Tp* __lhs, const _Up* __rhs, size_t __count) {
   }
 }
 
-inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 const char*
-__constexpr_char_memchr(const char* __str, int __char, size_t __count) {
-#if __has_builtin(__builtin_char_memchr)
-  return __builtin_char_memchr(__str, __char, __count);
-#else
-  if (!__libcpp_is_constant_evaluated())
-    return static_cast<const char*>(__builtin_memchr(__str, __char, __count));
-  for (; __count; --__count) {
-    if (*__str == __char)
-      return __str;
-    ++__str;
-  }
-  return nullptr;
+template <class _Tp, class _Up>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __constexpr_memchr(_Tp* __str, _Up __value, size_t __count) {
+  static_assert(sizeof(_Tp) == 1 && __libcpp_is_trivially_equality_comparable<_Tp, _Up>::value,
+                "Calling memchr on non-trivially equality comparable types is unsafe.");
+
+  if (__libcpp_is_constant_evaluated()) {
+// use __builtin_char_memchr to optimize constexpr evaluation if we can
+#if _LIBCPP_STD_VER >= 17 && __has_builtin(__builtin_char_memchr)
+    if constexpr (is_same_v<remove_cv_t<_Tp>, char> && is_same_v<remove_cv_t<_Up>, char>)
+      return __builtin_char_memchr(__str, __value, __count);
 #endif
+
+    for (; __count; --__count) {
+      if (*__str == __value)
+        return __str;
+      ++__str;
+    }
+    return nullptr;
+  } else {
+    char __value_buffer = 0;
+    __builtin_memcpy(&__value_buffer, &__value, sizeof(char));
+    return static_cast<_Tp*>(__builtin_memchr(__str, __value_buffer, __count));
+  }
 }
 
 _LIBCPP_END_NAMESPACE_STD
index fb7b92b..122af24 100644 (file)
@@ -104,7 +104,11 @@ size_t wcsrtombs(char* restrict dst, const wchar_t** restrict src, size_t len,
 
 #include <__assert> // all public C++ headers provide the assertion handler
 #include <__config>
+#include <__type_traits/apply_cv.h>
 #include <__type_traits/is_constant_evaluated.h>
+#include <__type_traits/is_equality_comparable.h>
+#include <__type_traits/is_same.h>
+#include <__type_traits/remove_cv.h>
 #include <cwctype>
 
 #include <wchar.h>
@@ -222,21 +226,31 @@ __constexpr_wmemcmp(const wchar_t* __lhs, const wchar_t* __rhs, size_t __count)
 #endif
 }
 
-inline _LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 const wchar_t*
-__constexpr_wmemchr(const wchar_t* __str, wchar_t __char, size_t __count) {
-#if __has_feature(cxx_constexpr_string_builtins)
-  return __builtin_wmemchr(__str, __char, __count);
-#else
-  if (!__libcpp_is_constant_evaluated())
-    return std::wmemchr(__str, __char, __count);
+template <class _Tp, class _Up>
+_LIBCPP_HIDE_FROM_ABI _LIBCPP_CONSTEXPR_SINCE_CXX14 _Tp* __constexpr_wmemchr(_Tp* __str, _Up __value, size_t __count) {
+  static_assert(sizeof(_Tp) == sizeof(wchar_t)&& _LIBCPP_ALIGNOF(_Tp) >= _LIBCPP_ALIGNOF(wchar_t) &&
+                    __libcpp_is_trivially_equality_comparable<_Tp, _Tp>::value,
+                "Calling wmemchr on non-trivially equality comparable types is unsafe.");
+
+#if __has_builtin(__builtin_wmemchr)
+  if (!__libcpp_is_constant_evaluated()) {
+    wchar_t __value_buffer = 0;
+    __builtin_memcpy(&__value_buffer, &__value, sizeof(wchar_t));
+    return reinterpret_cast<_Tp*>(
+        __builtin_wmemchr(reinterpret_cast<__apply_cv_t<_Tp, wchar_t>*>(__str), __value_buffer, __count));
+  }
+#  if _LIBCPP_STD_VER >= 17
+  else if constexpr (is_same_v<remove_cv_t<_Tp>, wchar_t>)
+    return __builtin_wmemchr(__str, __value, __count);
+#  endif
+#endif // __has_builtin(__builtin_wmemchr)
 
   for (; __count; --__count) {
-    if (*__str == __char)
+    if (*__str == __value)
       return __str;
     ++__str;
   }
   return nullptr;
-#endif
 }
 
 _LIBCPP_END_NAMESPACE_STD
index 5762553..9c941dd 100644 (file)
@@ -29,6 +29,6 @@ static_assert(std::__constexpr_memcmp_equal(Banane, Banane, 6), "");
 
 constexpr bool test_constexpr_wmemchr() {
   const char str[] = "Banane";
-  return std::__constexpr_char_memchr(str, 'n', 6) == str + 2;
+  return std::__constexpr_memchr(str, 'n', 6) == str + 2;
 }
 static_assert(test_constexpr_wmemchr(), "");
index 0770eaa..d45892a 100644 (file)
@@ -7,6 +7,7 @@ algorithm cstdint
 algorithm cstdlib
 algorithm cstring
 algorithm ctime
+algorithm cwchar
 algorithm execution
 algorithm initializer_list
 algorithm iosfwd
@@ -191,6 +192,7 @@ coroutine version
 cstddef version
 ctgmath ccomplex
 ctgmath cmath
+cwchar cstddef
 cwchar cwctype
 cwctype cctype
 deque algorithm
@@ -201,6 +203,7 @@ deque cstddef
 deque cstdint
 deque cstdlib
 deque cstring
+deque cwchar
 deque functional
 deque initializer_list
 deque iosfwd
@@ -950,6 +953,7 @@ vector cstddef
 vector cstdint
 vector cstdlib
 vector cstring
+vector cwchar
 vector initializer_list
 vector iosfwd
 vector limits
index f607561..264e144 100644 (file)
@@ -7,6 +7,7 @@ algorithm cstdint
 algorithm cstdlib
 algorithm cstring
 algorithm ctime
+algorithm cwchar
 algorithm execution
 algorithm initializer_list
 algorithm iosfwd
@@ -191,6 +192,7 @@ coroutine version
 cstddef version
 ctgmath ccomplex
 ctgmath cmath
+cwchar cstddef
 cwchar cwctype
 cwctype cctype
 deque algorithm
@@ -201,6 +203,7 @@ deque cstddef
 deque cstdint
 deque cstdlib
 deque cstring
+deque cwchar
 deque functional
 deque initializer_list
 deque iosfwd
@@ -951,6 +954,7 @@ vector cstddef
 vector cstdint
 vector cstdlib
 vector cstring
+vector cwchar
 vector initializer_list
 vector iosfwd
 vector limits
index 2fe6248..831001c 100644 (file)
@@ -7,6 +7,7 @@ algorithm cstdint
 algorithm cstdlib
 algorithm cstring
 algorithm ctime
+algorithm cwchar
 algorithm execution
 algorithm initializer_list
 algorithm iosfwd
@@ -191,6 +192,7 @@ coroutine version
 cstddef version
 ctgmath ccomplex
 ctgmath cmath
+cwchar cstddef
 cwchar cwctype
 cwctype cctype
 deque algorithm
@@ -201,6 +203,7 @@ deque cstddef
 deque cstdint
 deque cstdlib
 deque cstring
+deque cwchar
 deque functional
 deque initializer_list
 deque iosfwd
@@ -953,6 +956,7 @@ vector cstddef
 vector cstdint
 vector cstdlib
 vector cstring
+vector cwchar
 vector initializer_list
 vector iosfwd
 vector limits
index 2fe6248..831001c 100644 (file)
@@ -7,6 +7,7 @@ algorithm cstdint
 algorithm cstdlib
 algorithm cstring
 algorithm ctime
+algorithm cwchar
 algorithm execution
 algorithm initializer_list
 algorithm iosfwd
@@ -191,6 +192,7 @@ coroutine version
 cstddef version
 ctgmath ccomplex
 ctgmath cmath
+cwchar cstddef
 cwchar cwctype
 cwctype cctype
 deque algorithm
@@ -201,6 +203,7 @@ deque cstddef
 deque cstdint
 deque cstdlib
 deque cstring
+deque cwchar
 deque functional
 deque initializer_list
 deque iosfwd
@@ -953,6 +956,7 @@ vector cstddef
 vector cstdint
 vector cstdlib
 vector cstring
+vector cwchar
 vector initializer_list
 vector iosfwd
 vector limits
index 2c743b9..eee7103 100644 (file)
@@ -7,6 +7,7 @@ algorithm cstdint
 algorithm cstdlib
 algorithm cstring
 algorithm ctime
+algorithm cwchar
 algorithm execution
 algorithm initializer_list
 algorithm iosfwd
@@ -198,6 +199,7 @@ coroutine version
 cstddef version
 ctgmath ccomplex
 ctgmath cmath
+cwchar cstddef
 cwchar cwctype
 cwctype cctype
 deque algorithm
@@ -208,6 +210,7 @@ deque cstddef
 deque cstdint
 deque cstdlib
 deque cstring
+deque cwchar
 deque functional
 deque initializer_list
 deque iosfwd
@@ -958,6 +961,7 @@ vector cstddef
 vector cstdint
 vector cstdlib
 vector cstring
+vector cwchar
 vector initializer_list
 vector iosfwd
 vector limits
index 0a38631..854b233 100644 (file)
@@ -3,6 +3,7 @@ algorithm cstddef
 algorithm cstdint
 algorithm cstring
 algorithm ctime
+algorithm cwchar
 algorithm execution
 algorithm initializer_list
 algorithm iosfwd
@@ -127,12 +128,14 @@ coroutine version
 cstddef version
 ctgmath ccomplex
 ctgmath cmath
+cwchar cstddef
 cwchar cwctype
 cwctype cctype
 deque compare
 deque cstddef
 deque cstdint
 deque cstring
+deque cwchar
 deque initializer_list
 deque limits
 deque new
@@ -442,6 +445,7 @@ random vector
 random version
 ranges compare
 ranges cstddef
+ranges cwchar
 ranges initializer_list
 ranges iosfwd
 ranges iterator
@@ -640,6 +644,7 @@ vector cstddef
 vector cstdint
 vector cstdlib
 vector cstring
+vector cwchar
 vector initializer_list
 vector iosfwd
 vector limits
index c4de565..b55a852 100644 (file)
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+// ADDITIONAL_COMPILE_FLAGS: -Wno-sign-compare
+
 // <algorithm>
 
 // template<InputIterator Iter, class T>
 
 #include <algorithm>
 #include <cassert>
+#include <vector>
+#include <type_traits>
 
 #include "test_macros.h"
 #include "test_iterators.h"
+#include "type_algorithms.h"
+
+static std::vector<int> comparable_data;
+
+template <class ArrayT, class CompareT>
+struct Test {
+  template <class Iter>
+  TEST_CONSTEXPR_CXX20 void operator()() {
+    ArrayT arr[] = {
+        ArrayT(1), ArrayT(2), ArrayT(3), ArrayT(4), ArrayT(5), ArrayT(6), ArrayT(7), ArrayT(8), ArrayT(9), ArrayT(10)};
+
+    static_assert(std::is_same<decltype(std::find(Iter(arr), Iter(arr), 0)), Iter>::value, "");
+
+    { // first element matches
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(1));
+      assert(*iter == ArrayT(1));
+      assert(base(iter) == arr);
+    }
+
+    { // range is empty; return last
+      Iter iter = std::find(Iter(arr), Iter(arr), CompareT(1));
+      assert(base(iter) == arr);
+    }
+
+    { // if multiple elements match, return the first match
+      ArrayT data[] = {
+          ArrayT(1), ArrayT(2), ArrayT(3), ArrayT(4), ArrayT(5), ArrayT(6), ArrayT(7), ArrayT(5), ArrayT(4)};
+      Iter iter = std::find(Iter(std::begin(data)), Iter(std::end(data)), CompareT(5));
+      assert(*iter == ArrayT(5));
+      assert(base(iter) == data + 4);
+    }
+
+    { // some element matches
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(6));
+      assert(*iter == ArrayT(6));
+      assert(base(iter) == arr + 5);
+    }
+
+    { // last element matches
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(10));
+      assert(*iter == ArrayT(10));
+      assert(base(iter) == arr + 9);
+    }
 
-#if TEST_STD_VER > 17
-TEST_CONSTEXPR bool test_constexpr() {
-    int ia[] = {1, 3, 5, 2, 4, 6};
-    int ib[] = {1, 2, 3, 4, 5, 6};
-    return    (std::find(std::begin(ia), std::end(ia), 5) == ia+2)
-           && (std::find(std::begin(ib), std::end(ib), 9) == ib+6)
-           ;
+    { // if no element matches, last is returned
+      Iter iter = std::find(Iter(arr), Iter(arr + 10), CompareT(20));
+      assert(base(iter) == arr + 10);
     }
+
+    if (!TEST_IS_CONSTANT_EVALUATED)
+      comparable_data.clear();
+  }
+};
+
+template <class IndexT>
+class Comparable {
+  IndexT index_;
+
+  static IndexT init_index(IndexT i) {
+    IndexT size = static_cast<IndexT>(comparable_data.size());
+    comparable_data.push_back(i);
+    return size;
+  }
+
+public:
+  Comparable(IndexT i) : index_(init_index(i)) {}
+
+  friend bool operator==(const Comparable& lhs, const Comparable& rhs) {
+    return comparable_data[lhs.index_] == comparable_data[rhs.index_];
+  }
+};
+
+#if TEST_STD_VER >= 20
+template <class ElementT>
+class TriviallyComparable {
+  ElementT el_;
+
+public:
+  explicit constexpr TriviallyComparable(ElementT el) : el_(el) {}
+  bool operator==(const TriviallyComparable&) const = default;
+};
 #endif
 
-int main(int, char**)
-{
-    int ia[] = {0, 1, 2, 3, 4, 5};
-    const unsigned s = sizeof(ia)/sizeof(ia[0]);
-    cpp17_input_iterator<const int*> r = std::find(cpp17_input_iterator<const int*>(ia),
-                                             cpp17_input_iterator<const int*>(ia+s), 3);
-    assert(*r == 3);
-    r = std::find(cpp17_input_iterator<const int*>(ia), cpp17_input_iterator<const int*>(ia+s), 10);
-    assert(r == cpp17_input_iterator<const int*>(ia+s));
-
-#if TEST_STD_VER > 17
-    static_assert(test_constexpr());
+template <class CompareT>
+struct TestTypes {
+  template <class T>
+  TEST_CONSTEXPR_CXX20 void operator()() {
+    types::for_each(types::cpp17_input_iterator_list<T*>(), Test<T, CompareT>());
+  }
+};
+
+TEST_CONSTEXPR_CXX20 bool test() {
+  types::for_each(types::integer_types(), TestTypes<char>());
+  types::for_each(types::integer_types(), TestTypes<int>());
+  types::for_each(types::integer_types(), TestTypes<long long>());
+#if TEST_STD_VER >= 20
+  Test<TriviallyComparable<char>, TriviallyComparable<char>>().operator()<TriviallyComparable<char>*>();
+  Test<TriviallyComparable<wchar_t>, TriviallyComparable<wchar_t>>().operator()<TriviallyComparable<wchar_t>*>();
+#endif
+
+  return true;
+}
+
+int main(int, char**) {
+  test();
+#if TEST_STD_VER >= 20
+  static_assert(test());
 #endif
 
+  Test<Comparable<char>, Comparable<char> >().operator()<Comparable<char>*>();
+  Test<Comparable<wchar_t>, Comparable<wchar_t> >().operator()<Comparable<wchar_t>*>();
+
   return 0;
 }
index 4edded0..a1fb92a 100644 (file)
@@ -10,6 +10,8 @@
 
 // UNSUPPORTED: c++03, c++11, c++14, c++17
 
+// ADDITIONAL_COMPILE_FLAGS: -Wno-sign-compare
+
 // template<input_iterator I, sentinel_for<I> S, class T, class Proj = identity>
 //   requires indirect_binary_predicate<ranges::equal_to, projected<I, Proj>, const T*>
 //   constexpr I ranges::find(I first, S last, const T& value, Proj proj = {});
@@ -22,6 +24,7 @@
 #include <array>
 #include <cassert>
 #include <ranges>
+#include <vector>
 
 #include "almost_satisfies_types.h"
 #include "boolean_testable.h"
@@ -53,46 +56,78 @@ static_assert(!HasFindR<InputRangeNotInputOrOutputIterator, int>);
 static_assert(!HasFindR<InputRangeNotSentinelSemiregular, int>);
 static_assert(!HasFindR<InputRangeNotSentinelEqualityComparableWith, int>);
 
+static std::vector<int> comparable_data;
+
 template <class It, class Sent = It>
 constexpr void test_iterators() {
-  {
-    int a[] = {1, 2, 3, 4};
-    std::same_as<It> auto ret = std::ranges::find(It(a), Sent(It(a + 4)), 4);
-    assert(base(ret) == a + 3);
-    assert(*ret == 4);
-  }
-  {
-    int a[] = {1, 2, 3, 4};
-    auto range = std::ranges::subrange(It(a), Sent(It(a + 4)));
-    std::same_as<It> auto ret = std::ranges::find(range, 4);
-    assert(base(ret) == a + 3);
-    assert(*ret == 4);
+  using ValueT = std::iter_value_t<It>;
+  { // simple test
+    {
+      ValueT a[] = {1, 2, 3, 4};
+      std::same_as<It> auto ret = std::ranges::find(It(a), Sent(It(a + 4)), 4);
+      assert(base(ret) == a + 3);
+      assert(*ret == 4);
+    }
+    {
+      ValueT a[] = {1, 2, 3, 4};
+      auto range = std::ranges::subrange(It(a), Sent(It(a + 4)));
+      std::same_as<It> auto ret = std::ranges::find(range, 4);
+      assert(base(ret) == a + 3);
+      assert(*ret == 4);
+    }
   }
-}
 
-constexpr bool test() {
-  test_iterators<int*>();
-  test_iterators<const int*>();
-  test_iterators<cpp20_input_iterator<int*>, sentinel_wrapper<cpp20_input_iterator<int*>>>();
-  test_iterators<bidirectional_iterator<int*>>();
-  test_iterators<forward_iterator<int*>>();
-  test_iterators<random_access_iterator<int*>>();
-  test_iterators<contiguous_iterator<int*>>();
+  { // check that an empty range works
+    {
+      std::array<ValueT, 0> a = {};
+      auto ret = std::ranges::find(It(a.data()), Sent(It(a.data())), 1);
+      assert(base(ret) == a.data());
+    }
+    {
+      std::array<ValueT, 0> a = {};
+      auto range = std::ranges::subrange(It(a.data()), Sent(It(a.data())));
+      auto ret = std::ranges::find(range, 1);
+      assert(base(ret) == a.data());
+    }
+  }
 
-  {
-    // check that projections are used properly and that they are called with the iterator directly
+  { // check that last is returned with no match
     {
-      int a[] = {1, 2, 3, 4};
-      auto ret = std::ranges::find(a, a + 4, a + 3, [](int& i) { return &i; });
+      ValueT a[] = {1, 1, 1};
+      auto ret = std::ranges::find(a, a + 3, 0);
       assert(ret == a + 3);
     }
     {
-      int a[] = {1, 2, 3, 4};
-      auto ret = std::ranges::find(a, a + 3, [](int& i) { return &i; });
+      ValueT a[] = {1, 1, 1};
+      auto ret = std::ranges::find(a, 0);
       assert(ret == a + 3);
     }
   }
 
+  if (!std::is_constant_evaluated())
+    comparable_data.clear();
+}
+
+template <class ElementT>
+class TriviallyComparable {
+  ElementT el_;
+
+public:
+  TEST_CONSTEXPR TriviallyComparable(ElementT el) : el_(el) {}
+  bool operator==(const TriviallyComparable&) const = default;
+};
+
+constexpr bool test() {
+  types::for_each(types::type_list<char, wchar_t, int, long, TriviallyComparable<char>, TriviallyComparable<wchar_t>>{},
+                  []<class T> {
+                    types::for_each(types::cpp20_input_iterator_list<T*>{}, []<class Iter> {
+                      if constexpr (std::forward_iterator<Iter>)
+                        test_iterators<Iter>();
+                      test_iterators<Iter, sentinel_wrapper<Iter>>();
+                      test_iterators<Iter, sized_sentinel<Iter>>();
+                    });
+                  });
+
   { // check that the first element is returned
     {
       struct S {
@@ -118,19 +153,6 @@ constexpr bool test() {
     }
   }
 
-  { // check that end + 1 iterator is returned with no match
-    {
-      int a[] = {1, 1, 1};
-      auto ret = std::ranges::find(a, a + 3, 0);
-      assert(ret == a + 3);
-    }
-    {
-      int a[] = {1, 1, 1};
-      auto ret = std::ranges::find(a, 0);
-      assert(ret == a + 3);
-    }
-  }
-
   {
     // check that an iterator is returned with a borrowing range
     int a[] = {1, 2, 3, 4};
@@ -159,26 +181,45 @@ constexpr bool test() {
     }
   }
 
-  {
-    // check that an empty range works
-    {
-      std::array<int ,0> a = {};
-      auto ret = std::ranges::find(a.begin(), a.end(), 1);
-      assert(ret == a.begin());
-    }
-    {
-      std::array<int, 0> a = {};
-      auto ret = std::ranges::find(a, 1);
-      assert(ret == a.begin());
-    }
-  }
-
   return true;
 }
 
+template <class IndexT>
+class Comparable {
+  IndexT index_;
+
+public:
+  Comparable(IndexT i)
+      : index_([&]() {
+          IndexT size = static_cast<IndexT>(comparable_data.size());
+          comparable_data.push_back(i);
+          return size;
+        }()) {}
+
+  bool operator==(const Comparable& other) const {
+    return comparable_data[other.index_] == comparable_data[index_];
+  }
+
+  friend bool operator==(const Comparable& lhs, long long rhs) { return comparable_data[lhs.index_] == rhs; }
+};
+
 int main(int, char**) {
   test();
   static_assert(test());
 
+  types::for_each(types::cpp20_input_iterator_list<Comparable<char>*>{}, []<class Iter> {
+    if constexpr (std::forward_iterator<Iter>)
+      test_iterators<Iter>();
+    test_iterators<Iter, sentinel_wrapper<Iter>>();
+    test_iterators<Iter, sized_sentinel<Iter>>();
+  });
+
+  types::for_each(types::cpp20_input_iterator_list<Comparable<wchar_t>*>{}, []<class Iter> {
+    if constexpr (std::forward_iterator<Iter>)
+      test_iterators<Iter>();
+    test_iterators<Iter, sentinel_wrapper<Iter>>();
+    test_iterators<Iter, sized_sentinel<Iter>>();
+  });
+
   return 0;
 }
index 46e6586..da3d0ad 100644 (file)
@@ -123,7 +123,9 @@ using unsigned_integer_types =
 #endif
               >;
 
-using integral_types = concatenate_t<character_types, signed_integer_types, unsigned_integer_types, type_list<bool> >;
+using integer_types = concatenate_t<character_types, signed_integer_types, unsigned_integer_types>;
+
+using integral_types = concatenate_t<integer_types, type_list<bool> >;
 
 using floating_point_types = type_list<float, double, long double>;