/// mlir::StoreOp requires finding the proper indexing in the supporting MemRef.
/// This is most easily achieved by calling emitAndReturnFullyComposedView to
/// fold away all the SliceOp.
-template <typename LoadOrStoreOpTy> struct Rewriter : public RewritePattern {
- explicit Rewriter(MLIRContext *context)
- : RewritePattern(LoadOrStoreOpTy::getOperationName(), 1, context) {}
+template <typename LoadOrStoreOpTy>
+struct Rewriter : public OpRewritePattern<LoadOrStoreOpTy> {
+ using OpRewritePattern<LoadOrStoreOpTy>::OpRewritePattern;
/// Performs the rewrite.
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(LoadOrStoreOpTy op,
PatternRewriter &rewriter) const override;
};
template <>
PatternMatchResult
-Rewriter<linalg::LoadOp>::matchAndRewrite(Operation *op,
+Rewriter<linalg::LoadOp>::matchAndRewrite(linalg::LoadOp load,
PatternRewriter &rewriter) const {
- auto load = cast<linalg::LoadOp>(op);
SliceOp slice = dyn_cast<SliceOp>(load.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
: cast<ViewOp>(load.getView()->getDefiningOp());
ScopedContext scope(builder, load.getLoc());
auto *memRef = view.getSupportingMemRef();
auto operands = emitAndReturnLoadStoreOperands(load, view);
- rewriter.replaceOpWithNewOp<mlir::LoadOp>(op, memRef, operands);
+ rewriter.replaceOpWithNewOp<mlir::LoadOp>(load, memRef, operands);
return matchSuccess();
}
template <>
PatternMatchResult
-Rewriter<linalg::StoreOp>::matchAndRewrite(Operation *op,
+Rewriter<linalg::StoreOp>::matchAndRewrite(linalg::StoreOp store,
PatternRewriter &rewriter) const {
- auto store = cast<linalg::StoreOp>(op);
SliceOp slice = dyn_cast<SliceOp>(store.getView()->getDefiningOp());
ViewOp view = slice ? emitAndReturnFullyComposedView(slice.getResult())
: cast<ViewOp>(store.getView()->getDefiningOp());
auto *valueToStore = store.getValueToStore();
auto *memRef = view.getSupportingMemRef();
auto operands = emitAndReturnLoadStoreOperands(store, view);
- rewriter.replaceOpWithNewOp<mlir::StoreOp>(op, valueToStore, memRef,
+ rewriter.replaceOpWithNewOp<mlir::StoreOp>(store, valueToStore, memRef,
operands);
return matchSuccess();
}
namespace {
/// Fold transpose(transpose(x) -> transpose(x)
-struct SimplifyRedundantTranspose : public mlir::RewritePattern {
+struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
/// them in order of profitability.
SimplifyRedundantTranspose(mlir::MLIRContext *context)
- : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+ : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
- // We can directly cast the current operation as this will only get invoked
- // on TransposeOp.
- TransposeOp transpose = llvm::cast<TransposeOp>(op);
// Look through the input of the current transpose.
- mlir::Value *transposeInput = transpose.getOperand();
+ mlir::Value *transposeInput = op.getOperand();
TransposeOp transposeInputOp =
llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
// If the input is defined by another Transpose, bingo!
};
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
-struct SimplifyReshapeConstant : public mlir::RewritePattern {
- SimplifyReshapeConstant(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
+ using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(ReshapeOp reshape,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// Look through the input of the current reshape.
ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
reshape.getOperand()->getDefiningOp());
if (!constantOp)
return matchFailure();
- auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
+ auto reshapeType = reshape.getType().cast<ToyArrayType>();
if (auto valueAttr =
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
// FIXME Check matching of element count!
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr =
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
- rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
+ rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else if (auto valueAttr =
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
reshapeType.getElementType());
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
- rewriter.replaceOpWithNewOp<ConstantOp>(op, reshapeType.getShape(),
+ rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
newAttr);
} else {
llvm_unreachable("Unsupported Constant format");
};
/// Fold reshape(reshape(x)) -> reshape(x)
-struct SimplifyReshapeReshape : public mlir::RewritePattern {
- SimplifyReshapeReshape(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
+ using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
// Look through the input of the current reshape.
- mlir::Value *reshapeInput = reshape.getOperand();
+ mlir::Value *reshapeInput = op.getOperand();
+
// If the input is defined by another reshape, bingo!
if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
return matchFailure();
};
/// Fold reshape(x)) -> x, when input type matches output type
-struct SimplifyNullReshape : public mlir::RewritePattern {
- SimplifyNullReshape(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
+ using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
- if (reshape.getOperand()->getType() != reshape.getResult()->getType())
+ if (op.getOperand()->getType() != op.getType())
return matchFailure();
- rewriter.replaceOp(reshape, {reshape.getOperand()});
+ rewriter.replaceOp(op, {op.getOperand()});
return matchSuccess();
}
};
#include "toy/Dialect.h"
-#include "mlir/IR/Operation.h"
+#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
namespace {
-/// Fold transpose(transpose(x)) -> transpose(x)
-struct SimplifyRedundantTranspose : public mlir::RewritePattern {
+/// Fold transpose(transpose(x) -> transpose(x)
+struct SimplifyRedundantTranspose : public mlir::OpRewritePattern<TransposeOp> {
/// We register this pattern to match every toy.transpose in the IR.
/// The "benefit" is used by the framework to order the patterns and process
/// them in order of profitability.
SimplifyRedundantTranspose(mlir::MLIRContext *context)
- : RewritePattern(TransposeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+ : OpRewritePattern<TransposeOp>(context, /*benefit=*/1) {}
/// This method is attempting to match a pattern and rewrite it. The rewriter
/// argument is the orchestrator of the sequence of rewrites. It is expected
/// to interact with it to perform any changes to the IR from here.
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(TransposeOp op,
mlir::PatternRewriter &rewriter) const override {
- // We can directly cast the current operation as this will only get invoked
- // on TransposeOp.
- TransposeOp transpose = llvm::cast<TransposeOp>(op);
- // look through the input to the current transpose
- mlir::Value *transposeInput = transpose.getOperand();
- mlir::Operation *transposeInputInst = transposeInput->getDefiningOp();
- // If the input is defined by another Transpose, bingo!
+ // Look through the input of the current transpose.
+ mlir::Value *transposeInput = op.getOperand();
TransposeOp transposeInputOp =
- mlir::dyn_cast_or_null<TransposeOp>(transposeInputInst);
+ llvm::dyn_cast_or_null<TransposeOp>(transposeInput->getDefiningOp());
+
+ // If the input is defined by another Transpose, bingo!
if (!transposeInputOp)
return matchFailure();
};
/// Fold reshape(constant(x)) -> constant(x'), with x' being reshaped in place.
-struct SimplifyReshapeConstant : public mlir::RewritePattern {
- SimplifyReshapeConstant(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+struct SimplifyReshapeConstant : public mlir::OpRewritePattern<ReshapeOp> {
+ using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(ReshapeOp reshape,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
- // look through the input to the current reshape
- mlir::Value *reshapeInput = reshape.getOperand();
- mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
- // If the input is defined by another reshape, bingo!
- ConstantOp constantOp =
- mlir::dyn_cast_or_null<ConstantOp>(reshapeInputInst);
+ // Look through the input of the current reshape.
+ ConstantOp constantOp = llvm::dyn_cast_or_null<ConstantOp>(
+ reshape.getOperand()->getDefiningOp());
+
+ // If the input is defined by another constant, bingo!
if (!constantOp)
return matchFailure();
- auto reshapeType = op->getResult(0)->getType().cast<ToyArrayType>();
+ auto reshapeType = reshape.getType().cast<ToyArrayType>();
if (auto valueAttr =
constantOp.getAttrOfType<mlir::DenseElementsAttr>("value")) {
// FIXME Check matching of element count!
reshapeType.getShape(), valueAttr.getType().getElementType());
auto newAttr =
mlir::DenseElementsAttr::get(newType, valueAttr.getRawData());
- auto newConstant = rewriter.create<ConstantOp>(
- constantOp.getLoc(), reshapeType.getShape(), newAttr);
- rewriter.replaceOp(op, {newConstant});
+ rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
+ newAttr);
} else if (auto valueAttr =
constantOp.getAttrOfType<mlir::FloatAttr>("value")) {
// Broadcast
auto tensorTy = rewriter.getTensorType(reshapeType.getShape(),
reshapeType.getElementType());
auto newAttr = mlir::DenseElementsAttr::get(tensorTy, data);
- auto newConstant = rewriter.create<ConstantOp>(
- constantOp.getLoc(), reshapeType.getShape(), newAttr);
- rewriter.replaceOp(op, {newConstant});
+ rewriter.replaceOpWithNewOp<ConstantOp>(reshape, reshapeType.getShape(),
+ newAttr);
} else {
llvm_unreachable("Unsupported Constant format");
}
};
/// Fold reshape(reshape(x)) -> reshape(x)
-struct SimplifyReshapeReshape : public mlir::RewritePattern {
- SimplifyReshapeReshape(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+struct SimplifyReshapeReshape : public mlir::OpRewritePattern<ReshapeOp> {
+ using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
- // look through the input to the current reshape
- mlir::Value *reshapeInput = reshape.getOperand();
- mlir::Operation *reshapeInputInst = reshapeInput->getDefiningOp();
+ // Look through the input of the current reshape.
+ mlir::Value *reshapeInput = op.getOperand();
+
// If the input is defined by another reshape, bingo!
- ReshapeOp reshapeInputOp =
- mlir::dyn_cast_or_null<ReshapeOp>(reshapeInputInst);
- if (!reshapeInputOp)
+ if (!matchPattern(reshapeInput, mlir::m_Op<ReshapeOp>()))
return matchFailure();
// Use the rewriter to perform the replacement
- rewriter.replaceOp(op, {reshapeInputOp});
+ rewriter.replaceOp(op, {reshapeInput});
return matchSuccess();
}
};
/// Fold reshape(x)) -> x, when input type matches output type
-struct SimplifyNullReshape : public mlir::RewritePattern {
- SimplifyNullReshape(mlir::MLIRContext *context)
- : RewritePattern(ReshapeOp::getOperationName(), /* benefit = */ 1,
- context) {}
+struct SimplifyNullReshape : public mlir::OpRewritePattern<ReshapeOp> {
+ using mlir::OpRewritePattern<ReshapeOp>::OpRewritePattern;
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(ReshapeOp op,
mlir::PatternRewriter &rewriter) const override {
- ReshapeOp reshape = llvm::cast<ReshapeOp>(op);
- if (reshape.getOperand()->getType() != reshape.getResult()->getType())
+ if (op.getOperand()->getType() != op.getType())
return matchFailure();
- rewriter.replaceOp(reshape, {reshape.getOperand()});
+ rewriter.replaceOp(op, {op.getOperand()});
return matchSuccess();
}
};
namespace {
/// Fold type.cast(x) -> x, when input type matches output type
-struct SimplifyIdentityTypeCast : public mlir::RewritePattern {
- SimplifyIdentityTypeCast(mlir::MLIRContext *context)
- : RewritePattern(TypeCastOp::getOperationName(), /* benefit = */ 1,
- context) {}
+struct SimplifyIdentityTypeCast : public mlir::OpRewritePattern<TypeCastOp> {
+ using mlir::OpRewritePattern<TypeCastOp>::OpRewritePattern;
mlir::PatternMatchResult
- matchAndRewrite(mlir::Operation *op,
+ matchAndRewrite(TypeCastOp typeCast,
mlir::PatternRewriter &rewriter) const override {
- TypeCastOp typeCast = llvm::cast<TypeCastOp>(op);
- auto resTy = typeCast.getResult()->getType();
- auto *candidateOp = op;
+ auto resTy = typeCast.getType();
+ auto *candidateOp = typeCast.getOperation();
while (llvm::isa_and_nonnull<TypeCastOp>(candidateOp)) {
if (resTy == candidateOp->getOperand(0)->getType()) {
rewriter.replaceOp(typeCast, {candidateOp->getOperand(0)});
llvm::SmallVector<OperationName, 2> generatedOps;
};
+/// OpRewritePattern is a wrapper around RewritePattern that allows for
+/// matching and rewriting against an instance of a derived operation class as
+/// opposed to a raw Operation.
+template <typename SourceOp> struct OpRewritePattern : public RewritePattern {
+ /// Patterns must specify the root operation name they match against, and can
+ /// also specify the benefit of the pattern matching.
+ OpRewritePattern(MLIRContext *context, PatternBenefit benefit = 1)
+ : RewritePattern(SourceOp::getOperationName(), benefit, context) {}
+
+ /// Wrappers around the RewritePattern methods that pass the derived op type.
+ void rewrite(Operation *op, std::unique_ptr<PatternState> state,
+ PatternRewriter &rewriter) const final {
+ rewrite(llvm::cast<SourceOp>(op), std::move(state), rewriter);
+ }
+ void rewrite(Operation *op, PatternRewriter &rewriter) const final {
+ rewrite(llvm::cast<SourceOp>(op), rewriter);
+ }
+ PatternMatchResult match(Operation *op) const final {
+ return match(llvm::cast<SourceOp>(op));
+ }
+ PatternMatchResult matchAndRewrite(Operation *op,
+ PatternRewriter &rewriter) const final {
+ return matchAndRewrite(llvm::cast<SourceOp>(op), rewriter);
+ }
+
+ /// Rewrite and Match methods that operate on the SourceOp type. These must be
+ /// overridden by the derived pattern class.
+ virtual void rewrite(SourceOp op, std::unique_ptr<PatternState> state,
+ PatternRewriter &rewriter) const {
+ rewrite(op, rewriter);
+ }
+ virtual void rewrite(SourceOp op, PatternRewriter &rewriter) const {
+ llvm_unreachable("must override matchAndRewrite or a rewrite method");
+ }
+ virtual PatternMatchResult match(SourceOp op) const {
+ llvm_unreachable("must override match or matchAndRewrite");
+ }
+ virtual PatternMatchResult matchAndRewrite(SourceOp op,
+ PatternRewriter &rewriter) const {
+ if (auto matchResult = match(op)) {
+ rewrite(op, std::move(*matchResult), rewriter);
+ return matchSuccess();
+ }
+ return matchFailure();
+ }
+};
+
//===----------------------------------------------------------------------===//
// PatternRewriter class
//===----------------------------------------------------------------------===//
namespace {
/// Simplify AffineApply operations.
///
-struct SimplifyAffineApply : public RewritePattern {
- SimplifyAffineApply(MLIRContext *context)
- : RewritePattern(AffineApplyOp::getOperationName(), 1, context) {}
+struct SimplifyAffineApply : public OpRewritePattern<AffineApplyOp> {
+ using OpRewritePattern<AffineApplyOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(AffineApplyOp apply,
PatternRewriter &rewriter) const override {
- auto apply = cast<AffineApplyOp>(op);
auto map = apply.getAffineMap();
AffineMap oldMap = map;
SmallVector<Value *, 8> resultOperands(apply.getOperands());
composeAffineMapAndOperands(&map, &resultOperands);
- if (map != oldMap) {
- rewriter.replaceOpWithNewOp<AffineApplyOp>(op, map, resultOperands);
- return matchSuccess();
- }
+ if (map == oldMap)
+ return matchFailure();
- return matchFailure();
+ rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, resultOperands);
+ return matchSuccess();
}
};
} // end anonymous namespace.
namespace {
/// This is a pattern to fold constant loop bounds.
-struct AffineForLoopBoundFolder : public RewritePattern {
- /// The rootOpName is the name of the root operation to match against.
- AffineForLoopBoundFolder(MLIRContext *context)
- : RewritePattern(AffineForOp::getOperationName(), 1, context) {}
+struct AffineForLoopBoundFolder : public OpRewritePattern<AffineForOp> {
+ using OpRewritePattern<AffineForOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(AffineForOp forOp,
PatternRewriter &rewriter) const override {
- auto forOp = cast<AffineForOp>(op);
auto foldLowerOrUpperBound = [&forOp](bool lower) {
// Check to see if each of the operands is the result of a constant. If
// so, get the value. If not, ignore it.
// If any of the bounds were folded we return success.
if (!folded)
return matchFailure();
- rewriter.updatedRootInPlace(op);
+ rewriter.updatedRootInPlace(forOp);
return matchSuccess();
}
};
namespace {
-struct UniformDequantizePattern : public RewritePattern {
- UniformDequantizePattern(MLIRContext *context)
- : RewritePattern(DequantizeCastOp::getOperationName(), 1, context) {}
+struct UniformDequantizePattern : public OpRewritePattern<DequantizeCastOp> {
+ using OpRewritePattern<DequantizeCastOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(DequantizeCastOp op,
PatternRewriter &rewriter) const {
- auto dcastOp = cast<DequantizeCastOp>(op);
- Type inputType = dcastOp.arg()->getType();
- Type outputType = dcastOp.getResult()->getType();
+ Type inputType = op.arg()->getType();
+ Type outputType = op.getResult()->getType();
QuantizedType inputElementType =
QuantizedType::getQuantizedElementType(inputType);
return matchFailure();
}
- Value *dequantizedValue =
- emitDequantize(dcastOp.getLoc(), dcastOp.arg(), rewriter);
+ Value *dequantizedValue = emitDequantize(op.getLoc(), op.arg(), rewriter);
if (!dequantizedValue) {
return matchFailure();
}
namespace {
-struct UniformRealAddEwPattern : public RewritePattern {
- UniformRealAddEwPattern(MLIRContext *context)
- : RewritePattern(RealAddEwOp::getOperationName(), 1, context) {}
+struct UniformRealAddEwPattern : public OpRewritePattern<RealAddEwOp> {
+ using OpRewritePattern<RealAddEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(RealAddEwOp op,
PatternRewriter &rewriter) const {
- auto addOp = cast<RealAddEwOp>(op);
- const UniformBinaryOpInfo info(op, addOp.lhs(), addOp.rhs(),
- addOp.clamp_min(), addOp.clamp_max());
+ const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+ op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
}
};
-struct UniformRealMulEwPattern : public RewritePattern {
- UniformRealMulEwPattern(MLIRContext *context)
- : RewritePattern(RealMulEwOp::getOperationName(), 1, context) {}
+struct UniformRealMulEwPattern : public OpRewritePattern<RealMulEwOp> {
+ using OpRewritePattern<RealMulEwOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(RealMulEwOp op,
PatternRewriter &rewriter) const {
- auto mulOp = cast<RealMulEwOp>(op);
- const UniformBinaryOpInfo info(op, mulOp.lhs(), mulOp.rhs(),
- mulOp.clamp_min(), mulOp.clamp_max());
+ const UniformBinaryOpInfo info(op, op.lhs(), op.rhs(), op.clamp_min(),
+ op.clamp_max());
if (!info.isValid()) {
return matchFailure();
}
/// Matches x -> [scast -> scast] -> y, replacing the second scast with the
/// value of x if the casts invert each other.
-class RemoveRedundantStorageCastsRewrite : public RewritePattern {
+class RemoveRedundantStorageCastsRewrite
+ : public OpRewritePattern<StorageCastOp> {
public:
- RemoveRedundantStorageCastsRewrite(MLIRContext *context)
- : RewritePattern(StorageCastOp::getOperationName(), 1, context) {}
+ using OpRewritePattern<StorageCastOp>::OpRewritePattern;
- PatternMatchResult match(Operation *op) const override {
- auto scastOp = cast<StorageCastOp>(op);
- if (matchPattern(scastOp.arg(), m_Op<StorageCastOp>())) {
- auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
- if (srcScastOp.arg()->getType() == scastOp.getResult()->getType()) {
- return matchSuccess();
- }
- }
- return matchFailure();
- }
+ PatternMatchResult matchAndRewrite(StorageCastOp op,
+ PatternRewriter &rewriter) const override {
+ if (!matchPattern(op.arg(), m_Op<StorageCastOp>()))
+ return matchFailure();
+ auto srcScastOp = cast<StorageCastOp>(op.arg()->getDefiningOp());
+ if (srcScastOp.arg()->getType() != op.getType())
+ return matchFailure();
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto scastOp = cast<StorageCastOp>(op);
- auto srcScastOp = cast<StorageCastOp>(scastOp.arg()->getDefiningOp());
rewriter.replaceOp(op, srcScastOp.arg());
+ return matchSuccess();
}
};
void runOnFunction() override;
};
-class QuantizedConstRewrite : public RewritePattern {
-public:
- struct State : PatternState {
- QuantizedType quantizedElementType;
- Attribute value;
- };
-
- QuantizedConstRewrite(MLIRContext *context)
- : RewritePattern(QuantizeCastOp::getOperationName(), 1, context) {}
+struct QuantizedConstRewrite : public OpRewritePattern<QuantizeCastOp> {
+ using OpRewritePattern<QuantizeCastOp>::OpRewritePattern;
- PatternMatchResult match(Operation *op) const override;
- void rewrite(Operation *op, std::unique_ptr<PatternState> baseState,
- PatternRewriter &rewriter) const override;
+ PatternMatchResult matchAndRewrite(QuantizeCastOp qbarrier,
+ PatternRewriter &rewriter) const override;
};
} // end anonymous namespace
/// Matches a [constant] -> [qbarrier] where the qbarrier results type is
/// quantized and the operand type is quantizable.
-PatternMatchResult QuantizedConstRewrite::match(Operation *op) const {
- State state;
+
+PatternMatchResult
+QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
+ PatternRewriter &rewriter) const {
+ Attribute value;
// Is the operand a constant?
- auto qbarrier = cast<QuantizeCastOp>(op);
- if (!matchPattern(qbarrier.arg(), m_Constant(&state.value))) {
+ if (!matchPattern(qbarrier.arg(), m_Constant(&value))) {
return matchFailure();
}
+
// Does the qbarrier convert to a quantized type. This will not be true
// if a quantized type has not yet been chosen or if the cast to an equivalent
// storage type is not supported.
Type qbarrierResultType = qbarrier.getResult()->getType();
- state.quantizedElementType =
+ QuantizedType quantizedElementType =
QuantizedType::getQuantizedElementType(qbarrierResultType);
- if (!state.quantizedElementType) {
+ if (!quantizedElementType) {
return matchFailure();
}
if (!QuantizedType::castToStorageType(qbarrierResultType)) {
// Is the operand type compatible with the expressed type of the quantized
// type? This will not be true if the qbarrier is superfluous (converts
// from and to a quantized type).
- if (!state.quantizedElementType.isCompatibleExpressedType(
+ if (!quantizedElementType.isCompatibleExpressedType(
qbarrier.arg()->getType())) {
return matchFailure();
}
// Is the constant value a type expressed in a way that we support?
- if (!state.value.isa<FloatAttr>() && !state.value.isa<SplatElementsAttr>() &&
- !state.value.isa<DenseElementsAttr>() &&
- !state.value.isa<SparseElementsAttr>()) {
+ if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
+ !value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
return matchFailure();
}
- return matchSuccess(llvm::make_unique<State>(std::move(state)));
-}
-
-void QuantizedConstRewrite::rewrite(Operation *op,
- std::unique_ptr<PatternState> baseState,
- PatternRewriter &rewriter) const {
- auto state = static_cast<State *>(baseState.get());
-
Type newConstValueType;
- Attribute newConstValue = quantizeAttr(
- state->value, state->quantizedElementType, newConstValueType);
+ auto newConstValue =
+ quantizeAttr(value, quantizedElementType, newConstValueType);
if (!newConstValue) {
- return;
+ return matchFailure();
}
- auto *origConstOp = op->getOperand(0);
// When creating the new const op, use a fused location that combines the
// original const and the qbarrier that led to the quantization.
- auto fusedLoc =
- FusedLoc::get({origConstOp->getDefiningOp()->getLoc(), op->getLoc()},
- rewriter.getContext());
+ auto fusedLoc = FusedLoc::get(
+ {qbarrier.arg()->getDefiningOp()->getLoc(), qbarrier.getLoc()},
+ rewriter.getContext());
auto newConstOp =
rewriter.create<ConstantOp>(fusedLoc, newConstValueType, newConstValue);
- rewriter.replaceOpWithNewOp<StorageCastOp>(
- {origConstOp}, op, *op->result_type_begin(), newConstOp);
+ rewriter.replaceOpWithNewOp<StorageCastOp>({qbarrier.arg()}, qbarrier,
+ qbarrier.getType(), newConstOp);
+ return matchSuccess();
}
void ConvertConstPass::runOnFunction() {
namespace {
/// Fold constant dimensions into an alloc operation.
-struct SimplifyAllocConst : public RewritePattern {
- SimplifyAllocConst(MLIRContext *context)
- : RewritePattern(AllocOp::getOperationName(), 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- auto alloc = cast<AllocOp>(op);
+struct SimplifyAllocConst : public OpRewritePattern<AllocOp> {
+ using OpRewritePattern<AllocOp>::OpRewritePattern;
+ PatternMatchResult matchAndRewrite(AllocOp alloc,
+ PatternRewriter &rewriter) const override {
// Check to see if any dimensions operands are constants. If so, we can
// substitute and drop them.
- for (auto *operand : alloc.getOperands())
- if (matchPattern(operand, m_ConstantIndex()))
- return matchSuccess();
- return matchFailure();
- }
+ if (llvm::none_of(alloc.getOperands(), [](Value *operand) {
+ return matchPattern(operand, m_ConstantIndex());
+ }))
+ return matchFailure();
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- auto allocOp = cast<AllocOp>(op);
- auto memrefType = allocOp.getType();
+ auto memrefType = alloc.getType();
// Ok, we have one or more constant operands. Collect the non-constant ones
// and keep track of the resultant memref type to build.
newShapeConstants.push_back(dimSize);
continue;
}
- auto *defOp = allocOp.getOperand(dynamicDimPos)->getDefiningOp();
+ auto *defOp = alloc.getOperand(dynamicDimPos)->getDefiningOp();
if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
// Dynamic shape dimension will be folded.
newShapeConstants.push_back(constantIndexOp.getValue());
} else {
// Dynamic shape dimension not folded; copy operand from old memref.
newShapeConstants.push_back(-1);
- newOperands.push_back(allocOp.getOperand(dynamicDimPos));
+ newOperands.push_back(alloc.getOperand(dynamicDimPos));
}
dynamicDimPos++;
}
// Create and insert the alloc op for the new memref.
auto newAlloc =
- rewriter.create<AllocOp>(allocOp.getLoc(), newMemRefType, newOperands);
+ rewriter.create<AllocOp>(alloc.getLoc(), newMemRefType, newOperands);
// Insert a cast so we have the same type as the old alloc.
- auto resultCast = rewriter.create<MemRefCastOp>(allocOp.getLoc(), newAlloc,
- allocOp.getType());
+ auto resultCast = rewriter.create<MemRefCastOp>(alloc.getLoc(), newAlloc,
+ alloc.getType());
- rewriter.replaceOp(op, {resultCast}, droppedOperands);
+ rewriter.replaceOp(alloc, {resultCast}, droppedOperands);
+ return matchSuccess();
}
};
/// Fold alloc operations with no uses. Alloc has side effects on the heap,
/// but can still be deleted if it has zero uses.
-struct SimplifyDeadAlloc : public RewritePattern {
- SimplifyDeadAlloc(MLIRContext *context)
- : RewritePattern(AllocOp::getOperationName(), 1, context) {}
+struct SimplifyDeadAlloc : public OpRewritePattern<AllocOp> {
+ using OpRewritePattern<AllocOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(AllocOp alloc,
PatternRewriter &rewriter) const override {
// Check if the alloc'ed value has any uses.
- auto alloc = cast<AllocOp>(op);
if (!alloc.use_empty())
return matchFailure();
// If it doesn't, we can eliminate it.
- op->erase();
+ alloc.erase();
return matchSuccess();
}
};
//===----------------------------------------------------------------------===//
namespace {
/// Fold indirect calls that have a constant function as the callee operand.
-struct SimplifyIndirectCallWithKnownCallee : public RewritePattern {
- SimplifyIndirectCallWithKnownCallee(MLIRContext *context)
- : RewritePattern(CallIndirectOp::getOperationName(), 1, context) {}
+struct SimplifyIndirectCallWithKnownCallee
+ : public OpRewritePattern<CallIndirectOp> {
+ using OpRewritePattern<CallIndirectOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(CallIndirectOp indirectCall,
PatternRewriter &rewriter) const override {
- auto indirectCall = cast<CallIndirectOp>(op);
-
// Check that the callee is a constant callee.
FunctionAttr calledFn;
if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn)))
return matchFailure();
// Replace with a direct call.
- SmallVector<Type, 8> callResults(op->getResultTypes());
+ SmallVector<Type, 8> callResults(indirectCall.getResultTypes());
SmallVector<Value *, 8> callOperands(indirectCall.getArgOperands());
- rewriter.replaceOpWithNewOp<CallOp>(op, calledFn.getValue(), callResults,
- callOperands);
+ rewriter.replaceOpWithNewOp<CallOp>(indirectCall, calledFn.getValue(),
+ callResults, callOperands);
return matchSuccess();
}
};
/// cond_br true, ^bb1, ^bb2 -> br ^bb1
/// cond_br false, ^bb1, ^bb2 -> br ^bb2
///
-struct SimplifyConstCondBranchPred : public RewritePattern {
- SimplifyConstCondBranchPred(MLIRContext *context)
- : RewritePattern(CondBranchOp::getOperationName(), 1, context) {}
+struct SimplifyConstCondBranchPred : public OpRewritePattern<CondBranchOp> {
+ using OpRewritePattern<CondBranchOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(CondBranchOp condbr,
PatternRewriter &rewriter) const override {
- auto condbr = cast<CondBranchOp>(op);
-
// Check that the condition is a constant.
if (!matchPattern(condbr.getCondition(), m_Op<ConstantOp>()))
return matchFailure();
branchArgs.assign(condbr.true_operand_begin(), condbr.true_operand_end());
}
- rewriter.replaceOpWithNewOp<BranchOp>(op, foldedDest, branchArgs);
+ rewriter.replaceOpWithNewOp<BranchOp>(condbr, foldedDest, branchArgs);
return matchSuccess();
}
};
namespace {
/// Fold Dealloc operations that are deallocating an AllocOp that is only used
/// by other Dealloc operations.
-struct SimplifyDeadDealloc : public RewritePattern {
- SimplifyDeadDealloc(MLIRContext *context)
- : RewritePattern(DeallocOp::getOperationName(), 1, context) {}
+struct SimplifyDeadDealloc : public OpRewritePattern<DeallocOp> {
+ using OpRewritePattern<DeallocOp>::OpRewritePattern;
- PatternMatchResult matchAndRewrite(Operation *op,
+ PatternMatchResult matchAndRewrite(DeallocOp dealloc,
PatternRewriter &rewriter) const override {
- auto dealloc = cast<DeallocOp>(op);
-
// Check that the memref operand's defining operation is an AllocOp.
Value *memref = dealloc.memref();
- Operation *defOp = memref->getDefiningOp();
- if (!isa_and_nonnull<AllocOp>(defOp))
+ if (!isa_and_nonnull<AllocOp>(memref->getDefiningOp()))
return matchFailure();
// Check that all of the uses of the AllocOp are other DeallocOps.
return matchFailure();
// Erase the dealloc operation.
- op->erase();
+ rewriter.replaceOp(dealloc, llvm::None);
return matchSuccess();
}
};