return 0;
}
-/// Doubles the buffer of the supplied memref while replacing all uses of the
-/// old memref. Returns false if such a replacement cannot be performed.
-static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
+/// Doubles the buffer of the supplied memref on the specified 'for' statement
+/// by adding a leading dimension of size two to the memref. Replaces all uses
+/// of the old memref with the new one while indexing the newly added dimension by
+/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
+/// a replacement cannot be performed.
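+/// For example, with loop IV %i, a fast-memory buffer of type
+/// memref<32xf32, 1> becomes memref<2x32xf32, 1>, and each of its uses is
+/// rewritten to index the new leading dimension with (%i mod 2).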
+static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
MLFuncBuilder bInner(forStmt, forStmt->begin());
bInner.setInsertionPoint(forStmt, forStmt->begin());
return newMemRefType;
};
- auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
+ auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
+ auto newMemRefType = doubleShape(oldMemRefType);
+
+ // Put together alloc operands for the dynamic dimensions of the memref.
+ MLFuncBuilder bOuter(forStmt);
+ SmallVector<SSAValue *, 4> allocOperands;
+ unsigned dynamicDimCount = 0;
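+ // A dimension size of -1 in the shape denotes a dynamic dimension; its
+ // runtime extent is fetched with a 'dim' op on the old memref so it can be
+ // passed as an operand to the new alloc.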
+ for (auto dimSize : oldMemRefType.getShape()) {
+ if (dimSize == -1)
+ allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef,
+ dynamicDimCount++));
+ }
- // Create and place the alloc at the top level.
- MLFuncBuilder topBuilder(forStmt->getFunction());
- auto newMemRef = cast<MLValue>(
- topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
- ->getResult());
+ // Create and place the alloc right before the 'for' statement.
+ // TODO(mlir-team): we are assuming scoped allocation here, and aren't
+ // inserting a dealloc -- this isn't the right thing.
+ SSAValue *newMemRef =
+ bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
+ // Create 'iv mod 2' value to index the leading dimension.
auto d0 = bInner.getAffineDimExpr(0);
auto modTwoMap =
bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
auto ivModTwoOp =
bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
- if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
- cast<MLValue>(ivModTwoOp->getResult(0)))) {
+
+ // replaceAllMemRefUsesWith will always succeed unless the forStmt body has
+ // non-dereferencing uses of the memref.
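+ // The 'iv mod 2' value is passed as an extra leading index; no remapping of
+ // the old indices is needed, and the last argument restricts the replacement
+ // to uses dominated by the first statement of the loop body.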
+ if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
+ ivModTwoOp->getResult(0), AffineMap::Null(), {},
+ &*forStmt->begin())) {
LLVM_DEBUG(llvm::dbgs()
<< "memref replacement for double buffering failed\n";);
ivModTwoOp->getOperation()->erase();
cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
bool escapingUses = false;
for (const auto &use : memref->getUses()) {
- if (!dominates(*forStmt, *use.getOwner())) {
+ if (!dominates(*forStmt->begin(), *use.getOwner())) {
LLVM_DEBUG(llvm::dbgs()
<< "can't pipeline: buffer is live out of loop\n";);
escapingUses = true;
return success();
}
// If the old memref has no more uses, remove its 'dead' alloc if it was
- // alloc'ed (note: DMA buffers are rarely function live-in).
+ // alloc'ed. (Note: DMA buffers are rarely function live-in; however, a 'dim'
+ // operation could have been used on the memref if it was dynamically shaped,
+ // in order to create the double buffer above.)
if (oldMemRef->use_empty())
if (auto *allocStmt = oldMemRef->getDefiningStmt())
allocStmt->erase();
// CHECK-LABEL: mlfunc @loop_nest_dma() {
mlfunc @loop_nest_dma() {
-// CHECK: %c8 = constant 8 : index
-// CHECK-NEXT: %c0 = constant 0 : index
-// CHECK-NEXT: %0 = alloc() : memref<2x1xf32>
-// CHECK-NEXT: %1 = alloc() : memref<2x32xf32, 1>
-// CHECK-NEXT: %2 = alloc() : memref<256xf32>
-// CHECK-NEXT: %c0_0 = constant 0 : index
-// CHECK-NEXT: %c128 = constant 128 : index
-// CHECK-NEXT: %3 = affine_apply #map0(%c0)
-// CHECK-NEXT: dma_start %2[%c0], %1[%3#0, %c0], %c128, %0[%3#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
+// CHECK: %0 = alloc() : memref<256xf32>
+// CHECK: %1 = alloc() : memref<2x32xf32, 1>
+// CHECK: %2 = alloc() : memref<2x1xf32>
+// CHECK: dma_start %0[%c0], %1[%3#0, %c0], %c128, %2[%3#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
// CHECK-NEXT: for %i0 = 1 to 8 {
// CHECK-NEXT: %4 = affine_apply #map0(%i0)
-// CHECK-NEXT: dma_start %2[%i0], %1[%4#0, %i0], %c128, %0[%4#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
+// CHECK-NEXT: dma_start %0[%i0], %1[%4#0, %i0], %c128, %2[%4#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
// CHECK-NEXT: %5 = affine_apply #map1(%i0)
// CHECK-NEXT: %6 = affine_apply #map2(%5)
// CHECK-NEXT: %7 = affine_apply #map2(%5)
-// CHECK-NEXT: dma_wait %0[%6, %c0_0], %c128 : memref<2x1xf32>
-// CHECK-NEXT: %8 = load %1[%7, %5] : memref<2x32xf32, 1>
+// CHECK-NEXT: dma_wait %2[%6, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT: %8 = load %1[%7, %5] : memref<2x32xf32, 1>
// CHECK-NEXT: %9 = "compute"(%8) : (f32) -> f32
// CHECK-NEXT: store %9, %1[%7, %5] : memref<2x32xf32, 1>
// CHECK-NEXT: for %i1 = 0 to 128 {
// CHECK-NEXT: %10 = affine_apply #map1(%c8)
// CHECK-NEXT: %11 = affine_apply #map2(%10)
// CHECK-NEXT: %12 = affine_apply #map2(%10)
-// CHECK-NEXT: dma_wait %0[%11, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT: dma_wait %2[%11, %c0_0], %c128 : memref<2x1xf32>
// CHECK-NEXT: %13 = load %1[%12, %10] : memref<2x32xf32, 1>
// CHECK-NEXT: %14 = "compute"(%13) : (f32) -> f32
// CHECK-NEXT: store %14, %1[%12, %10] : memref<2x32xf32, 1>
dma_wait %5[%c0], %num_elts : memref<2xi32>
// Steady state for DMA overlap on arg2
// CHECK: dma_start %arg2[
- // CHECK: dma_wait %0[
+ // CHECK: dma_wait %1[
// Prologue for DMA overlap on arg0, arg1 nested within i0
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
// Steady state for DMA overlap on arg0, arg1
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
- // CHECK: dma_wait %3[
- // CHECK: dma_wait %2[
+ // CHECK: dma_wait %10[
+ // CHECK: dma_wait %11[
// CHECK-NEXT: for %i2 = 0 to 4 {
for %i2 = 0 to 4 {
"foo"() : () -> ()
}
}
// epilogue for arg0, arg1
- // CHECK: dma_wait %3[
- // CHECK: dma_wait %2[
+ // CHECK: dma_wait %10[
+ // CHECK: dma_wait %11[
// epilogue for DMA overlap on %arg2
- // CHECK: dma_wait %0[%31, %c0_2], %c256 : memref<2x2xi32>
+ // CHECK: dma_wait %1[%31, %c0_2], %c256 : memref<2x2xi32>
// Within the epilogue for arg2's DMA, we have the DMAs on %arg1, %arg2 nested.
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
// CHECK: for %i4 = 1 to 8 {
// CHECK: dma_start %arg0[
// CHECK: dma_start %arg1[
- // CHECK: dma_wait %3[
- // CHECK: dma_wait %2[
+ // CHECK: dma_wait %36[
+ // CHECK: dma_wait %37[
// CHECK: for %i5 = 0 to 4 {
// CHECK: "foo"() : () -> ()
- // CHECK: dma_wait %3[
- // CHECK: dma_wait %2[
+ // CHECK: dma_wait %36[
+ // CHECK: dma_wait %37[
// CHECK: for %i6 = 0 to 4 {
// The DMAs below are outgoing DMAs on arg2, not yet overlapped.
- // CHECK: dma_start %1{{.*}}, %arg2[
- // CHECK-NEXT: dma_wait %0[
+ // CHECK: dma_start %0{{.*}}, %arg2[
+ // CHECK-NEXT: dma_wait %1[
dma_start %2[%c0, %c0], %arg2[%6#0, %6#1], %num_elts, %5[%c0] : memref<64x4xvector<8xf32>, #map0, 2>, memref<512x32xvector<8xf32>, #map0>, memref<2xi32>
dma_wait %5[%c0], %num_elts : memref<2xi32>
} // CHECK }
return
}
+
+// CHECK-LABEL: mlfunc @escaping_use
+mlfunc @escaping_use(%arg0: memref<512 x 32 x f32>) {
+ %c32 = constant 32 : index
+ %num_elt = constant 512 : index
+ %zero = constant 0 : index
+ %Av = alloc() : memref<32 x 32 x f32, 2>
+ %tag = alloc() : memref<1 x i32>
+
+ // CHECK-NOT: dma_start
+ // CHECK: for %i0 = 0 to 16 {
+ for %kTT = 0 to 16 {
+ dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] :
+ memref<512 x 32 x f32>,
+ memref<32 x 32 x f32, 2>, memref<1 x i32>
+ dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+ // Escaping use; no DMA pipelining / double buffering will be done.
+ "foo"(%Av) : (memref<32 x 32 x f32, 2>) -> ()
+ }
+ return
+// CHECK: "foo"(%{{[0-9]+}}) : (memref<32x32xf32, 2>) -> ()
+// CHECK: }
+// CHECK-NEXT: return
+}
+
+// CHECK-LABEL: mlfunc @live_out_use
+mlfunc @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 {
+ %c32 = constant 32 : index
+ %num_elt = constant 512 : index
+ %zero = constant 0 : index
+ %Av = alloc() : memref<32 x 32 x f32, 2>
+ %tag = alloc() : memref<1 x i32>
+
+ // CHECK-NOT: dma_start
+ // CHECK: for %i0 = 0 to 16 {
+ for %kTT = 0 to 16 {
+ dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] :
+ memref<512 x 32 x f32>,
+ memref<32 x 32 x f32, 2>, memref<1 x i32>
+ dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+ }
+ // Use live out of 'for' stmt; no DMA pipelining will be done.
+ %v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2>
+ return %v : f32
+// CHECK: %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2>
+// CHECK-NEXT: return
+}
+
+// CHECK-LABEL: mlfunc @dynamic_shape_dma_buffer
+mlfunc @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) {
+ %c32 = constant 32 : index
+ %num_elt = constant 512 : index
+ %zero = constant 0 : index
+
+ %Av = alloc(%c32, %c32) : memref<? x ? x f32, 2>
+ %tag = alloc() : memref<1 x i32>
+
+// Double buffering for a dynamically shaped buffer.
+// CHECK: %0 = alloc(%c32, %c32) : memref<?x?xf32, 2>
+// CHECK-NEXT: %1 = dim %0, 0 : memref<?x?xf32, 2>
+// CHECK-NEXT: %2 = dim %0, 1 : memref<?x?xf32, 2>
+// CHECK-NEXT: %3 = alloc(%1, %2) : memref<2x?x?xf32, 2>
+
+// CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%5#0, %c0_0, %c0_0],
+// CHECK-NEXT: for %i0 = 1 to 16 {
+ for %kTT = 0 to 16 {
+ dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] :
+ memref<512 x 32 x f32>,
+ memref<? x ? x f32, 2>, memref<1 x i32>
+ dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+ }
+ return
+// CHECK: dma_start %arg0[%c0_0, %c0_0], %3[%6#0, %c0_0, %c0_0], %c512, %4[%6#1, %c0_0]
+// CHECK: dma_wait %4[%8, %c0_0], %c512 : memref<2x1xi32>
+// CHECK: }
+// CHECK: dma_wait %4[%11, %c0_0], %c512 : memref<2x1xi32>
+// CHECK-NEXT: return
+}