From 9d1a0c72b4ae54b97809966257bd1b9cb3140dfe Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 6 Dec 2019 20:06:48 -0800 Subject: [PATCH] Add a new ValueRange class. This class represents a generic abstraction over the different ways to represent a range of Values: ArrayRef, operand_range, result_range. This class will allow for removing the many instances of explicit SmallVector construction. It has the same memory cost as ArrayRef, and only suffers cost from indexing(if+elsing the different underlying representations). This change only updates a few of the existing usages, with more to be changed in followups; e.g. 'build' API. PiperOrigin-RevId: 284307996 --- mlir/include/mlir/IR/OpImplementation.h | 5 +- mlir/include/mlir/IR/Operation.h | 58 ++++++++++++++++++++++ mlir/include/mlir/IR/PatternMatch.h | 15 +++--- mlir/include/mlir/Support/LLVM.h | 2 + mlir/include/mlir/Support/STLExtras.h | 6 +++ mlir/include/mlir/Transforms/DialectConversion.h | 7 ++- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 3 +- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 23 +++------ mlir/lib/Dialect/StandardOps/Ops.cpp | 2 +- mlir/lib/Dialect/VectorOps/VectorToVector.cpp | 5 +- mlir/lib/IR/AsmPrinter.cpp | 7 ++- mlir/lib/IR/Operation.cpp | 36 ++++++++++++++ mlir/lib/IR/PatternMatch.cpp | 17 +++---- mlir/lib/Transforms/DialectConversion.cpp | 19 ++++--- 14 files changed, 143 insertions(+), 62 deletions(-) diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index 666a90e..3052f79 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -95,15 +95,14 @@ public: /// SSA values in namesToUse. This may only be used for IsolatedFromAbove /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. - virtual void shadowRegionArgs(Region ®ion, - ArrayRef namesToUse) = 0; + virtual void shadowRegionArgs(Region ®ion, ValueRange namesToUse) = 0; /// Prints an affine map of SSA ids, where SSA id names are used in place /// of dims/symbols. /// Operand values must come from single-result sources, and be valid /// dimensions/symbol identifiers according to mlir::isValidDim/Symbol. virtual void printAffineMapOfSSAIds(AffineMapAttr mapAttr, - ArrayRef operands) = 0; + ValueRange operands) = 0; /// Print an optional arrow followed by a type list. void printOptionalArrowTypeList(ArrayRef types) { diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h index 1d9a401..a70970d 100644 --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -798,6 +798,64 @@ inline auto Operation::getResultTypes() -> result_type_range { return {result_type_begin(), result_type_end()}; } +/// This class provides an abstraction over the different types of ranges over +/// Value*s. In many cases, this prevents the need to explicitly materialize a +/// SmallVector/std::vector. This class should be used in places that are not +/// suitable for a more derived type(e.g. ArrayRef) or a template range +/// parameter. +class ValueRange { + /// The type representing the owner of this range. This is either a list of + /// values, operands, or results. + using OwnerT = llvm::PointerUnion; + +public: + ValueRange(const ValueRange &) = default; + ValueRange(ValueRange &&) = default; + + template , Arg>::value>> + ValueRange(Arg &&arg) + : ValueRange(ArrayRef(std::forward(arg))) {} + ValueRange(const std::initializer_list &values) + : ValueRange(ArrayRef(values)) {} + ValueRange(ArrayRef values = llvm::None); + ValueRange(iterator_range values); + ValueRange(iterator_range values); + + /// An iterator element of this range. + class Iterator : public indexed_accessor_iterator { + public: + Value *operator*() const; + + private: + Iterator(OwnerT owner, unsigned curIndex); + + /// Allow access to the constructor. + friend ValueRange; + }; + + Iterator begin() const { return Iterator(owner, 0); } + Iterator end() const { return Iterator(owner, count); } + Value *operator[](unsigned index) const { + assert(index < size() && "invalid index for value range"); + return *std::next(begin(), index); + } + + /// Return the size of this range. + size_t size() const { return count; } + + /// Return if the range is empty. + bool empty() const { return size() == 0; } + +private: + /// The object that owns the provided range of values. + OwnerT owner; + /// The size from the owning range. + unsigned count; +}; + } // end namespace mlir namespace llvm { diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 7e5596df..366d2b8 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -331,9 +331,9 @@ public: /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those values are dead, this will /// remove them as well. - virtual void replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead); - void replaceOp(Operation *op, ArrayRef newValues) { + virtual void replaceOp(Operation *op, ValueRange newValues, + ValueRange valuesToRemoveIfDead); + void replaceOp(Operation *op, ValueRange newValues) { replaceOp(op, newValues, llvm::None); } @@ -349,7 +349,7 @@ public: /// The result values of the two ops must be the same types. This allows /// specifying a list of ops that may be removed if dead. template - void replaceOpWithNewOp(ArrayRef valuesToRemoveIfDead, Operation *op, + void replaceOpWithNewOp(ValueRange valuesToRemoveIfDead, Operation *op, Args &&... args) { auto newOp = create(op->getLoc(), std::forward(args)...); replaceOpWithResultsOfAnotherOp(op, newOp.getOperation(), @@ -364,7 +364,7 @@ public: /// 'argValues' is used to replace the block arguments of 'source' after /// merging. virtual void mergeBlocks(Block *source, Block *dest, - ArrayRef argValues = llvm::None); + ValueRange argValues = llvm::None); /// Split the operations starting at "before" (inclusive) out of the given /// block into a new block, and return it. @@ -378,8 +378,7 @@ public: /// The valuesToRemoveIfDead list is an optional list of values that the /// rewriter should remove if they are dead at this point. /// - void updatedRootInPlace(Operation *op, - ArrayRef valuesToRemoveIfDead = {}); + void updatedRootInPlace(Operation *op, ValueRange valuesToRemoveIfDead = {}); protected: explicit PatternRewriter(MLIRContext *ctx) : OpBuilder(ctx) {} @@ -406,7 +405,7 @@ private: /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp void replaceOpWithResultsOfAnotherOp(Operation *op, Operation *newOp, - ArrayRef valuesToRemoveIfDead); + ValueRange valuesToRemoveIfDead); }; //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Support/LLVM.h b/mlir/include/mlir/Support/LLVM.h index 6ae8d5c..91d145d 100644 --- a/mlir/include/mlir/Support/LLVM.h +++ b/mlir/include/mlir/Support/LLVM.h @@ -56,6 +56,7 @@ template class DenseSet; template class DenseMap; template class function_ref; +template class iterator_range; // Other common classes. class raw_ostream; @@ -82,6 +83,7 @@ using DenseMap = llvm::DenseMap; template > using DenseSet = llvm::DenseSet; template using function_ref = llvm::function_ref; +using llvm::iterator_range; using llvm::MutableArrayRef; using llvm::None; using llvm::Optional; diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h index 24e2ac6..95e52f9 100644 --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -177,6 +177,12 @@ public: return static_cast(*this); } + /// Returns the current index of the iterator. + ptrdiff_t getIndex() const { return index; } + + /// Returns the current object of the iterator. + const ObjectType &getObject() const { return object; } + protected: indexed_accessor_iterator(ObjectType object, ptrdiff_t index) : object(object), index(index) {} diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 8866950..fee58a4 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -347,8 +347,8 @@ public: //===--------------------------------------------------------------------===// /// PatternRewriter hook for replacing the results of an operation. - void replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead) override; + void replaceOp(Operation *op, ValueRange newValues, + ValueRange valuesToRemoveIfDead) override; using PatternRewriter::replaceOp; /// PatternRewriter hook for erasing a dead operation. The uses of this @@ -360,8 +360,7 @@ public: Block *splitBlock(Block *block, Block::iterator before) override; /// PatternRewriter hook for merging a block into another. - void mergeBlocks(Block *source, Block *dest, - ArrayRef argValues) override; + void mergeBlocks(Block *source, Block *dest, ValueRange argValues) override; /// PatternRewriter hook for moving blocks out of a region. void inlineRegionBefore(Region ®ion, Region &parent, diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index c1a7a33..c2fd0aa 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -1118,8 +1118,7 @@ struct CallOpInterfaceLowering : public LLVMLegalizationPattern { // If < 2 results, packing did not do anything and we can just return. if (numResults < 2) { - SmallVector results(newOp.getResults()); - rewriter.replaceOp(op, results); + rewriter.replaceOp(op, newOp.getResults()); return this->matchSuccess(); } diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 7232c6e..689207c 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -863,14 +863,11 @@ void AffineDmaStartOp::build(Builder *builder, OperationState &result, void AffineDmaStartOp::print(OpAsmPrinter &p) { p << "affine.dma_start " << *getSrcMemRef() << '['; - SmallVector operands(getSrcIndices()); - p.printAffineMapOfSSAIds(getSrcMapAttr(), operands); + p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); p << "], " << *getDstMemRef() << '['; - operands.assign(getDstIndices().begin(), getDstIndices().end()); - p.printAffineMapOfSSAIds(getDstMapAttr(), operands); + p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); p << "], " << *getTagMemRef() << '['; - operands.assign(getTagIndices().begin(), getTagIndices().end()); - p.printAffineMapOfSSAIds(getTagMapAttr(), operands); + p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); p << "], " << *getNumElements(); if (isStrided()) { p << ", " << *getStride(); @@ -1827,11 +1824,8 @@ ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) { void AffineLoadOp::print(OpAsmPrinter &p) { p << "affine.load " << *getMemRef() << '['; - AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); - if (mapAttr) { - SmallVector operands(getMapOperands()); - p.printAffineMapOfSSAIds(mapAttr, operands); - } + if (AffineMapAttr mapAttr = getAttrOfType(getMapAttrName())) + p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); p << " : " << getMemRefType(); @@ -1922,11 +1916,8 @@ ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) { void AffineStoreOp::print(OpAsmPrinter &p) { p << "affine.store " << *getValueToStore(); p << ", " << *getMemRef() << '['; - AffineMapAttr mapAttr = getAttrOfType(getMapAttrName()); - if (mapAttr) { - SmallVector operands(getMapOperands()); - p.printAffineMapOfSSAIds(mapAttr, operands); - } + if (AffineMapAttr mapAttr = getAttrOfType(getMapAttrName())) + p.printAffineMapOfSSAIds(mapAttr, getMapOperands()); p << ']'; p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()}); p << " : " << getMemRefType(); diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index a9e9364..34e0ecb 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -477,7 +477,7 @@ struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { return matchFailure(); // Merge the successor into the current block and erase the branch. - rewriter.mergeBlocks(succ, opParent, llvm::to_vector<1>(op.getOperands())); + rewriter.mergeBlocks(succ, opParent, op.getOperands()); rewriter.eraseOp(op); return matchSuccess(); } diff --git a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp b/mlir/lib/Dialect/VectorOps/VectorToVector.cpp index c2726ed..82d19f5 100644 --- a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorToVector.cpp @@ -544,10 +544,7 @@ struct ConvertMatchingFakeForkFakeJoinOp : public RewritePattern { "]: ConvertMatchingFakeForkFakeJoinOp on op: " << *op << " in func:\n"); LLVM_DEBUG(op->getParentOfType().print(dbgs())); - SmallVector forwardedOperands; - forwardedOperands.append(definingOp->getOperands().begin(), - definingOp->getOperands().end()); - rewriter.replaceOp(op, forwardedOperands); + rewriter.replaceOp(op, definingOp->getOperands()); return matchSuccess(); } }; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index ed97b8b..0ea447e 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -1486,10 +1486,10 @@ public: /// SSA values in namesToUse. This may only be used for IsolatedFromAbove /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. - void shadowRegionArgs(Region ®ion, ArrayRef namesToUse) override; + void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override; void printAffineMapOfSSAIds(AffineMapAttr mapAttr, - ArrayRef operands) override { + ValueRange operands) override { AffineMap map = mapAttr.getValue(); unsigned numDims = map.getNumDims(); auto printValueName = [&](unsigned pos, bool isSymbol) { @@ -1851,8 +1851,7 @@ void OperationPrinter::printValueIDImpl(Value *value, bool printResultNo, /// SSA values in namesToUse. This may only be used for IsolatedFromAbove /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. -void OperationPrinter::shadowRegionArgs(Region ®ion, - ArrayRef namesToUse) { +void OperationPrinter::shadowRegionArgs(Region ®ion, ValueRange namesToUse) { assert(!region.empty() && "cannot shadow arguments of an empty region"); assert(region.front().getNumArguments() == namesToUse.size() && "incorrect number of names passed in"); diff --git a/mlir/lib/IR/Operation.cpp b/mlir/lib/IR/Operation.cpp index 1d213f4..3ebc712 100644 --- a/mlir/lib/IR/Operation.cpp +++ b/mlir/lib/IR/Operation.cpp @@ -740,6 +740,42 @@ Operation *Operation::clone() { } //===----------------------------------------------------------------------===// +// ValueRange +//===----------------------------------------------------------------------===// + +ValueRange::ValueRange(ArrayRef values) + : owner(values.data()), count(values.size()) {} +ValueRange::ValueRange(llvm::iterator_range values) + : count(llvm::size(values)) { + if (count != 0) { + auto begin = values.begin(); + owner = &begin.getObject()->getOpOperand(begin.getIndex()); + } +} +ValueRange::ValueRange(llvm::iterator_range values) + : count(llvm::size(values)) { + if (count != 0) { + auto begin = values.begin(); + owner = &begin.getObject()->getOpResult(begin.getIndex()); + } +} + +ValueRange::Iterator::Iterator(OwnerT owner, unsigned curIndex) + : indexed_accessor_iterator( + owner, curIndex) {} + +Value *ValueRange::Iterator::operator*() const { + // Operands access the held value via 'get'. + if (OpOperand *operand = object.dyn_cast()) + return operand[index].get(); + // An OpResult is a value, so we can return it directly. + if (OpResult *result = object.dyn_cast()) + return &result[index]; + // Otherwise, this is a raw value array so just index directly. + return object.get()[index]; +} + +//===----------------------------------------------------------------------===// // OpState trait class. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index b8ecab9..3887a03 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -85,8 +85,8 @@ PatternRewriter::~PatternRewriter() { /// clients can specify a list of other nodes that this replacement may make /// (perhaps transitively) dead. If any of those ops are dead, this will /// remove them as well. -void PatternRewriter::replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead) { +void PatternRewriter::replaceOp(Operation *op, ValueRange newValues, + ValueRange valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootReplaced(op); @@ -114,7 +114,7 @@ void PatternRewriter::eraseOp(Operation *op) { /// 'argValues' is used to replace the block arguments of 'source' after /// merging. void PatternRewriter::mergeBlocks(Block *source, Block *dest, - ArrayRef argValues) { + ValueRange argValues) { assert(llvm::all_of(source->getPredecessors(), [dest](Block *succ) { return succ == dest; }) && "expected 'source' to have no predecessors or only 'dest'"); @@ -141,15 +141,12 @@ Block *PatternRewriter::splitBlock(Block *block, Block::iterator before) { /// op and newOp are known to have the same number of results, replace the /// uses of op with uses of newOp void PatternRewriter::replaceOpWithResultsOfAnotherOp( - Operation *op, Operation *newOp, ArrayRef valuesToRemoveIfDead) { + Operation *op, Operation *newOp, ValueRange valuesToRemoveIfDead) { assert(op->getNumResults() == newOp->getNumResults() && "replacement op doesn't match results of original op"); if (op->getNumResults() == 1) return replaceOp(op, newOp->getResult(0), valuesToRemoveIfDead); - - SmallVector newResults(newOp->getResults().begin(), - newOp->getResults().end()); - return replaceOp(op, newResults, valuesToRemoveIfDead); + return replaceOp(op, newOp->getResults(), valuesToRemoveIfDead); } /// Move the blocks that belong to "region" before the given position in @@ -190,8 +187,8 @@ void PatternRewriter::cloneRegionBefore(Region ®ion, Block *before) { /// The opsToRemoveIfDead list is an optional list of nodes that the rewriter /// should remove if they are dead at this point. /// -void PatternRewriter::updatedRootInPlace( - Operation *op, ArrayRef valuesToRemoveIfDead) { +void PatternRewriter::updatedRootInPlace(Operation *op, + ValueRange valuesToRemoveIfDead) { // Notify the rewriter subclass that we're about to replace this root. notifyRootUpdated(op); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index b1feea6..6d34db9 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -416,7 +416,7 @@ struct ConversionPatternRewriterImpl { /// This class represents one requested operation replacement via 'replaceOp'. struct OpReplacement { OpReplacement() = default; - OpReplacement(Operation *op, ArrayRef newValues) + OpReplacement(Operation *op, ValueRange newValues) : op(op), newValues(newValues.begin(), newValues.end()) {} Operation *op; @@ -501,8 +501,8 @@ struct ConversionPatternRewriterImpl { TypeConverter::SignatureConversion &conversion); /// PatternRewriter hook for replacing the results of an operation. - void replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead); + void replaceOp(Operation *op, ValueRange newValues, + ValueRange valuesToRemoveIfDead); /// Notifies that a block was split. void notifySplitBlock(Block *block, Block *continuation); @@ -687,9 +687,9 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( return nullptr; } -void ConversionPatternRewriterImpl::replaceOp( - Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead) { +void ConversionPatternRewriterImpl::replaceOp(Operation *op, + ValueRange newValues, + ValueRange valuesToRemoveIfDead) { assert(newValues.size() == op->getNumResults()); // Create mappings for each of the new result values. @@ -769,9 +769,8 @@ ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx, ConversionPatternRewriter::~ConversionPatternRewriter() {} /// PatternRewriter hook for replacing the results of an operation. -void ConversionPatternRewriter::replaceOp( - Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead) { +void ConversionPatternRewriter::replaceOp(Operation *op, ValueRange newValues, + ValueRange valuesToRemoveIfDead) { LLVM_DEBUG(llvm::dbgs() << "** Replacing operation : " << op->getName() << "\n"); impl->replaceOp(op, newValues, valuesToRemoveIfDead); @@ -826,7 +825,7 @@ Block *ConversionPatternRewriter::splitBlock(Block *block, /// PatternRewriter hook for merging a block into another. void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest, - ArrayRef argValues) { + ValueRange argValues) { // TODO(riverriddle) This requires fixing the implementation of // 'replaceUsesOfBlockArgument', which currently isn't undoable. llvm_unreachable("block merging updates are currently not supported"); -- 2.7.4