From 75fda2e0a5530b72443712bd9606bbd48f884817 Mon Sep 17 00:00:00 2001 From: Vedant Kumar Date: Wed, 25 Apr 2018 21:50:09 +0000 Subject: [PATCH] [ADT] Make filter_iterator support bidirectional iteration This makes it possible to reverse a filtered range. For example, here's a way to visit memory accesses in a BasicBlock in reverse order: auto MemInsts = reverse(make_filter_range(BB, [](Instruction &I) { return isa(&I) || isa(&I); })); for (auto &MI : MemInsts) ... To implement this functionality, I factored out forward iteration functionality into filter_iterator_base, and added a specialization of filter_iterator_impl which supports bidirectional iteration. Thanks to Tim Shen, Zachary Turner, and others for suggesting this design and providing feedback! This version of the patch supersedes the original (https://reviews.llvm.org/D45792). This was motivated by a problem we encountered in D45657: we'd like to visit the non-debug-info instructions in a BasicBlock in reverse order. Testing: check-llvm, check-clang Differential Revision: https://reviews.llvm.org/D45853 llvm-svn: 330875 --- llvm/include/llvm/ADT/STLExtras.h | 132 +++++++++++++++++++++++++---------- llvm/unittests/ADT/IteratorTest.cpp | 28 ++++++++ llvm/unittests/IR/BasicBlockTest.cpp | 9 +++ 3 files changed, 134 insertions(+), 35 deletions(-) diff --git a/llvm/include/llvm/ADT/STLExtras.h b/llvm/include/llvm/ADT/STLExtras.h index 051b900..06e751e 100644 --- a/llvm/include/llvm/ADT/STLExtras.h +++ b/llvm/include/llvm/ADT/STLExtras.h @@ -271,60 +271,121 @@ auto reverse( /// auto R = make_filter_range(A, [](int N) { return N % 2 == 1; }); /// // R contains { 1, 3 }. /// \endcode -template -class filter_iterator +/// +/// Note: filter_iterator_base implements support for forward iteration. +/// filter_iterator_impl exists to provide support for bidirectional iteration, +/// conditional on whether the wrapped iterator supports it. +template +class filter_iterator_base : public iterator_adaptor_base< - filter_iterator, WrappedIteratorT, + filter_iterator_base, + WrappedIteratorT, typename std::common_type< - std::forward_iterator_tag, - typename std::iterator_traits< - WrappedIteratorT>::iterator_category>::type> { + IterTag, typename std::iterator_traits< + WrappedIteratorT>::iterator_category>::type> { using BaseT = iterator_adaptor_base< - filter_iterator, WrappedIteratorT, + filter_iterator_base, + WrappedIteratorT, typename std::common_type< - std::forward_iterator_tag, - typename std::iterator_traits::iterator_category>:: - type>; + IterTag, typename std::iterator_traits< + WrappedIteratorT>::iterator_category>::type>; - struct PayloadType { - WrappedIteratorT End; - PredicateT Pred; - }; - - Optional Payload; +protected: + WrappedIteratorT End; + PredicateT Pred; void findNextValid() { - assert(Payload && "Payload should be engaged when findNextValid is called"); - while (this->I != Payload->End && !Payload->Pred(*this->I)) + while (this->I != End && !Pred(*this->I)) BaseT::operator++(); } - // Construct the begin iterator. The begin iterator requires to know where end - // is, so that it can properly stop when it hits end. - filter_iterator(WrappedIteratorT Begin, WrappedIteratorT End, PredicateT Pred) - : BaseT(std::move(Begin)), - Payload(PayloadType{std::move(End), std::move(Pred)}) { + // Construct the iterator. The begin iterator needs to know where the end + // is, so that it can properly stop when it gets there. The end iterator only + // needs the predicate to support bidirectional iteration. + filter_iterator_base(WrappedIteratorT Begin, WrappedIteratorT End, + PredicateT Pred) + : BaseT(Begin), End(End), Pred(Pred) { findNextValid(); } - // Construct the end iterator. It's not incrementable, so Payload doesn't - // have to be engaged. - filter_iterator(WrappedIteratorT End) : BaseT(End) {} - public: using BaseT::operator++; - filter_iterator &operator++() { + filter_iterator_base &operator++() { BaseT::operator++(); findNextValid(); return *this; } +}; - template - friend iterator_range, PT>> - make_filter_range(RT &&, PT); +/// Specialization of filter_iterator_base for forward iteration only. +template +class filter_iterator_impl + : public filter_iterator_base { + using BaseT = filter_iterator_base; + +public: + filter_iterator_impl(WrappedIteratorT Begin, WrappedIteratorT End, + PredicateT Pred) + : BaseT(Begin, End, Pred) {} }; +/// Specialization of filter_iterator_base for bidirectional iteration. +template +class filter_iterator_impl + : public filter_iterator_base { + using BaseT = filter_iterator_base; + void findPrevValid() { + while (!this->Pred(*this->I)) + BaseT::operator--(); + } + +public: + using BaseT::operator--; + + filter_iterator_impl(WrappedIteratorT Begin, WrappedIteratorT End, + PredicateT Pred) + : BaseT(Begin, End, Pred) {} + + filter_iterator_impl &operator--() { + BaseT::operator--(); + findPrevValid(); + return *this; + } +}; + +namespace detail { + +template struct fwd_or_bidi_tag_impl { + using type = std::forward_iterator_tag; +}; + +template <> struct fwd_or_bidi_tag_impl { + using type = std::bidirectional_iterator_tag; +}; + +/// Helper which sets its type member to forward_iterator_tag if the category +/// of \p IterT does not derive from bidirectional_iterator_tag, and to +/// bidirectional_iterator_tag otherwise. +template struct fwd_or_bidi_tag { + using type = typename fwd_or_bidi_tag_impl::iterator_category>::value>::type; +}; + +} // namespace detail + +/// Defines filter_iterator to a suitable specialization of +/// filter_iterator_impl, based on the underlying iterator's category. +template +using filter_iterator = filter_iterator_impl< + WrappedIteratorT, PredicateT, + typename detail::fwd_or_bidi_tag::type>; + /// Convenience function that takes a range of elements and a predicate, /// and return a new filter_iterator range. /// @@ -337,10 +398,11 @@ iterator_range, PredicateT>> make_filter_range(RangeT &&Range, PredicateT Pred) { using FilterIteratorT = filter_iterator, PredicateT>; - return make_range(FilterIteratorT(std::begin(std::forward(Range)), - std::end(std::forward(Range)), - std::move(Pred)), - FilterIteratorT(std::end(std::forward(Range)))); + return make_range( + FilterIteratorT(std::begin(std::forward(Range)), + std::end(std::forward(Range)), Pred), + FilterIteratorT(std::end(std::forward(Range)), + std::end(std::forward(Range)), Pred)); } // forward declarations required by zip_shortest/zip_first diff --git a/llvm/unittests/ADT/IteratorTest.cpp b/llvm/unittests/ADT/IteratorTest.cpp index c95ce80..5b9320e 100644 --- a/llvm/unittests/ADT/IteratorTest.cpp +++ b/llvm/unittests/ADT/IteratorTest.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/iterator.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "gtest/gtest.h" @@ -196,6 +197,33 @@ TEST(FilterIteratorTest, InputIterator) { EXPECT_EQ((SmallVector{1, 3, 5}), Actual); } +TEST(FilterIteratorTest, ReverseFilterRange) { + auto IsOdd = [](int N) { return N % 2 == 1; }; + int A[] = {0, 1, 2, 3, 4, 5, 6}; + + // Check basic reversal. + auto Range = reverse(make_filter_range(A, IsOdd)); + SmallVector Actual(Range.begin(), Range.end()); + EXPECT_EQ((SmallVector{5, 3, 1}), Actual); + + // Check that the reverse of the reverse is the original. + auto Range2 = reverse(reverse(make_filter_range(A, IsOdd))); + SmallVector Actual2(Range2.begin(), Range2.end()); + EXPECT_EQ((SmallVector{1, 3, 5}), Actual2); + + // Check empty ranges. + auto Range3 = reverse(make_filter_range(ArrayRef(), IsOdd)); + SmallVector Actual3(Range3.begin(), Range3.end()); + EXPECT_EQ((SmallVector{}), Actual3); + + // Check that we don't skip the first element, provided it isn't filtered + // away. + auto IsEven = [](int N) { return N % 2 == 0; }; + auto Range4 = reverse(make_filter_range(A, IsEven)); + SmallVector Actual4(Range4.begin(), Range4.end()); + EXPECT_EQ((SmallVector{6, 4, 2, 0}), Actual4); +} + TEST(PointerIterator, Basic) { int A[] = {1, 2, 3, 4}; pointer_iterator Begin(std::begin(A)), End(std::end(A)); diff --git a/llvm/unittests/IR/BasicBlockTest.cpp b/llvm/unittests/IR/BasicBlockTest.cpp index 5e3b11e..07ed997 100644 --- a/llvm/unittests/IR/BasicBlockTest.cpp +++ b/llvm/unittests/IR/BasicBlockTest.cpp @@ -69,6 +69,15 @@ TEST(BasicBlockTest, PhiRange) { CI = BB->phis().begin(); EXPECT_NE(CI, BB->phis().end()); + // Test that filtering iterators work with basic blocks. + auto isPhi = [](Instruction &I) { return isa(&I); }; + auto Phis = make_filter_range(*BB, isPhi); + auto ReversedPhis = reverse(make_filter_range(*BB, isPhi)); + EXPECT_EQ(std::distance(Phis.begin(), Phis.end()), 3); + EXPECT_EQ(&*Phis.begin(), P1); + EXPECT_EQ(std::distance(ReversedPhis.begin(), ReversedPhis.end()), 3); + EXPECT_EQ(&*ReversedPhis.begin(), P3); + // And iterate a const range. for (const auto &PN : const_cast(BB.get())->phis()) { EXPECT_EQ(BB.get(), PN.getIncomingBlock(0)); -- 2.7.4