From 40715789f84df3e230fa29fcd397d6603630daaa Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 15 Jul 2019 07:35:00 -0700 Subject: [PATCH] Refactor LowerAffine to use OpRewritePattern instead of ConversionPattern. ConversionPattern should ideally only be used when the types of the operands are changing, which in this case they aren't. Using OpRewritePattern also lends to much simpler code. PiperOrigin-RevId: 258158474 --- mlir/lib/Transforms/LowerAffine.cpp | 167 +++++++++++++++--------------------- 1 file changed, 70 insertions(+), 97 deletions(-) diff --git a/mlir/lib/Transforms/LowerAffine.cpp b/mlir/lib/Transforms/LowerAffine.cpp index 76846c0..0ff80ca 100644 --- a/mlir/lib/Transforms/LowerAffine.cpp +++ b/mlir/lib/Transforms/LowerAffine.cpp @@ -295,55 +295,48 @@ Value *mlir::lowerAffineUpperBound(AffineForOp op, OpBuilder &builder) { namespace { // Affine terminators are removed. -class AffineTerminatorLowering : public ConversionPattern { +class AffineTerminatorLowering : public OpRewritePattern { public: - AffineTerminatorLowering(MLIRContext *ctx) - : ConversionPattern(AffineTerminatorOp::getOperationName(), 1, ctx) {} + using OpRewritePattern::OpRewritePattern; - virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { + PatternMatchResult matchAndRewrite(AffineTerminatorOp op, + PatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op); return matchSuccess(); } }; -class AffineForLowering : public ConversionPattern { +class AffineForLowering : public OpRewritePattern { public: - AffineForLowering(MLIRContext *ctx) - : ConversionPattern(AffineForOp::getOperationName(), 1, ctx) {} - - virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { - auto affineForOp = cast(op); - Location loc = op->getLoc(); - Value *lowerBound = lowerAffineLowerBound(affineForOp, rewriter); - Value *upperBound = lowerAffineUpperBound(affineForOp, rewriter); - Value *step = rewriter.create(loc, affineForOp.getStep()); + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(AffineForOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value *lowerBound = lowerAffineLowerBound(op, rewriter); + Value *upperBound = lowerAffineUpperBound(op, rewriter); + Value *step = rewriter.create(loc, op.getStep()); auto f = rewriter.create(loc, lowerBound, upperBound, step); f.region().getBlocks().clear(); - rewriter.inlineRegionBefore(affineForOp.getRegion(), f.region(), - f.region().end()); + rewriter.inlineRegionBefore(op.getRegion(), f.region(), f.region().end()); rewriter.replaceOp(op, {}); return matchSuccess(); } }; -class AffineIfLowering : public ConversionPattern { +class AffineIfLowering : public OpRewritePattern { public: - AffineIfLowering(MLIRContext *ctx) - : ConversionPattern(AffineIfOp::getOperationName(), 1, ctx) {} + using OpRewritePattern::OpRewritePattern; - virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { - auto affineIfOp = cast(op); - auto loc = op->getLoc(); + PatternMatchResult matchAndRewrite(AffineIfOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); // Now we just have to handle the condition logic. - auto integerSet = affineIfOp.getIntegerSet(); + auto integerSet = op.getIntegerSet(); Value *zeroConstant = rewriter.create(loc, 0); + SmallVector operands(op.getOperands()); + auto operandsRef = llvm::makeArrayRef(operands); // Calculate cond as a conjunction without short-circuiting. Value *cond = nullptr; @@ -354,8 +347,8 @@ public: // Build and apply an affine expression auto numDims = integerSet.getNumDims(); Value *affResult = expandAffineExpr(rewriter, loc, constraintExpr, - operands.take_front(numDims), - operands.drop_front(numDims)); + operandsRef.take_front(numDims), + operandsRef.drop_front(numDims)); if (!affResult) return matchFailure(); auto pred = isEquality ? CmpIPredicate::EQ : CmpIPredicate::SGE; @@ -367,13 +360,12 @@ public: cond = cond ? cond : rewriter.create(loc, /*value=*/1, /*width=*/1); - bool hasElseRegion = !affineIfOp.getElseBlocks().empty(); + bool hasElseRegion = !op.getElseBlocks().empty(); auto ifOp = rewriter.create(loc, cond, hasElseRegion); - rewriter.inlineRegionBefore(affineIfOp.getThenBlocks(), - &ifOp.thenRegion().back()); + rewriter.inlineRegionBefore(op.getThenBlocks(), &ifOp.thenRegion().back()); ifOp.thenRegion().back().erase(); if (hasElseRegion) { - rewriter.inlineRegionBefore(affineIfOp.getElseBlocks(), + rewriter.inlineRegionBefore(op.getElseBlocks(), &ifOp.elseRegion().back()); ifOp.elseRegion().back().erase(); } @@ -386,17 +378,15 @@ public: // Convert an "affine.apply" operation into a sequence of arithmetic // operations using the StandardOps dialect. -class AffineApplyLowering : public ConversionPattern { +class AffineApplyLowering : public OpRewritePattern { public: - AffineApplyLowering(MLIRContext *ctx) - : ConversionPattern(AffineApplyOp::getOperationName(), 1, ctx) {} + using OpRewritePattern::OpRewritePattern; virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { - auto affineApplyOp = cast(op); - auto maybeExpandedMap = expandAffineMap( - rewriter, op->getLoc(), affineApplyOp.getAffineMap(), operands); + matchAndRewrite(AffineApplyOp op, PatternRewriter &rewriter) const override { + auto maybeExpandedMap = + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), + llvm::to_vector<8>(op.getOperands())); if (!maybeExpandedMap) return matchFailure(); rewriter.replaceOp(op, *maybeExpandedMap); @@ -407,23 +397,21 @@ public: // Apply the affine map from an 'affine.load' operation to its operands, and // feed the results to a newly created 'std.load' operation (which replaces the // original 'affine.load'). -class AffineLoadLowering : public ConversionPattern { +class AffineLoadLowering : public OpRewritePattern { public: - AffineLoadLowering(MLIRContext *ctx) - : ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {} + using OpRewritePattern::OpRewritePattern; virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { - auto affineLoadOp = cast(op); + matchAndRewrite(AffineLoadOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineLoadOp'. + SmallVector indices(op.getIndices()); auto maybeExpandedMap = - expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(), - operands.drop_front()); + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) return matchFailure(); + // Build std.load memref[expandedMap.results]. - rewriter.replaceOpWithNewOp(op, operands[0], *maybeExpandedMap); + rewriter.replaceOpWithNewOp(op, op.getMemRef(), *maybeExpandedMap); return matchSuccess(); } }; @@ -431,24 +419,22 @@ public: // Apply the affine map from an 'affine.store' operation to its operands, and // feed the results to a newly created 'std.store' operation (which replaces the // original 'affine.store'). -class AffineStoreLowering : public ConversionPattern { +class AffineStoreLowering : public OpRewritePattern { public: - AffineStoreLowering(MLIRContext *ctx) - : ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {} + using OpRewritePattern::OpRewritePattern; virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - PatternRewriter &rewriter) const override { - auto affineStoreOp = cast(op); + matchAndRewrite(AffineStoreOp op, PatternRewriter &rewriter) const override { // Expand affine map from 'affineStoreOp'. + SmallVector indices(op.getIndices()); auto maybeExpandedMap = - expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(), - operands.drop_front(2)); + expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), indices); if (!maybeExpandedMap) return matchFailure(); + // Build std.store valutToStore, memref[expandedMap.results]. - rewriter.replaceOpWithNewOp(op, operands[0], operands[1], - *maybeExpandedMap); + rewriter.replaceOpWithNewOp(op, op.getValueToStore(), + op.getMemRef(), *maybeExpandedMap); return matchSuccess(); } }; @@ -456,49 +442,40 @@ public: // Apply the affine maps from an 'affine.dma_start' operation to each of their // respective map operands, and feed the results to a newly created // 'std.dma_start' operation (which replaces the original 'affine.dma_start'). -class AffineDmaStartLowering : public ConversionPattern { +class AffineDmaStartLowering : public OpRewritePattern { public: - AffineDmaStartLowering(MLIRContext *ctx) - : ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {} + using OpRewritePattern::OpRewritePattern; virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(AffineDmaStartOp op, PatternRewriter &rewriter) const override { - auto affineDmaStartOp = cast(op); + SmallVector operands(op.getOperands()); + auto operandsRef = llvm::makeArrayRef(operands); + // Expand affine map for DMA source memref. auto maybeExpandedSrcMap = expandAffineMap( - rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(), - operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1)); + rewriter, op.getLoc(), op.getSrcMap(), + operandsRef.drop_front(op.getSrcMemRefOperandIndex() + 1)); if (!maybeExpandedSrcMap) return matchFailure(); // Expand affine map for DMA destination memref. auto maybeExpandedDstMap = expandAffineMap( - rewriter, op->getLoc(), affineDmaStartOp.getDstMap(), - operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1)); + rewriter, op.getLoc(), op.getDstMap(), + operandsRef.drop_front(op.getDstMemRefOperandIndex() + 1)); if (!maybeExpandedDstMap) return matchFailure(); // Expand affine map for DMA tag memref. auto maybeExpandedTagMap = expandAffineMap( - rewriter, op->getLoc(), affineDmaStartOp.getTagMap(), - operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1)); + rewriter, op.getLoc(), op.getTagMap(), + operandsRef.drop_front(op.getTagMemRefOperandIndex() + 1)); if (!maybeExpandedTagMap) return matchFailure(); // Build std.dma_start operation with affine map results. - auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()]; - auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()]; - auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()]; - unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() + - 1 + affineDmaStartOp.getTagMap().getNumInputs(); - auto *numElements = operands[numElementsIndex]; - auto *stride = - affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr; - auto *eltsPerStride = - affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr; - rewriter.replaceOpWithNewOp( - op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap, - numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride); + op, op.getSrcMemRef(), *maybeExpandedSrcMap, op.getDstMemRef(), + *maybeExpandedDstMap, op.getNumElements(), op.getTagMemRef(), + *maybeExpandedTagMap, op.getStride(), op.getNumElementsPerStride()); return matchSuccess(); } }; @@ -506,26 +483,23 @@ public: // Apply the affine map from an 'affine.dma_wait' operation tag memref, // and feed the results to a newly created 'std.dma_wait' operation (which // replaces the original 'affine.dma_wait'). -class AffineDmaWaitLowering : public ConversionPattern { +class AffineDmaWaitLowering : public OpRewritePattern { public: - AffineDmaWaitLowering(MLIRContext *ctx) - : ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {} + using OpRewritePattern::OpRewritePattern; virtual PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, + matchAndRewrite(AffineDmaWaitOp op, PatternRewriter &rewriter) const override { - auto affineDmaWaitOp = cast(op); // Expand affine map for DMA tag memref. + SmallVector indices(op.getTagIndices()); auto maybeExpandedTagMap = - expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(), - operands.drop_front()); + expandAffineMap(rewriter, op.getLoc(), op.getTagMap(), indices); if (!maybeExpandedTagMap) return matchFailure(); // Build std.dma_wait operation with affine map results. - unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs(); rewriter.replaceOpWithNewOp( - op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]); + op, op.getTagMemRef(), *maybeExpandedTagMap, op.getNumElements()); return matchSuccess(); } }; @@ -540,8 +514,7 @@ LogicalResult mlir::lowerAffineConstructs(FuncOp function) { AffineTerminatorLowering>::build(patterns, function.getContext()); ConversionTarget target(*function.getContext()); - target.addLegalDialect(); - target.addLegalDialect(); + target.addLegalDialect(); return applyConversionPatterns(function, target, std::move(patterns)); } -- 2.7.4