Add affine-to-standard lowerings for affine.load/store/dma_start/dma_wait.
authorAndy Davis <andydavis@google.com>
Mon, 1 Jul 2019 15:32:44 +0000 (08:32 -0700)
committerjpienaar <jpienaar@google.com>
Mon, 1 Jul 2019 16:56:22 +0000 (09:56 -0700)
PiperOrigin-RevId: 255960171

mlir/lib/Transforms/LowerAffine.cpp
mlir/test/Transforms/lower-affine.mlir

index 6b6ba9065373b9c66c3af12ab6fbe48481a4518e..77a23b156b0f330086d98c148853270d2d82cf58 100644 (file)
@@ -597,11 +597,140 @@ public:
     return matchSuccess();
   }
 };
+
+// Apply the affine map from an 'affine.load' operation to its operands, and
+// feed the results to a newly created 'std.load' operation (which replaces the
+// original 'affine.load').
+class AffineLoadLowering : public ConversionPattern {
+public:
+  AffineLoadLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineLoadOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineLoadOp = cast<AffineLoadOp>(op);
+    // Expand affine map from 'affineLoadOp'.
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op->getLoc(), affineLoadOp.getAffineMap(),
+                        operands.drop_front());
+    if (!maybeExpandedMap)
+      return matchFailure();
+    // Build std.load memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<LoadOp>(op, operands[0], *maybeExpandedMap);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine map from an 'affine.store' operation to its operands, and
+// feed the results to a newly created 'std.store' operation (which replaces the
+// original 'affine.store').
+class AffineStoreLowering : public ConversionPattern {
+public:
+  AffineStoreLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineStoreOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineStoreOp = cast<AffineStoreOp>(op);
+    // Expand affine map from 'affineStoreOp'.
+    auto maybeExpandedMap =
+        expandAffineMap(rewriter, op->getLoc(), affineStoreOp.getAffineMap(),
+                        operands.drop_front(2));
+    if (!maybeExpandedMap)
+      return matchFailure();
+    // Build std.store valutToStore, memref[expandedMap.results].
+    rewriter.replaceOpWithNewOp<StoreOp>(op, operands[0], operands[1],
+                                         *maybeExpandedMap);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine maps from an 'affine.dma_start' operation to each of their
+// respective map operands, and feed the results to a newly created
+// 'std.dma_start' operation (which replaces the original 'affine.dma_start').
+class AffineDmaStartLowering : public ConversionPattern {
+public:
+  AffineDmaStartLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineDmaStartOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineDmaStartOp = cast<AffineDmaStartOp>(op);
+    // Expand affine map for DMA source memref.
+    auto maybeExpandedSrcMap = expandAffineMap(
+        rewriter, op->getLoc(), affineDmaStartOp.getSrcMap(),
+        operands.drop_front(affineDmaStartOp.getSrcMemRefOperandIndex() + 1));
+    if (!maybeExpandedSrcMap)
+      return matchFailure();
+    // Expand affine map for DMA destination memref.
+    auto maybeExpandedDstMap = expandAffineMap(
+        rewriter, op->getLoc(), affineDmaStartOp.getDstMap(),
+        operands.drop_front(affineDmaStartOp.getDstMemRefOperandIndex() + 1));
+    if (!maybeExpandedDstMap)
+      return matchFailure();
+    // Expand affine map for DMA tag memref.
+    auto maybeExpandedTagMap = expandAffineMap(
+        rewriter, op->getLoc(), affineDmaStartOp.getTagMap(),
+        operands.drop_front(affineDmaStartOp.getTagMemRefOperandIndex() + 1));
+    if (!maybeExpandedTagMap)
+      return matchFailure();
+
+    // Build std.dma_start operation with affine map results.
+    auto *srcMemRef = operands[affineDmaStartOp.getSrcMemRefOperandIndex()];
+    auto *dstMemRef = operands[affineDmaStartOp.getDstMemRefOperandIndex()];
+    auto *tagMemRef = operands[affineDmaStartOp.getTagMemRefOperandIndex()];
+    unsigned numElementsIndex = affineDmaStartOp.getTagMemRefOperandIndex() +
+                                1 + affineDmaStartOp.getTagMap().getNumInputs();
+    auto *numElements = operands[numElementsIndex];
+    auto *stride =
+        affineDmaStartOp.isStrided() ? operands[numElementsIndex + 1] : nullptr;
+    auto *eltsPerStride =
+        affineDmaStartOp.isStrided() ? operands[numElementsIndex + 2] : nullptr;
+
+    rewriter.replaceOpWithNewOp<DmaStartOp>(
+        op, srcMemRef, *maybeExpandedSrcMap, dstMemRef, *maybeExpandedDstMap,
+        numElements, tagMemRef, *maybeExpandedTagMap, stride, eltsPerStride);
+    return matchSuccess();
+  }
+};
+
+// Apply the affine map from an 'affine.dma_wait' operation tag memref,
+// and feed the results to a newly created 'std.dma_wait' operation (which
+// replaces the original 'affine.dma_wait').
+class AffineDmaWaitLowering : public ConversionPattern {
+public:
+  AffineDmaWaitLowering(MLIRContext *ctx)
+      : ConversionPattern(AffineDmaWaitOp::getOperationName(), 1, ctx) {}
+
+  virtual PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  PatternRewriter &rewriter) const override {
+    auto affineDmaWaitOp = cast<AffineDmaWaitOp>(op);
+    // Expand affine map for DMA tag memref.
+    auto maybeExpandedTagMap =
+        expandAffineMap(rewriter, op->getLoc(), affineDmaWaitOp.getTagMap(),
+                        operands.drop_front());
+    if (!maybeExpandedTagMap)
+      return matchFailure();
+
+    // Build std.dma_wait operation with affine map results.
+    unsigned numElementsIndex = 1 + affineDmaWaitOp.getTagMap().getNumInputs();
+    rewriter.replaceOpWithNewOp<DmaWaitOp>(
+        op, operands[0], *maybeExpandedTagMap, operands[numElementsIndex]);
+    return matchSuccess();
+  }
+};
+
 } // end namespace
 
 LogicalResult mlir::lowerAffineConstructs(Function &function) {
   OwningRewritePatternList patterns;
-  RewriteListBuilder<AffineApplyLowering, AffineForLowering, AffineIfLowering,
+  RewriteListBuilder<AffineApplyLowering, AffineDmaStartLowering,
+                     AffineDmaWaitLowering, AffineLoadLowering,
+                     AffineStoreLowering, AffineForLowering, AffineIfLowering,
                      AffineTerminatorLowering>::build(patterns,
                                                       function.getContext());
   ConversionTarget target(*function.getContext());
index fc6afbd9b68ed1346043662cf6585df88136cf5f..8538a94a2616909b4ce6ba484004636ef56cec75 100644 (file)
@@ -637,3 +637,65 @@ func @affine_apply_ceildiv(%arg0 : index) -> (index) {
   %0 = affine.apply #mapceildiv (%arg0)
   return %0 : index
 }
+
+// CHECK-LABEL: func @affine_load
+func @affine_load(%arg0 : index) {
+  %0 = alloc() : memref<10xf32>
+  affine.for %i0 = 0 to 10 {
+    %1 = affine.load %0[%i0 + symbol(%arg0) + 7] : memref<10xf32>
+  }
+// CHECK:       %3 = addi %1, %arg0 : index
+// CHECK-NEXT:  %c7 = constant 7 : index
+// CHECK-NEXT:  %4 = addi %3, %c7 : index
+// CHECK-NEXT:  %5 = load %0[%4] : memref<10xf32>
+  return
+}
+
+// CHECK-LABEL: func @affine_store
+func @affine_store(%arg0 : index) {
+  %0 = alloc() : memref<10xf32>
+  %1 = constant 11.0 : f32 
+  affine.for %i0 = 0 to 10 {
+    affine.store %1, %0[%i0 - symbol(%arg0) + 7] : memref<10xf32>
+  }
+// CHECK:       %c-1 = constant -1 : index
+// CHECK-NEXT:  %3 = muli %arg0, %c-1 : index
+// CHECK-NEXT:  %4 = addi %1, %3 : index
+// CHECK-NEXT:  %c7 = constant 7 : index
+// CHECK-NEXT:  %5 = addi %4, %c7 : index
+// CHECK-NEXT:  store %cst, %0[%5] : memref<10xf32>
+  return
+}
+
+// CHECK-LABEL: func @affine_dma_start
+func @affine_dma_start(%arg0 : index) {
+  %0 = alloc() : memref<100xf32>
+  %1 = alloc() : memref<100xf32, 2>
+  %2 = alloc() : memref<1xi32>
+  %c0 = constant 0 : index
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    affine.dma_start %0[%i0 + 7], %1[%arg0 + 11], %2[%c0], %c64
+        : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
+  }
+// CHECK:       %c7 = constant 7 : index
+// CHECK-NEXT:  %5 = addi %3, %c7 : index
+// CHECK-NEXT:  %c11 = constant 11 : index
+// CHECK-NEXT:  %6 = addi %arg0, %c11 : index
+// CHECK-NEXT:  dma_start %0[%5], %1[%6], %c64, %2[%c0] : memref<100xf32>, memref<100xf32, 2>, memref<1xi32>
+  return
+}
+
+// CHECK-LABEL: func @affine_dma_wait
+func @affine_dma_wait(%arg0 : index) {
+  %2 = alloc() : memref<1xi32>
+  %c64 = constant 64 : index
+  affine.for %i0 = 0 to 10 {
+    affine.dma_wait %2[%i0 + %arg0 + 17], %c64 : memref<1xi32>
+  }
+// CHECK:       %3 = addi %1, %arg0 : index
+// CHECK-NEXT:  %c17 = constant 17 : index
+// CHECK-NEXT:  %4 = addi %3, %c17 : index
+// CHECK-NEXT:  dma_wait %0[%4], %c64 : memref<1xi32>
+  return
+}