From: River Riddle Date: Sun, 26 May 2019 00:22:27 +0000 (-0700) Subject: Add a templated wrapper around RewritePattern that allows for defining match... X-Git-Tag: llvmorg-11-init~1466^2~1594 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=9e21ab8f522265d37159372dbce96f66488c4e34;p=platform%2Fupstream%2Fllvm.git Add a templated wrapper around RewritePattern that allows for defining match/rewrite methods with an instance of the source op instead of a raw Operation*. -- PiperOrigin-RevId: 250003405 --- diff --git a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp index 621bc26..0fe70e2 100644 --- a/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/Transforms.cpp @@ -248,12 +248,12 @@ namespace { /// mlir::StoreOp requires finding the proper indexing in the supporting MemRef. /// This is most easily achieved by calling emitAndReturnFullyComposedView to /// fold away all the SliceOp. -template struct Rewriter : public RewritePattern { - explicit Rewriter(MLIRContext *context) - : RewritePattern(LoadOrStoreOpTy::getOperationName(), 1, context) {} +template +struct Rewriter : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; /// Performs the rewrite. - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op, PatternRewriter &rewriter) const override; }; @@ -270,9 +270,8 @@ struct LowerLinalgLoadStorePass template <> PatternMatchResult -Rewriter::matchAndRewrite(Operation *op, +Rewriter::matchAndRewrite(linalg::LoadOp load, PatternRewriter &rewriter) const { - auto load = cast(op); SliceOp slice = dyn_cast(load.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : cast(load.getView()->getDefiningOp()); @@ -280,15 +279,14 @@ Rewriter::matchAndRewrite(Operation *op, ScopedContext scope(builder, load.getLoc()); auto *memRef = view.getSupportingMemRef(); auto operands = emitAndReturnLoadStoreOperands(load, view); - rewriter.replaceOpWithNewOp(op, memRef, operands); + rewriter.replaceOpWithNewOp(load, memRef, operands); return matchSuccess(); } template <> PatternMatchResult -Rewriter::matchAndRewrite(Operation *op, +Rewriter::matchAndRewrite(linalg::StoreOp store, PatternRewriter &rewriter) const { - auto store = cast(op); SliceOp slice = dyn_cast(store.getView()->getDefiningOp()); ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult()) : cast(store.getView()->getDefiningOp()); @@ -297,7 +295,7 @@ Rewriter::matchAndRewrite(Operation *op, auto *valueToStore = store.getValueToStore(); auto *memRef = view.getSupportingMemRef(); auto operands = emitAndReturnLoadStoreOperands(store, view); - rewriter.replaceOpWithNewOp(op, valueToStore, memRef, + rewriter.replaceOpWithNewOp(store, valueToStore, memRef, operands); return matchSuccess(); } diff --git a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp index 37aa47f..8baa45c 100644 --- a/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch4/mlir/ToyCombine.cpp @@ -33,25 +33,21 @@ namespace toy { namespace { /// Fold transpose(transpose(x) -> transpose(x) -struct SimplifyRedundantTranspose : public mlir::RewritePattern { +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// We register this pattern to match every toy.transpose in the IR. /// The "benefit" is used by the framework to order the patterns and process /// them in order of profitability. SimplifyRedundantTranspose(mlir::MLIRContext *context) - : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1, - context) {} + : OpRewritePattern(context, /*benefit=*/1) {} /// This method is attempting to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. It is expected /// to interact with it to perform any changes to the IR from here. mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { - // We can directly cast the current operation as this will only get invoked - // on TransposeOp. - TransposeOp transpose = llvm::cast(op); // Look through the input of the current transpose. - mlir::Value *transposeInput = transpose.getOperand(); + mlir::Value *transposeInput = op.getOperand(); TransposeOp transposeInputOp = llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); // If the input is defined by another Transpose, bingo! @@ -65,15 +61,12 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { }; /// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place. -struct SimplifyReshapeConstant : public mlir::RewritePattern { - SimplifyReshapeConstant(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} +struct SimplifyReshapeConstant : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(ReshapeOp reshape, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = llvm::cast(op); // Look through the input of the current reshape. ConstantOp constantOp = llvm::dyn_cast_or_null( reshape.getOperand()->getDefiningOp()); @@ -81,7 +74,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { if (!constantOp) return matchFailure(); - auto reshapeType = op->getResult(0)->getType().cast(); + auto reshapeType = reshape.getType().cast(); if (auto valueAttr = constantOp.getAttrOfType("value")) { // FIXME Check matching of element count! @@ -90,7 +83,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { reshapeType.getShape(), valueAttr.getType().getElementType()); auto newAttr = mlir::DenseElementsAttr::get(newType, valueAttr.getRawData()); - rewriter.replaceOpWithNewOp(op, reshapeType.getShape(), + rewriter.replaceOpWithNewOp(reshape, reshapeType.getShape(), newAttr); } else if (auto valueAttr = constantOp.getAttrOfType("value")) { @@ -102,7 +95,7 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { auto tensorTy = rewriter.getTensorType(reshapeType.getShape(), reshapeType.getElementType()); auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data); - rewriter.replaceOpWithNewOp(op, reshapeType.getShape(), + rewriter.replaceOpWithNewOp(reshape, reshapeType.getShape(), newAttr); } else { llvm_unreachable("Unsupported Constant format"); @@ -112,17 +105,15 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { }; /// Fold reshape(reshape(x)) -> reshape(x) -struct SimplifyReshapeReshape : public mlir::RewritePattern { - SimplifyReshapeReshape(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} +struct SimplifyReshapeReshape : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(ReshapeOp op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = llvm::cast(op); // Look through the input of the current reshape. - mlir::Value *reshapeInput = reshape.getOperand(); + mlir::Value *reshapeInput = op.getOperand(); + // If the input is defined by another reshape, bingo! if (!matchPattern(reshapeInput, mlir::m_Op())) return matchFailure(); @@ -134,18 +125,15 @@ struct SimplifyReshapeReshape : public mlir::RewritePattern { }; /// Fold reshape(x)) -> x, when input type matches output type -struct SimplifyNullReshape : public mlir::RewritePattern { - SimplifyNullReshape(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} +struct SimplifyNullReshape : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(ReshapeOp op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = llvm::cast(op); - if (reshape.getOperand()->getType() != reshape.getResult()->getType()) + if (op.getOperand()->getType() != op.getType()) return matchFailure(); - rewriter.replaceOp(reshape, {reshape.getOperand()}); + rewriter.replaceOp(op, {op.getOperand()}); return matchSuccess(); } }; diff --git a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp index 6e05eaf..64bd2c9 100644 --- a/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp +++ b/mlir/examples/toy/Ch5/mlir/ToyCombine.cpp @@ -22,7 +22,7 @@ #include "toy/Dialect.h" -#include "mlir/IR/Operation.h" +#include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" @@ -32,30 +32,26 @@ namespace toy { namespace { -/// Fold transpose(transpose(x)) -> transpose(x) -struct SimplifyRedundantTranspose : public mlir::RewritePattern { +/// Fold transpose(transpose(x) -> transpose(x) +struct SimplifyRedundantTranspose : public mlir::OpRewritePattern { /// We register this pattern to match every toy.transpose in the IR. /// The "benefit" is used by the framework to order the patterns and process /// them in order of profitability. SimplifyRedundantTranspose(mlir::MLIRContext *context) - : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1, - context) {} + : OpRewritePattern(context, /*benefit=*/1) {} /// This method is attempting to match a pattern and rewrite it. The rewriter /// argument is the orchestrator of the sequence of rewrites. It is expected /// to interact with it to perform any changes to the IR from here. mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(TransposeOp op, mlir::PatternRewriter &rewriter) const override { - // We can directly cast the current operation as this will only get invoked - // on TransposeOp. - TransposeOp transpose = llvm::cast(op); - // look through the input to the current transpose - mlir::Value *transposeInput = transpose.getOperand(); - mlir::Operation *transposeInputInst = transposeInput->getDefiningOp(); - // If the input is defined by another Transpose, bingo! + // Look through the input of the current transpose. + mlir::Value *transposeInput = op.getOperand(); TransposeOp transposeInputOp = - mlir::dyn_cast_or_null(transposeInputInst); + llvm::dyn_cast_or_null(transposeInput->getDefiningOp()); + + // If the input is defined by another Transpose, bingo! if (!transposeInputOp) return matchFailure(); @@ -66,25 +62,21 @@ struct SimplifyRedundantTranspose : public mlir::RewritePattern { }; /// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place. -struct SimplifyReshapeConstant : public mlir::RewritePattern { - SimplifyReshapeConstant(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} +struct SimplifyReshapeConstant : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(ReshapeOp reshape, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = llvm::cast(op); - // look through the input to the current reshape - mlir::Value *reshapeInput = reshape.getOperand(); - mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); - // If the input is defined by another reshape, bingo! - ConstantOp constantOp = - mlir::dyn_cast_or_null(reshapeInputInst); + // Look through the input of the current reshape. + ConstantOp constantOp = llvm::dyn_cast_or_null( + reshape.getOperand()->getDefiningOp()); + + // If the input is defined by another constant, bingo! if (!constantOp) return matchFailure(); - auto reshapeType = op->getResult(0)->getType().cast(); + auto reshapeType = reshape.getType().cast(); if (auto valueAttr = constantOp.getAttrOfType("value")) { // FIXME Check matching of element count! @@ -93,9 +85,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { reshapeType.getShape(), valueAttr.getType().getElementType()); auto newAttr = mlir::DenseElementsAttr::get(newType, valueAttr.getRawData()); - auto newConstant = rewriter.create( - constantOp.getLoc(), reshapeType.getShape(), newAttr); - rewriter.replaceOp(op, {newConstant}); + rewriter.replaceOpWithNewOp(reshape, reshapeType.getShape(), + newAttr); } else if (auto valueAttr = constantOp.getAttrOfType("value")) { // Broadcast @@ -106,9 +97,8 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { auto tensorTy = rewriter.getTensorType(reshapeType.getShape(), reshapeType.getElementType()); auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data); - auto newConstant = rewriter.create( - constantOp.getLoc(), reshapeType.getShape(), newAttr); - rewriter.replaceOp(op, {newConstant}); + rewriter.replaceOpWithNewOp(reshape, reshapeType.getShape(), + newAttr); } else { llvm_unreachable("Unsupported Constant format"); } @@ -117,43 +107,35 @@ struct SimplifyReshapeConstant : public mlir::RewritePattern { }; /// Fold reshape(reshape(x)) -> reshape(x) -struct SimplifyReshapeReshape : public mlir::RewritePattern { - SimplifyReshapeReshape(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} +struct SimplifyReshapeReshape : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(ReshapeOp op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = llvm::cast(op); - // look through the input to the current reshape - mlir::Value *reshapeInput = reshape.getOperand(); - mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp(); + // Look through the input of the current reshape. + mlir::Value *reshapeInput = op.getOperand(); + // If the input is defined by another reshape, bingo! - ReshapeOp reshapeInputOp = - mlir::dyn_cast_or_null(reshapeInputInst); - if (!reshapeInputOp) + if (!matchPattern(reshapeInput, mlir::m_Op())) return matchFailure(); // Use the rewriter to perform the replacement - rewriter.replaceOp(op, {reshapeInputOp}); + rewriter.replaceOp(op, {reshapeInput}); return matchSuccess(); } }; /// Fold reshape(x)) -> x, when input type matches output type -struct SimplifyNullReshape : public mlir::RewritePattern { - SimplifyNullReshape(mlir::MLIRContext *context) - : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1, - context) {} +struct SimplifyNullReshape : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(ReshapeOp op, mlir::PatternRewriter &rewriter) const override { - ReshapeOp reshape = llvm::cast(op); - if (reshape.getOperand()->getType() != reshape.getResult()->getType()) + if (op.getOperand()->getType() != op.getType()) return matchFailure(); - rewriter.replaceOp(reshape, {reshape.getOperand()}); + rewriter.replaceOp(op, {op.getOperand()}); return matchSuccess(); } }; @@ -176,17 +158,14 @@ void ReshapeOp::getCanonicalizationPatterns( namespace { /// Fold type.cast(x) -> x, when input type matches output type -struct SimplifyIdentityTypeCast : public mlir::RewritePattern { - SimplifyIdentityTypeCast(mlir::MLIRContext *context) - : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1, - context) {} +struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern { + using mlir::OpRewritePattern::OpRewritePattern; mlir::PatternMatchResult - matchAndRewrite(mlir::Operation *op, + matchAndRewrite(TypeCastOp typeCast, mlir::PatternRewriter &rewriter) const override { - TypeCastOp typeCast = llvm::cast(op); - auto resTy = typeCast.getResult()->getType(); - auto *candidateOp = op; + auto resTy = typeCast.getType(); + auto *candidateOp = typeCast.getOperation(); while (llvm::isa_and_nonnull(candidateOp)) { if (resTy == candidateOp->getOperand(0)->getType()) { rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)}); diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 60c8255..bbca58b 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -205,6 +205,53 @@ protected: llvm::SmallVector generatedOps; }; +/// OpRewritePattern is a wrapper around RewritePattern that allows for +/// matching and rewriting against an instance of a derived operation class as +/// opposed to a raw Operation. +template struct OpRewritePattern : public RewritePattern { + /// Patterns must specify the root operation name they match against, and can + /// also specify the benefit of the pattern matching. + OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) + : RewritePattern(SourceOp::getOperationName(), benefit, context) {} + + /// Wrappers around the RewritePattern methods that pass the derived op type. + void rewrite(Operation *op, std::unique_ptr state, + PatternRewriter &rewriter) const final { + rewrite(llvm::cast(op), std::move(state), rewriter); + } + void rewrite(Operation *op, PatternRewriter &rewriter) const final { + rewrite(llvm::cast(op), rewriter); + } + PatternMatchResult match(Operation *op) const final { + return match(llvm::cast(op)); + } + PatternMatchResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const final { + return matchAndRewrite(llvm::cast(op), rewriter); + } + + /// Rewrite and Match methods that operate on the SourceOp type. These must be + /// overridden by the derived pattern class. + virtual void rewrite(SourceOp op, std::unique_ptr state, + PatternRewriter &rewriter) const { + rewrite(op, rewriter); + } + virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const { + llvm_unreachable("must override matchAndRewrite or a rewrite method"); + } + virtual PatternMatchResult match(SourceOp op) const { + llvm_unreachable("must override match or matchAndRewrite"); + } + virtual PatternMatchResult matchAndRewrite(SourceOp op, + PatternRewriter &rewriter) const { + if (auto matchResult = match(op)) { + rewrite(op, std::move(*matchResult), rewriter); + return matchSuccess(); + } + return matchFailure(); + } +}; + //===----------------------------------------------------------------------===// // PatternRewriter class //===----------------------------------------------------------------------===// diff --git a/mlir/lib/AffineOps/AffineOps.cpp b/mlir/lib/AffineOps/AffineOps.cpp index 130cb15..1a34c71 100644 --- a/mlir/lib/AffineOps/AffineOps.cpp +++ b/mlir/lib/AffineOps/AffineOps.cpp @@ -654,24 +654,21 @@ void mlir::canonicalizeMapAndOperands( namespace { /// Simplify AffineApply operations. /// -struct SimplifyAffineApply : public RewritePattern { - SimplifyAffineApply(MLIRContext *context) - : RewritePattern(AffineApplyOp::getOperationName(), 1, context) {} +struct SimplifyAffineApply : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(AffineApplyOp apply, PatternRewriter &rewriter) const override { - auto apply = cast(op); auto map = apply.getAffineMap(); AffineMap oldMap = map; SmallVector resultOperands(apply.getOperands()); composeAffineMapAndOperands(&map, &resultOperands); - if (map != oldMap) { - rewriter.replaceOpWithNewOp(op, map, resultOperands); - return matchSuccess(); - } + if (map == oldMap) + return matchFailure(); - return matchFailure(); + rewriter.replaceOpWithNewOp(apply, map, resultOperands); + return matchSuccess(); } }; } // end anonymous namespace. @@ -1002,14 +999,11 @@ void AffineForOp::print(OpAsmPrinter *p) { namespace { /// This is a pattern to fold constant loop bounds. -struct AffineForLoopBoundFolder : public RewritePattern { - /// The rootOpName is the name of the root operation to match against. - AffineForLoopBoundFolder(MLIRContext *context) - : RewritePattern(AffineForOp::getOperationName(), 1, context) {} +struct AffineForLoopBoundFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(AffineForOp forOp, PatternRewriter &rewriter) const override { - auto forOp = cast(op); auto foldLowerOrUpperBound = [&forOp](bool lower) { // Check to see if each of the operands is the result of a constant. If // so, get the value. If not, ignore it. @@ -1056,7 +1050,7 @@ struct AffineForLoopBoundFolder : public RewritePattern { // If any of the bounds were folded we return success. if (!folded) return matchFailure(); - rewriter.updatedRootInPlace(op); + rewriter.updatedRootInPlace(forOp); return matchSuccess(); } }; diff --git a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp index 32d8de3..2a752c2 100644 --- a/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp +++ b/mlir/lib/Dialect/FxpMathOps/Transforms/LowerUniformRealMath.cpp @@ -118,15 +118,13 @@ static Value *emitDequantize(Location loc, Value *input, namespace { -struct UniformDequantizePattern : public RewritePattern { - UniformDequantizePattern(MLIRContext *context) - : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {} +struct UniformDequantizePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(DequantizeCastOp op, PatternRewriter &rewriter) const { - auto dcastOp = cast(op); - Type inputType = dcastOp.arg()->getType(); - Type outputType = dcastOp.getResult()->getType(); + Type inputType = op.arg()->getType(); + Type outputType = op.getResult()->getType(); QuantizedType inputElementType = QuantizedType::getQuantizedElementType(inputType); @@ -136,8 +134,7 @@ struct UniformDequantizePattern : public RewritePattern { return matchFailure(); } - Value *dequantizedValue = - emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter); + Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter); if (!dequantizedValue) { return matchFailure(); } @@ -322,15 +319,13 @@ tryRewriteAffineMulEwSigned(const UniformBinaryOpInfo &info, namespace { -struct UniformRealAddEwPattern : public RewritePattern { - UniformRealAddEwPattern(MLIRContext *context) - : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {} +struct UniformRealAddEwPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(RealAddEwOp op, PatternRewriter &rewriter) const { - auto addOp = cast(op); - const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(), - addOp.clamp_min(), addOp.clamp_max()); + const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), + op.clamp_max()); if (!info.isValid()) { return matchFailure(); } @@ -344,15 +339,13 @@ struct UniformRealAddEwPattern : public RewritePattern { } }; -struct UniformRealMulEwPattern : public RewritePattern { - UniformRealMulEwPattern(MLIRContext *context) - : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {} +struct UniformRealMulEwPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(RealMulEwOp op, PatternRewriter &rewriter) const { - auto mulOp = cast(op); - const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(), - mulOp.clamp_min(), mulOp.clamp_max()); + const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(), + op.clamp_max()); if (!info.isValid()) { return matchFailure(); } diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index fb5b2e1..e237e8b 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -38,26 +38,21 @@ namespace { /// Matches x -> [scast -> scast] -> y, replacing the second scast with the /// value of x if the casts invert each other. -class RemoveRedundantStorageCastsRewrite : public RewritePattern { +class RemoveRedundantStorageCastsRewrite + : public OpRewritePattern { public: - RemoveRedundantStorageCastsRewrite(MLIRContext *context) - : RewritePattern(StorageCastOp::getOperationName(), 1, context) {} + using OpRewritePattern::OpRewritePattern; - PatternMatchResult match(Operation *op) const override { - auto scastOp = cast(op); - if (matchPattern(scastOp.arg(), m_Op())) { - auto srcScastOp = cast(scastOp.arg()->getDefiningOp()); - if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) { - return matchSuccess(); - } - } - return matchFailure(); - } + PatternMatchResult matchAndRewrite(StorageCastOp op, + PatternRewriter &rewriter) const override { + if (!matchPattern(op.arg(), m_Op())) + return matchFailure(); + auto srcScastOp = cast(op.arg()->getDefiningOp()); + if (srcScastOp.arg()->getType() != op.getType()) + return matchFailure(); - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - auto scastOp = cast(op); - auto srcScastOp = cast(scastOp.arg()->getDefiningOp()); rewriter.replaceOp(op, srcScastOp.arg()); + return matchSuccess(); } }; diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 44b1156..0c8ba31 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -36,40 +36,35 @@ public: void runOnFunction() override; }; -class QuantizedConstRewrite : public RewritePattern { -public: - struct State : PatternState { - QuantizedType quantizedElementType; - Attribute value; - }; - - QuantizedConstRewrite(MLIRContext *context) - : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {} +struct QuantizedConstRewrite : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult match(Operation *op) const override; - void rewrite(Operation *op, std::unique_ptr baseState, - PatternRewriter &rewriter) const override; + PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const override; }; } // end anonymous namespace /// Matches a [constant] -> [qbarrier] where the qbarrier results type is /// quantized and the operand type is quantizable. -PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { - State state; + +PatternMatchResult +QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const { + Attribute value; // Is the operand a constant? - auto qbarrier = cast(op); - if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) { + if (!matchPattern(qbarrier.arg(), m_Constant(&value))) { return matchFailure(); } + // Does the qbarrier convert to a quantized type. This will not be true // if a quantized type has not yet been chosen or if the cast to an equivalent // storage type is not supported. Type qbarrierResultType = qbarrier.getResult()->getType(); - state.quantizedElementType = + QuantizedType quantizedElementType = QuantizedType::getQuantizedElementType(qbarrierResultType); - if (!state.quantizedElementType) { + if (!quantizedElementType) { return matchFailure(); } if (!QuantizedType::castToStorageType(qbarrierResultType)) { @@ -79,43 +74,34 @@ PatternMatchResult QuantizedConstRewrite::match(Operation *op) const { // Is the operand type compatible with the expressed type of the quantized // type? This will not be true if the qbarrier is superfluous (converts // from and to a quantized type). - if (!state.quantizedElementType.isCompatibleExpressedType( + if (!quantizedElementType.isCompatibleExpressedType( qbarrier.arg()->getType())) { return matchFailure(); } // Is the constant value a type expressed in a way that we support? - if (!state.value.isa() && !state.value.isa() && - !state.value.isa() && - !state.value.isa()) { + if (!value.isa() && !value.isa() && + !value.isa() && !value.isa()) { return matchFailure(); } - return matchSuccess(llvm::make_unique(std::move(state))); -} - -void QuantizedConstRewrite::rewrite(Operation *op, - std::unique_ptr baseState, - PatternRewriter &rewriter) const { - auto state = static_cast(baseState.get()); - Type newConstValueType; - Attribute newConstValue = quantizeAttr( - state->value, state->quantizedElementType, newConstValueType); + auto newConstValue = + quantizeAttr(value, quantizedElementType, newConstValueType); if (!newConstValue) { - return; + return matchFailure(); } - auto *origConstOp = op->getOperand(0); // When creating the new const op, use a fused location that combines the // original const and the qbarrier that led to the quantization. - auto fusedLoc = - FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()}, - rewriter.getContext()); + auto fusedLoc = FusedLoc::get( + {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()}, + rewriter.getContext()); auto newConstOp = rewriter.create(fusedLoc, newConstValueType, newConstValue); - rewriter.replaceOpWithNewOp( - {origConstOp}, op, *op->result_type_begin(), newConstOp); + rewriter.replaceOpWithNewOp({qbarrier.arg()}, qbarrier, + qbarrier.getType(), newConstOp); + return matchSuccess(); } void ConvertConstPass::runOnFunction() { diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 508ebfe..dd67546 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -291,24 +291,19 @@ static LogicalResult verify(AllocOp op) { namespace { /// Fold constant dimensions into an alloc operation. -struct SimplifyAllocConst : public RewritePattern { - SimplifyAllocConst(MLIRContext *context) - : RewritePattern(AllocOp::getOperationName(), 1, context) {} - - PatternMatchResult match(Operation *op) const override { - auto alloc = cast(op); +struct SimplifyAllocConst : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + PatternMatchResult matchAndRewrite(AllocOp alloc, + PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. - for (auto *operand : alloc.getOperands()) - if (matchPattern(operand, m_ConstantIndex())) - return matchSuccess(); - return matchFailure(); - } + if (llvm::none_of(alloc.getOperands(), [](Value *operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return matchFailure(); - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - auto allocOp = cast(op); - auto memrefType = allocOp.getType(); + auto memrefType = alloc.getType(); // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. @@ -325,7 +320,7 @@ struct SimplifyAllocConst : public RewritePattern { newShapeConstants.push_back(dimSize); continue; } - auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp(); + auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); @@ -334,7 +329,7 @@ struct SimplifyAllocConst : public RewritePattern { } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(-1); - newOperands.push_back(allocOp.getOperand(dynamicDimPos)); + newOperands.push_back(alloc.getOperand(dynamicDimPos)); } dynamicDimPos++; } @@ -347,30 +342,29 @@ struct SimplifyAllocConst : public RewritePattern { // Create and insert the alloc op for the new memref. auto newAlloc = - rewriter.create(allocOp.getLoc(), newMemRefType, newOperands); + rewriter.create(alloc.getLoc(), newMemRefType, newOperands); // Insert a cast so we have the same type as the old alloc. - auto resultCast = rewriter.create(allocOp.getLoc(), newAlloc, - allocOp.getType()); + auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, + alloc.getType()); - rewriter.replaceOp(op, {resultCast}, droppedOperands); + rewriter.replaceOp(alloc, {resultCast}, droppedOperands); + return matchSuccess(); } }; /// Fold alloc operations with no uses. Alloc has side effects on the heap, /// but can still be deleted if it has zero uses. -struct SimplifyDeadAlloc : public RewritePattern { - SimplifyDeadAlloc(MLIRContext *context) - : RewritePattern(AllocOp::getOperationName(), 1, context) {} +struct SimplifyDeadAlloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { // Check if the alloc'ed value has any uses. - auto alloc = cast(op); if (!alloc.use_empty()) return matchFailure(); // If it doesn't, we can eliminate it. - op->erase(); + alloc.erase(); return matchSuccess(); } }; @@ -484,24 +478,22 @@ FunctionType CallOp::getCalleeType() { //===----------------------------------------------------------------------===// namespace { /// Fold indirect calls that have a constant function as the callee operand. -struct SimplifyIndirectCallWithKnownCallee : public RewritePattern { - SimplifyIndirectCallWithKnownCallee(MLIRContext *context) - : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {} +struct SimplifyIndirectCallWithKnownCallee + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall, PatternRewriter &rewriter) const override { - auto indirectCall = cast(op); - // Check that the callee is a constant callee. FunctionAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return matchFailure(); // Replace with a direct call. - SmallVector callResults(op->getResultTypes()); + SmallVector callResults(indirectCall.getResultTypes()); SmallVector callOperands(indirectCall.getArgOperands()); - rewriter.replaceOpWithNewOp(op, calledFn.getValue(), callResults, - callOperands); + rewriter.replaceOpWithNewOp(indirectCall, calledFn.getValue(), + callResults, callOperands); return matchSuccess(); } }; @@ -964,14 +956,11 @@ namespace { /// cond_br true, ^bb1, ^bb2 -> br ^bb1 /// cond_br false, ^bb1, ^bb2 -> br ^bb2 /// -struct SimplifyConstCondBranchPred : public RewritePattern { - SimplifyConstCondBranchPred(MLIRContext *context) - : RewritePattern(CondBranchOp::getOperationName(), 1, context) {} +struct SimplifyConstCondBranchPred : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { - auto condbr = cast(op); - // Check that the condition is a constant. if (!matchPattern(condbr.getCondition(), m_Op())) return matchFailure(); @@ -991,7 +980,7 @@ struct SimplifyConstCondBranchPred : public RewritePattern { branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end()); } - rewriter.replaceOpWithNewOp(op, foldedDest, branchArgs); + rewriter.replaceOpWithNewOp(condbr, foldedDest, branchArgs); return matchSuccess(); } }; @@ -1230,18 +1219,14 @@ void ConstantIndexOp::build(Builder *builder, OperationState *result, namespace { /// Fold Dealloc operations that are deallocating an AllocOp that is only used /// by other Dealloc operations. -struct SimplifyDeadDealloc : public RewritePattern { - SimplifyDeadDealloc(MLIRContext *context) - : RewritePattern(DeallocOp::getOperationName(), 1, context) {} +struct SimplifyDeadDealloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - PatternMatchResult matchAndRewrite(Operation *op, + PatternMatchResult matchAndRewrite(DeallocOp dealloc, PatternRewriter &rewriter) const override { - auto dealloc = cast(op); - // Check that the memref operand's defining operation is an AllocOp. Value *memref = dealloc.memref(); - Operation *defOp = memref->getDefiningOp(); - if (!isa_and_nonnull(defOp)) + if (!isa_and_nonnull(memref->getDefiningOp())) return matchFailure(); // Check that all of the uses of the AllocOp are other DeallocOps. @@ -1250,7 +1235,7 @@ struct SimplifyDeadDealloc : public RewritePattern { return matchFailure(); // Erase the dealloc operation. - op->erase(); + rewriter.replaceOp(dealloc, llvm::None); return matchSuccess(); } };