return matchSuccess();
}
};
+
+// Apply the affine map from an 'affine.load' operation to its operands, and
+// feed the results to a newly created 'std.load' operation (which replaces the
+// original 'affine.load').
+class AffineLoadLowering : public ConversionPattern {
+public:
+ AffineLoadLowering(MLIRContext *ctx)
+ : ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {}
+
+ virtual PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ PatternRewriter &rewriter) const override {
+ auto affineLoadOp = cast<AffineLoadOp>(op);
+ // Expand affine map from 'affineLoadOp'.
+ auto maybeExpandedMap =
+ expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(),
+ operands.drop_front());
+ if (!maybeExpandedMap)
+ return matchFailure();
+ // Build std.load memref[expandedMap.results].
+ rewriter.replaceOpWithNewOp<LoadOp>(op, operands[0], *maybeExpandedMap);
+ return matchSuccess();
+ }
+};
+
+// Apply the affine map from an 'affine.store' operation to its operands, and
+// feed the results to a newly created 'std.store' operation (which replaces the
+// original 'affine.store').
+class AffineStoreLowering : public ConversionPattern {
+public:
+ AffineStoreLowering(MLIRContext *ctx)
+ : ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {}
+
+ virtual PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ PatternRewriter &rewriter) const override {
+ auto affineStoreOp = cast<AffineStoreOp>(op);
+ // Expand affine map from 'affineStoreOp'.
+ auto maybeExpandedMap =
+ expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(),
+ operands.drop_front(2));
+ if (!maybeExpandedMap)
+ return matchFailure();
+ // Build std.store valutToStore, memref[expandedMap.results].
+ rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1],
+ *maybeExpandedMap);
+ return matchSuccess();
+ }
+};
+
+// Apply the affine maps from an 'affine.dma_start' operation to each of their
+// respective map operands, and feed the results to a newly created
+// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
+class AffineDmaStartLowering : public ConversionPattern {
+public:
+ AffineDmaStartLowering(MLIRContext *ctx)
+ : ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {}
+
+ virtual PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ PatternRewriter &rewriter) const override {
+ auto affineDmaStartOp = cast<AffineDmaStartOp>(op);
+ // Expand affine map for DMA source memref.
+ auto maybeExpandedSrcMap = expandAffineMap(
+ rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(),
+ operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1));
+ if (!maybeExpandedSrcMap)
+ return matchFailure();
+ // Expand affine map for DMA destination memref.
+ auto maybeExpandedDstMap = expandAffineMap(
+ rewriter, op->getLoc(), affineDmaStartOp.getDstMap(),
+ operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1));
+ if (!maybeExpandedDstMap)
+ return matchFailure();
+ // Expand affine map for DMA tag memref.
+ auto maybeExpandedTagMap = expandAffineMap(
+ rewriter, op->getLoc(), affineDmaStartOp.getTagMap(),
+ operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1));
+ if (!maybeExpandedTagMap)
+ return matchFailure();
+
+ // Build std.dma_start operation with affine map results.
+ auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()];
+ auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()];
+ auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()];
+ unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() +
+ 1 + affineDmaStartOp.getTagMap().getNumInputs();
+ auto *numElements = operands[numElementsIndex];
+ auto *stride =
+ affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr;
+ auto *eltsPerStride =
+ affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr;
+
+ rewriter.replaceOpWithNewOp<DmaStartOp>(
+ op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap,
+ numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride);
+ return matchSuccess();
+ }
+};
+
+// Apply the affine map from an 'affine.dma_wait' operation tag memref,
+// and feed the results to a newly created 'std.dma_wait' operation (which
+// replaces the original 'affine.dma_wait').
+class AffineDmaWaitLowering : public ConversionPattern {
+public:
+ AffineDmaWaitLowering(MLIRContext *ctx)
+ : ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {}
+
+ virtual PatternMatchResult
+ matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+ PatternRewriter &rewriter) const override {
+ auto affineDmaWaitOp = cast<AffineDmaWaitOp>(op);
+ // Expand affine map for DMA tag memref.
+ auto maybeExpandedTagMap =
+ expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(),
+ operands.drop_front());
+ if (!maybeExpandedTagMap)
+ return matchFailure();
+
+ // Build std.dma_wait operation with affine map results.
+ unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs();
+ rewriter.replaceOpWithNewOp<DmaWaitOp>(
+ op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]);
+ return matchSuccess();
+ }
+};
+
} // end namespace
LogicalResult mlir::lowerAffineConstructs(Function &function) {
OwningRewritePatternList patterns;
- RewriteListBuilder<AffineApplyLowering, AffineForLowering, AffineIfLowering,
+ RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
+ AffineDmaWaitLowering, AffineLoadLowering,
+ AffineStoreLowering, AffineForLowering, AffineIfLowering,
AffineTerminatorLowering>::build(patterns,
function.getContext());
ConversionTarget target(*function.getContext());
%0 = affine.apply #mapceildiv (%arg0)
return %0 : index
}
+
+// CHECK-LABEL: func @affine_load
+func @affine_load(%arg0 : index) {
+ %0 = alloc() : memref<10xf32>
+ affine.for %i0 = 0 to 10 {
+ %1 = affine.load %0[%i0 + symbol(%arg0) + 7] : memref<10xf32>
+ }
+// CHECK: %3 = addi %1, %arg0 : index
+// CHECK-NEXT: %c7 = constant 7 : index
+// CHECK-NEXT: %4 = addi %3, %c7 : index
+// CHECK-NEXT: %5 = load %0[%4] : memref<10xf32>
+ return
+}
+
+// CHECK-LABEL: func @affine_store
+func @affine_store(%arg0 : index) {
+ %0 = alloc() : memref<10xf32>
+ %1 = constant 11.0 : f32
+ affine.for %i0 = 0 to 10 {
+ affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
+ }
+// CHECK: %c-1 = constant -1 : index
+// CHECK-NEXT: %3 = muli %arg0, %c-1 : index
+// CHECK-NEXT: %4 = addi %1, %3 : index
+// CHECK-NEXT: %c7 = constant 7 : index
+// CHECK-NEXT: %5 = addi %4, %c7 : index
+// CHECK-NEXT: store %cst, %0[%5] : memref<10xf32>
+ return
+}
+
+// CHECK-LABEL: func @affine_dma_start
+func @affine_dma_start(%arg0 : index) {
+ %0 = alloc() : memref<100xf32>
+ %1 = alloc() : memref<100xf32, 2>
+ %2 = alloc() : memref<1xi32>
+ %c0 = constant 0 : index
+ %c64 = constant 64 : index
+ affine.for %i0 = 0 to 10 {
+ affine.dma_start %0[%i0 + 7], %1[%arg0 + 11], %2[%c0], %c64
+ : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
+ }
+// CHECK: %c7 = constant 7 : index
+// CHECK-NEXT: %5 = addi %3, %c7 : index
+// CHECK-NEXT: %c11 = constant 11 : index
+// CHECK-NEXT: %6 = addi %arg0, %c11 : index
+// CHECK-NEXT: dma_start %0[%5], %1[%6], %c64, %2[%c0] : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
+ return
+}
+
+// CHECK-LABEL: func @affine_dma_wait
+func @affine_dma_wait(%arg0 : index) {
+ %2 = alloc() : memref<1xi32>
+ %c64 = constant 64 : index
+ affine.for %i0 = 0 to 10 {
+ affine.dma_wait %2[%i0 + %arg0 + 17], %c64 : memref<1xi32>
+ }
+// CHECK: %3 = addi %1, %arg0 : index
+// CHECK-NEXT: %c17 = constant 17 : index
+// CHECK-NEXT: %4 = addi %3, %c17 : index
+// CHECK-NEXT: dma_wait %0[%4], %c64 : memref<1xi32>
+ return
+}