[libc++] Refactor __debug_three_way_comp
authorLouis Dionne <ldionne.2@gmail.com>
Mon, 12 Jun 2023 21:16:32 +0000 (14:16 -0700)
committerLouis Dionne <ldionne.2@gmail.com>
Tue, 13 Jun 2023 21:24:56 +0000 (14:24 -0700)
This makes __debug_three_way_comp consistent with __debug_less and
in particular gets rid of a potential use-after-move caused by the
use of std::forward. In the previous version of the code, we would
call `__do_compare_assert` after forwarding the arguments into the
comparator, which could end up using the arguments after they've been
moved from.

This also simplifies how we call `__do_compare_assert` by using
`if constexpr` and adds a missing test for proxy iterators in
lexicographical_compare_three_way, which could have found this
issue.

Differential Revision: https://reviews.llvm.org/D152753

libcxx/include/__algorithm/three_way_comp_ref_type.h
libcxx/test/std/algorithms/alg.sorting/alg.three.way/lexicographical_compare_three_way.pass.cpp
libcxx/test/std/algorithms/alg.sorting/alg.three.way/lexicographical_compare_three_way_comp.pass.cpp

index ce1350a..5dc922f 100644 (file)
@@ -29,16 +29,23 @@ struct __debug_three_way_comp {
   _LIBCPP_HIDE_FROM_ABI constexpr __debug_three_way_comp(_Comp& __c) : __comp_(__c) {}
 
   template <class _Tp, class _Up>
-  _LIBCPP_HIDE_FROM_ABI constexpr auto operator()(_Tp&& __x, _Up&& __y) {
-    auto __r = __comp_(std::forward<_Tp>(__x), std::forward<_Up>(__y));
-    __do_compare_assert(0, __y, __x, __r);
+  _LIBCPP_HIDE_FROM_ABI constexpr auto operator()(const _Tp& __x, const _Up& __y) {
+    auto __r = __comp_(__x, __y);
+    if constexpr (__comparison_category<decltype(__comp_(__x, __y))>)
+      __do_compare_assert(__y, __x, __r);
+    return __r;
+  }
+
+  template <class _Tp, class _Up>
+  _LIBCPP_HIDE_FROM_ABI constexpr auto operator()(_Tp& __x, _Up& __y) {
+    auto __r = __comp_(__x, __y);
+    if constexpr (__comparison_category<decltype(__comp_(__x, __y))>)
+      __do_compare_assert(__y, __x, __r);
     return __r;
   }
 
   template <class _LHS, class _RHS, class _Order>
-  _LIBCPP_HIDE_FROM_ABI constexpr inline void __do_compare_assert(int, _LHS& __l, _RHS& __r, _Order __o)
-    requires __comparison_category<decltype(std::declval<_Comp&>()(std::declval<_LHS&>(), std::declval<_RHS&>()))>
-  {
+  _LIBCPP_HIDE_FROM_ABI constexpr void __do_compare_assert(_LHS& __l, _RHS& __r, _Order __o) {
     _Order __expected = __o;
     if (__o == _Order::less)
       __expected = _Order::greater;
@@ -47,11 +54,7 @@ struct __debug_three_way_comp {
     _LIBCPP_DEBUG_ASSERT(__comp_(__l, __r) == __expected, "Comparator does not induce a strict weak ordering");
     (void)__l;
     (void)__r;
-    (void)__expected;
   }
-
-  template <class _LHS, class _RHS, class _Order>
-  _LIBCPP_HIDE_FROM_ABI constexpr inline void __do_compare_assert(long, _LHS&, _RHS&, _Order) {}
 };
 
 // Pass the comparator by lvalue reference. Or in debug mode, using a
index 4c940eb..0bd1d74 100644 (file)
@@ -19,6 +19,7 @@
 #include <cassert>
 #include <compare>
 #include <concepts>
+#include <vector>
 
 #include "test_macros.h"
 #include "test_comparisons.h"
@@ -107,9 +108,17 @@ constexpr void test_comparison_categories() {
       std::partial_ordering::unordered);
 }
 
+// Check that it works with proxy iterators
+constexpr void test_proxy_iterators() {
+    std::vector<bool> vec(10, true);
+    auto result = std::lexicographical_compare_three_way(vec.begin(), vec.end(), vec.begin(), vec.end());
+    assert(result == std::strong_ordering::equal);
+}
+
 constexpr bool test() {
   test_iterator_types();
   test_comparison_categories();
+  test_proxy_iterators();
 
   return true;
 }
index d690f38..ebbf833 100644 (file)
@@ -22,6 +22,7 @@
 #include <compare>
 #include <concepts>
 #include <limits>
+#include <vector>
 
 #include "test_iterators.h"
 #include "test_macros.h"
@@ -162,10 +163,18 @@ constexpr void test_comparator_invocation_count() {
 #endif
 }
 
+// Check that it works with proxy iterators
+constexpr void test_proxy_iterators() {
+    std::vector<bool> vec(10, true);
+    auto result = std::lexicographical_compare_three_way(vec.begin(), vec.end(), vec.begin(), vec.end(), compare_last_digit_strong);
+    assert(result == std::strong_ordering::equal);
+}
+
 constexpr bool test() {
   test_iterator_types();
   test_comparison_categories();
   test_comparator_invocation_count();
+  test_proxy_iterators();
 
   return true;
 }