[ADT] Allow `llvm::enumerate` to enumerate over multiple ranges
authorJakub Kuderski <kubak@google.com>
Wed, 15 Mar 2023 23:34:21 +0000 (19:34 -0400)
committerJakub Kuderski <kubak@google.com>
Wed, 15 Mar 2023 23:34:22 +0000 (19:34 -0400)
This does not work by a mere composition of `enumerate` and `zip_equal`,
because C++17 does not allow for recursive expansion of structured
bindings.

This implementation uses `zippy` to manage the iteratees and adds the
stream of indices as the first zipped range. Because we have an upfront
assertion that all input ranges are of the same length, we only need to
check if the second range has ended during iteration.

As a consequence of using `zippy`, `enumerate` will now follow the
reference and lifetime semantics of the `zip*` family of functions. The
main difference is that `enumerate` exposes each tuple of references
through a new tuple-like type `enumerate_result`, with the familiar
`.index()` and `.value()` member functions.

Because the `enumerate_result` returned on dereference is a
temporary, enumeration result can no longer be used through an
lvalue ref.

Reviewed By: dblaikie, zero9178

Differential Revision: https://reviews.llvm.org/D144503

15 files changed:
llvm/include/llvm/ADT/STLExtras.h
llvm/lib/Target/AArch64/AArch64PerfectShuffle.h
llvm/lib/Target/ARM/ARMLowOverheadLoops.cpp
llvm/lib/Target/RISCV/RISCVInstrInfo.cpp
llvm/tools/llvm-mca/Views/InstructionInfoView.cpp
llvm/unittests/ADT/STLExtrasTest.cpp
llvm/utils/TableGen/GlobalISel/GIMatchTree.cpp
mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/SCF/IR/SCF.cpp
mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorCodegen.cpp
mlir/lib/Dialect/Tensor/IR/TensorInferTypeOpInterfaceImpl.cpp
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-yaml-gen.cpp

index 545e888c18230bc417ad24824769af4d43aff215..bf33d7980106543fde5b25c33f844ff7efe8f979 100644 (file)
@@ -755,26 +755,25 @@ template<typename... Iters> struct ZipTupleType {
   using type = std::tuple<decltype(*declval<Iters>())...>;
 };
 
-template <typename ZipType, typename... Iters>
+template <typename ZipType, typename ReferenceTupleType, typename... Iters>
 using zip_traits = iterator_facade_base<
     ZipType,
     std::common_type_t<
         std::bidirectional_iterator_tag,
         typename std::iterator_traits<Iters>::iterator_category...>,
     // ^ TODO: Implement random access methods.
-    typename ZipTupleType<Iters...>::type,
+    ReferenceTupleType,
     typename std::iterator_traits<
         std::tuple_element_t<0, std::tuple<Iters...>>>::difference_type,
     // ^ FIXME: This follows boost::make_zip_iterator's assumption that all
     // inner iterators have the same difference_type. It would fail if, for
     // instance, the second field's difference_type were non-numeric while the
     // first is.
-    typename ZipTupleType<Iters...>::type *,
-    typename ZipTupleType<Iters...>::type>;
+    ReferenceTupleType *, ReferenceTupleType>;
 
-template <typename ZipType, typename... Iters>
-struct zip_common : public zip_traits<ZipType, Iters...> {
-  using Base = zip_traits<ZipType, Iters...>;
+template <typename ZipType, typename ReferenceTupleType, typename... Iters>
+struct zip_common : public zip_traits<ZipType, ReferenceTupleType, Iters...> {
+  using Base = zip_traits<ZipType, ReferenceTupleType, Iters...>;
   using IndexSequence = std::index_sequence_for<Iters...>;
   using value_type = typename Base::value_type;
 
@@ -824,8 +823,10 @@ public:
 };
 
 template <typename... Iters>
-struct zip_first : zip_common<zip_first<Iters...>, Iters...> {
-  using zip_common<zip_first, Iters...>::zip_common;
+struct zip_first : zip_common<zip_first<Iters...>,
+                              typename ZipTupleType<Iters...>::type, Iters...> {
+  using zip_common<zip_first, typename ZipTupleType<Iters...>::type,
+                   Iters...>::zip_common;
 
   bool operator==(const zip_first &other) const {
     return std::get<0>(this->iterators) == std::get<0>(other.iterators);
@@ -833,8 +834,11 @@ struct zip_first : zip_common<zip_first<Iters...>, Iters...> {
 };
 
 template <typename... Iters>
-struct zip_shortest : zip_common<zip_shortest<Iters...>, Iters...> {
-  using zip_common<zip_shortest, Iters...>::zip_common;
+struct zip_shortest
+    : zip_common<zip_shortest<Iters...>, typename ZipTupleType<Iters...>::type,
+                 Iters...> {
+  using zip_common<zip_shortest, typename ZipTupleType<Iters...>::type,
+                   Iters...>::zip_common;
 
   bool operator==(const zip_shortest &other) const {
     return any_iterator_equals(other, std::index_sequence_for<Iters...>{});
@@ -2213,113 +2217,182 @@ template <typename T> struct deref {
 
 namespace detail {
 
-template <typename R> class enumerator_iter;
+/// Tuple-like type for `zip_enumerator` dereference.
+template <typename... Refs> struct enumerator_result;
 
-template <typename R> struct result_pair {
-  using value_reference =
-      typename std::iterator_traits<IterOfRange<R>>::reference;
-
-  friend class enumerator_iter<R>;
-
-  result_pair(std::size_t Index, IterOfRange<R> Iter)
-      : Index(Index), Iter(Iter) {}
-
-  std::size_t index() const { return Index; }
-  value_reference value() const { return *Iter; }
-
-private:
-  std::size_t Index = std::numeric_limits<std::size_t>::max();
-  IterOfRange<R> Iter;
-};
-
-template <std::size_t i, typename R>
-decltype(auto) get(const result_pair<R> &Pair) {
-  static_assert(i < 2);
-  if constexpr (i == 0) {
-    return Pair.index();
-  } else {
-    return Pair.value();
-  }
-}
-
-template <typename R>
-class enumerator_iter
-    : public iterator_facade_base<enumerator_iter<R>, std::forward_iterator_tag,
-                                  const result_pair<R>> {
-  using result_type = result_pair<R>;
-
-public:
-  explicit enumerator_iter(IterOfRange<R> EndIter)
-      : Result(std::numeric_limits<size_t>::max(), EndIter) {}
-
-  enumerator_iter(std::size_t Index, IterOfRange<R> Iter)
-      : Result(Index, Iter) {}
-
-  const result_type &operator*() const { return Result; }
+template <typename... Iters>
+using EnumeratorTupleType = enumerator_result<decltype(*declval<Iters>())...>;
+
+/// Zippy iterator that uses the second iterator for comparisons. For the
+/// increment to be safe, the second range has to be the shortest.
+/// Returns `enumerator_result` on dereference to provide `.index()` and
+/// `.value()` member functions.
+/// Note: Because the dereference operator returns `enumerator_result` as a
+/// value instead of a reference and does not strictly conform to the C++17's
+/// definition of forward iterator. However, it satisfies all the
+/// forward_iterator requirements that the `zip_common` and `zippy` depend on
+/// and fully conforms to the C++20 definition of forward iterator.
+/// This is similar to `std::vector<bool>::iterator` that returns bit reference
+/// wrappers on dereference.
+template <typename... Iters>
+struct zip_enumerator : zip_common<zip_enumerator<Iters...>,
+                                   EnumeratorTupleType<Iters...>, Iters...> {
+  static_assert(sizeof...(Iters) >= 2, "Expected at least two iteratees");
+  using zip_common<zip_enumerator<Iters...>, EnumeratorTupleType<Iters...>,
+                   Iters...>::zip_common;
 
-  enumerator_iter &operator++() {
-    assert(Result.Index != std::numeric_limits<size_t>::max());
-    ++Result.Iter;
-    ++Result.Index;
-    return *this;
+  bool operator==(const zip_enumerator &Other) const {
+    return std::get<1>(this->iterators) == std::get<1>(Other.iterators);
   }
+};
 
-  bool operator==(const enumerator_iter &RHS) const {
-    // Don't compare indices here, only iterators.  It's possible for an end
-    // iterator to have different indices depending on whether it was created
-    // by calling std::end() versus incrementing a valid iterator.
-    return Result.Iter == RHS.Result.Iter;
+template <typename... Refs> struct enumerator_result<std::size_t, Refs...> {
+  static constexpr std::size_t NumRefs = sizeof...(Refs);
+  static_assert(NumRefs != 0);
+  // `NumValues` includes the index.
+  static constexpr std::size_t NumValues = NumRefs + 1;
+
+  // Tuple type whose element types are references for each `Ref`.
+  using range_reference_tuple = std::tuple<Refs...>;
+  // Tuple type who elements are references to all values, including both
+  // the index and `Refs` reference types.
+  using value_reference_tuple = std::tuple<std::size_t, Refs...>;
+
+  enumerator_result(std::size_t Index, Refs &&...Rs)
+      : Idx(Index), Storage(std::forward<Refs>(Rs)...) {}
+
+  /// Returns the 0-based index of the current position within the original
+  /// input range(s).
+  std::size_t index() const { return Idx; }
+
+  /// Returns the value(s) for the current iterator. This does not include the
+  /// index.
+  decltype(auto) value() const {
+    if constexpr (NumRefs == 1)
+      return std::get<0>(Storage);
+    else
+      return Storage;
+  }
+
+  /// Returns the value at index `I`. This includes the index.
+  template <std::size_t I>
+  friend decltype(auto) get(const enumerator_result &Result) {
+    static_assert(I < NumValues, "Index out of bounds");
+    if constexpr (I == 0)
+      return Result.Idx;
+    else
+      return std::get<I - 1>(Result.Storage);
+  }
+
+  template <typename... Ts>
+  friend bool operator==(const enumerator_result &Result,
+                         const std::tuple<std::size_t, Ts...> &Other) {
+    static_assert(NumRefs == sizeof...(Ts), "Size mismatch");
+    if (Result.Idx != std::get<0>(Other))
+      return false;
+    return Result.is_value_equal(Other, std::make_index_sequence<NumRefs>{});
   }
 
 private:
-  result_type Result;
+  template <typename Tuple, std::size_t... Idx>
+  bool is_value_equal(const Tuple &Other, std::index_sequence<Idx...>) const {
+    return ((std::get<Idx>(Storage) == std::get<Idx + 1>(Other)) && ...);
+  }
+
+  std::size_t Idx;
+  // Make this tuple mutable to avoid casts that obfuscate const-correctness
+  // issues. Const-correctness of references is taken care of by `zippy` that
+  // defines const-non and const iterator types that will propagate down to
+  // `enumerator_result`'s `Refs`.
+  //  Note that unlike the results of `zip*` functions, `enumerate`'s result are
+  //  supposed to be modifiable even when defined as
+  // `const`.
+  mutable range_reference_tuple Storage;
 };
 
-template <typename R> class enumerator {
-public:
-  explicit enumerator(R &&Range) : TheRange(std::forward<R>(Range)) {}
+/// Infinite stream of increasing 0-based `size_t` indices.
+struct index_stream {
+  struct iterator : iterator_facade_base<iterator, std::forward_iterator_tag,
+                                         const iterator> {
+    iterator &operator++() {
+      assert(Index != std::numeric_limits<std::size_t>::max() &&
+             "Attempting to increment end iterator");
+      ++Index;
+      return *this;
+    }
 
-  enumerator_iter<R> begin() {
-    return enumerator_iter<R>(0, adl_begin(TheRange));
-  }
-  enumerator_iter<R> begin() const {
-    return enumerator_iter<R>(0, adl_begin(TheRange));
-  }
+    // Note: This dereference operator returns a value instead of a reference
+    // and does not strictly conform to the C++17's definition of forward
+    // iterator. However, it satisfies all the forward_iterator requirements
+    // that the `zip_common` depends on and fully conforms to the C++20
+    // definition of forward iterator.
+    std::size_t operator*() const { return Index; }
 
-  enumerator_iter<R> end() { return enumerator_iter<R>(adl_end(TheRange)); }
-  enumerator_iter<R> end() const {
-    return enumerator_iter<R>(adl_end(TheRange));
-  }
+    friend bool operator==(const iterator &Lhs, const iterator &Rhs) {
+      return Lhs.Index == Rhs.Index;
+    }
 
-private:
-  R TheRange;
+    std::size_t Index = 0;
+  };
+
+  iterator begin() const { return {}; }
+  iterator end() const {
+    // We approximate 'infinity' with the max size_t value, which should be good
+    // enough to index over any container.
+    iterator It;
+    It.Index = std::numeric_limits<std::size_t>::max();
+    return It;
+  }
 };
 
 } // end namespace detail
 
-/// Given an input range, returns a new range whose values are are pair (A,B)
-/// such that A is the 0-based index of the item in the sequence, and B is
-/// the value from the original sequence.  Example:
+/// Given two or more input ranges, returns a new range whose values are are
+/// tuples (A, B, C, ...), such that A is the 0-based index of the item in the
+/// sequence, and B, C, ..., are the values from the original input ranges. All
+/// input ranges are required to have equal lengths. Note that the returned
+/// iterator allows for the values (B, C, ...) to be modified.  Example:
+///
+/// ```c++
+/// std::vector<char> Letters = {'A', 'B', 'C', 'D'};
+/// std::vector<int> Vals = {10, 11, 12, 13};
 ///
-/// std::vector<char> Items = {'A', 'B', 'C', 'D'};
-/// for (auto X : enumerate(Items)) {
-///   printf("Item %zu - %c\n", X.index(), X.value());
+/// for (auto [Index, Letter, Value] : enumerate(Letters, Vals)) {
+///   printf("Item %zu - %c: %d\n", Index, Letter, Value);
+///   Value -= 10;
 /// }
+/// ```
 ///
-/// or using structured bindings:
+/// Output:
+///   Item 0 - A: 10
+///   Item 1 - B: 11
+///   Item 2 - C: 12
+///   Item 3 - D: 13
 ///
-/// for (auto [Index, Value] : enumerate(Items)) {
-///   printf("Item %zu - %c\n", Index, Value);
+/// or using an iterator:
+/// ```c++
+/// for (auto it : enumerate(Vals)) {
+///   it.value() += 10;
+///   printf("Item %zu: %d\n", it.index(), it.value());
 /// }
+/// ```
 ///
 /// Output:
-///   Item 0 - A
-///   Item 1 - B
-///   Item 2 - C
-///   Item 3 - D
+///   Item 0: 20
+///   Item 1: 21
+///   Item 2: 22
+///   Item 3: 23
 ///
-template <typename R> detail::enumerator<R> enumerate(R &&TheRange) {
-  return detail::enumerator<R>(std::forward<R>(TheRange));
+template <typename FirstRange, typename... RestRanges>
+auto enumerate(FirstRange &&First, RestRanges &&...Rest) {
+  assert((sizeof...(Rest) == 0 ||
+          all_equal({std::distance(adl_begin(First), adl_end(First)),
+                     std::distance(adl_begin(Rest), adl_end(Rest))...})) &&
+         "Ranges have different length");
+  using enumerator = detail::zippy<detail::zip_enumerator, detail::index_stream,
+                                   FirstRange, RestRanges...>;
+  return enumerator(detail::index_stream{}, std::forward<FirstRange>(First),
+                    std::forward<RestRanges>(Rest)...);
 }
 
 namespace detail {
@@ -2451,15 +2524,17 @@ template <class T> constexpr T *to_address(T *P) { return P; }
 } // end namespace llvm
 
 namespace std {
-template <typename R>
-struct tuple_size<llvm::detail::result_pair<R>>
-    : std::integral_constant<std::size_t, 2> {};
+template <typename... Refs>
+struct tuple_size<llvm::detail::enumerator_result<Refs...>>
+    : std::integral_constant<std::size_t, sizeof...(Refs)> {};
 
-template <std::size_t i, typename R>
-struct tuple_element<i, llvm::detail::result_pair<R>>
-    : std::conditional<i == 0, std::size_t,
-                       typename llvm::detail::result_pair<R>::value_reference> {
-};
+template <std::size_t I, typename... Refs>
+struct tuple_element<I, llvm::detail::enumerator_result<Refs...>>
+    : std::tuple_element<I, std::tuple<Refs...>> {};
+
+template <std::size_t I, typename... Refs>
+struct tuple_element<I, const llvm::detail::enumerator_result<Refs...>>
+    : std::tuple_element<I, std::tuple<Refs...>> {};
 
 } // namespace std
 
index 4555f1a3ebb0885e134d72d88cf8c3e4dafcbd1b..5846fd454b654847fc65aaaadb74f2e121b06cf5 100644 (file)
@@ -6590,11 +6590,11 @@ static unsigned getPerfectShuffleCost(llvm::ArrayRef<int> M) {
   assert(M.size() == 4 && "Expected a 4 entry perfect shuffle");
 
   // Special case zero-cost nop copies, from either LHS or RHS.
-  if (llvm::all_of(llvm::enumerate(M), [](auto &E) {
+  if (llvm::all_of(llvm::enumerate(M), [](const auto &E) {
         return E.value() < 0 || E.value() == (int)E.index();
       }))
     return 0;
-  if (llvm::all_of(llvm::enumerate(M), [](auto &E) {
+  if (llvm::all_of(llvm::enumerate(M), [](const auto &E) {
         return E.value() < 0 || E.value() == (int)E.index() + 4;
       }))
     return 0;
index c8b156a6f2c0f3d194af13db86748c76d08948e4..1e0a7d11d8ac21a4da3c8301e45e1cd4f38459ed 100644 (file)
@@ -1249,7 +1249,7 @@ bool LowOverheadLoop::ValidateMVEInst(MachineInstr *MI) {
   const MCInstrDesc &MCID = MI->getDesc();
   bool IsUse = false;
   unsigned LastOpIdx = MI->getNumOperands() - 1;
-  for (auto &Op : enumerate(reverse(MCID.operands()))) {
+  for (const auto &Op : enumerate(reverse(MCID.operands()))) {
     const MachineOperand &MO = MI->getOperand(LastOpIdx - Op.index());
     if (!MO.isReg() || !MO.isUse() || MO.getReg() != ARM::VPR)
       continue;
index 08f1c9ffb738a3499f741a50cb52affa08b7fc71..d525365762293f190068c029bb27f03d7b10c7d5 100644 (file)
@@ -1651,11 +1651,11 @@ bool RISCVInstrInfo::verifyInstruction(const MachineInstr &MI,
                                        StringRef &ErrInfo) const {
   MCInstrDesc const &Desc = MI.getDesc();
 
-  for (auto &OI : enumerate(Desc.operands())) {
-    unsigned OpType = OI.value().OperandType;
+  for (const auto &[Index, Operand] : enumerate(Desc.operands())) {
+    unsigned OpType = Operand.OperandType;
     if (OpType >= RISCVOp::OPERAND_FIRST_RISCV_IMM &&
         OpType <= RISCVOp::OPERAND_LAST_RISCV_IMM) {
-      const MachineOperand &MO = MI.getOperand(OI.index());
+      const MachineOperand &MO = MI.getOperand(Index);
       if (MO.isImm()) {
         int64_t Imm = MO.getImm();
         bool Ok;
index 257fdca8cb366b11f20233566df18ebb3adf688e..3262de48d41e30819c93399514a482a72e4950c7 100644 (file)
@@ -55,10 +55,7 @@ void InstructionInfoView::printView(raw_ostream &OS) const {
     }
   }
 
-  int Index = 0;
-  for (const auto &I : enumerate(zip(IIVD, Source))) {
-    const InstructionInfoViewData &IIVDEntry = std::get<0>(I.value());
-
+  for (const auto &[Index, IIVDEntry, Inst] : enumerate(IIVD, Source)) {
     TempStream << ' ' << IIVDEntry.NumMicroOpcodes << "    ";
     if (IIVDEntry.NumMicroOpcodes < 10)
       TempStream << "  ";
@@ -92,7 +89,7 @@ void InstructionInfoView::printView(raw_ostream &OS) const {
     }
 
     if (PrintEncodings) {
-      StringRef Encoding(CE.getEncoding(I.index()));
+      StringRef Encoding(CE.getEncoding(Index));
       unsigned EncodingSize = Encoding.size();
       TempStream << " " << EncodingSize
                  << (EncodingSize < 10 ? "     " : "    ");
@@ -104,9 +101,7 @@ void InstructionInfoView::printView(raw_ostream &OS) const {
       FOS.flush();
     }
 
-    const MCInst &Inst = std::get<1>(I.value());
     TempStream << printInstructionString(Inst) << '\n';
-    ++Index;
   }
 
   TempStream.flush();
index 8ec3df10bc6c3cff90d87bc3acd8cc95d3eee254..bb602bb6c39f736bc4fcee60d9e617d828cf0677 100644 (file)
 
 #include <array>
 #include <climits>
+#include <cstddef>
+#include <initializer_list>
 #include <list>
 #include <tuple>
+#include <type_traits>
 #include <utility>
 #include <vector>
 
@@ -153,6 +156,131 @@ TEST(STLExtrasTest, EnumerateModifyRValue) {
                                    PairType(2u, '4')));
 }
 
+TEST(STLExtrasTest, EnumerateTwoRanges) {
+  using Tuple = std::tuple<size_t, int, bool>;
+
+  std::vector<int> Ints = {1, 2};
+  std::vector<bool> Bools = {true, false};
+  EXPECT_THAT(llvm::enumerate(Ints, Bools),
+              ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false)));
+
+  // Check that we can modify the values when the temporary is a const
+  // reference.
+  for (const auto &[Idx, Int, Bool] : llvm::enumerate(Ints, Bools)) {
+    (void)Idx;
+    Bool = false;
+    Int = -1;
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(-1, -1));
+  EXPECT_THAT(Bools, ElementsAre(false, false));
+
+  // Check that we can modify the values when the result gets copied.
+  for (auto [Idx, Bool, Int] : llvm::enumerate(Bools, Ints)) {
+    (void)Idx;
+    Int = 3;
+    Bool = true;
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(3, 3));
+  EXPECT_THAT(Bools, ElementsAre(true, true));
+
+  // Check that we can modify the values through `.value()`.
+  size_t Iters = 0;
+  for (auto It : llvm::enumerate(Bools, Ints)) {
+    EXPECT_EQ(It.index(), Iters);
+    ++Iters;
+
+    std::get<0>(It.value()) = false;
+    std::get<1>(It.value()) = 4;
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(4, 4));
+  EXPECT_THAT(Bools, ElementsAre(false, false));
+}
+
+TEST(STLExtrasTest, EnumerateThreeRanges) {
+  using Tuple = std::tuple<size_t, int, bool, char>;
+
+  std::vector<int> Ints = {1, 2};
+  std::vector<bool> Bools = {true, false};
+  char Chars[] = {'X', 'D'};
+  EXPECT_THAT(llvm::enumerate(Ints, Bools, Chars),
+              ElementsAre(Tuple(0, 1, true, 'X'), Tuple(1, 2, false, 'D')));
+
+  for (auto [Idx, Int, Bool, Char] : llvm::enumerate(Ints, Bools, Chars)) {
+    (void)Idx;
+    Int = 0;
+    Bool = true;
+    Char = '!';
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(0, 0));
+  EXPECT_THAT(Bools, ElementsAre(true, true));
+  EXPECT_THAT(Chars, ElementsAre('!', '!'));
+
+  // Check that we can modify the values through `.values()`.
+  size_t Iters = 0;
+  for (auto It : llvm::enumerate(Ints, Bools, Chars)) {
+    EXPECT_EQ(It.index(), Iters);
+    ++Iters;
+    auto [Int, Bool, Char] = It.value();
+    Int = 42;
+    Bool = false;
+    Char = '$';
+  }
+
+  EXPECT_THAT(Ints, ElementsAre(42, 42));
+  EXPECT_THAT(Bools, ElementsAre(false, false));
+  EXPECT_THAT(Chars, ElementsAre('$', '$'));
+}
+
+TEST(STLExtrasTest, EnumerateTemporaries) {
+  using Tuple = std::tuple<size_t, int, bool>;
+
+  EXPECT_THAT(
+      llvm::enumerate(llvm::SmallVector<int>({1, 2, 3}),
+                      std::vector<bool>({true, false, true})),
+      ElementsAre(Tuple(0, 1, true), Tuple(1, 2, false), Tuple(2, 3, true)));
+
+  size_t Iters = 0;
+  // This is fine from the point of view of range lifetimes because `zippy` will
+  // move all temporaries into its storage. No lifetime extension is necessary.
+  for (auto [Idx, Int, Bool] :
+       llvm::enumerate(llvm::SmallVector<int>({1, 2, 3}),
+                       std::vector<bool>({true, false, true}))) {
+    EXPECT_EQ(Idx, Iters);
+    ++Iters;
+    Int = 0;
+    Bool = true;
+  }
+
+  Iters = 0;
+  // The same thing but with the result as a const reference.
+  for (const auto &[Idx, Int, Bool] :
+       llvm::enumerate(llvm::SmallVector<int>({1, 2, 3}),
+                       std::vector<bool>({true, false, true}))) {
+    EXPECT_EQ(Idx, Iters);
+    ++Iters;
+    Int = 0;
+    Bool = true;
+  }
+}
+
+#if defined(GTEST_HAS_DEATH_TEST) && !defined(NDEBUG)
+TEST(STLExtrasTest, EnumerateDifferentLengths) {
+  std::vector<int> Ints = {0, 1};
+  bool Bools[] = {true, false, true};
+  std::string Chars = "abc";
+  EXPECT_DEATH(llvm::enumerate(Ints, Bools, Chars),
+               "Ranges have different length");
+  EXPECT_DEATH(llvm::enumerate(Bools, Ints, Chars),
+               "Ranges have different length");
+  EXPECT_DEATH(llvm::enumerate(Bools, Chars, Ints),
+               "Ranges have different length");
+}
+#endif
+
 template <bool B> struct CanMove {};
 template <> struct CanMove<false> {
   CanMove(CanMove &&) = delete;
@@ -190,8 +318,8 @@ public:
 template <bool Moveable, bool Copyable>
 struct Range : Counted<Moveable, Copyable> {
   using Counted<Moveable, Copyable>::Counted;
-  int *begin() { return nullptr; }
-  int *end() { return nullptr; }
+  int *begin() const { return nullptr; }
+  int *end() const { return nullptr; }
 };
 
 TEST(STLExtrasTest, EnumerateLifetimeSemanticsPRValue) {
index 493c6b35c8406c1cc4664df2cf94f2034784c673..fcf9c25ae74a3e57566707dbb1efb64bb214fc0f 100644 (file)
@@ -338,9 +338,9 @@ void GIMatchTreeBuilder::runStep() {
          "Must always partition into at least one partition");
 
   TreeNode->setNumChildren(Partitioner->getNumPartitions());
-  for (auto &C : enumerate(TreeNode->children())) {
-    SubtreeBuilders.emplace_back(&C.value(), NextInstrID);
-    Partitioner->applyForPartition(C.index(), *this, SubtreeBuilders.back());
+  for (const auto &[Idx, Child] : enumerate(TreeNode->children())) {
+    SubtreeBuilders.emplace_back(&Child, NextInstrID);
+    Partitioner->applyForPartition(Idx, *this, SubtreeBuilders.back());
   }
 
   TreeNode->setPartitioner(std::move(Partitioner));
index 6e2aa2c2755bc3d4c89f1cdddeb430fbe3dd62a4..4df55ddd62e0f5723acbc84c980906f4231844d5 100644 (file)
@@ -1394,7 +1394,7 @@ void generateCollapsedIndexingRegion(Location loc, Block *block,
   //   i1 = (i_{folded} / d2) % d1
   //   i0 = i_{folded} / (d1 * d2)
   llvm::DenseMap<unsigned, Value> indexReplacementVals;
-  for (auto &foldedDims :
+  for (auto foldedDims :
        enumerate(collapsingInfo.getCollapsedOpToOrigOpMapping())) {
     ReassociationIndicesRef foldedDimsRef(foldedDims.value());
     Value newIndexVal =
index c8999943d868eeae5f83631ea6a640616f73d7eb..27f1eab09eec6e315e3b99469c6e3efa4efda5fb 100644 (file)
@@ -1871,15 +1871,13 @@ LogicalResult ReinterpretCastOp::verify() {
            << srcType << " and result memref type " << resultType;
 
   // Match sizes in result memref type and in static_sizes attribute.
-  for (auto &en :
-       llvm::enumerate(llvm::zip(resultType.getShape(), getStaticSizes()))) {
-    int64_t resultSize = std::get<0>(en.value());
-    int64_t expectedSize = std::get<1>(en.value());
+  for (auto [idx, resultSize, expectedSize] :
+       llvm::enumerate(resultType.getShape(), getStaticSizes())) {
     if (!ShapedType::isDynamic(resultSize) &&
         !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
       return emitError("expected result type with size = ")
              << expectedSize << " instead of " << resultSize
-             << " in dim = " << en.index();
+             << " in dim = " << idx;
   }
 
   // Match offset and strides in static_offset and static_strides attributes. If
@@ -1900,16 +1898,14 @@ LogicalResult ReinterpretCastOp::verify() {
            << resultOffset << " instead of " << expectedOffset;
 
   // Match strides in result memref type and in static_strides attribute.
-  for (auto &en :
-       llvm::enumerate(llvm::zip(resultStrides, getStaticStrides()))) {
-    int64_t resultStride = std::get<0>(en.value());
-    int64_t expectedStride = std::get<1>(en.value());
+  for (auto [idx, resultStride, expectedStride] :
+       llvm::enumerate(resultStrides, getStaticStrides())) {
     if (!ShapedType::isDynamic(resultStride) &&
         !ShapedType::isDynamic(expectedStride) &&
         resultStride != expectedStride)
       return emitError("expected result type with stride = ")
              << expectedStride << " instead of " << resultStride
-             << " in dim = " << en.index();
+             << " in dim = " << idx;
   }
 
   return success();
index 2b59ada88c33877b29761cb2ff6510de138c9975..f3fcd5ac202633114454b04c2d250f95df9cebfa 100644 (file)
@@ -1891,10 +1891,7 @@ struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
     auto elseYieldArgs = op.elseYield().getOperands();
 
     SmallVector<Type> nonHoistable;
-    for (const auto &it :
-         llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
-      Value trueVal = std::get<0>(it.value());
-      Value falseVal = std::get<1>(it.value());
+    for (auto [trueVal, falseVal] : llvm::zip(thenYieldArgs, elseYieldArgs)) {
       if (&op.getThenRegion() == trueVal.getParentRegion() ||
           &op.getElseRegion() == falseVal.getParentRegion())
         nonHoistable.push_back(trueVal.getType());
index 7a802379355fe1ea43e4e67f54692bdb2250adb5..915e4b4ed1c565d786c9374fa2bd7777624eb026 100644 (file)
@@ -423,7 +423,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
     return b.notifyMatchFailure(
         op, "only support ops with one reduction dimension.");
   int reductionDim;
-  for (auto &[idx, iteratorType] :
+  for (auto [idx, iteratorType] :
        llvm::enumerate(tilingInterfaceOp.getLoopIteratorTypes())) {
     if (iteratorType == utils::IteratorType::reduction) {
       reductionDim = idx;
index 71c78d9061a9bf94dcc7073c09b0f7b8ba06e6c9..ae6e40f1c19cbd4b3845b3007d9a970a6fe38d4e 100644 (file)
@@ -1148,10 +1148,9 @@ public:
     desc.setSpecifier(newSpec);
 
     // Fills in slice information.
-    for (const auto &it : llvm::enumerate(llvm::zip(
-             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides()))) {
-      Dimension dim = it.index();
-      auto [offset, size, stride] = it.value();
+    for (auto [idx, offset, size, stride] : llvm::enumerate(
+             op.getMixedOffsets(), op.getMixedSizes(), op.getMixedStrides())) {
+      Dimension dim = idx;
 
       Value offsetV = getValueOrCreateConstantIndexOp(rewriter, loc, offset);
       Value sizeV = getValueOrCreateConstantIndexOp(rewriter, loc, size);
index eb73a2c87295647f8c68ecae8046489d3d7e4b43..ea087c1357aec2c1f396acf61538b131ef3dd7a9 100644 (file)
@@ -93,7 +93,7 @@ static OpFoldResult getExpandedOutputDimFromInputShape(
                         .cast<AffineDimExpr>()
                         .getPosition();
   int64_t linearizedStaticDim = 1;
-  for (auto &d :
+  for (auto d :
        llvm::enumerate(dstStaticShape.slice(startPos, endPos - startPos + 1))) {
     if (d.index() + startPos == static_cast<unsigned>(dimIndex))
       continue;
index fdeb1e490033dd88547a289599560d5ef87a2ef3..43dedf1c4ce065e43f1e3fd9eb5c63eb3cb63bd4 100644 (file)
@@ -762,7 +762,7 @@ struct InsertSliceOpInterface
           return isConstantIntValue(ofr, 0);
         });
     bool sizesMatchDestSizes = llvm::all_of(
-        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](auto &it) {
+        llvm::enumerate(insertSliceOp.getMixedSizes()), [&](const auto &it) {
           return getConstantIntValue(it.value()) ==
                  destType.getDimSize(it.index());
         });
index a51bdb50da9edee0833feaaf9d33f06cf1b56678..52ea148179357238073201f11607080617db7361 100644 (file)
@@ -869,14 +869,14 @@ exprs.push_back(getAffineConstantExpr(cst{1}, context));
           if (arg.kind != LinalgOperandDefKind::IndexAttr)
             continue;
           assert(arg.indexAttrMap);
-          for (auto &en :
+          for (auto [idx, result] :
                llvm::enumerate(arg.indexAttrMap->affineMap().getResults())) {
-            if (auto symbol = en.value().dyn_cast<AffineSymbolExpr>()) {
+            if (auto symbol = result.dyn_cast<AffineSymbolExpr>()) {
               std::string argName = arg.name;
               argName[0] = toupper(argName[0]);
               symbolBindings[symbol.getPosition()] =
                   llvm::formatv(structuredOpAccessAttrFormat, argName,
-                                symbol.getPosition(), en.index());
+                                symbol.getPosition(), idx);
             }
           }
         }