From e25796ef6e7043de3a4d8d85b92cad5832ee676f Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 6 Jun 2019 15:38:08 -0700 Subject: [PATCH] Add support for matchAndRewrite to the DialectConversion patterns. This also drops the default "always succeed" match override to better align with RewritePattern. PiperOrigin-RevId: 251941625 --- .../Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp | 20 +++++--- .../Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp | 10 ++-- mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp | 5 +- mlir/examples/toy/Ch5/mlir/LateLowering.cpp | 25 ++++++---- mlir/include/mlir/Transforms/DialectConversion.h | 38 ++++++++++----- .../lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp | 50 ++++++++++--------- mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp | 57 +++++++++++++--------- mlir/lib/Transforms/DialectConversion.cpp | 21 ++++---- 8 files changed, 133 insertions(+), 93 deletions(-) diff --git a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp index d13f7f3..f090856 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp @@ -135,8 +135,8 @@ public: explicit RangeOpConversion(MLIRContext *context) : ConversionPattern(linalg::RangeOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorType = linalg::convertLinalgType(rangeOp.getResult()->getType()); @@ -153,6 +153,7 @@ public: rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor, operands[2], makePositionAttr(rewriter, 2)); rewriter.replaceOp(op, rangeDescriptor); + return matchSuccess(); } }; @@ -161,8 +162,8 @@ public: explicit ViewOpConversion(MLIRContext *context) : ConversionPattern(linalg::ViewOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto viewOp = cast(op); auto viewDescriptorType = linalg::convertLinalgType(viewOp.getViewType()); auto memrefType = @@ -277,6 +278,7 @@ public: } rewriter.replaceOp(op, viewDescriptor); + return matchSuccess(); } }; @@ -285,8 +287,8 @@ public: explicit SliceOpConversion(MLIRContext *context) : ConversionPattern(linalg::SliceOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto sliceOp = cast(op); auto newViewDescriptorType = linalg::convertLinalgType(sliceOp.getViewType()); @@ -366,6 +368,7 @@ public: } rewriter.replaceOp(op, newViewDescriptor); + return matchSuccess(); } }; @@ -376,9 +379,10 @@ public: explicit DropConsumer(MLIRContext *context) : ConversionPattern("some_consumer", 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { rewriter.replaceOp(op, llvm::None); + return matchSuccess(); } }; diff --git a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp index ef0d858..26d6af8 100644 --- a/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp @@ -95,8 +95,8 @@ public: // an LLVM IR load. class LoadOpConversion : public LoadStoreOpConversion { using Base::Base; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); auto elementType = linalg::convertLinalgType(*op->result_type_begin()); Value *viewDescriptor = operands[0]; @@ -104,6 +104,7 @@ class LoadOpConversion : public LoadStoreOpConversion { Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); Value *element = intrinsics::load(elementType, ptr); rewriter.replaceOp(op, {element}); + return matchSuccess(); } }; @@ -111,8 +112,8 @@ class LoadOpConversion : public LoadStoreOpConversion { // an LLVM IR store. class StoreOpConversion : public LoadStoreOpConversion { using Base::Base; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); Value *viewDescriptor = operands[1]; Value *data = operands[0]; @@ -120,6 +121,7 @@ class StoreOpConversion : public LoadStoreOpConversion { Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); intrinsics::store(data, ptr); rewriter.replaceOp(op, llvm::None); + return matchSuccess(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp index 189add0..82541f8 100644 --- a/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/EarlyLowering.cpp @@ -87,8 +87,8 @@ public: explicit MulOpConversion(MLIRContext *context) : ConversionPattern(toy::MulOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { using namespace edsc; using intrinsics::constant_index; using linalg::intrinsics::range; @@ -117,6 +117,7 @@ public: auto resultView = view(result, {r0, r2}); rewriter.create(loc, lhsView, rhsView, resultView); rewriter.replaceOp(op, {typeCast(rewriter, result, mul.getType())}); + return matchSuccess(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp index ecf6c9d..4434e1b 100644 --- a/mlir/examples/toy/Ch5/mlir/LateLowering.cpp +++ b/mlir/examples/toy/Ch5/mlir/LateLowering.cpp @@ -92,8 +92,8 @@ public: /// the rewritten operands for `op` in the new function. /// The results created by the new IR with the builder are returned, and their /// number must match the number of result of `op`. - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto add = cast(op); auto loc = add.getLoc(); // Create a `toy.alloc` operation to allocate the output buffer for this op. @@ -122,6 +122,7 @@ public: // Return the newly allocated buffer, with a type.cast to preserve the // consumers. rewriter.replaceOp(op, {typeCast(rewriter, result, add.getType())}); + return matchSuccess(); } }; @@ -132,8 +133,8 @@ public: explicit PrintOpConversion(MLIRContext *context) : ConversionPattern(toy::PrintOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { // Get or create the declaration of the printf function in the module. Function *printfFunc = getPrintf(*op->getFunction()->getModule()); @@ -178,6 +179,7 @@ public: // clang-format on } rewriter.replaceOp(op, llvm::None); + return matchSuccess(); } private: @@ -230,8 +232,8 @@ public: explicit ConstantOpConversion(MLIRContext *context) : ConversionPattern(toy::ConstantOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { toy::ConstantOp cstOp = cast(op); auto loc = cstOp.getLoc(); auto retTy = cstOp.getResult()->getType().cast(); @@ -264,6 +266,7 @@ public: } } rewriter.replaceOp(op, result); + return matchSuccess(); } }; @@ -273,8 +276,8 @@ public: explicit TransposeOpConversion(MLIRContext *context) : ConversionPattern(toy::TransposeOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto transpose = cast(op); auto loc = transpose.getLoc(); Value *result = memRefTypeCast( @@ -296,6 +299,7 @@ public: // clang-format on rewriter.replaceOp(op, {typeCast(rewriter, result, transpose.getType())}); + return matchSuccess(); } }; @@ -305,13 +309,14 @@ public: explicit ReturnOpConversion(MLIRContext *context) : ConversionPattern(toy::ReturnOp::getOperationName(), 1, context) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { // Argument is optional, handle both cases. if (op->getNumOperands()) rewriter.replaceOpWithNewOp(op, operands[0]); else rewriter.replaceOpWithNewOp(op); + return matchSuccess(); } }; diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 8b476c0..af08a1f 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -48,13 +48,6 @@ public: MLIRContext *ctx) : RewritePattern(rootName, benefit, ctx) {} - /// Hook for derived classes to implement matching. Dialect conversion - /// generally unconditionally match the root operation, so default to success - /// here. - virtual PatternMatchResult match(Operation *op) const override { - return matchSuccess(); - } - /// Hook for derived classes to implement rewriting. `op` is the (first) /// operation matched by the pattern, `operands` is a list of rewritten values /// that are passed to this operation, `rewriter` can be used to emit the new @@ -84,14 +77,33 @@ public: llvm_unreachable("unimplemented rewrite for terminators"); } - /// Rewrite the IR rooted at the specified operation with the result of - /// this pattern. If an unexpected error is encountered (an internal compiler - /// error), it is emitted through the normal MLIR diagnostic hooks and the IR - /// is left in a valid state. - void rewrite(Operation *op, PatternRewriter &rewriter) const final; + /// Hook for derived classes to implement combined matching and rewriting. + virtual PatternMatchResult + matchAndRewrite(Operation *op, ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands, + PatternRewriter &rewriter) const { + if (!match(op)) + return matchFailure(); + rewrite(op, properOperands, destinations, operands, rewriter); + return matchSuccess(); + } + + /// Hook for derived classes to implement combined matching and rewriting. + virtual PatternMatchResult matchAndRewrite(Operation *op, + ArrayRef operands, + PatternRewriter &rewriter) const { + if (!match(op)) + return matchFailure(); + rewrite(op, operands, rewriter); + return matchSuccess(); + } + + /// Attempt to match and rewrite the IR root at the specified operation. + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final; private: - using RewritePattern::matchAndRewrite; using RewritePattern::rewrite; }; diff --git a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp index 1b50320..dd91f06 100644 --- a/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp +++ b/mlir/lib/LLVMIR/Transforms/ConvertToLLVMDialect.cpp @@ -210,10 +210,6 @@ public: lowering_), dialect(dialect_) {} - PatternMatchResult match(Operation *op) const override { - return this->matchSuccess(); - } - // Get the LLVM IR dialect. LLVM::LLVMDialect &getDialect() const { return dialect; } // Get the LLVM context. @@ -279,8 +275,8 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // Convert the type of the result to an LLVM type, pass operands as is, // preserve attributes. - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { unsigned numResults = op->getNumResults(); Type packedType; @@ -296,9 +292,10 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { // If the operation produced 0 or 1 result, return them immediately. if (numResults == 0) - return rewriter.replaceOp(op, llvm::None); + return rewriter.replaceOp(op, llvm::None), this->matchSuccess(); if (numResults == 1) - return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)); + return rewriter.replaceOp(op, newOp.getOperation()->getResult(0)), + this->matchSuccess(); // Otherwise, it had been converted to an operation producing a structure. // Extract individual results from the structure and return them as list. @@ -311,6 +308,7 @@ struct OneToOneLLVMOpLowering : public LLVMLegalizationPattern { this->getIntegerArrayAttr(rewriter, i))); } rewriter.replaceOp(op, results); + return this->matchSuccess(); } }; @@ -500,8 +498,8 @@ struct AllocOpLowering : public LLVMLegalizationPattern { struct DeallocOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { assert(operands.size() == 1 && "dealloc takes one operand"); OperandAdaptor transformed(operands); @@ -524,6 +522,7 @@ struct DeallocOpLowering : public LLVMLegalizationPattern { op->getLoc(), getVoidPtrType(), bufferPtr); rewriter.replaceOpWithNewOp( op, ArrayRef(), rewriter.getFunctionAttr(freeFunc), casted); + return matchSuccess(); } }; @@ -759,8 +758,8 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern { struct LoadOpLowering : public LoadStoreOpLowering { using Base::Base; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto loadOp = cast(op); OperandAdaptor transformed(operands); auto type = loadOp.getMemRefType(); @@ -771,6 +770,7 @@ struct LoadOpLowering : public LoadStoreOpLowering { rewriter.replaceOpWithNewOp(op, elementType, ArrayRef{dataPtr}); + return matchSuccess(); } }; @@ -779,8 +779,8 @@ struct LoadOpLowering : public LoadStoreOpLowering { struct StoreOpLowering : public LoadStoreOpLowering { using Base::Base; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto type = cast(op).getMemRefType(); OperandAdaptor transformed(operands); @@ -788,6 +788,7 @@ struct StoreOpLowering : public LoadStoreOpLowering { transformed.indices(), rewriter, getModule()); rewriter.replaceOpWithNewOp(op, transformed.value(), dataPtr); + return matchSuccess(); } }; @@ -798,12 +799,14 @@ struct OneToOneLLVMTerminatorLowering using LLVMLegalizationPattern::LLVMLegalizationPattern; using Super = OneToOneLLVMTerminatorLowering; - void rewrite(Operation *op, ArrayRef properOperands, - ArrayRef destinations, - ArrayRef> operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, + ArrayRef properOperands, + ArrayRef destinations, + ArrayRef> operands, + PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, properOperands, destinations, operands, op->getAttrs()); + return this->matchSuccess(); } }; @@ -816,21 +819,23 @@ struct OneToOneLLVMTerminatorLowering struct ReturnOpLowering : public LLVMLegalizationPattern { using LLVMLegalizationPattern::LLVMLegalizationPattern; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { unsigned numArguments = op->getNumOperands(); // If ReturnOp has 0 or 1 operand, create it and return immediately. if (numArguments == 0) { - return rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, llvm::ArrayRef(), llvm::ArrayRef(), llvm::ArrayRef>(), op->getAttrs()); + return matchSuccess(); } if (numArguments == 1) { - return rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, llvm::ArrayRef(operands.front()), llvm::ArrayRef(), llvm::ArrayRef>(), op->getAttrs()); + return matchSuccess(); } // Otherwise, we need to pack the arguments into an LLVM struct type before @@ -847,6 +852,7 @@ struct ReturnOpLowering : public LLVMLegalizationPattern { rewriter.replaceOpWithNewOp( op, llvm::makeArrayRef(packed), llvm::ArrayRef(), llvm::ArrayRef>(), op->getAttrs()); + return matchSuccess(); } }; diff --git a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp index b3857ac..60c16d9 100644 --- a/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -159,8 +159,8 @@ public: LLVMTypeConverter &lowering_) : LLVMOpLowering(BufferAllocOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto indexType = IndexType::get(op->getContext()); auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); @@ -204,6 +204,7 @@ public: desc = insertvalue(bufferDescriptorType, desc, size, positionAttr(rewriter, 1)); rewriter.replaceOp(op, desc); + return matchSuccess(); } }; @@ -215,8 +216,8 @@ public: : LLVMOpLowering(BufferDeallocOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto voidPtrTy = LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo(); // Insert the `free` declaration if it is not already present. @@ -239,6 +240,7 @@ public: positionAttr(rewriter, 0))); call(ArrayRef(), rewriter.getFunctionAttr(freeFunc), casted); rewriter.replaceOp(op, llvm::None); + return matchSuccess(); } }; @@ -248,12 +250,13 @@ public: BufferSizeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(BufferSizeOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto int64Ty = lowering.convertType(operands[0]->getType()); edsc::ScopedContext context(rewriter, op->getLoc()); rewriter.replaceOp( op, {extractvalue(int64Ty, operands[0], positionAttr(rewriter, 1))}); + return matchSuccess(); } }; @@ -263,8 +266,8 @@ public: explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto dimOp = cast(op); auto indexTy = lowering.convertType(rewriter.getIndexType()); edsc::ScopedContext context(rewriter, op->getLoc()); @@ -273,6 +276,7 @@ public: {extractvalue( indexTy, operands[0], positionAttr(rewriter, {2, static_cast(dimOp.getIndex())}))}); + return matchSuccess(); } }; @@ -318,14 +322,15 @@ public: // an LLVM IR load. class LoadOpConversion : public LoadStoreOpConversion { using Base::Base; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); auto elementTy = lowering.convertType(*op->result_type_begin()); Value *viewDescriptor = operands[0]; ArrayRef indices = operands.drop_front(); auto ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); rewriter.replaceOp(op, {llvm_load(elementTy, ptr)}); + return matchSuccess(); } }; @@ -335,8 +340,8 @@ public: explicit RangeOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(RangeOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorTy = convertLinalgType(rangeOp.getResult()->getType(), lowering); @@ -352,6 +357,7 @@ public: desc = insertvalue(rangeDescriptorTy, desc, operands[2], positionAttr(rewriter, 2)); rewriter.replaceOp(op, desc); + return matchSuccess(); } }; @@ -363,8 +369,8 @@ public: : LLVMOpLowering(RangeIntersectOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto rangeIntersectOp = cast(op); auto rangeDescriptorTy = convertLinalgType(rangeIntersectOp.getResult()->getType(), lowering); @@ -397,6 +403,7 @@ public: desc = insertvalue(rangeDescriptorTy, desc, mul(step1, step2), positionAttr(rewriter, 2)); rewriter.replaceOp(op, desc); + return matchSuccess(); } }; @@ -405,8 +412,8 @@ public: explicit SliceOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(SliceOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto sliceOp = cast(op); auto viewDescriptorTy = convertLinalgType(sliceOp.getViewType(), lowering); auto viewType = sliceOp.getBaseViewType(); @@ -477,6 +484,7 @@ public: } rewriter.replaceOp(op, desc); + return matchSuccess(); } }; @@ -484,8 +492,8 @@ public: // an LLVM IR store. class StoreOpConversion : public LoadStoreOpConversion { using Base::Base; - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { edsc::ScopedContext edscContext(rewriter, op->getLoc()); Value *data = operands[0]; Value *viewDescriptor = operands[1]; @@ -493,6 +501,7 @@ class StoreOpConversion : public LoadStoreOpConversion { Value *ptr = obtainDataPtr(op, viewDescriptor, indices, rewriter); llvm_store(data, ptr); rewriter.replaceOp(op, llvm::None); + return matchSuccess(); } }; @@ -501,8 +510,8 @@ public: explicit ViewOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) : LLVMOpLowering(ViewOp::getOperationName(), context, lowering_) {} - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto viewOp = cast(op); auto viewDescriptorTy = convertLinalgType(viewOp.getViewType(), lowering); auto elementTy = getPtrToElementType(viewOp.getViewType(), lowering); @@ -548,6 +557,7 @@ public: } rewriter.replaceOp(op, desc); + return matchSuccess(); } }; @@ -560,20 +570,21 @@ public: static StringRef libraryFunctionName() { return "linalg_dot"; } - void rewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(Operation *op, ArrayRef operands, + PatternRewriter &rewriter) const override { auto *f = op->getFunction()->getModule()->getNamedFunction(libraryFunctionName()); if (!f) { op->emitError("Could not find function: " + libraryFunctionName() + "in lowering to LLVM "); - return; + return matchFailure(); } auto fAttr = rewriter.getFunctionAttr(f); auto named = rewriter.getNamedAttr("callee", fAttr); rewriter.replaceOpWithNewOp(op, operands, ArrayRef{named}); + return matchSuccess(); } }; diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp index 6647434..00ae605 100644 --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -226,19 +226,17 @@ struct DialectConversionRewriter final : public PatternRewriter { // ConversionPattern //===----------------------------------------------------------------------===// -/// Rewrite the IR rooted at the specified operation with the result of this -/// pattern. If an unexpected error is encountered (an internal compiler -/// error), it is emitted through the normal MLIR diagnostic hooks and the IR is -/// left in a valid state. -void ConversionPattern::rewrite(Operation *op, - PatternRewriter &rewriter) const { +/// Attempt to match and rewrite the IR root at the specified operation. +PatternMatchResult +ConversionPattern::matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const { SmallVector operands; auto &dialectRewriter = static_cast(rewriter); dialectRewriter.remapValues(op->getOperands(), operands); // If this operation has no successors, invoke the rewrite directly. if (op->getNumSuccessors() == 0) - return rewrite(op, operands, rewriter); + return matchAndRewrite(op, operands, rewriter); // Otherwise, we need to remap the successors. SmallVector destinations; @@ -257,10 +255,11 @@ void ConversionPattern::rewrite(Operation *op, } // Rewrite the operation. - rewrite(op, - llvm::makeArrayRef(operands.data(), - operands.data() + firstSuccessorOperand), - destinations, operandsPerDestination, rewriter); + return matchAndRewrite( + op, + llvm::makeArrayRef(operands.data(), + operands.data() + firstSuccessorOperand), + destinations, operandsPerDestination, rewriter); } //===----------------------------------------------------------------------===// -- 2.7.4