// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
-namespace {
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
-struct MemRefCastFolder : public RewritePattern {
- /// The rootOpName is the name of the root operation to match against.
- MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
- : RewritePattern(rootOpName, 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- for (auto *operand : op->getOperands())
- if (matchPattern(operand, m_Op<MemRefCastOp>()))
- return matchSuccess();
-
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOp())
- if (auto cast = dyn_cast<MemRefCastOp>(memref))
- op->setOperand(i, cast.getOperand());
- rewriter.updatedRootInPlace(op);
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+ if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ operand.set(cast.getOperand());
+ folded = true;
+ }
}
-};
-
-} // end anonymous namespace.
+ return success(folded);
+}
//===----------------------------------------------------------------------===//
// AffineDmaStartOp
return success();
}
-void AffineDmaStartOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
return success();
}
-void AffineDmaWaitOp::getCanonicalizationPatterns(
- OwningRewritePatternList &results, MLIRContext *context) {
+LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
"expected valid affine map representation for loop bounds");
}
-ParseResult parseAffineForOp(OpAsmParser &parser, OperationState &result) {
+static ParseResult parseAffineForOp(OpAsmParser &parser,
+ OperationState &result) {
auto &builder = parser.getBuilder();
OpAsmParser::OperandType inductionVariable;
// Parse the induction variable followed by '='.
map.getNumDims(), p);
}
-void print(OpAsmPrinter &p, AffineForOp op) {
+static void print(OpAsmPrinter &p, AffineForOp op) {
p << "affine.for ";
p.printOperand(op.getBody()->getArgument(0));
p << " = ";
op.getStepAttrName()});
}
-namespace {
-/// This is a pattern to fold trivially empty loops.
-struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
+/// Fold the constant bounds of a loop.
+static LogicalResult foldLoopBounds(AffineForOp forOp) {
+ auto foldLowerOrUpperBound = [&forOp](bool lower) {
+ // Check to see if each of the operands is the result of a constant. If
+ // so, get the value. If not, ignore it.
+ SmallVector<Attribute, 8> operandConstants;
+ auto boundOperands =
+ lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
+ for (auto *operand : boundOperands) {
+ Attribute operandCst;
+ matchPattern(operand, m_Constant(&operandCst));
+ operandConstants.push_back(operandCst);
+ }
- PatternMatchResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- // Check that the body only contains a terminator.
- if (!has_single_element(*forOp.getBody()))
- return matchFailure();
- rewriter.eraseOp(forOp);
- return matchSuccess();
- }
-};
+ AffineMap boundMap =
+ lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
+ assert(boundMap.getNumResults() >= 1 &&
+ "bound maps should have at least one result");
+ SmallVector<Attribute, 4> foldedResults;
+ if (failed(boundMap.constantFold(operandConstants, foldedResults)))
+ return failure();
-/// This is a pattern to fold constant loop bounds.
-struct AffineForOpBoundFolder : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
+ // Compute the max or min as applicable over the results.
+ assert(!foldedResults.empty() && "bounds should have at least one result");
+ auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
+ for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
+ auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
+ maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
+ : llvm::APIntOps::smin(maxOrMin, foldedResult);
+ }
+ lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
+ : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
+ return success();
+ };
- PatternMatchResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- auto foldLowerOrUpperBound = [&forOp](bool lower) {
- // Check to see if each of the operands is the result of a constant. If
- // so, get the value. If not, ignore it.
- SmallVector<Attribute, 8> operandConstants;
- auto boundOperands =
- lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
- for (auto *operand : boundOperands) {
- Attribute operandCst;
- matchPattern(operand, m_Constant(&operandCst));
- operandConstants.push_back(operandCst);
- }
+ // Try to fold the lower bound.
+ bool folded = false;
+ if (!forOp.hasConstantLowerBound())
+ folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
- AffineMap boundMap =
- lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
- assert(boundMap.getNumResults() >= 1 &&
- "bound maps should have at least one result");
- SmallVector<Attribute, 4> foldedResults;
- if (failed(boundMap.constantFold(operandConstants, foldedResults)))
- return failure();
-
- // Compute the max or min as applicable over the results.
- assert(!foldedResults.empty() &&
- "bounds should have at least one result");
- auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
- for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
- auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
- maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
- : llvm::APIntOps::smin(maxOrMin, foldedResult);
- }
- lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
- : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
- return success();
- };
-
- // Try to fold the lower bound.
- bool folded = false;
- if (!forOp.hasConstantLowerBound())
- folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
-
- // Try to fold the upper bound.
- if (!forOp.hasConstantUpperBound())
- folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
-
- // If any of the bounds were folded we return success.
- if (!folded)
- return matchFailure();
- rewriter.updatedRootInPlace(forOp);
- return matchSuccess();
- }
-};
+ // Try to fold the upper bound.
+ if (!forOp.hasConstantUpperBound())
+ folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
+ return success(folded);
+}
-// This is a pattern to canonicalize affine for op loop bounds.
-struct AffineForOpBoundCanonicalizer : public OpRewritePattern<AffineForOp> {
- using OpRewritePattern<AffineForOp>::OpRewritePattern;
+/// Canonicalize the bounds of the given loop.
+static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
+ SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
+ SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
- PatternMatchResult matchAndRewrite(AffineForOp forOp,
- PatternRewriter &rewriter) const override {
- SmallVector<Value *, 4> lbOperands(forOp.getLowerBoundOperands());
- SmallVector<Value *, 4> ubOperands(forOp.getUpperBoundOperands());
+ auto lbMap = forOp.getLowerBoundMap();
+ auto ubMap = forOp.getUpperBoundMap();
+ auto prevLbMap = lbMap;
+ auto prevUbMap = ubMap;
- auto lbMap = forOp.getLowerBoundMap();
- auto ubMap = forOp.getUpperBoundMap();
- auto prevLbMap = lbMap;
- auto prevUbMap = ubMap;
+ canonicalizeMapAndOperands(&lbMap, &lbOperands);
+ canonicalizeMapAndOperands(&ubMap, &ubOperands);
- canonicalizeMapAndOperands(&lbMap, &lbOperands);
- canonicalizeMapAndOperands(&ubMap, &ubOperands);
+ // Any canonicalization change always leads to updated map(s).
+ if (lbMap == prevLbMap && ubMap == prevUbMap)
+ return failure();
- // Any canonicalization change always leads to updated map(s).
- if (lbMap == prevLbMap && ubMap == prevUbMap)
- return matchFailure();
+ if (lbMap != prevLbMap)
+ forOp.setLowerBound(lbOperands, lbMap);
+ if (ubMap != prevUbMap)
+ forOp.setUpperBound(ubOperands, ubMap);
+ return success();
+}
- if (lbMap != prevLbMap)
- forOp.setLowerBound(lbOperands, lbMap);
- if (ubMap != prevUbMap)
- forOp.setUpperBound(ubOperands, ubMap);
+namespace {
+/// This is a pattern to fold trivially empty loops.
+struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
+ using OpRewritePattern<AffineForOp>::OpRewritePattern;
- rewriter.updatedRootInPlace(forOp);
+ PatternMatchResult matchAndRewrite(AffineForOp forOp,
+ PatternRewriter &rewriter) const override {
+ // Check that the body only contains a terminator.
+ if (!has_single_element(*forOp.getBody()))
+ return matchFailure();
+ rewriter.eraseOp(forOp);
return matchSuccess();
}
};
-
} // end anonymous namespace
/// Registers the canonicalization patterns for affine.for. Bound
/// folding/canonicalization is handled by the fold hook instead.
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForEmptyLoopFolder>(context);
}
+
+LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ bool folded = succeeded(foldLoopBounds(*this));
+ folded |= succeeded(canonicalizeLoopBounds(*this));
+ return success(folded);
}
AffineBound AffineForOp::getLowerBound() {
AffineIfOp::ensureTerminator(*elseRegion, *builder, result.location);
}
-namespace {
-// This is a pattern to canonicalize an affine if op's conditional (integer
-// set + operands).
-struct AffineIfOpCanonicalizer : public OpRewritePattern<AffineIfOp> {
- using OpRewritePattern<AffineIfOp>::OpRewritePattern;
+/// Canonicalize an affine if op's conditional (integer set + operands).
+LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
+ SmallVectorImpl<OpFoldResult> &) {
+ auto set = getIntegerSet();
+ SmallVector<Value *, 4> operands(getOperands());
+ canonicalizeSetAndOperands(&set, &operands);
- PatternMatchResult matchAndRewrite(AffineIfOp ifOp,
- PatternRewriter &rewriter) const override {
- auto set = ifOp.getIntegerSet();
- SmallVector<Value *, 4> operands(ifOp.getOperands());
-
- canonicalizeSetAndOperands(&set, &operands);
-
- // Any canonicalization change always leads to either a reduction in the
- // number of operands or a change in the number of symbolic operands
- // (promotion of dims to symbols).
- if (operands.size() < ifOp.getIntegerSet().getNumInputs() ||
- set.getNumSymbols() > ifOp.getIntegerSet().getNumSymbols()) {
- ifOp.setConditional(set, operands);
- rewriter.updatedRootInPlace(ifOp);
- return matchSuccess();
- }
-
- return matchFailure();
+ // Any canonicalization change always leads to either a reduction in the
+ // number of operands or a change in the number of symbolic operands
+ // (promotion of dims to symbols).
+ if (operands.size() < getIntegerSet().getNumInputs() ||
+ set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
+ setConditional(set, operands);
+ return success();
}
-};
-} // end anonymous namespace
-void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
- results.insert<AffineIfOpCanonicalizer>(context);
+ return failure();
}
//===----------------------------------------------------------------------===//
/// Registers the canonicalization patterns for affine.load. The memref_cast
/// folding is handled by the fold hook instead.
void AffineLoadOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
}
+OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
+ /// load(memrefcast) -> load
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return OpFoldResult();
+}
+
//===----------------------------------------------------------------------===//
// AffineStoreOp
//===----------------------------------------------------------------------===//
/// Registers the canonicalization patterns for affine.store. The memref_cast
/// folding is handled by the fold hook instead.
void AffineStoreOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}
+LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// store(memrefcast) -> store
+ return foldMemRefCast(*this);
+}
+
//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//
-namespace {
/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
-struct MemRefCastFolder : public RewritePattern {
- /// The rootOpName is the name of the root operation to match against.
- MemRefCastFolder(StringRef rootOpName, MLIRContext *context)
- : RewritePattern(rootOpName, 1, context) {}
-
- PatternMatchResult match(Operation *op) const override {
- for (auto *operand : op->getOperands())
- if (matchPattern(operand, m_Op<MemRefCastOp>()))
- return matchSuccess();
-
- return matchFailure();
- }
-
- void rewrite(Operation *op, PatternRewriter &rewriter) const override {
- for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
- if (auto *memref = op->getOperand(i)->getDefiningOp())
- if (auto cast = dyn_cast<MemRefCastOp>(memref))
- op->setOperand(i, cast.getOperand());
- rewriter.updatedRootInPlace(op);
+static LogicalResult foldMemRefCast(Operation *op) {
+ bool folded = false;
+ for (OpOperand &operand : op->getOpOperands()) {
+ auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get()->getDefiningOp());
+ if (cast && !cast.getOperand()->getType().isa<UnrankedMemRefType>()) {
+ operand.set(cast.getOperand());
+ folded = true;
+ }
}
-};
-} // end anonymous namespace.
+ return success(folded);
+}
//===----------------------------------------------------------------------===//
// AddFOp
/// Registers the canonicalization patterns for dealloc. The memref_cast
/// folding is handled by the fold hook instead.
void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                            MLIRContext *context) {
  results.insert<SimplifyDeadDealloc>(context);
}
+LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
+ /// dealloc(memrefcast) -> dealloc
+ return foldMemRefCast(*this);
+}
+
//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//
return {};
// The size at getIndex() is now a dynamic size of a memref.
-
auto memref = memrefOrTensor()->getDefiningOp();
if (auto alloc = dyn_cast_or_null<AllocOp>(memref))
return *(alloc.getDynamicSizes().begin() +
return *(sizes.begin() + getIndex());
}
- return {};
-}
-
-void DimOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
/// dim(memrefcast) -> dim
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+
+ return {};
}
//===----------------------------------------------------------------------===//
return success();
}
-void DmaStartOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_start(memrefcast) -> dma_start
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
// ---------------------------------------------------------------------------
return success();
}
-void DmaWaitOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// dma_wait(memrefcast) -> dma_wait
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//
return success();
}
-void LoadOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
/// load(memrefcast) -> load
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ if (succeeded(foldMemRefCast(*this)))
+ return getResult();
+ return OpFoldResult();
}
//===----------------------------------------------------------------------===//
return success();
}
-void StoreOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
- MLIRContext *context) {
+LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
+ SmallVectorImpl<OpFoldResult> &results) {
/// store(memrefcast) -> store
- results.insert<MemRefCastFolder>(getOperationName(), context);
+ return foldMemRefCast(*this);
}
//===----------------------------------------------------------------------===//