From 7ac42fa26e5ac2c3554eb38b7456c6bd81e69cec Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 13 Dec 2019 14:52:39 -0800 Subject: [PATCH] Refactor various canonicalization patterns as in-place folds. This is more efficient, and allows for these to fire in more situations: e.g. createOrFold, DialectConversion, etc. PiperOrigin-RevId: 285476837 --- mlir/include/mlir/Dialect/AffineOps/AffineOps.h | 11 +- mlir/include/mlir/Dialect/AffineOps/AffineOps.td | 5 +- mlir/include/mlir/Dialect/QuantOps/QuantOps.td | 2 +- mlir/include/mlir/Dialect/StandardOps/Ops.h | 8 +- mlir/include/mlir/Dialect/StandardOps/Ops.td | 6 +- mlir/lib/Dialect/AffineOps/AffineOps.cpp | 278 ++++++++++------------- mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp | 44 +--- mlir/lib/Dialect/StandardOps/Ops.cpp | 76 +++---- 8 files changed, 187 insertions(+), 243 deletions(-) diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h index 835ac24..8268f81 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.h +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.h @@ -295,8 +295,8 @@ public: static ParseResult parse(OpAsmParser &parser, OperationState &result); void print(OpAsmPrinter &p); LogicalResult verify(); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); /// Returns true if this DMA operation is strided, returns false otherwise. bool isStrided() { @@ -380,8 +380,8 @@ public: static ParseResult parse(OpAsmParser &parser, OperationState &result); void print(OpAsmPrinter &p); LogicalResult verify(); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); }; /// The "affine.load" op reads an element from a memref, where the index @@ -450,6 +450,7 @@ public: LogicalResult verify(); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); + OpFoldResult fold(ArrayRef operands); }; /// The "affine.store" op writes an element to a memref, where the index @@ -520,6 +521,8 @@ public: LogicalResult verify(); static void getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); }; /// Returns true if the given Value can be used as a dimension id. diff --git a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td index 4d40604..cea44b8 100644 --- a/mlir/include/mlir/Dialect/AffineOps/AffineOps.td +++ b/mlir/include/mlir/Dialect/AffineOps/AffineOps.td @@ -177,12 +177,13 @@ def AffineForOp : Affine_Op<"for", /// Sets the upper bound to the given constant value. void setConstantUpperBound(int64_t value); - /// Returns true if both the lower and upper bound have the same operand + /// Returns true if both the lower and upper bound have the same operand /// lists (same operands in the same order). bool matchingBoundOperandList(); }]; let hasCanonicalizer = 1; + let hasFolder = 1; } def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { @@ -239,7 +240,7 @@ def AffineIfOp : Affine_Op<"if", [ImplicitAffineTerminator]> { } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } def AffineMinOp : Affine_Op<"min"> { diff --git a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td index 85d5cd2..072715d 100644 --- a/mlir/include/mlir/Dialect/QuantOps/QuantOps.td +++ b/mlir/include/mlir/Dialect/QuantOps/QuantOps.td @@ -93,7 +93,7 @@ def quant_DequantizeCastOp : quant_Op<"dcast", [NoSideEffect]> { def quant_StorageCastOp : quant_Op<"scast", [NoSideEffect]> { let arguments = (ins quant_RealOrStorageValueType:$arg); let results = (outs quant_RealOrStorageValueType); - let hasCanonicalizer = 0b1; + let hasFolder = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.h b/mlir/include/mlir/Dialect/StandardOps/Ops.h index c7c8714..fcf16c0 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.h @@ -268,8 +268,8 @@ public: void print(OpAsmPrinter &p); LogicalResult verify(); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); bool isStrided() { return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + @@ -331,8 +331,8 @@ public: static ParseResult parse(OpAsmParser &parser, OperationState &result); void print(OpAsmPrinter &p); - static void getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context); + LogicalResult fold(ArrayRef cstOperands, + SmallVectorImpl &results); }; /// Prints dimension and symbol list. diff --git a/mlir/include/mlir/Dialect/StandardOps/Ops.td b/mlir/include/mlir/Dialect/StandardOps/Ops.td index 8e21a8b..553a612 100644 --- a/mlir/include/mlir/Dialect/StandardOps/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/Ops.td @@ -659,6 +659,7 @@ def DeallocOp : Std_Op<"dealloc"> { let arguments = (ins AnyMemRef:$memref); let hasCanonicalizer = 1; + let hasFolder = 1; } def DimOp : Std_Op<"dim", [NoSideEffect]> { @@ -691,7 +692,6 @@ def DimOp : Std_Op<"dim", [NoSideEffect]> { }]; let hasFolder = 1; - let hasCanonicalizer = 1; } def DivFOp : FloatArithmeticOp<"divf"> { @@ -834,7 +834,7 @@ def LoadOp : Std_Op<"load"> { operand_range getIndices() { return {operand_begin() + 1, operand_end()}; } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } def LogOp : FloatUnaryOp<"log"> { @@ -1137,7 +1137,7 @@ def StoreOp : Std_Op<"store"> { } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } def SubFOp : FloatArithmeticOp<"subf"> { diff --git a/mlir/lib/Dialect/AffineOps/AffineOps.cpp b/mlir/lib/Dialect/AffineOps/AffineOps.cpp index 22d4ec1..e58f6f8 100644 --- a/mlir/lib/Dialect/AffineOps/AffineOps.cpp +++ b/mlir/lib/Dialect/AffineOps/AffineOps.cpp @@ -814,33 +814,20 @@ void AffineApplyOp::getCanonicalizationPatterns( // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// -namespace { /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref_cast /// into the root operation directly. -struct MemRefCastFolder : public RewritePattern { - /// The rootOpName is the name of the root operation to match against. - MemRefCastFolder(StringRef rootOpName, MLIRContext *context) - : RewritePattern(rootOpName, 1, context) {} - - PatternMatchResult match(Operation *op) const override { - for (auto *operand : op->getOperands()) - if (matchPattern(operand, m_Op())) - return matchSuccess(); - - return matchFailure(); - } - - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = dyn_cast(memref)) - op->setOperand(i, cast.getOperand()); - rewriter.updatedRootInPlace(op); +static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto cast = dyn_cast_or_null(operand.get()->getDefiningOp()); + if (cast && !cast.getOperand()->getType().isa()) { + operand.set(cast.getOperand()); + folded = true; + } } -}; - -} // end anonymous namespace. + return success(folded); +} //===----------------------------------------------------------------------===// // AffineDmaStartOp @@ -1005,10 +992,10 @@ LogicalResult AffineDmaStartOp::verify() { return success(); } -void AffineDmaStartOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { +LogicalResult AffineDmaStartOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { /// dma_start(memrefcast) -> dma_start - results.insert(getOperationName(), context); + return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// @@ -1081,10 +1068,10 @@ LogicalResult AffineDmaWaitOp::verify() { return success(); } -void AffineDmaWaitOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { +LogicalResult AffineDmaWaitOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait - results.insert(getOperationName(), context); + return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// @@ -1255,7 +1242,8 @@ static ParseResult parseBound(bool isLower, OperationState &result, "expected valid affine map representation for loop bounds"); } -ParseResult parseAffineForOp(OpAsmParser &parser, OperationState &result) { +static ParseResult parseAffineForOp(OpAsmParser &parser, + OperationState &result) { auto &builder = parser.getBuilder(); OpAsmParser::OperandType inductionVariable; // Parse the induction variable followed by '='. @@ -1344,7 +1332,7 @@ static void printBound(AffineMapAttr boundMap, map.getNumDims(), p); } -void print(OpAsmPrinter &p, AffineForOp op) { +static void print(OpAsmPrinter &p, AffineForOp op) { p << "affine.for "; p.printOperand(op.getBody()->getArgument(0)); p << " = "; @@ -1363,115 +1351,102 @@ void print(OpAsmPrinter &p, AffineForOp op) { op.getStepAttrName()}); } -namespace { -/// This is a pattern to fold trivially empty loops. -struct AffineForEmptyLoopFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Fold the constant bounds of a loop. +static LogicalResult foldLoopBounds(AffineForOp forOp) { + auto foldLowerOrUpperBound = [&forOp](bool lower) { + // Check to see if each of the operands is the result of a constant. If + // so, get the value. If not, ignore it. + SmallVector operandConstants; + auto boundOperands = + lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); + for (auto *operand : boundOperands) { + Attribute operandCst; + matchPattern(operand, m_Constant(&operandCst)); + operandConstants.push_back(operandCst); + } - PatternMatchResult matchAndRewrite(AffineForOp forOp, - PatternRewriter &rewriter) const override { - // Check that the body only contains a terminator. - if (!has_single_element(*forOp.getBody())) - return matchFailure(); - rewriter.eraseOp(forOp); - return matchSuccess(); - } -}; + AffineMap boundMap = + lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); + assert(boundMap.getNumResults() >= 1 && + "bound maps should have at least one result"); + SmallVector foldedResults; + if (failed(boundMap.constantFold(operandConstants, foldedResults))) + return failure(); -/// This is a pattern to fold constant loop bounds. -struct AffineForOpBoundFolder : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + // Compute the max or min as applicable over the results. + assert(!foldedResults.empty() && "bounds should have at least one result"); + auto maxOrMin = foldedResults[0].cast().getValue(); + for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { + auto foldedResult = foldedResults[i].cast().getValue(); + maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) + : llvm::APIntOps::smin(maxOrMin, foldedResult); + } + lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) + : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); + return success(); + }; - PatternMatchResult matchAndRewrite(AffineForOp forOp, - PatternRewriter &rewriter) const override { - auto foldLowerOrUpperBound = [&forOp](bool lower) { - // Check to see if each of the operands is the result of a constant. If - // so, get the value. If not, ignore it. - SmallVector operandConstants; - auto boundOperands = - lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); - for (auto *operand : boundOperands) { - Attribute operandCst; - matchPattern(operand, m_Constant(&operandCst)); - operandConstants.push_back(operandCst); - } + // Try to fold the lower bound. + bool folded = false; + if (!forOp.hasConstantLowerBound()) + folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); - AffineMap boundMap = - lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap(); - assert(boundMap.getNumResults() >= 1 && - "bound maps should have at least one result"); - SmallVector foldedResults; - if (failed(boundMap.constantFold(operandConstants, foldedResults))) - return failure(); - - // Compute the max or min as applicable over the results. - assert(!foldedResults.empty() && - "bounds should have at least one result"); - auto maxOrMin = foldedResults[0].cast().getValue(); - for (unsigned i = 1, e = foldedResults.size(); i < e; i++) { - auto foldedResult = foldedResults[i].cast().getValue(); - maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) - : llvm::APIntOps::smin(maxOrMin, foldedResult); - } - lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue()) - : forOp.setConstantUpperBound(maxOrMin.getSExtValue()); - return success(); - }; - - // Try to fold the lower bound. - bool folded = false; - if (!forOp.hasConstantLowerBound()) - folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true)); - - // Try to fold the upper bound. - if (!forOp.hasConstantUpperBound()) - folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); - - // If any of the bounds were folded we return success. - if (!folded) - return matchFailure(); - rewriter.updatedRootInPlace(forOp); - return matchSuccess(); - } -}; + // Try to fold the upper bound. + if (!forOp.hasConstantUpperBound()) + folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false)); + return success(folded); +} -// This is a pattern to canonicalize affine for op loop bounds. -struct AffineForOpBoundCanonicalizer : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Canonicalize the bounds of the given loop. +static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) { + SmallVector lbOperands(forOp.getLowerBoundOperands()); + SmallVector ubOperands(forOp.getUpperBoundOperands()); - PatternMatchResult matchAndRewrite(AffineForOp forOp, - PatternRewriter &rewriter) const override { - SmallVector lbOperands(forOp.getLowerBoundOperands()); - SmallVector ubOperands(forOp.getUpperBoundOperands()); + auto lbMap = forOp.getLowerBoundMap(); + auto ubMap = forOp.getUpperBoundMap(); + auto prevLbMap = lbMap; + auto prevUbMap = ubMap; - auto lbMap = forOp.getLowerBoundMap(); - auto ubMap = forOp.getUpperBoundMap(); - auto prevLbMap = lbMap; - auto prevUbMap = ubMap; + canonicalizeMapAndOperands(&lbMap, &lbOperands); + canonicalizeMapAndOperands(&ubMap, &ubOperands); - canonicalizeMapAndOperands(&lbMap, &lbOperands); - canonicalizeMapAndOperands(&ubMap, &ubOperands); + // Any canonicalization change always leads to updated map(s). + if (lbMap == prevLbMap && ubMap == prevUbMap) + return failure(); - // Any canonicalization change always leads to updated map(s). - if (lbMap == prevLbMap && ubMap == prevUbMap) - return matchFailure(); + if (lbMap != prevLbMap) + forOp.setLowerBound(lbOperands, lbMap); + if (ubMap != prevUbMap) + forOp.setUpperBound(ubOperands, ubMap); + return success(); +} - if (lbMap != prevLbMap) - forOp.setLowerBound(lbOperands, lbMap); - if (ubMap != prevUbMap) - forOp.setUpperBound(ubOperands, ubMap); +namespace { +/// This is a pattern to fold trivially empty loops. +struct AffineForEmptyLoopFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; - rewriter.updatedRootInPlace(forOp); + PatternMatchResult matchAndRewrite(AffineForOp forOp, + PatternRewriter &rewriter) const override { + // Check that the body only contains a terminator. + if (!has_single_element(*forOp.getBody())) + return matchFailure(); + rewriter.eraseOp(forOp); return matchSuccess(); } }; - } // end anonymous namespace void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); +} + +LogicalResult AffineForOp::fold(ArrayRef operands, + SmallVectorImpl &results) { + bool folded = succeeded(foldLoopBounds(*this)); + folded |= succeeded(canonicalizeLoopBounds(*this)); + return success(folded); } AffineBound AffineForOp::getLowerBound() { @@ -1741,37 +1716,23 @@ void AffineIfOp::build(Builder *builder, OperationState &result, IntegerSet set, AffineIfOp::ensureTerminator(*elseRegion, *builder, result.location); } -namespace { -// This is a pattern to canonicalize an affine if op's conditional (integer -// set + operands). -struct AffineIfOpCanonicalizer : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; +/// Canonicalize an affine if op's conditional (integer set + operands). +LogicalResult AffineIfOp::fold(ArrayRef, + SmallVectorImpl &) { + auto set = getIntegerSet(); + SmallVector operands(getOperands()); + canonicalizeSetAndOperands(&set, &operands); - PatternMatchResult matchAndRewrite(AffineIfOp ifOp, - PatternRewriter &rewriter) const override { - auto set = ifOp.getIntegerSet(); - SmallVector operands(ifOp.getOperands()); - - canonicalizeSetAndOperands(&set, &operands); - - // Any canonicalization change always leads to either a reduction in the - // number of operands or a change in the number of symbolic operands - // (promotion of dims to symbols). - if (operands.size() < ifOp.getIntegerSet().getNumInputs() || - set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) { - ifOp.setConditional(set, operands); - rewriter.updatedRootInPlace(ifOp); - return matchSuccess(); - } - - return matchFailure(); + // Any canonicalization change always leads to either a reduction in the + // number of operands or a change in the number of symbolic operands + // (promotion of dims to symbols). + if (operands.size() < getIntegerSet().getNumInputs() || + set.getNumSymbols() > getIntegerSet().getNumSymbols()) { + setConditional(set, operands); + return success(); } -}; -} // end anonymous namespace -void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { - results.insert(context); + return failure(); } //===----------------------------------------------------------------------===// @@ -1866,11 +1827,16 @@ LogicalResult AffineLoadOp::verify() { void AffineLoadOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - /// load(memrefcast) -> load - results.insert(getOperationName(), context); results.insert>(context); } +OpFoldResult AffineLoadOp::fold(ArrayRef cstOperands) { + /// load(memrefcast) -> load + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); +} + //===----------------------------------------------------------------------===// // AffineStoreOp //===----------------------------------------------------------------------===// @@ -1959,11 +1925,15 @@ LogicalResult AffineStoreOp::verify() { void AffineStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { - /// load(memrefcast) -> load - results.insert(getOperationName(), context); results.insert>(context); } +LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { + /// store(memrefcast) -> store + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // AffineMinOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp index b618ac0..51f1994 100644 --- a/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/QuantOps/IR/QuantOps.cpp @@ -32,38 +32,6 @@ using namespace mlir; using namespace mlir::quant; using namespace mlir::quant::detail; -#define GET_OP_CLASSES -#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" - -namespace { - -/// Matches x -> [scast -> scast] -> y, replacing the second scast with the -/// value of x if the casts invert each other. -class RemoveRedundantStorageCastsRewrite - : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(StorageCastOp op, - PatternRewriter &rewriter) const override { - if (!matchPattern(op.arg(), m_Op())) - return matchFailure(); - auto srcScastOp = cast(op.arg()->getDefiningOp()); - if (srcScastOp.arg()->getType() != op.getType()) - return matchFailure(); - - rewriter.replaceOp(op, srcScastOp.arg()); - return matchSuccess(); - } -}; - -} // end anonymous namespace - -void StorageCastOp::getCanonicalizationPatterns( - OwningRewritePatternList &patterns, MLIRContext *context) { - patterns.insert(context); -} - QuantizationDialect::QuantizationDialect(MLIRContext *context) : Dialect(/*name=*/"quant", context) { addTypes(); } + +OpFoldResult StorageCastOp::fold(ArrayRef operands) { + /// Matches x -> [scast -> scast] -> y, replacing the second scast with the + /// value of x if the casts invert each other. + auto srcScastOp = dyn_cast_or_null(arg()->getDefiningOp()); + if (!srcScastOp || srcScastOp.arg()->getType() != getType()) + return OpFoldResult(); + return srcScastOp.arg(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/QuantOps/QuantOps.cpp.inc" diff --git a/mlir/lib/Dialect/StandardOps/Ops.cpp b/mlir/lib/Dialect/StandardOps/Ops.cpp index 713546f..3189e42 100644 --- a/mlir/lib/Dialect/StandardOps/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/Ops.cpp @@ -212,32 +212,20 @@ static detail::op_matcher m_ConstantIndex() { // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// -namespace { /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref_cast /// into the root operation directly. -struct MemRefCastFolder : public RewritePattern { - /// The rootOpName is the name of the root operation to match against. - MemRefCastFolder(StringRef rootOpName, MLIRContext *context) - : RewritePattern(rootOpName, 1, context) {} - - PatternMatchResult match(Operation *op) const override { - for (auto *operand : op->getOperands()) - if (matchPattern(operand, m_Op())) - return matchSuccess(); - - return matchFailure(); - } - - void rewrite(Operation *op, PatternRewriter &rewriter) const override { - for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) - if (auto *memref = op->getOperand(i)->getDefiningOp()) - if (auto cast = dyn_cast(memref)) - op->setOperand(i, cast.getOperand()); - rewriter.updatedRootInPlace(op); +static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto cast = dyn_cast_or_null(operand.get()->getDefiningOp()); + if (cast && !cast.getOperand()->getType().isa()) { + operand.set(cast.getOperand()); + folded = true; + } } -}; -} // end anonymous namespace. + return success(folded); +} //===----------------------------------------------------------------------===// // AddFOp @@ -1232,11 +1220,15 @@ static LogicalResult verify(DeallocOp op) { void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - /// dealloc(memrefcast) -> dealloc - results.insert(getOperationName(), context); results.insert(context); } +LogicalResult DeallocOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { + /// dealloc(memrefcast) -> dealloc + return foldMemRefCast(*this); +} + //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// @@ -1304,7 +1296,6 @@ OpFoldResult DimOp::fold(ArrayRef operands) { return {}; // The size at getIndex() is now a dynamic size of a memref. - auto memref = memrefOrTensor()->getDefiningOp(); if (auto alloc = dyn_cast_or_null(memref)) return *(alloc.getDynamicSizes().begin() + @@ -1321,13 +1312,11 @@ OpFoldResult DimOp::fold(ArrayRef operands) { return *(sizes.begin() + getIndex()); } - return {}; -} - -void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { /// dim(memrefcast) -> dim - results.insert(getOperationName(), context); + if (succeeded(foldMemRefCast(*this))) + return getResult(); + + return {}; } //===----------------------------------------------------------------------===// @@ -1507,10 +1496,10 @@ LogicalResult DmaStartOp::verify() { return success(); } -void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +LogicalResult DmaStartOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { /// dma_start(memrefcast) -> dma_start - results.insert(getOperationName(), context); + return foldMemRefCast(*this); } // --------------------------------------------------------------------------- @@ -1565,10 +1554,10 @@ ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { return success(); } -void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +LogicalResult DmaWaitOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait - results.insert(getOperationName(), context); + return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// @@ -1688,10 +1677,11 @@ static LogicalResult verify(LoadOp op) { return success(); } -void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +OpFoldResult LoadOp::fold(ArrayRef cstOperands) { /// load(memrefcast) -> load - results.insert(getOperationName(), context); + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return OpFoldResult(); } //===----------------------------------------------------------------------===// @@ -2092,10 +2082,10 @@ static LogicalResult verify(StoreOp op) { return success(); } -void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results, - MLIRContext *context) { +LogicalResult StoreOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { /// store(memrefcast) -> store - results.insert(getOperationName(), context); + return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// -- 2.7.4