From 8b447b6cad22c36a3fa653b4bdea9fc1d2fd2915 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 18 Jul 2019 12:04:57 -0700 Subject: [PATCH] NFC: Expose a ConversionPatternRewriter for use with ConversionPatterns. This specific PatternRewriter will allow for exposing hooks in the future that are only useful for the conversion framework, e.g. type conversions. PiperOrigin-RevId: 258818122 --- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 20 +- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 10 +- mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 9 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 25 +- mlir/include/mlir/IR/PatternMatch.h | 5 +- mlir/include/mlir/Transforms/DialectConversion.h | 59 ++- .../ControlFlowToCFG/ConvertControlFlowToCFG.cpp | 19 +- .../StandardToLLVM/ConvertStandardToLLVM.cpp | 67 +-- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 63 +-- mlir/lib/Transforms/DialectConversion.cpp | 473 ++++++++++++--------- mlir/test/lib/TestDialect/TestPatterns.cpp | 20 +- mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp | 5 +- 12 files changed, 469 insertions(+), 306 deletions(-) diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index c43a2ae..67b0ac0 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -138,8 +138,9 @@ public: explicit RangeOpConversion(MLIRContext *context) : ConversionPattern(linalg::RangeOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorType = linalg::convertLinalgType(rangeOp.getResult()->getType()); @@ -165,8 +166,9 @@ public: explicit ViewOpConversion(MLIRContext *context) : ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto viewOp = cast(op); auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType()); auto memrefType = @@ -290,8 +292,9 @@ public: explicit SliceOpConversion(MLIRContext *context) : ConversionPattern(linalg::SliceOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto sliceOp = cast(op); auto newViewDescriptorType = linalg::convertLinalgType(sliceOp.getViewType()); @@ -382,8 +385,9 @@ public: explicit DropConsumer(MLIRContext *context) : ConversionPattern("some_consumer", 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { rewriter.replaceOp(op, llvm::None); return matchSuccess(); } diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index c86f5d7..68a48d6 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -96,8 +96,9 @@ public: // an LLVM IR load. class LoadOpConversion : public LoadStoreOpConversion { using Base::Base; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); auto elementType = linalg::convertLinalgType(*op->result_type_begin()); Value *viewDescriptor = operands[0]; @@ -113,8 +114,9 @@ class LoadOpConversion : public LoadStoreOpConversion { // an LLVM IR store. class StoreOpConversion : public LoadStoreOpConversion { using Base::Base; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); Value *viewDescriptor = operands[1]; Value *data = operands[0]; diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index e4df917..f3463ba 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -57,7 +57,7 @@ namespace { /// time both side of the cast (producer and consumer) will be lowered to a /// dialect like LLVM and end up with the same LLVM representation, at which /// point this becomes a no-op and is eliminated. -Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) { +Value *typeCast(ConversionPatternRewriter &builder, Value *val, Type destTy) { if (val->getType() == destTy) return val; return builder.create(val->getLoc(), val, destTy) @@ -67,7 +67,7 @@ Value *typeCast(PatternRewriter &builder, Value *val, Type destTy) { /// Create a type cast to turn a toy.array into a memref. The Toy Array will be /// lowered to a memref during buffer allocation, at which point the type cast /// becomes useless. -Value *memRefTypeCast(PatternRewriter &builder, Value *val) { +Value *memRefTypeCast(ConversionPatternRewriter &builder, Value *val) { if (val->getType().isa()) return val; auto toyArrayTy = val->getType().dyn_cast(); @@ -87,8 +87,9 @@ public: explicit MulOpConversion(MLIRContext *context) : ConversionPattern(toy::MulOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { using namespace edsc; using intrinsics::constant_index; using linalg::intrinsics::range; diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index cd826fb..8b80588 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -92,8 +92,9 @@ public: /// the rewritten operands for `op` in the new function. /// The results created by the new IR with the builder are returned, and their /// number must match the number of result of `op`. - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto add = cast(op); auto loc = add.getLoc(); // Create a `toy.alloc` operation to allocate the output buffer for this op. @@ -133,8 +134,9 @@ public: explicit PrintOpConversion(MLIRContext *context) : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. FuncOp printfFunc = getPrintf(op->getParentOfType()); @@ -232,8 +234,9 @@ public: explicit ConstantOpConversion(MLIRContext *context) : ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { toy::ConstantOp cstOp = cast(op); auto loc = cstOp.getLoc(); auto retTy = cstOp.getResult()->getType().cast(); @@ -276,8 +279,9 @@ public: explicit TransposeOpConversion(MLIRContext *context) : ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto transpose = cast(op); auto loc = transpose.getLoc(); Value *result = memRefTypeCast( @@ -309,8 +313,9 @@ public: explicit ReturnOpConversion(MLIRContext *context) : ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { // Argument is optional, handle both cases. if (op->getNumOperands()) rewriter.replaceOpWithNewOp(op, operands[0]); diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index de68e4b..d739a80 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -321,7 +321,10 @@ public: /// (perhaps transitively) dead. If any of those values are dead, this will /// remove them as well. virtual void replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead = {}); + ArrayRef valuesToRemoveIfDead); + void replaceOp(Operation *op, ArrayRef newValues) { + replaceOp(op, newValues, llvm::None); + } /// Replaces the result op with a new op that is created without verification. /// The result values of the two ops must be the same types. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index bfe3674..68c6f12 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -31,6 +31,7 @@ namespace mlir { // Forward declarations. class Block; +class ConversionPatternRewriter; class FuncOp; class MLIRContext; class Operation; @@ -192,7 +193,7 @@ public: /// have successors. This function should not fail. If some specific cases of /// the operation are not supported, these cases should not be matched. virtual void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite"); } @@ -209,7 +210,7 @@ public: virtual void rewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, ArrayRef> operands, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { llvm_unreachable("unimplemented rewrite for terminators"); } @@ -218,7 +219,7 @@ public: matchAndRewrite(Operation *op, ArrayRef properOperands, ArrayRef destinations, ArrayRef> operands, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); rewrite(op, properOperands, destinations, operands, rewriter); @@ -226,9 +227,9 @@ public: } /// Hook for derived classes to implement combined matching and rewriting. - virtual PatternMatchResult matchAndRewrite(Operation *op, - ArrayRef operands, - PatternRewriter &rewriter) const { + virtual PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const { if (!match(op)) return matchFailure(); rewrite(op, operands, rewriter); @@ -244,6 +245,50 @@ private: }; //===----------------------------------------------------------------------===// +// Conversion PatternRewriter +//===----------------------------------------------------------------------===// + +namespace detail { +struct ConversionPatternRewriterImpl; +} // end namespace detail + +/// This class implements a pattern rewriter for use with ConversionPatterns. +class ConversionPatternRewriter final : public PatternRewriter { +public: + ConversionPatternRewriter(MLIRContext *ctx, TypeConverter *converter); + ~ConversionPatternRewriter() override; + + //===--------------------------------------------------------------------===// + // PatternRewriter Hooks + //===--------------------------------------------------------------------===// + + /// PatternRewriter hook for replacing the results of an operation. + void replaceOp(Operation *op, ArrayRef newValues, + ArrayRef valuesToRemoveIfDead) override; + using PatternRewriter::replaceOp; + + /// PatternRewriter hook for splitting a block into two parts. + Block *splitBlock(Block *block, Block::iterator before) override; + + /// PatternRewriter hook for moving blocks out of a region. + void inlineRegionBefore(Region ®ion, Region &parent, + Region::iterator before) override; + using PatternRewriter::inlineRegionBefore; + + /// PatternRewriter hook for creating a new operation. + Operation *createOperation(const OperationState &state) override; + + /// PatternRewriter hook for updating the root operation in-place. + void notifyRootUpdated(Operation *op) override; + + /// Return a reference to the internal implementation. + detail::ConversionPatternRewriterImpl &getImpl(); + +private: + std::unique_ptr impl; +}; + +//===----------------------------------------------------------------------===// // ConversionTarget //===----------------------------------------------------------------------===// @@ -260,7 +305,7 @@ public: /// by the target. Dynamic, - /// This target explicitly does not support this operation. + /// The target explicitly does not support this operation. Illegal, }; diff --git a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp index 9c2053d..1515d95 100644 --- a/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp +++ b/mlir/lib/Conversion/ControlFlowToCFG/ConvertControlFlowToCFG.cpp @@ -104,8 +104,9 @@ struct ForLowering : public ConversionPattern { ForLowering(MLIRContext *ctx) : ConversionPattern(ForOp::getOperationName(), 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override; + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; }; // Create a CFG subgraph for the loop.if operation (including its "then" and @@ -154,16 +155,18 @@ struct IfLowering : public ConversionPattern { IfLowering(MLIRContext *ctx) : ConversionPattern(IfOp::getOperationName(), 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override; + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override; }; struct TerminatorLowering : public ConversionPattern { TerminatorLowering(MLIRContext *ctx) : ConversionPattern(TerminatorOp::getOperationName(), 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { rewriter.replaceOp(op, {}); return matchSuccess(); } @@ -172,7 +175,7 @@ struct TerminatorLowering : public ConversionPattern { PatternMatchResult ForLowering::matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { auto forOp = cast(op); Location loc = op->getLoc(); @@ -228,7 +231,7 @@ ForLowering::matchAndRewrite(Operation *op, ArrayRef operands, PatternMatchResult IfLowering::matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { auto ifOp = cast(op); auto loc = op->getLoc(); diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp index aa72a7b..042e768 100644 --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -229,7 +229,7 @@ public: } // Create an LLVM IR pseudo-operation defining the given index constant. - Value *createIndexConstant(PatternRewriter &builder, Location loc, + Value *createIndexConstant(ConversionPatternRewriter &builder, Location loc, uint64_t value) const { auto attr = builder.getIntegerAttr(builder.getIndexType(), value); return builder.create(loc, getIndexType(), attr); @@ -237,7 +237,7 @@ public: // Get the array attribute named "position" containing the given list of // integers as integer attribute elements. - static ArrayAttr getIntegerArrayAttr(PatternRewriter &builder, + static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder, ArrayRef values) { SmallVector attrs; attrs.reserve(values.size()); @@ -247,7 +247,8 @@ public: } // Extract raw data pointer value from a value representing a memref. - static Value *extractMemRefElementPtr(PatternRewriter &builder, Location loc, + static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder, + Location loc, Value *convertedMemRefValue, Type elementTypePtr, bool hasStaticShape) { @@ -274,8 +275,9 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { unsigned numResults = op->getNumResults(); Type packedType; @@ -398,7 +400,7 @@ struct AllocOpLowering : public LLVMLegalizationPattern { } void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const override { auto allocOp = cast(op); MemRefType type = allocOp.getType(); @@ -495,8 +497,9 @@ struct AllocOpLowering : public LLVMLegalizationPattern { struct DeallocOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { assert(operands.size() == 1 && "dealloc takes one operand"); OperandAdaptor transformed(operands); @@ -538,7 +541,7 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern { } void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const override { auto memRefCastOp = cast(op); OperandAdaptor transformed(operands); auto targetType = memRefCastOp.getType(); @@ -610,7 +613,7 @@ struct DimOpLowering : public LLVMLegalizationPattern { } void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); OperandAdaptor transformed(operands); MemRefType type = dimOp.getOperand()->getType().cast(); @@ -660,7 +663,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { // by accumulating the running linearized value. // Note that `indices` and `allocSizes` are passed in the same order as they // appear in load/store operations and memref type declarations. - Value *linearizeSubscripts(PatternRewriter &builder, Location loc, + Value *linearizeSubscripts(ConversionPatternRewriter &builder, Location loc, ArrayRef indices, ArrayRef allocSizes) const { assert(indices.size() == allocSizes.size() && @@ -686,7 +689,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { Value *getElementPtr(Location loc, Type elementTypePtr, ArrayRef shape, Value *memRefDescriptor, ArrayRef indices, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { // Get the list of MemRef sizes. Static sizes are defined as constants. // Dynamic sizes are extracted from the MemRef descriptor, where they start // from the position 1 (the buffer is at position 0). @@ -722,7 +725,7 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { Value *getRawElementPtr(Location loc, Type elementTypePtr, ArrayRef shape, Value *rawDataPtr, ArrayRef indices, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { if (shape.empty()) return rawDataPtr; @@ -738,7 +741,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { } Value *getDataPtr(Location loc, MemRefType type, Value *dataPtr, - ArrayRef indices, PatternRewriter &rewriter, + ArrayRef indices, + ConversionPatternRewriter &rewriter, llvm::Module &module) const { auto ptrType = getMemRefElementPtrType(type, this->lowering); auto shape = type.getShape(); @@ -755,8 +759,9 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto loadOp = cast(op); OperandAdaptor transformed(operands); auto type = loadOp.getMemRefType(); @@ -776,8 +781,9 @@ struct LoadOpLowering : public LoadStoreOpLowering { struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); OperandAdaptor transformed(operands); @@ -796,8 +802,9 @@ struct StoreOpLowering : public LoadStoreOpLowering { struct IndexCastOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { IndexCastOpOperandAdaptor transformed(operands); auto indexCastOp = cast(op); @@ -829,8 +836,9 @@ static LLVM::ICmpPredicate convertCmpIPredicate(CmpIPredicate pred) { struct CmpIOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto cmpiOp = cast(op); CmpIOpOperandAdaptor transformed(operands); @@ -851,11 +859,11 @@ struct OneToOneLLVMTerminatorLowering using LLVMLegalizationPattern::LLVMLegalizationPattern; using Super = OneToOneLLVMTerminatorLowering; - PatternMatchResult matchAndRewrite(Operation *op, - ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands, + ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, properOperands, destinations, operands, op->getAttrs()); return this->matchSuccess(); @@ -871,8 +879,9 @@ struct OneToOneLLVMTerminatorLowering struct ReturnOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); // If ReturnOp has 0 or 1 operand, create it and return immediately. diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index fb26f85..98be230 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -164,8 +164,9 @@ public: LLVMTypeConverter &lowering_) : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto indexType = IndexType::get(op->getContext()); auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); @@ -227,8 +228,9 @@ public: : LLVMOpLowering(BufferDeallocOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. @@ -261,8 +263,9 @@ public: BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); edsc::ScopedContext context(rewriter, op->getLoc()); rewriter.replaceOp( @@ -277,8 +280,9 @@ public: explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto dimOp = cast(op); auto indexTy = lowering.convertType(rewriter.getIndexType()); edsc::ScopedContext context(rewriter, op->getLoc()); @@ -307,7 +311,7 @@ public: // a getelementptr. This must be called under an edsc::ScopedContext. Value *obtainDataPtr(Operation *op, Value *viewDescriptor, ArrayRef indices, - PatternRewriter &rewriter) const { + ConversionPatternRewriter &rewriter) const { auto loadOp = cast(op); auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering); auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); @@ -333,8 +337,9 @@ public: // an LLVM IR load. class LoadOpConversion : public LoadStoreOpConversion { using Base::Base; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); auto elementTy = lowering.convertType(*op->result_type_begin()); Value *viewDescriptor = operands[0]; @@ -351,8 +356,9 @@ public: explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorTy = convertLinalgType(rangeOp.getResult()->getType(), lowering); @@ -380,8 +386,9 @@ public: : LLVMOpLowering(RangeIntersectOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto rangeIntersectOp = cast(op); auto rangeDescriptorTy = convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering); @@ -423,8 +430,9 @@ public: explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto sliceOp = cast(op); auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering); auto viewType = sliceOp.getBaseViewType(); @@ -503,8 +511,9 @@ public: // an LLVM IR store. class StoreOpConversion : public LoadStoreOpConversion { using Base::Base; - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); Value *data = operands[0]; Value *viewDescriptor = operands[1]; @@ -521,8 +530,9 @@ public: explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { auto viewOp = cast(op); auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering); auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering); @@ -598,9 +608,9 @@ static FuncOp getLLVMLibraryCallImplDefinition(FuncOp libFn) { // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition. template -static FuncOp getLLVMLibraryCallDeclaration(Operation *op, - LLVMTypeConverter &lowering, - PatternRewriter &rewriter) { +static FuncOp +getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering, + ConversionPatternRewriter &rewriter) { assert(isa(op)); auto fnName = LinalgOp::getLibraryCallName(); auto module = op->getParentOfType(); @@ -689,8 +699,9 @@ public: LinalgTypeConverter &lowering_) : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { // Only emit library call declaration. Fill in the body later. auto f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); static_cast(lowering).addLibraryFnDeclaration(f); diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 3ef9766..ed271b6 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -28,6 +28,7 @@ #include "llvm/Support/raw_ostream.h" using namespace mlir; +using namespace mlir::detail; #define DEBUG_TYPE "dialect-conversion" @@ -102,6 +103,7 @@ struct ArgConverter { /// The pattern rewriter to use when materializing conversions. PatternRewriter &rewriter; }; +} // end anonymous namespace constexpr StringLiteral ArgConverter::kCastName; @@ -283,9 +285,9 @@ Operation *ArgConverter::createCast(ArrayRef inputs, Type outputType) { } //===----------------------------------------------------------------------===// -// DialectConversionRewriter +// ConversionPatternRewriterImpl //===----------------------------------------------------------------------===// - +namespace { /// This class contains a snapshot of the current conversion rewriter state. /// This is useful when saving and undoing a set of rewrites. struct RewriterState { @@ -307,10 +309,11 @@ struct RewriterState { /// The current number of type conversion actions performed. unsigned numTypeConversions; }; +} // end anonymous namespace -/// This class implements a pattern rewriter for ConversionPattern -/// patterns. It automatically performs remapping of replaced operation values. -struct DialectConversionRewriter final : public PatternRewriter { +namespace mlir { +namespace detail { +struct ConversionPatternRewriterImpl { /// This class represents one requested operation replacement via 'replaceOp'. struct OpReplacement { OpReplacement() = default; @@ -362,205 +365,55 @@ struct DialectConversionRewriter final : public PatternRewriter { NamedAttributeList originalParentAttributes; }; - DialectConversionRewriter(MLIRContext *ctx, TypeConverter *converter) - : PatternRewriter(ctx), argConverter(converter, *this) {} - ~DialectConversionRewriter() = default; + ConversionPatternRewriterImpl(PatternRewriter &rewriter, + TypeConverter *converter) + : argConverter(converter, rewriter) {} /// Return the current state of the rewriter. - RewriterState getCurrentState() { - return RewriterState(createdOps.size(), replacements.size(), - blockActions.size(), typeConversions.size()); - } + RewriterState getCurrentState(); /// Reset the state of the rewriter to a previously saved point. - void resetState(RewriterState state) { - // Undo any type conversions or block actions. - undoTypeConversions(state.numTypeConversions); - undoBlockActions(state.numBlockActions); - - // Reset any replaced operations and undo any saved mappings. - for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) - for (auto *result : repl.op->getResults()) - mapping.erase(result); - replacements.resize(state.numReplacements); - - // Pop all of the newly created operations. - while (createdOps.size() != state.numCreatedOperations) - createdOps.pop_back_val()->erase(); - } + void resetState(RewriterState state); /// Undo the block actions (motions, splits) one by one in reverse order until /// "numActionsToKeep" actions remains. - void undoBlockActions(unsigned numActionsToKeep = 0) { - for (auto &action : - llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { - switch (action.kind) { - // Merge back the block that was split out. - case BlockActionKind::Split: { - action.originalBlock->getOperations().splice( - action.originalBlock->end(), action.block->getOperations()); - action.block->erase(); - break; - } - // Move the block back to its original position. - case BlockActionKind::Move: { - Region *originalRegion = action.originalPosition.region; - originalRegion->getBlocks().splice( - std::next(originalRegion->begin(), - action.originalPosition.position), - action.block->getParent()->getBlocks(), action.block); - break; - } - } - } - blockActions.resize(numActionsToKeep); - } + void undoBlockActions(unsigned numActionsToKeep = 0); /// Undo the type conversion actions one by one, until "numActionsToKeep" /// actions remain. - void undoTypeConversions(unsigned numActionsToKeep = 0) { - for (auto &conversion : - llvm::drop_begin(typeConversions, numActionsToKeep)) { - if (auto *region = conversion.object.dyn_cast()) - region->getContainingOp()->setAttrs( - conversion.originalParentAttributes); - else - argConverter.discardPendingRewrites(conversion.object.get()); - } - typeConversions.resize(numActionsToKeep); - } + void undoTypeConversions(unsigned numActionsToKeep = 0); /// Cleanup and destroy any generated rewrite operations. This method is /// invoked when the conversion process fails. - void discardRewrites() { - undoTypeConversions(); - undoBlockActions(); - - // Remove any newly created ops. - for (auto *op : createdOps) { - op->dropAllDefinedValueUses(); - op->erase(); - } - } + void discardRewrites(); /// Apply all requested operation rewrites. This method is invoked when the /// conversion process succeeds. - void applyRewrites() { - // Apply all of the rewrites replacements requested during conversion. - for (auto &repl : replacements) { - for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) - repl.op->getResult(i)->replaceAllUsesWith( - mapping.lookupOrDefault(repl.newValues[i])); - - // if this operation defines any regions, drop any pending argument - // rewrites. - if (argConverter.typeConverter && repl.op->getNumRegions()) { - for (auto ®ion : repl.op->getRegions()) - for (auto &block : region) - argConverter.cancelPendingRewrites(&block); - } - } - - // In a second pass, erase all of the replaced operations in reverse. This - // allows processing nested operations before their parent region is - // destroyed. - for (auto &repl : llvm::reverse(replacements)) - repl.op->erase(); - - argConverter.applyRewrites(); - } + void applyRewrites(); /// Return if the given block has already been converted. - bool hasSignatureBeenConverted(Block *block) { - return argConverter.hasBeenConverted(block); - } + bool hasSignatureBeenConverted(Block *block); /// Convert the signature of the given region. - LogicalResult convertRegionSignature(Region ®ion) { - auto parentAttrs = region.getContainingOp()->getAttrList(); - auto result = argConverter.convertSignature(region, mapping); - if (succeeded(result)) { - typeConversions.push_back(TypeConversion{®ion, parentAttrs}); - if (!region.empty()) - typeConversions.push_back( - TypeConversion{®ion.front(), NamedAttributeList()}); - } - return result; - } + LogicalResult convertRegionSignature(Region ®ion); /// Convert the signature of the given block. - LogicalResult convertBlockSignature(Block *block) { - auto result = argConverter.convertSignature(block, mapping); - if (succeeded(result)) - typeConversions.push_back(TypeConversion{block, NamedAttributeList()}); - return result; - } + LogicalResult convertBlockSignature(Block *block); /// PatternRewriter hook for replacing the results of an operation. void replaceOp(Operation *op, ArrayRef newValues, - ArrayRef valuesToRemoveIfDead) override { - assert(newValues.size() == op->getNumResults()); - - // Create mappings for each of the new result values. - for (unsigned i = 0, e = newValues.size(); i < e; ++i) { - assert((newValues[i] || op->getResult(i)->use_empty()) && - "result value has remaining uses that must be replaced"); - if (newValues[i]) - mapping.map(op->getResult(i), newValues[i]); - } + ArrayRef valuesToRemoveIfDead); - // Record the requested operation replacement. - replacements.emplace_back(op, newValues); - } - - /// PatternRewriter hook for splitting a block into two parts. - Block *splitBlock(Block *block, Block::iterator before) override { - auto *continuation = PatternRewriter::splitBlock(block, before); - BlockAction action; - action.kind = BlockActionKind::Split; - action.block = continuation; - action.originalBlock = block; - blockActions.push_back(action); - return continuation; - } - - /// PatternRewriter hook for moving blocks out of a region. - void inlineRegionBefore(Region ®ion, Region &parent, - Region::iterator before) override { - for (auto &pair : llvm::enumerate(region)) { - Block &block = pair.value(); - unsigned position = pair.index(); - BlockAction action; - action.kind = BlockActionKind::Move; - action.block = █ - action.originalPosition = {®ion, position}; - blockActions.push_back(action); - } - PatternRewriter::inlineRegionBefore(region, parent, before); - } - - /// PatternRewriter hook for creating a new operation. - Operation *createOperation(const OperationState &state) override { - auto *result = OpBuilder::createOperation(state); - createdOps.push_back(result); - return result; - } + /// Notifies that a block was split. + void notifySplitBlock(Block *block, Block *continuation); - /// PatternRewriter hook for updating the root operation in-place. - void notifyRootUpdated(Operation *op) override { - // The rewriter caches changes to the IR to allow for operating in-place and - // backtracking. The rewrite is currently not capable of backtracking - // in-place modifications. - llvm_unreachable("in-place operation updates are not supported"); - } + /// Notifies that the blocks of a region are about to be moved. + void notifyRegionIsBeingInlinedBefore(Region ®ion, Region &parent, + Region::iterator before); /// Remap the given operands to those with potentially different types. void remapValues(Operation::operand_range operands, - SmallVectorImpl &remapped) { - remapped.reserve(llvm::size(operands)); - for (Value *operand : operands) - remapped.push_back(mapping.lookupOrDefault(operand)); - } + SmallVectorImpl &remapped); // Mapping between replaced values that differ in type. This happens when // replacing a value with one of a different type. @@ -581,7 +434,226 @@ struct DialectConversionRewriter final : public PatternRewriter { /// Ordered list of type conversion actions. SmallVector typeConversions; }; -} // end anonymous namespace +} // end namespace detail +} // end namespace mlir + +RewriterState ConversionPatternRewriterImpl::getCurrentState() { + return RewriterState(createdOps.size(), replacements.size(), + blockActions.size(), typeConversions.size()); +} + +void ConversionPatternRewriterImpl::resetState(RewriterState state) { + // Undo any type conversions or block actions. + undoTypeConversions(state.numTypeConversions); + undoBlockActions(state.numBlockActions); + + // Reset any replaced operations and undo any saved mappings. + for (auto &repl : llvm::drop_begin(replacements, state.numReplacements)) + for (auto *result : repl.op->getResults()) + mapping.erase(result); + replacements.resize(state.numReplacements); + + // Pop all of the newly created operations. + while (createdOps.size() != state.numCreatedOperations) + createdOps.pop_back_val()->erase(); +} + +void ConversionPatternRewriterImpl::undoBlockActions( + unsigned numActionsToKeep) { + for (auto &action : + llvm::reverse(llvm::drop_begin(blockActions, numActionsToKeep))) { + switch (action.kind) { + // Merge back the block that was split out. + case BlockActionKind::Split: { + action.originalBlock->getOperations().splice( + action.originalBlock->end(), action.block->getOperations()); + action.block->erase(); + break; + } + // Move the block back to its original position. + case BlockActionKind::Move: { + Region *originalRegion = action.originalPosition.region; + originalRegion->getBlocks().splice( + std::next(originalRegion->begin(), action.originalPosition.position), + action.block->getParent()->getBlocks(), action.block); + break; + } + } + } + blockActions.resize(numActionsToKeep); +} + +void ConversionPatternRewriterImpl::undoTypeConversions( + unsigned numActionsToKeep) { + for (auto &conversion : llvm::drop_begin(typeConversions, numActionsToKeep)) { + if (auto *region = conversion.object.dyn_cast()) + region->getContainingOp()->setAttrs(conversion.originalParentAttributes); + else + argConverter.discardPendingRewrites(conversion.object.get()); + } + typeConversions.resize(numActionsToKeep); +} + +void ConversionPatternRewriterImpl::discardRewrites() { + undoTypeConversions(); + undoBlockActions(); + + // Remove any newly created ops. + for (auto *op : createdOps) { + op->dropAllDefinedValueUses(); + op->erase(); + } +} + +void ConversionPatternRewriterImpl::applyRewrites() { + // Apply all of the rewrites replacements requested during conversion. + for (auto &repl : replacements) { + for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) + repl.op->getResult(i)->replaceAllUsesWith( + mapping.lookupOrDefault(repl.newValues[i])); + + // If this operation defines any regions, drop any pending argument + // rewrites. + if (argConverter.typeConverter && repl.op->getNumRegions()) { + for (auto ®ion : repl.op->getRegions()) + for (auto &block : region) + argConverter.cancelPendingRewrites(&block); + } + } + + // In a second pass, erase all of the replaced operations in reverse. This + // allows processing nested operations before their parent region is + // destroyed. + for (auto &repl : llvm::reverse(replacements)) + repl.op->erase(); + + argConverter.applyRewrites(); +} + +bool ConversionPatternRewriterImpl::hasSignatureBeenConverted(Block *block) { + return argConverter.hasBeenConverted(block); +} + +LogicalResult +ConversionPatternRewriterImpl::convertRegionSignature(Region ®ion) { + auto parentAttrs = region.getContainingOp()->getAttrList(); + auto result = argConverter.convertSignature(region, mapping); + if (succeeded(result)) { + typeConversions.push_back(TypeConversion{®ion, parentAttrs}); + if (!region.empty()) + typeConversions.push_back( + TypeConversion{®ion.front(), NamedAttributeList()}); + } + return result; +} + +LogicalResult +ConversionPatternRewriterImpl::convertBlockSignature(Block *block) { + auto result = argConverter.convertSignature(block, mapping); + if (succeeded(result)) + typeConversions.push_back(TypeConversion{block, NamedAttributeList()}); + return result; +} + +void ConversionPatternRewriterImpl::replaceOp( + Operation *op, ArrayRef newValues, + ArrayRef valuesToRemoveIfDead) { + assert(newValues.size() == op->getNumResults()); + + // Create mappings for each of the new result values. + for (unsigned i = 0, e = newValues.size(); i < e; ++i) { + assert((newValues[i] || op->getResult(i)->use_empty()) && + "result value has remaining uses that must be replaced"); + if (newValues[i]) + mapping.map(op->getResult(i), newValues[i]); + } + + // Record the requested operation replacement. + replacements.emplace_back(op, newValues); +} + +void ConversionPatternRewriterImpl::notifySplitBlock(Block *block, + Block *continuation) { + BlockAction action; + action.kind = BlockActionKind::Split; + action.block = continuation; + action.originalBlock = block; + blockActions.push_back(action); +} + +void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore( + Region ®ion, Region &parent, Region::iterator before) { + for (auto &pair : llvm::enumerate(region)) { + Block &block = pair.value(); + unsigned position = pair.index(); + BlockAction action; + action.kind = BlockActionKind::Move; + action.block = █ + action.originalPosition = {®ion, position}; + blockActions.push_back(action); + } +} + +void ConversionPatternRewriterImpl::remapValues( + Operation::operand_range operands, SmallVectorImpl &remapped) { + remapped.reserve(llvm::size(operands)); + for (Value *operand : operands) + remapped.push_back(mapping.lookupOrDefault(operand)); +} + +//===----------------------------------------------------------------------===// +// ConversionPatternRewriter +//===----------------------------------------------------------------------===// + +ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx, + TypeConverter *converter) + : PatternRewriter(ctx), + impl(new detail::ConversionPatternRewriterImpl(*this, converter)) {} +ConversionPatternRewriter::~ConversionPatternRewriter() {} + +/// PatternRewriter hook for replacing the results of an operation. +void ConversionPatternRewriter::replaceOp( + Operation *op, ArrayRef newValues, + ArrayRef valuesToRemoveIfDead) { + impl->replaceOp(op, newValues, valuesToRemoveIfDead); +} + +/// PatternRewriter hook for splitting a block into two parts. +Block *ConversionPatternRewriter::splitBlock(Block *block, + Block::iterator before) { + auto *continuation = PatternRewriter::splitBlock(block, before); + impl->notifySplitBlock(block, continuation); + return continuation; +} + +/// PatternRewriter hook for moving blocks out of a region. +void ConversionPatternRewriter::inlineRegionBefore(Region ®ion, + Region &parent, + Region::iterator before) { + impl->notifyRegionIsBeingInlinedBefore(region, parent, before); + PatternRewriter::inlineRegionBefore(region, parent, before); +} + +/// PatternRewriter hook for creating a new operation. +Operation * +ConversionPatternRewriter::createOperation(const OperationState &state) { + auto *result = OpBuilder::createOperation(state); + impl->createdOps.push_back(result); + return result; +} + +/// PatternRewriter hook for updating the root operation in-place. +void ConversionPatternRewriter::notifyRootUpdated(Operation *op) { + // The rewriter caches changes to the IR to allow for operating in-place and + // backtracking. The rewriter is currently not capable of backtracking + // in-place modifications. + llvm_unreachable("in-place operation updates are not supported"); +} + +/// Return a reference to the internal implementation. +detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() { + return *impl; +} //===----------------------------------------------------------------------===// // Conversion Patterns @@ -592,12 +664,12 @@ PatternMatchResult ConversionPattern::matchAndRewrite(Operation *op, PatternRewriter &rewriter) const { SmallVector operands; - auto &dialectRewriter = static_cast(rewriter); - dialectRewriter.remapValues(op->getOperands(), operands); + auto &dialectRewriter = static_cast(rewriter); + dialectRewriter.getImpl().remapValues(op->getOperands(), operands); // If this operation has no successors, invoke the rewrite directly. if (op->getNumSuccessors() == 0) - return matchAndRewrite(op, operands, rewriter); + return matchAndRewrite(op, operands, dialectRewriter); // Otherwise, we need to remap the successors. SmallVector destinations; @@ -620,7 +692,7 @@ ConversionPattern::matchAndRewrite(Operation *op, op, llvm::makeArrayRef(operands.data(), operands.data() + firstSuccessorOperand), - destinations, operandsPerDestination, rewriter); + destinations, operandsPerDestination, dialectRewriter); } //===----------------------------------------------------------------------===// @@ -648,13 +720,13 @@ public: /// Attempt to legalize the given operation. Returns success if the operation /// was legalized, failure otherwise. - LogicalResult legalize(Operation *op, DialectConversionRewriter &rewriter); + LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter); private: /// Attempt to legalize the given operation by applying the provided pattern. /// Returns success if the operation was legalized, failure otherwise. LogicalResult legalizePattern(Operation *op, RewritePattern *pattern, - DialectConversionRewriter &rewriter); + ConversionPatternRewriter &rewriter); /// Build an optimistic legalization graph given the provided patterns. This /// function populates 'legalizerPatterns' with the operations that are not @@ -693,15 +765,16 @@ bool OperationLegalizer::isIllegal(Operation *op) const { LogicalResult OperationLegalizer::legalize(Operation *op, - DialectConversionRewriter &rewriter) { + ConversionPatternRewriter &rewriter) { // Make sure that the signature of the parent block of this operation has been // converted. - if (rewriter.argConverter.typeConverter) { + auto &rewriterImpl = rewriter.getImpl(); + if (rewriterImpl.argConverter.typeConverter) { auto *block = op->getBlock(); - if (block && !rewriter.hasSignatureBeenConverted(block)) { + if (block && !rewriterImpl.hasSignatureBeenConverted(block)) { if (failed(block->isEntryBlock() - ? rewriter.convertRegionSignature(*block->getParent()) - : rewriter.convertBlockSignature(block))) + ? rewriterImpl.convertRegionSignature(*block->getParent()) + : rewriterImpl.convertBlockSignature(block))) return failure(); } } @@ -743,7 +816,7 @@ OperationLegalizer::legalize(Operation *op, LogicalResult OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, - DialectConversionRewriter &rewriter) { + ConversionPatternRewriter &rewriter) { LLVM_DEBUG({ llvm::dbgs() << "-* Applying rewrite pattern '" << op->getName() << " -> ("; interleaveComma(pattern->getGeneratedOps(), llvm::dbgs()); @@ -759,10 +832,11 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, return failure(); } - RewriterState curState = rewriter.getCurrentState(); + auto &rewriterImpl = rewriter.getImpl(); + RewriterState curState = rewriterImpl.getCurrentState(); auto cleanupFailure = [&] { // Reset the rewriter state and pop this pattern. - rewriter.resetState(curState); + rewriterImpl.resetState(curState); appliedPatterns.erase(pattern); return failure(); }; @@ -776,9 +850,9 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern, // Recursively legalize each of the new operations. for (unsigned i = curState.numCreatedOperations, - e = rewriter.createdOps.size(); + e = rewriterImpl.createdOps.size(); i != e; ++i) { - if (failed(legalize(rewriter.createdOps[i], rewriter))) { + if (failed(legalize(rewriterImpl.createdOps[i], rewriter))) { LLVM_DEBUG(llvm::dbgs() << "-- FAIL: Generated operation was illegal.\n"); return cleanupFailure(); } @@ -941,7 +1015,7 @@ struct OperationConverter { private: /// Converts an operation with the given rewriter. - LogicalResult convert(DialectConversionRewriter &rewriter, Operation *op); + LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op); /// Recursively collect all of the operations, to convert from within /// 'region'. @@ -991,7 +1065,7 @@ OperationConverter::computeConversionSet(Region ®ion, } /// Converts an operation with the given rewriter. -LogicalResult OperationConverter::convert(DialectConversionRewriter &rewriter, +LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter, Operation *op) { // Legalize the given operation. if (failed(opLegalizer.legalize(op, rewriter))) { @@ -1013,16 +1087,17 @@ LogicalResult OperationConverter::convert(DialectConversionRewriter &rewriter, // within. // FIXME(riverriddle) This should be replaced by patterns when the pattern // rewriter exposes functionality to remap region signatures. - if (rewriter.argConverter.typeConverter) { + auto &rewriterImpl = rewriter.getImpl(); + if (rewriterImpl.argConverter.typeConverter) { for (auto ®ion : op->getRegions()) - if (region.empty() && failed(rewriter.convertRegionSignature(region))) + if (region.empty() && failed(rewriterImpl.convertRegionSignature(region))) return failure(); } return success(); } -/// Converts the given top-level operation to the conversion target. +/// Converts the given operations to the conversion target. LogicalResult OperationConverter::convertOperations(ArrayRef ops, TypeConverter *typeConverter) { @@ -1039,16 +1114,16 @@ OperationConverter::convertOperations(ArrayRef ops, } // Convert each operation and discard rewrites on failure. - DialectConversionRewriter rewriter(ops.front()->getContext(), typeConverter); + ConversionPatternRewriter rewriter(ops.front()->getContext(), typeConverter); for (auto *op : toConvert) { if (failed(convert(rewriter, op))) { - rewriter.discardRewrites(); + rewriter.getImpl().discardRewrites(); return failure(); } } // Otherwise the body conversion succeeded, so apply all rewrites. - rewriter.applyRewrites(); + rewriter.getImpl().applyRewrites(); return success(); } diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp index 6b1266e..410536c 100644 --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -62,8 +62,9 @@ struct TestRegionRewriteBlockMovement : public ConversionPattern { TestRegionRewriteBlockMovement(MLIRContext *ctx) : ConversionPattern("test.region", 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const final { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { // Inline this region into the parent region. auto &parentRegion = *op->getContainingRegion(); rewriter.inlineRegionBefore(op->getRegion(0), parentRegion, @@ -101,8 +102,9 @@ struct TestRegionRewriteUndo : public RewritePattern { /// This pattern simply erases the given operation. struct TestDropOp : public ConversionPattern { TestDropOp(MLIRContext *ctx) : ConversionPattern("test.drop_op", 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const final { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { rewriter.replaceOp(op, llvm::None); return matchSuccess(); } @@ -111,8 +113,9 @@ struct TestDropOp : public ConversionPattern { struct TestPassthroughInvalidOp : public ConversionPattern { TestPassthroughInvalidOp(MLIRContext *ctx) : ConversionPattern("test.invalid", 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const final { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { rewriter.replaceOpWithNewOp(op, llvm::None, operands, llvm::None); return matchSuccess(); @@ -122,8 +125,9 @@ struct TestPassthroughInvalidOp : public ConversionPattern { struct TestSplitReturnType : public ConversionPattern { TestSplitReturnType(MLIRContext *ctx) : ConversionPattern("test.return", 1, ctx) {} - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const final { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const final { // Check for a return of F32. if (op->getNumOperands() != 1 || !op->getOperand(0)->getType().isF32()) return matchFailure(); diff --git a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp index b76a565..0077e95 100644 --- a/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp +++ b/mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp @@ -115,8 +115,9 @@ public: lowering_.getDialect()->getContext(), lowering_) {} // Convert the kernel arguments to an LLVM type, preserve the rest. - PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { rewriter.clone(*op)->setOperands(operands); return rewriter.replaceOp(op, llvm::None), matchSuccess(); } -- 2.7.4