From 7be6a40ab9b914b14ab61ae13e47e0bb8237e74d Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 9 Dec 2019 12:55:05 -0800 Subject: [PATCH] Add new indexed_accessor_range_base and indexed_accessor_range classes that simplify defining index-able ranges. Many ranges want similar functionality from a range type(e.g. slice/drop_front/operator[]/etc.), so these classes provide a generic implementation that may be used by many different types of ranges. This removes some code duplication, and also empowers many of the existing range types in MLIR(e.g. result type ranges, operand ranges, ElementsAttr ranges, etc.). This change only updates RegionRange and ValueRange, more ranges will be updated in followup commits. PiperOrigin-RevId: 284615679 --- mlir/include/mlir/IR/Attributes.h | 6 +- mlir/include/mlir/IR/Block.h | 4 +- mlir/include/mlir/IR/Operation.h | 67 +++++----------- mlir/include/mlir/IR/Region.h | 45 ++++------- mlir/include/mlir/Support/STLExtras.h | 140 +++++++++++++++++++++++++++++++--- mlir/lib/IR/Attributes.cpp | 2 +- mlir/lib/IR/Operation.cpp | 57 +++++--------- mlir/lib/IR/Region.cpp | 26 ++++--- 8 files changed, 206 insertions(+), 141 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 3968d44..59df75d 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -641,12 +641,12 @@ protected: /// Return the current index for this iterator, adjusted for the case of a /// splat. ptrdiff_t getDataIndex() const { - bool isSplat = this->object.getInt(); + bool isSplat = this->base.getInt(); return isSplat ? 0 : this->index; } - /// Return the data object pointer. - const char *getData() const { return this->object.getPointer(); } + /// Return the data base pointer. + const char *getData() const { return this->base.getPointer(); } }; } // namespace detail diff --git a/mlir/include/mlir/IR/Block.h b/mlir/include/mlir/IR/Block.h index f01f191..532352e 100644 --- a/mlir/include/mlir/IR/Block.h +++ b/mlir/include/mlir/IR/Block.h @@ -460,9 +460,9 @@ public: Block *>(object, index) {} SuccessorIterator(const SuccessorIterator &other) - : SuccessorIterator(other.object, other.index) {} + : SuccessorIterator(other.base, other.index) {} - Block *operator*() const { return this->object->getSuccessor(this->index); } + Block *operator*() const { return this->base->getSuccessor(this->index); } /// Get the successor number in the terminator. unsigned getSuccessorIndex() const { return this->index; } diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 75ea972..037c4fc 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -668,7 +668,7 @@ public: : indexed_accessor_iterator(object, index) {} - Value *operator*() const { return this->object->getOperand(this->index); } + Value *operator*() const { return this->base->getOperand(this->index); } }; /// This class implements the operand type iterators for the Operation @@ -721,11 +721,11 @@ class ResultIterator final Value *, Value *> { public: /// Initializes the result iterator to the specified index. - ResultIterator(Operation *object, unsigned index) + ResultIterator(Operation *base, unsigned index) : indexed_accessor_iterator(object, index) {} + Value *>(base, index) {} - Value *operator*() const { return this->object->getResult(this->index); } + Value *operator*() const { return this->base->getResult(this->index); } }; /// This class implements the result type iterators for the Operation @@ -799,15 +799,19 @@ inline auto Operation::getResultTypes() -> result_type_range { /// SmallVector/std::vector. This class should be used in places that are not /// suitable for a more derived type (e.g. ArrayRef) or a template range /// parameter. -class ValueRange { +class ValueRange + : public detail::indexed_accessor_range_base< + ValueRange, + llvm::PointerUnion, Value *, + Value *, Value *> { /// The type representing the owner of this range. This is either a list of /// values, operands, or results. using OwnerT = llvm::PointerUnion; public: - ValueRange(const ValueRange &) = default; - ValueRange(ValueRange &&) = default; - ValueRange &operator=(const ValueRange &) = default; + using detail::indexed_accessor_range_base< + ValueRange, OwnerT, Value *, Value *, + Value *>::indexed_accessor_range_base; template values); ValueRange(iterator_range values); - /// An iterator element of this range. - class Iterator : public indexed_accessor_iterator { - public: - Value *operator*() const; - - private: - Iterator(OwnerT owner, unsigned curIndex); - - /// Allow access to the constructor. - friend ValueRange; - }; - - Iterator begin() const { return Iterator(owner, 0); } - Iterator end() const { return Iterator(owner, count); } - Value *operator[](unsigned index) const { - assert(index < size() && "invalid index for value range"); - return *std::next(begin(), index); - } - - /// Return the size of this range. - size_t size() const { return count; } - - /// Return if the range is empty. - bool empty() const { return size() == 0; } - - /// Drop the first N elements, and keep M elements. - ValueRange slice(unsigned n, unsigned m) const; - /// Drop the first n elements. - ValueRange drop_front(unsigned n = 1) const; - /// Drop the last n elements. - ValueRange drop_back(unsigned n = 1) const; - private: - ValueRange(OwnerT owner, unsigned count) : owner(owner), count(count) {} - - /// The object that owns the provided range of values. - OwnerT owner; - /// The size from the owning range. - unsigned count; + /// See `detail::indexed_accessor_range_base` for details. + static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); + /// See `detail::indexed_accessor_range_base` for details. + static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index); + + /// Allow access to `offset_base` and `dereference_iterator`. + friend detail::indexed_accessor_range_base; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 933bf10..3d25140 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -165,14 +165,19 @@ private: /// SmallVector/std::vector. This class should be used in places that are not /// suitable for a more derived type (e.g. ArrayRef) or a template range /// parameter. -class RegionRange { +class RegionRange + : public detail::indexed_accessor_range_base< + RegionRange, + llvm::PointerUnion *>, + Region *, Region *, Region *> { /// The type representing the owner of this range. This is either a list of /// values, operands, or results. using OwnerT = llvm::PointerUnion *>; public: - RegionRange(const RegionRange &) = default; - RegionRange(RegionRange &&) = default; + using detail::indexed_accessor_range_base< + RegionRange, OwnerT, Region *, Region *, + Region *>::indexed_accessor_range_base; RegionRange(MutableArrayRef regions = llvm::None); @@ -184,33 +189,15 @@ public: } RegionRange(ArrayRef> regions); - /// An iterator element of this range. - class Iterator : public indexed_accessor_iterator { - public: - Region *operator*() const; - - private: - Iterator(OwnerT owner, unsigned curIndex); - /// Allow access to the constructor. - friend RegionRange; - }; - Iterator begin() const { return Iterator(owner, 0); } - Iterator end() const { return Iterator(owner, count); } - Region *operator[](unsigned index) const { - assert(index < size() && "invalid index for region range"); - return *std::next(begin(), index); - } - /// Return the size of this range. - size_t size() const { return count; } - /// Return if the range is empty. - bool empty() const { return size() == 0; } - private: - /// The object that owns the provided range of regions. - OwnerT owner; - /// The size from the owning range. - unsigned count; + /// See `detail::indexed_accessor_range_base` for details. + static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); + /// See `detail::indexed_accessor_range_base` for details. + static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index); + + /// Allow access to `offset_base` and `dereference_iterator`. + friend detail::indexed_accessor_range_base; }; } // end namespace mlir diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h index 95e52f9..07db06a 100644 --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -147,9 +147,9 @@ using is_invocable = is_detected; // Extra additions to //===----------------------------------------------------------------------===// -/// A utility class used to implement an iterator that contains some object and -/// an index. The iterator moves the index but keeps the object constant. -template class indexed_accessor_iterator : public llvm::iterator_facade_base { public: ptrdiff_t operator-(const indexed_accessor_iterator &rhs) const { - assert(object == rhs.object && "incompatible iterators"); + assert(base == rhs.base && "incompatible iterators"); return index - rhs.index; } bool operator==(const indexed_accessor_iterator &rhs) const { - return object == rhs.object && index == rhs.index; + return base == rhs.base && index == rhs.index; } bool operator<(const indexed_accessor_iterator &rhs) const { - assert(object == rhs.object && "incompatible iterators"); + assert(base == rhs.base && "incompatible iterators"); return index < rhs.index; } @@ -180,16 +180,134 @@ public: /// Returns the current index of the iterator. ptrdiff_t getIndex() const { return index; } - /// Returns the current object of the iterator. - const ObjectType &getObject() const { return object; } + /// Returns the current base of the iterator. + const BaseT &getBase() const { return base; } protected: - indexed_accessor_iterator(ObjectType object, ptrdiff_t index) - : object(object), index(index) {} - ObjectType object; + indexed_accessor_iterator(BaseT base, ptrdiff_t index) + : base(base), index(index) {} + BaseT base; ptrdiff_t index; }; +namespace detail { +/// The class represents the base of a range of indexed_accessor_iterators. It +/// provides support for many different range functionalities, e.g. +/// drop_front/slice/etc.. Derived range classes must implement the following +/// static methods: +/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index) +/// - Derefence an iterator pointing to the base object at the given index. +/// * BaseT offset_base(const BaseT &base, ptrdiff_t index) +/// - Return a new base that is offset from the provide base by 'index' +/// elements. +template +class indexed_accessor_range_base { +public: + /// An iterator element of this range. + class iterator : public indexed_accessor_iterator { + public: + // Index into this iterator, invoking a static method on the derived type. + ReferenceT operator*() const { + return DerivedT::dereference_iterator(this->getBase(), this->getIndex()); + } + + private: + iterator(BaseT owner, ptrdiff_t curIndex) + : indexed_accessor_iterator( + owner, curIndex) {} + + /// Allow access to the constructor. + friend indexed_accessor_range_base; + }; + + iterator begin() const { return iterator(base, 0); } + iterator end() const { return iterator(base, count); } + ReferenceT operator[](unsigned index) const { + assert(index < size() && "invalid index for value range"); + return *std::next(begin(), index); + } + + /// Return the size of this range. + size_t size() const { return count; } + + /// Return if the range is empty. + bool empty() const { return size() == 0; } + + /// Drop the first N elements, and keep M elements. + DerivedT slice(unsigned n, unsigned m) const { + assert(n + m <= size() && "invalid size specifiers"); + return DerivedT(DerivedT::offset_base(base, n), m); + } + + /// Drop the first n elements. + DerivedT drop_front(unsigned n = 1) const { + assert(size() >= n && "Dropping more elements than exist"); + return slice(n, size() - n); + } + /// Drop the last n elements. + DerivedT drop_back(unsigned n = 1) const { + assert(size() >= n && "Dropping more elements than exist"); + return DerivedT(base, size() - n); + } + +protected: + indexed_accessor_range_base(BaseT base, ptrdiff_t count) + : base(base), count(count) {} + indexed_accessor_range_base(const indexed_accessor_range_base &) = default; + indexed_accessor_range_base(indexed_accessor_range_base &&) = default; + indexed_accessor_range_base & + operator=(const indexed_accessor_range_base &) = default; + + /// The base that owns the provided range of values. + BaseT base; + /// The size from the owning range. + ptrdiff_t count; +}; +} // end namespace detail + +/// This class provides an implementation of a range of +/// indexed_accessor_iterators where the base is not indexable. Ranges with +/// bases that are offsetable should derive from indexed_accessor_range_base +/// instead. Derived range classes are expected to implement the following +/// static method: +/// * ReferenceT dereference_iterator(const BaseT &base, ptrdiff_t index) +/// - Derefence an iterator pointing to a parent base at the given index. +template +class indexed_accessor_range + : public detail::indexed_accessor_range_base< + indexed_accessor_range, + std::pair, T, PointerT, ReferenceT> { +protected: + indexed_accessor_range(BaseT base, ptrdiff_t startIndex, ptrdiff_t count) + : detail::indexed_accessor_range_base< + DerivedT, std::pair, T, PointerT, ReferenceT>( + std::make_pair(base, startIndex), count) {} + +private: + /// See `detail::indexed_accessor_range_base` for details. + static std::pair + offset_base(const std::pair &base, ptrdiff_t index) { + // We encode the internal base as a pair of the derived base and a start + // index into the derived base. + return std::make_pair(base.first, base.second + index); + } + /// See `detail::indexed_accessor_range_base` for details. + static ReferenceT + dereference_iterator(const std::pair &base, + ptrdiff_t index) { + return DerivedT::dereference_iterator(base.first, base.second + index); + } + + /// Allow access to `offset_base` and `dereference_iterator`. + friend detail::indexed_accessor_range_base< + indexed_accessor_range, + std::pair, T, PointerT, ReferenceT>; +}; + /// Given a container of pairs, return a range over the second elements. template auto make_second_range(ContainerTy &&c) { return llvm::map_range( diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index f2f3d41..b546643 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -527,7 +527,7 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( /// Accesses the Attribute value at this iterator position. Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { - auto owner = getFromOpaquePointer(object).cast(); + auto owner = getFromOpaquePointer(base).cast(); Type eltTy = owner.getType().getElementType(); if (auto intEltTy = eltTy.dyn_cast()) { if (intEltTy.getWidth() == 1) diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index ae635d1..0483c27 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -750,60 +750,41 @@ Operation *Operation::clone() { //===----------------------------------------------------------------------===// ValueRange::ValueRange(ArrayRef values) - : owner(values.data()), count(values.size()) {} + : ValueRange(values.data(), values.size()) {} ValueRange::ValueRange(llvm::iterator_range values) - : count(llvm::size(values)) { - if (count != 0) { + : ValueRange(nullptr, llvm::size(values)) { + if (!empty()) { auto begin = values.begin(); - owner = &begin.getObject()->getOpOperand(begin.getIndex()); + base = &begin.getBase()->getOpOperand(begin.getIndex()); } } ValueRange::ValueRange(llvm::iterator_range values) - : count(llvm::size(values)) { - if (count != 0) { + : ValueRange(nullptr, llvm::size(values)) { + if (!empty()) { auto begin = values.begin(); - owner = &begin.getObject()->getOpResult(begin.getIndex()); + base = &begin.getBase()->getOpResult(begin.getIndex()); } } -/// Drop the first N elements, and keep M elements. -ValueRange ValueRange::slice(unsigned n, unsigned m) const { - assert(n + m <= size() && "Invalid specifier"); - OwnerT newOwner; +/// See `detail::indexed_accessor_range_base` for details. +ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, + ptrdiff_t index) { if (OpOperand *operand = owner.dyn_cast()) - newOwner = operand + n; - else if (OpResult *result = owner.dyn_cast()) - newOwner = result + n; - else - newOwner = owner.get() + n; - return ValueRange(newOwner, m); -} - -/// Drop the first n elements. -ValueRange ValueRange::drop_front(unsigned n) const { - assert(size() >= n && "Dropping more elements than exist"); - return slice(n, size() - n); -} - -/// Drop the last n elements. -ValueRange ValueRange::drop_back(unsigned n) const { - assert(size() >= n && "Dropping more elements than exist"); - return ValueRange(owner, size() - n); + return operand + index; + if (OpResult *result = owner.dyn_cast()) + return result + index; + return owner.get() + index; } - -ValueRange::Iterator::Iterator(OwnerT owner, unsigned curIndex) - : indexed_accessor_iterator( - owner, curIndex) {} - -Value *ValueRange::Iterator::operator*() const { +/// See `detail::indexed_accessor_range_base` for details. +Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { // Operands access the held value via 'get'. - if (OpOperand *operand = object.dyn_cast()) + if (OpOperand *operand = owner.dyn_cast()) return operand[index].get(); // An OpResult is a value, so we can return it directly. - if (OpResult *result = object.dyn_cast()) + if (OpResult *result = owner.dyn_cast()) return &result[index]; // Otherwise, this is a raw value array so just index directly. - return object.get()[index]; + return owner.get()[index]; } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Region.cpp b/mlir/lib/IR/Region.cpp index a5a19cb..c588e56 100644 --- a/mlir/lib/IR/Region.cpp +++ b/mlir/lib/IR/Region.cpp @@ -217,17 +217,23 @@ void llvm::ilist_traits<::mlir::Block>::transferNodesFromList( //===----------------------------------------------------------------------===// // RegionRange //===----------------------------------------------------------------------===// + RegionRange::RegionRange(MutableArrayRef regions) - : owner(regions.data()), count(regions.size()) {} + : RegionRange(regions.data(), regions.size()) {} RegionRange::RegionRange(ArrayRef> regions) - : owner(regions.data()), count(regions.size()) {} -RegionRange::Iterator::Iterator(OwnerT owner, unsigned curIndex) - : indexed_accessor_iterator( - owner, curIndex) {} - -Region *RegionRange::Iterator::operator*() const { - if (const std::unique_ptr *operand = - object.dyn_cast *>()) + : RegionRange(regions.data(), regions.size()) {} + +/// See `detail::indexed_accessor_range_base` for details. +RegionRange::OwnerT RegionRange::offset_base(const OwnerT &owner, + ptrdiff_t index) { + if (auto *operand = owner.dyn_cast *>()) + return operand + index; + return &owner.get()[index]; +} +/// See `detail::indexed_accessor_range_base` for details. +Region *RegionRange::dereference_iterator(const OwnerT &owner, + ptrdiff_t index) { + if (auto *operand = owner.dyn_cast *>()) return operand[index].get(); - return &object.get()[index]; + return &owner.get()[index]; } -- 2.7.4