From 9ed22ae5b8c8f286a992bca7ef4e4b3263c01116 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 10 Dec 2019 13:20:50 -0800 Subject: [PATCH] Refactor the various operand/result/type iterators to use indexed_accessor_range. This has several benefits: * The implementation is much cleaner and more efficient. * The ranges now have support for many useful operations: operator[], slice, drop_front, size, etc. * Value ranges can now directly query a range for their types via 'getTypes()': e.g: void foo(Operation::operand_range operands) { auto operandTypes = operands.getTypes(); } PiperOrigin-RevId: 284834912 --- mlir/include/mlir/IR/BlockSupport.h | 11 +- mlir/include/mlir/IR/Operation.h | 208 ++++---------------------------- mlir/include/mlir/IR/OperationSupport.h | 152 +++++++++++++++++++++++ mlir/include/mlir/IR/Region.h | 7 +- mlir/include/mlir/IR/TypeUtilities.h | 10 +- mlir/include/mlir/Support/STLExtras.h | 30 ++++- mlir/lib/Dialect/StandardOps/Ops.cpp | 6 +- mlir/lib/IR/Operation.cpp | 59 ++------- mlir/lib/IR/OperationSupport.cpp | 47 ++++++++ mlir/lib/IR/TypeUtilities.cpp | 12 +- 10 files changed, 279 insertions(+), 263 deletions(-) diff --git a/mlir/include/mlir/IR/BlockSupport.h b/mlir/include/mlir/IR/BlockSupport.h index 83fbee3..fd30c36 100644 --- a/mlir/include/mlir/IR/BlockSupport.h +++ b/mlir/include/mlir/IR/BlockSupport.h @@ -68,14 +68,12 @@ class SuccessorRange final : public detail::indexed_accessor_range_base { public: - using detail::indexed_accessor_range_base< - SuccessorRange, BlockOperand *, Block *, Block *, - Block *>::indexed_accessor_range_base; + using RangeBaseT::RangeBaseT; SuccessorRange(Block *block); private: /// See `detail::indexed_accessor_range_base` for details. - static BlockOperand *offset_object(BlockOperand *object, ptrdiff_t index) { + static BlockOperand *offset_base(BlockOperand *object, ptrdiff_t index) { return object + index; } /// See `detail::indexed_accessor_range_base` for details. @@ -83,9 +81,8 @@ private: return object[index].get(); } - /// Allow access to `offset_object` and `dereference_iterator`. - friend detail::indexed_accessor_range_base; + /// Allow access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 037c4fc..ac78647 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -29,15 +29,6 @@ #include "llvm/ADT/Twine.h" namespace mlir { -class BlockAndValueMapping; -class Location; -class MLIRContext; -class OperandIterator; -class OperandTypeIterator; -struct OperationState; -class ResultIterator; -class ResultTypeIterator; - /// Terminator operations can have Block operands to represent successors. using BlockOperand = IROperandImpl; @@ -230,14 +221,14 @@ public: } // Support operand iteration. - using operand_iterator = OperandIterator; - using operand_range = llvm::iterator_range; + using operand_range = OperandRange; + using operand_iterator = operand_range::iterator; - operand_iterator operand_begin(); - operand_iterator operand_end(); + operand_iterator operand_begin() { return getOperands().begin(); } + operand_iterator operand_end() { return getOperands().end(); } /// Returns an iterator on the underlying Value's (Value *). - operand_range getOperands(); + operand_range getOperands() { return operand_range(this); } /// Erase the operand at position `idx`. void eraseOperand(unsigned idx) { getOperandStorage().eraseOperand(idx); } @@ -249,11 +240,11 @@ public: OpOperand &getOpOperand(unsigned idx) { return getOpOperands()[idx]; } // Support operand type iteration. - using operand_type_iterator = OperandTypeIterator; - using operand_type_range = llvm::iterator_range; - operand_type_iterator operand_type_begin(); - operand_type_iterator operand_type_end(); - operand_type_range getOperandTypes(); + using operand_type_iterator = operand_range::type_iterator; + using operand_type_range = iterator_range; + operand_type_iterator operand_type_begin() { return operand_begin(); } + operand_type_iterator operand_type_end() { return operand_end(); } + operand_type_range getOperandTypes() { return getOperands().getTypes(); } //===--------------------------------------------------------------------===// // Results @@ -266,14 +257,13 @@ public: Value *getResult(unsigned idx) { return &getOpResult(idx); } - // Support result iteration. - using result_iterator = ResultIterator; - using result_range = llvm::iterator_range; - - result_iterator result_begin(); - result_iterator result_end(); + /// Support result iteration. + using result_range = ResultRange; + using result_iterator = result_range::iterator; - result_range getResults(); + result_iterator result_begin() { return getResults().begin(); } + result_iterator result_end() { return getResults().end(); } + result_range getResults() { return result_range(this); } MutableArrayRef getOpResults() { return {getTrailingObjects(), numResults}; @@ -281,12 +271,12 @@ public: OpResult &getOpResult(unsigned idx) { return getOpResults()[idx]; } - // Support result type iteration. - using result_type_iterator = ResultTypeIterator; - using result_type_range = llvm::iterator_range; - result_type_iterator result_type_begin(); - result_type_iterator result_type_end(); - result_type_range getResultTypes(); + /// Support result type iteration. + using result_type_iterator = result_range::type_iterator; + using result_type_range = iterator_range; + result_type_iterator result_type_begin() { return result_begin(); } + result_type_iterator result_type_end() { return result_end(); } + result_type_range getResultTypes() { return getResults().getTypes(); } //===--------------------------------------------------------------------===// // Attributes @@ -657,91 +647,6 @@ inline raw_ostream &operator<<(raw_ostream &os, Operation &op) { return os; } -/// This class implements the const/non-const operand iterators for the -/// Operation class in terms of getOperand(idx). -class OperandIterator final - : public indexed_accessor_iterator { -public: - /// Initializes the operand iterator to the specified operand index. - OperandIterator(Operation *object, unsigned index) - : indexed_accessor_iterator(object, index) {} - - Value *operator*() const { return this->base->getOperand(this->index); } -}; - -/// This class implements the operand type iterators for the Operation -/// class in terms of operand_iterator->getType(). -class OperandTypeIterator final - : public llvm::mapped_iterator { - static Type unwrap(Value *value) { return value->getType(); } - -public: - using reference = Type; - - /// Provide a const deference method. - Type operator*() const { return unwrap(*I); } - - /// Initializes the operand type iterator to the specified operand iterator. - OperandTypeIterator(OperandIterator it) - : llvm::mapped_iterator(it, &unwrap) { - } -}; - -// Implement the inline operand iterator methods. -inline auto Operation::operand_begin() -> operand_iterator { - return operand_iterator(this, 0); -} - -inline auto Operation::operand_end() -> operand_iterator { - return operand_iterator(this, getNumOperands()); -} - -inline auto Operation::getOperands() -> operand_range { - return {operand_begin(), operand_end()}; -} - -inline auto Operation::operand_type_begin() -> operand_type_iterator { - return operand_type_iterator(operand_begin()); -} - -inline auto Operation::operand_type_end() -> operand_type_iterator { - return operand_type_iterator(operand_end()); -} - -inline auto Operation::getOperandTypes() -> operand_type_range { - return {operand_type_begin(), operand_type_end()}; -} - -/// This class implements the result iterators for the Operation class -/// in terms of getResult(idx). -class ResultIterator final - : public indexed_accessor_iterator { -public: - /// Initializes the result iterator to the specified index. - ResultIterator(Operation *base, unsigned index) - : indexed_accessor_iterator(base, index) {} - - Value *operator*() const { return this->base->getResult(this->index); } -}; - -/// This class implements the result type iterators for the Operation -/// class in terms of result_iterator->getType(). -class ResultTypeIterator final - : public llvm::mapped_iterator { - static Type unwrap(Value *value) { return value->getType(); } - -public: - using reference = Type; - - /// Initializes the result type iterator to the specified result iterator. - ResultTypeIterator(ResultIterator it) - : llvm::mapped_iterator(it, &unwrap) {} -}; - /// This class implements use iterator for the Operation. This iterates over all /// uses of all results of an Operation. class UseIterator final @@ -768,75 +673,6 @@ private: /// The use of the result. Value::use_iterator use; }; - -// Implement the inline result iterator methods. -inline auto Operation::result_begin() -> result_iterator { - return result_iterator(this, 0); -} - -inline auto Operation::result_end() -> result_iterator { - return result_iterator(this, getNumResults()); -} - -inline auto Operation::getResults() -> llvm::iterator_range { - return {result_begin(), result_end()}; -} - -inline auto Operation::result_type_begin() -> result_type_iterator { - return result_type_iterator(result_begin()); -} - -inline auto Operation::result_type_end() -> result_type_iterator { - return result_type_iterator(result_end()); -} - -inline auto Operation::getResultTypes() -> result_type_range { - return {result_type_begin(), result_type_end()}; -} - -/// This class provides an abstraction over the different types of ranges over -/// Value*s. In many cases, this prevents the need to explicitly materialize a -/// SmallVector/std::vector. This class should be used in places that are not -/// suitable for a more derived type (e.g. ArrayRef) or a template range -/// parameter. -class ValueRange - : public detail::indexed_accessor_range_base< - ValueRange, - llvm::PointerUnion, Value *, - Value *, Value *> { - /// The type representing the owner of this range. This is either a list of - /// values, operands, or results. - using OwnerT = llvm::PointerUnion; - -public: - using detail::indexed_accessor_range_base< - ValueRange, OwnerT, Value *, Value *, - Value *>::indexed_accessor_range_base; - - template , Arg>::value && - !std::is_convertible::value>> - ValueRange(Arg &&arg) - : ValueRange(ArrayRef(std::forward(arg))) {} - ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {} - ValueRange(const std::initializer_list &values) - : ValueRange(ArrayRef(values)) {} - ValueRange(ArrayRef values = llvm::None); - ValueRange(iterator_range values); - ValueRange(iterator_range values); - -private: - /// See `detail::indexed_accessor_range_base` for details. - static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); - /// See `detail::indexed_accessor_range_base` for details. - static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index); - - /// Allow access to `offset_base` and `dereference_iterator`. - friend detail::indexed_accessor_range_base; -}; - } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index 14ddf2d..0a0e1ac 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -60,6 +60,10 @@ template using OperandAdaptor = typename OpTy::OperandAdaptor; class OwningRewritePatternList; +//===----------------------------------------------------------------------===// +// AbstractOperation +//===----------------------------------------------------------------------===// + enum class OperationProperty { /// This bit is set for an operation if it is a commutative operation: that /// is a binary operator (two inputs) where "a op b" and "b op a" produce the @@ -201,6 +205,10 @@ private: bool (&hasRawTrait)(ClassID *traitID); }; +//===----------------------------------------------------------------------===// +// OperationName +//===----------------------------------------------------------------------===// + class OperationName { public: using RepresentationUnion = @@ -251,6 +259,10 @@ inline llvm::hash_code hash_value(OperationName arg) { return llvm::hash_value(arg.getAsOpaquePointer()); } +//===----------------------------------------------------------------------===// +// OperationState +//===----------------------------------------------------------------------===// + /// This represents an operation in an abstracted form, suitable for use with /// the builder APIs. This object is a large and heavy weight object meant to /// be used as a temporary object on the stack. It is generally unwise to put @@ -322,6 +334,10 @@ public: MLIRContext *getContext() { return location->getContext(); } }; +//===----------------------------------------------------------------------===// +// OperandStorage +//===----------------------------------------------------------------------===// + namespace detail { /// A utility class holding the information necessary to dynamically resize /// operands. @@ -445,6 +461,10 @@ private: }; } // end namespace detail +//===----------------------------------------------------------------------===// +// OpPrintingFlags +//===----------------------------------------------------------------------===// + /// Set of flags used to control the behavior of the various IR print methods /// (e.g. Operation::Print). class OpPrintingFlags { @@ -504,6 +524,138 @@ private: bool printLocalScope : 1; }; +//===----------------------------------------------------------------------===// +// Operation Value-Iterators +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// ValueTypeRange + +/// This class implements iteration on the types of a given range of values. +template +class ValueTypeIterator final + : public llvm::mapped_iterator { + static Type unwrap(Value *value) { return value->getType(); } + +public: + using reference = Type; + + /// Provide a const dereference method. + Type operator*() const { return unwrap(*this->I); } + + /// Initializes the type iterator to the specified value iterator. + ValueTypeIterator(ValueIteratorT it) + : llvm::mapped_iterator(it, &unwrap) {} +}; + +//===----------------------------------------------------------------------===// +// OperandRange + +/// This class implements the operand iterators for the Operation class. +class OperandRange final + : public detail::indexed_accessor_range_base { +public: + using RangeBaseT::RangeBaseT; + OperandRange(Operation *op); + + /// Returns the types of the values within this range. + using type_iterator = ValueTypeIterator; + iterator_range getTypes() const { return {begin(), end()}; } + +private: + /// See `detail::indexed_accessor_range_base` for details. + static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) { + return object + index; + } + /// See `detail::indexed_accessor_range_base` for details. + static Value *dereference_iterator(OpOperand *object, ptrdiff_t index) { + return object[index].get(); + } + + /// Allow access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; +}; + +//===----------------------------------------------------------------------===// +// ResultRange + +/// This class implements the result iterators for the Operation class. +class ResultRange final + : public detail::indexed_accessor_range_base { +public: + using RangeBaseT::RangeBaseT; + ResultRange(Operation *op); + + /// Returns the types of the values within this range. + using type_iterator = ValueTypeIterator; + iterator_range getTypes() const { return {begin(), end()}; } + +private: + /// See `detail::indexed_accessor_range_base` for details. + static OpResult *offset_base(OpResult *object, ptrdiff_t index) { + return object + index; + } + /// See `detail::indexed_accessor_range_base` for details. + static Value *dereference_iterator(OpResult *object, ptrdiff_t index) { + return &object[index]; + } + + /// Allow access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; +}; + +//===----------------------------------------------------------------------===// +// ValueRange + +/// This class provides an abstraction over the different types of ranges over +/// Value*s. In many cases, this prevents the need to explicitly materialize a +/// SmallVector/std::vector. This class should be used in places that are not +/// suitable for a more derived type (e.g. ArrayRef) or a template range +/// parameter. +class ValueRange final + : public detail::indexed_accessor_range_base< + ValueRange, + llvm::PointerUnion, Value *, + Value *, Value *> { +public: + using RangeBaseT::RangeBaseT; + + template , Arg>::value && + !std::is_convertible::value>> + ValueRange(Arg &&arg) + : ValueRange(ArrayRef(std::forward(arg))) {} + ValueRange(Value *const &value) : ValueRange(&value, /*count=*/1) {} + ValueRange(const std::initializer_list &values) + : ValueRange(ArrayRef(values)) {} + ValueRange(iterator_range values) + : ValueRange(OperandRange(values)) {} + ValueRange(iterator_range values) + : ValueRange(ResultRange(values)) {} + ValueRange(ArrayRef values = llvm::None); + ValueRange(OperandRange values); + ValueRange(ResultRange values); + + /// Returns the types of the values within this range. + using type_iterator = ValueTypeIterator; + iterator_range getTypes() const { return {begin(), end()}; } + +private: + /// The type representing the owner of this range. This is either a list of + /// values, operands, or results. + using OwnerT = llvm::PointerUnion; + + /// See `detail::indexed_accessor_range_base` for details. + static OwnerT offset_base(const OwnerT &owner, ptrdiff_t index); + /// See `detail::indexed_accessor_range_base` for details. + static Value *dereference_iterator(const OwnerT &owner, ptrdiff_t index); + + /// Allow access to `offset_base` and `dereference_iterator`. + friend RangeBaseT; +}; } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/Region.h b/mlir/include/mlir/IR/Region.h index 3d25140..27b20c2 100644 --- a/mlir/include/mlir/IR/Region.h +++ b/mlir/include/mlir/IR/Region.h @@ -175,9 +175,7 @@ class RegionRange using OwnerT = llvm::PointerUnion *>; public: - using detail::indexed_accessor_range_base< - RegionRange, OwnerT, Region *, Region *, - Region *>::indexed_accessor_range_base; + using RangeBaseT::RangeBaseT; RegionRange(MutableArrayRef regions = llvm::None); @@ -196,8 +194,7 @@ private: static Region *dereference_iterator(const OwnerT &owner, ptrdiff_t index); /// Allow access to `offset_base` and `dereference_iterator`. - friend detail::indexed_accessor_range_base; + friend RangeBaseT; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/TypeUtilities.h b/mlir/include/mlir/IR/TypeUtilities.h index 49d57e8..6512f8f 100644 --- a/mlir/include/mlir/IR/TypeUtilities.h +++ b/mlir/include/mlir/IR/TypeUtilities.h @@ -65,13 +65,14 @@ LogicalResult verifyCompatibleShape(Type type1, Type type2); // An iterator for the element types of an op's operands of shaped types. class OperandElementTypeIterator final - : public llvm::mapped_iterator { + : public llvm::mapped_iterator { public: using reference = Type; /// Initializes the result element type iterator to the specified operand /// iterator. - explicit OperandElementTypeIterator(OperandIterator it); + explicit OperandElementTypeIterator(Operation::operand_iterator it); private: static Type unwrap(Value *value); @@ -82,13 +83,14 @@ using OperandElementTypeRange = // An iterator for the tensor element types of an op's results of shaped types. class ResultElementTypeIterator final - : public llvm::mapped_iterator { + : public llvm::mapped_iterator { public: using reference = Type; /// Initializes the result element type iterator to the specified result /// iterator. - explicit ResultElementTypeIterator(ResultIterator it); + explicit ResultElementTypeIterator(Operation::result_iterator it); private: static Type unwrap(Value *value); diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h index 07db06a..c98f925 100644 --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -204,6 +204,9 @@ template class indexed_accessor_range_base { public: + using RangeBaseT = + indexed_accessor_range_base; + /// An iterator element of this range. class iterator : public indexed_accessor_iterator { @@ -223,11 +226,17 @@ public: ReferenceT>; }; + indexed_accessor_range_base(iterator begin, iterator end) + : base(DerivedT::offset_base(begin.getBase(), begin.getIndex())), + count(end.getIndex() - begin.getIndex()) {} + indexed_accessor_range_base(const iterator_range &range) + : indexed_accessor_range_base(range.begin(), range.end()) {} + iterator begin() const { return iterator(base, 0); } iterator end() const { return iterator(base, count); } ReferenceT operator[](unsigned index) const { assert(index < size() && "invalid index for value range"); - return *std::next(begin(), index); + return DerivedT::dereference_iterator(base, index); } /// Return the size of this range. @@ -237,22 +246,35 @@ public: bool empty() const { return size() == 0; } /// Drop the first N elements, and keep M elements. - DerivedT slice(unsigned n, unsigned m) const { + DerivedT slice(size_t n, size_t m) const { assert(n + m <= size() && "invalid size specifiers"); return DerivedT(DerivedT::offset_base(base, n), m); } /// Drop the first n elements. - DerivedT drop_front(unsigned n = 1) const { + DerivedT drop_front(size_t n = 1) const { assert(size() >= n && "Dropping more elements than exist"); return slice(n, size() - n); } /// Drop the last n elements. - DerivedT drop_back(unsigned n = 1) const { + DerivedT drop_back(size_t n = 1) const { assert(size() >= n && "Dropping more elements than exist"); return DerivedT(base, size() - n); } + /// Take the first n elements. + DerivedT take_front(size_t n = 1) const { + return n < size() ? drop_back(size() - n) + : static_cast(*this); + } + + /// Allow conversion to SmallVector if necessary. + /// TODO(riverriddle) Remove this when SmallVector accepts different range + /// types in its constructor. + template operator SmallVector() const { + return {begin(), end()}; + } + protected: indexed_accessor_range_base(BaseT base, ptrdiff_t count) : base(base), count(count) {} diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index ee90cea..7726c04 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -2819,7 +2819,7 @@ public: return matchFailure(); } SmallVector staticShape(subViewOp.getNumSizes()); - for (auto size : enumerate(subViewOp.sizes())) { + for (auto size : llvm::enumerate(subViewOp.sizes())) { auto defOp = size.value()->getDefiningOp(); assert(defOp); staticShape[size.index()] = cast(defOp).getValue(); @@ -2865,7 +2865,7 @@ public: } SmallVector staticStrides(subViewOp.getNumStrides()); - for (auto stride : enumerate(subViewOp.strides())) { + for (auto stride : llvm::enumerate(subViewOp.strides())) { auto defOp = stride.value()->getDefiningOp(); assert(defOp); assert(baseStrides[stride.index()] > 0); @@ -2916,7 +2916,7 @@ public: } auto staticOffset = baseOffset; - for (auto offset : enumerate(subViewOp.offsets())) { + for (auto offset : llvm::enumerate(subViewOp.offsets())) { auto defOp = offset.value()->getDefiningOp(); assert(defOp); assert(baseStrides[offset.index()] > 0); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 0483c27..fd747a9 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -597,9 +597,8 @@ void Operation::setSuccessor(Block *block, unsigned index) { } auto Operation::getNonSuccessorOperands() -> operand_range { - return {operand_iterator(this, 0), - operand_iterator(this, hasSuccessors() ? getSuccessorOperandIndex(0) - : getNumOperands())}; + return getOperands().take_front(hasSuccessors() ? getSuccessorOperandIndex(0) + : getNumOperands()); } /// Get the index of the first operand of the successor at the provided @@ -635,9 +634,7 @@ Operation::decomposeSuccessorOperandIndex(unsigned operandIndex) { auto Operation::getSuccessorOperands(unsigned index) -> operand_range { unsigned succOperandIndex = getSuccessorOperandIndex(index); - return {operand_iterator(this, succOperandIndex), - operand_iterator(this, - succOperandIndex + getNumSuccessorOperands(index))}; + return getOperands().slice(succOperandIndex, getNumSuccessorOperands(index)); } /// Attempt to fold this operation using the Op's registered foldHook. @@ -746,48 +743,6 @@ Operation *Operation::clone() { } //===----------------------------------------------------------------------===// -// ValueRange -//===----------------------------------------------------------------------===// - -ValueRange::ValueRange(ArrayRef values) - : ValueRange(values.data(), values.size()) {} -ValueRange::ValueRange(llvm::iterator_range values) - : ValueRange(nullptr, llvm::size(values)) { - if (!empty()) { - auto begin = values.begin(); - base = &begin.getBase()->getOpOperand(begin.getIndex()); - } -} -ValueRange::ValueRange(llvm::iterator_range values) - : ValueRange(nullptr, llvm::size(values)) { - if (!empty()) { - auto begin = values.begin(); - base = &begin.getBase()->getOpResult(begin.getIndex()); - } -} - -/// See `detail::indexed_accessor_range_base` for details. -ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, - ptrdiff_t index) { - if (OpOperand *operand = owner.dyn_cast()) - return operand + index; - if (OpResult *result = owner.dyn_cast()) - return result + index; - return owner.get() + index; -} -/// See `detail::indexed_accessor_range_base` for details. -Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { - // Operands access the held value via 'get'. - if (OpOperand *operand = owner.dyn_cast()) - return operand[index].get(); - // An OpResult is a value, so we can return it directly. - if (OpResult *result = owner.dyn_cast()) - return &result[index]; - // Otherwise, this is a raw value array so just index directly. - return owner.get()[index]; -} - -//===----------------------------------------------------------------------===// // OpState trait class. //===----------------------------------------------------------------------===// @@ -979,7 +934,7 @@ OpTrait::impl::verifySameOperandsAndResultElementType(Operation *op) { auto elementType = getElementTypeOrSelf(op->getResult(0)); // Verify result element type matches first result's element type. - for (auto result : drop_begin(op->getResults(), 1)) { + for (auto result : llvm::drop_begin(op->getResults(), 1)) { if (getElementTypeOrSelf(result) != elementType) return op->emitOpError( "requires the same element type for all operands and results"); @@ -1210,7 +1165,7 @@ Value *impl::foldCastOp(Operation *op) { } //===----------------------------------------------------------------------===// -// CastOp implementation +// Misc. utils //===----------------------------------------------------------------------===// /// Insert an operation, generated by `buildTerminatorOp`, at the end of the @@ -1230,6 +1185,10 @@ void impl::ensureRegionTerminator( block.push_back(buildTerminatorOp()); } +//===----------------------------------------------------------------------===// +// UseIterator +//===----------------------------------------------------------------------===// + UseIterator::UseIterator(Operation *op, bool end) : op(op), res(end ? op->result_end() : op->result_begin()) { // Only initialize current use if there are results/can be uses. diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp index e4ff889..256a261 100644 --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -144,3 +144,50 @@ void detail::OperandStorage::grow(ResizableStorage &resizeUtil, operand.~OpOperand(); resizeUtil.setDynamicStorage(newStorage); } + +//===----------------------------------------------------------------------===// +// Operation Value-Iterators +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// OperandRange + +OperandRange::OperandRange(Operation *op) + : OperandRange(op->getOpOperands().data(), op->getNumOperands()) {} + +//===----------------------------------------------------------------------===// +// ResultRange + +ResultRange::ResultRange(Operation *op) + : ResultRange(op->getOpResults().data(), op->getNumResults()) {} + +//===----------------------------------------------------------------------===// +// ValueRange + +ValueRange::ValueRange(ArrayRef values) + : ValueRange(values.data(), values.size()) {} +ValueRange::ValueRange(OperandRange values) + : ValueRange(values.begin().getBase(), values.size()) {} +ValueRange::ValueRange(ResultRange values) + : ValueRange(values.begin().getBase(), values.size()) {} + +/// See `detail::indexed_accessor_range_base` for details. +ValueRange::OwnerT ValueRange::offset_base(const OwnerT &owner, + ptrdiff_t index) { + if (OpOperand *operand = owner.dyn_cast()) + return operand + index; + if (OpResult *result = owner.dyn_cast()) + return result + index; + return owner.get() + index; +} +/// See `detail::indexed_accessor_range_base` for details. +Value *ValueRange::dereference_iterator(const OwnerT &owner, ptrdiff_t index) { + // Operands access the held value via 'get'. + if (OpOperand *operand = owner.dyn_cast()) + return operand[index].get(); + // An OpResult is a value, so we can return it directly. + if (OpResult *result = owner.dyn_cast()) + return &result[index]; + // Otherwise, this is a raw value array so just index directly. + return owner.get()[index]; +} diff --git a/mlir/lib/IR/TypeUtilities.cpp b/mlir/lib/IR/TypeUtilities.cpp index a963a8d..0172141 100644 --- a/mlir/lib/IR/TypeUtilities.cpp +++ b/mlir/lib/IR/TypeUtilities.cpp @@ -92,15 +92,19 @@ LogicalResult mlir::verifyCompatibleShape(Type type1, Type type2) { return success(); } -OperandElementTypeIterator::OperandElementTypeIterator(OperandIterator it) - : llvm::mapped_iterator(it, &unwrap) {} +OperandElementTypeIterator::OperandElementTypeIterator( + Operation::operand_iterator it) + : llvm::mapped_iterator( + it, &unwrap) {} Type OperandElementTypeIterator::unwrap(Value *value) { return value->getType().cast().getElementType(); } -ResultElementTypeIterator::ResultElementTypeIterator(ResultIterator it) - : llvm::mapped_iterator(it, &unwrap) {} +ResultElementTypeIterator::ResultElementTypeIterator( + Operation::result_iterator it) + : llvm::mapped_iterator( + it, &unwrap) {} Type ResultElementTypeIterator::unwrap(Value *value) { return value->getType().cast().getElementType(); -- 2.7.4