Update/fix -pipeline-data-transfer; fix b/120770946
authorUday Bondhugula <bondhugula@google.com>
Mon, 10 Dec 2018 19:39:31 +0000 (11:39 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:24:22 +0000 (14:24 -0700)
- fix replaceAllMemRefUsesWith call to replace only inside loop body.
- handle the case where DMA buffers are dynamic; extend doubleBuffer() method
  to handle dynamically shaped DMA buffers (pass the right operands to AllocOp)
- place alloc's for DMA buffers at the depth at which pipelining is being done
  (instead of at top-level)
- add more test cases

PiperOrigin-RevId: 224852231

mlir/include/mlir/Transforms/Utils.h
mlir/lib/Transforms/PipelineDataTransfer.cpp
mlir/lib/Transforms/Utils/Utils.cpp
mlir/test/Transforms/pipeline-data-transfer.mlir

index 5f871992caecdf8f7be21731c2af95d7af112980..7fe4b8a0a0652dcf695c0b5608812e762ccb5bdf 100644 (file)
@@ -54,7 +54,7 @@ class SSAValue;
 // TODO(mlir-team): extend this for SSAValue / CFGFunctions. Can also be easily
 // extended to add additional indices at any position.
 bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
-                              ArrayRef<MLValue *> extraIndices = {},
+                              ArrayRef<SSAValue *> extraIndices = {},
                               AffineMap indexRemap = AffineMap::Null(),
                               ArrayRef<SSAValue *> extraOperands = {},
                               const Statement *domStmtFilter = nullptr);
index 3d8c21c543e5db61e187c53b70e95c65ee5b842e..fc97aa8d2d20e478baccf085f2b319de108ab6e1 100644 (file)
@@ -75,9 +75,12 @@ static unsigned getTagMemRefPos(const OperationStmt &dmaStmt) {
   return 0;
 }
 
-/// Doubles the buffer of the supplied memref while replacing all uses of the
-/// old memref. Returns false if such a replacement cannot be performed.
-static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
+/// Doubles the buffer of the supplied memref on the specified 'for' statement
+/// by adding a leading dimension of size two to the memref. Replaces all uses
+/// of the old memref by the new one while indexing the newly added dimension by
+/// the loop IV of the specified 'for' statement modulo 2. Returns false if such
+/// a replacement cannot be performed.
+static bool doubleBuffer(MLValue *oldMemRef, ForStmt *forStmt) {
   MLFuncBuilder bInner(forStmt, forStmt->begin());
   bInner.setInsertionPoint(forStmt, forStmt->begin());
 
@@ -94,21 +97,37 @@ static bool doubleBuffer(const MLValue *oldMemRef, ForStmt *forStmt) {
     return newMemRefType;
   };
 
-  auto newMemRefType = doubleShape(oldMemRef->getType().cast<MemRefType>());
+  auto oldMemRefType = oldMemRef->getType().cast<MemRefType>();
+  auto newMemRefType = doubleShape(oldMemRefType);
+
+  // Put together alloc operands for the dynamic dimensions of the memref.
+  MLFuncBuilder bOuter(forStmt);
+  SmallVector<SSAValue *, 4> allocOperands;
+  unsigned dynamicDimCount = 0;
+  for (auto dimSize : oldMemRefType.getShape()) {
+    if (dimSize == -1)
+      allocOperands.push_back(bOuter.create<DimOp>(forStmt->getLoc(), oldMemRef,
+                                                   dynamicDimCount++));
+  }
 
-  // Create and place the alloc at the top level.
-  MLFuncBuilder topBuilder(forStmt->getFunction());
-  auto newMemRef = cast<MLValue>(
-      topBuilder.create<AllocOp>(forStmt->getLoc(), newMemRefType)
-          ->getResult());
+  // Create and place the alloc right before the 'for' statement.
+  // TODO(mlir-team): we are assuming scoped allocation here, and aren't
+  // inserting a dealloc -- this isn't the right thing.
+  SSAValue *newMemRef =
+      bOuter.create<AllocOp>(forStmt->getLoc(), newMemRefType, allocOperands);
 
+  // Create 'iv mod 2' value to index the leading dimension.
   auto d0 = bInner.getAffineDimExpr(0);
   auto modTwoMap =
       bInner.getAffineMap(/*dimCount=*/1, /*symbolCount=*/0, {d0 % 2}, {});
   auto ivModTwoOp =
       bInner.create<AffineApplyOp>(forStmt->getLoc(), modTwoMap, forStmt);
-  if (!replaceAllMemRefUsesWith(oldMemRef, newMemRef,
-                                cast<MLValue>(ivModTwoOp->getResult(0)))) {
+
+  // replaceAllMemRefUsesWith will always succeed unless the forStmt body has
+  // non-deferencing uses of the memref.
+  if (!replaceAllMemRefUsesWith(oldMemRef, cast<MLValue>(newMemRef),
+                                ivModTwoOp->getResult(0), AffineMap::Null(), {},
+                                &*forStmt->begin())) {
     LLVM_DEBUG(llvm::dbgs()
                    << "memref replacement for double buffering failed\n";);
     ivModTwoOp->getOperation()->erase();
@@ -185,7 +204,7 @@ static void findMatchingStartFinishStmts(
         cast<MLValue>(dmaStartOp->getOperand(dmaStartOp->getFasterMemPos()));
     bool escapingUses = false;
     for (const auto &use : memref->getUses()) {
-      if (!dominates(*forStmt, *use.getOwner())) {
+      if (!dominates(*forStmt->begin(), *use.getOwner())) {
         LLVM_DEBUG(llvm::dbgs()
                        << "can't pipeline: buffer is live out of loop\n";);
         escapingUses = true;
@@ -247,7 +266,9 @@ PassResult PipelineDataTransfer::runOnForStmt(ForStmt *forStmt) {
       return success();
     }
     // If the old memref has no more uses, remove its 'dead' alloc if it was
-    // alloc'ed (note: DMA buffers are rarely function live-in).
+    // alloc'ed. (note: DMA buffers are rarely function live-in; but a 'dim'
+    // operation could have been used on it if it was dynamically shaped in
+    // order to create the double buffer above)
     if (oldMemRef->use_empty())
       if (auto *allocStmt = oldMemRef->getDefiningStmt())
         allocStmt->erase();
index ae3533fd90ce7ab7662a0dbcc855a442bef66834..8d375c42ca3c14cdb3eeb806b0939b29de67eb78 100644 (file)
@@ -65,7 +65,7 @@ static bool isMemRefDereferencingOp(const Operation &op) {
 // extended to add additional indices at any position.
 bool mlir::replaceAllMemRefUsesWith(const MLValue *oldMemRef,
                                     MLValue *newMemRef,
-                                    ArrayRef<MLValue *> extraIndices,
+                                    ArrayRef<SSAValue *> extraIndices,
                                     AffineMap indexRemap,
                                     ArrayRef<SSAValue *> extraOperands,
                                     const Statement *domStmtFilter) {
index db4ea67b0cb3d7db815e32f2a133f09ce09104f2..70468f7e6c163f6fb13f71d111744252024d9923 100644 (file)
@@ -2,23 +2,18 @@
 
 // CHECK-LABEL: mlfunc @loop_nest_dma() {
 mlfunc @loop_nest_dma() {
-// CHECK:        %c8 = constant 8 : index
-// CHECK-NEXT:   %c0 = constant 0 : index
-// CHECK-NEXT:   %0 = alloc() : memref<2x1xf32>
-// CHECK-NEXT:   %1 = alloc() : memref<2x32xf32, 1>
-// CHECK-NEXT:   %2 = alloc() : memref<256xf32>
-// CHECK-NEXT:   %c0_0 = constant 0 : index
-// CHECK-NEXT:   %c128 = constant 128 : index
-// CHECK-NEXT:   %3 = affine_apply #map0(%c0)
-// CHECK-NEXT:   dma_start %2[%c0], %1[%3#0, %c0], %c128, %0[%3#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
+// CHECK:        %0 = alloc() : memref<256xf32>
+// CHECK:        %1 = alloc() : memref<2x32xf32, 1>
+// CHECK:        %2 = alloc() : memref<2x1xf32>
+// CHECK:        dma_start %0[%c0], %1[%3#0, %c0], %c128, %2[%3#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
 // CHECK-NEXT:   for %i0 = 1 to 8 {
 // CHECK-NEXT:     %4 = affine_apply #map0(%i0)
-// CHECK-NEXT:     dma_start %2[%i0], %1[%4#0, %i0], %c128, %0[%4#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
+// CHECK-NEXT:     dma_start %0[%i0], %1[%4#0, %i0], %c128, %2[%4#1, %c0_0] : memref<256xf32>, memref<2x32xf32, 1>, memref<2x1xf32>
 // CHECK-NEXT:     %5 = affine_apply #map1(%i0)
 // CHECK-NEXT:     %6 = affine_apply #map2(%5)
 // CHECK-NEXT:     %7 = affine_apply #map2(%5)
-// CHECK-NEXT:     dma_wait %0[%6, %c0_0], %c128 : memref<2x1xf32>
-// CHECK-NEXT:    %8 = load %1[%7, %5] : memref<2x32xf32, 1>
+// CHECK-NEXT:     dma_wait %2[%6, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT:     %8 = load %1[%7, %5] : memref<2x32xf32, 1>
 // CHECK-NEXT:     %9 = "compute"(%8) : (f32) -> f32
 // CHECK-NEXT:     store %9, %1[%7, %5] : memref<2x32xf32, 1>
 // CHECK-NEXT:     for %i1 = 0 to 128 {
@@ -28,7 +23,7 @@ mlfunc @loop_nest_dma() {
 // CHECK-NEXT:   %10 = affine_apply #map1(%c8)
 // CHECK-NEXT:   %11 = affine_apply #map2(%10)
 // CHECK-NEXT:   %12 = affine_apply #map2(%10)
-// CHECK-NEXT:   dma_wait %0[%11, %c0_0], %c128 : memref<2x1xf32>
+// CHECK-NEXT:   dma_wait %2[%11, %c0_0], %c128 : memref<2x1xf32>
 // CHECK-NEXT:   %13 = load %1[%12, %10] : memref<2x32xf32, 1>
 // CHECK-NEXT:   %14 = "compute"(%13) : (f32) -> f32
 // CHECK-NEXT:   store %14, %1[%12, %10] : memref<2x32xf32, 1>
@@ -80,7 +75,7 @@ mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>, #map0>, %arg1 : mem
     dma_wait %5[%c0], %num_elts : memref<2xi32>
     // Steady state for DMA overlap on arg2
     // CHECK: dma_start %arg2[
-    // CHECK: dma_wait %0[
+    // CHECK: dma_wait %1[
     // Prologue for DMA overlap on arg0, arg1 nested within i0
     // CHECK: dma_start %arg0[
     // CHECK: dma_start %arg1[
@@ -95,38 +90,116 @@ mlfunc @loop_dma_nested(%arg0 : memref<512x32xvector<8xf32>, #map0>, %arg1 : mem
       // Steady state for DMA overlap on arg0, arg1
       // CHECK: dma_start %arg0[
       // CHECK: dma_start %arg1[
-      // CHECK: dma_wait %3[
-      // CHECK: dma_wait %2[
+      // CHECK: dma_wait %10[
+      // CHECK: dma_wait %11[
       // CHECK-NEXT: for %i2 = 0 to 4 {
       for %i2 = 0 to 4 {
         "foo"() : () -> ()
       }
     }
     // epilogue for arg0, arg1
-    // CHECK: dma_wait %3[
-    // CHECK: dma_wait %2[
+    // CHECK: dma_wait %10[
+    // CHECK: dma_wait %11[
 
     // epilogue for DMA overlap on %arg2
-    // CHECK:  dma_wait %0[%31, %c0_2], %c256 : memref<2x2xi32>
+    // CHECK:  dma_wait %1[%31, %c0_2], %c256 : memref<2x2xi32>
     // Within the epilogue for arg2's DMA, we have the DMAs on %arg1, %arg2 nested.
     // CHECK:  dma_start %arg0[
     // CHECK:  dma_start %arg1[
     // CHECK:  for %i4 = 1 to 8 {
     // CHECK:    dma_start %arg0[
     // CHECK:    dma_start %arg1[
-    // CHECK:    dma_wait %3[
-    // CHECK:    dma_wait %2[
+    // CHECK:    dma_wait %36[
+    // CHECK:    dma_wait %37[
     // CHECK:    for %i5 = 0 to 4 {
     // CHECK:      "foo"() : () -> ()
-    // CHECK:  dma_wait %3[
-    // CHECK:  dma_wait %2[
+    // CHECK:  dma_wait %36[
+    // CHECK:  dma_wait %37[
     // CHECK:  for %i6 = 0 to 4 {
 
     // The DMAs below are outgoing DMAs on arg2, not yet overlapped.
-    // CHECK: dma_start %1{{.*}}, %arg2[
-    // CHECK-NEXT:  dma_wait %0[
+    // CHECK: dma_start %0{{.*}}, %arg2[
+    // CHECK-NEXT:  dma_wait %1[
     dma_start %2[%c0, %c0], %arg2[%6#0, %6#1], %num_elts, %5[%c0] : memref<64x4xvector<8xf32>, #map0, 2>, memref<512x32xvector<8xf32>, #map0>, memref<2xi32>
     dma_wait %5[%c0], %num_elts : memref<2xi32>
   } // CHECK }
   return
 }
+
+// CHECK-LABEL: mlfunc @escaping_use
+mlfunc @escaping_use(%arg0: memref<512 x 32 x f32>) {
+  %c32 = constant 32 : index
+  %num_elt = constant 512 : index
+  %zero = constant 0 : index
+  %Av = alloc() : memref<32 x 32 x f32, 2>
+  %tag = alloc() : memref<1 x i32>
+
+  // CHECK-NOT: dma_start
+  // CHECK: for %i0 = 0 to 16 {
+  for %kTT = 0 to 16 {
+    dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] :
+      memref<512 x 32 x f32>,
+      memref<32 x 32 x f32, 2>, memref<1 x i32>
+    dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+    // escaping use; no DMA pipelining / double buffering will be done.
+    "foo"(%Av) : (memref<32 x 32 x f32, 2>) -> ()
+  }
+  return
+// CHECK:        "foo"(%{{[0-9]+}}) : (memref<32x32xf32, 2>) -> ()
+// CHECK:      }
+// CHECK-NEXT: return
+}
+
+// CHECK-LABEL: mlfunc @live_out_use
+mlfunc @live_out_use(%arg0: memref<512 x 32 x f32>) -> f32 {
+  %c32 = constant 32 : index
+  %num_elt = constant 512 : index
+  %zero = constant 0 : index
+  %Av = alloc() : memref<32 x 32 x f32, 2>
+  %tag = alloc() : memref<1 x i32>
+
+  // CHECK-NOT: dma_start
+  // CHECK: for %i0 = 0 to 16 {
+  for %kTT = 0 to 16 {
+    dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] :
+      memref<512 x 32 x f32>,
+      memref<32 x 32 x f32, 2>, memref<1 x i32>
+    dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+  }
+  // Use live out of 'for' stmt; no DMA pipelining will be done.
+  %v = load %Av[%zero, %zero] : memref<32 x 32 x f32, 2>
+  return %v : f32
+// CHECK:      %{{[0-9]+}} = load %{{[0-9]+}}[%c0, %c0] : memref<32x32xf32, 2>
+// CHECK-NEXT: return
+}
+
+// CHECK-LABEL: mlfunc @dynamic_shape_dma_buffer
+mlfunc @dynamic_shape_dma_buffer(%arg0: memref<512 x 32 x f32>) {
+  %c32 = constant 32 : index
+  %num_elt = constant 512 : index
+  %zero = constant 0 : index
+
+  %Av = alloc(%c32, %c32) : memref<? x ? x f32, 2>
+  %tag = alloc() : memref<1 x i32>
+
+// Double buffering for dynamic shaped buffer.
+// CHECK:       %0 = alloc(%c32, %c32) : memref<?x?xf32, 2>
+// CHECK-NEXT:  %1 = dim %0, 0 : memref<?x?xf32, 2>
+// CHECK-NEXT:  %2 = dim %0, 1 : memref<?x?xf32, 2>
+// CHECK-NEXT:  %3 = alloc(%1, %2) : memref<2x?x?xf32, 2>
+
+// CHECK:  dma_start %arg0[%c0_0, %c0_0], %3[%5#0, %c0_0, %c0_0],
+// CHECK-NEXT:  for %i0 = 1 to 16 {
+  for %kTT = 0 to 16 {
+    dma_start %arg0[%zero, %zero], %Av[%zero, %zero], %num_elt, %tag[%zero] :
+      memref<512 x 32 x f32>,
+      memref<? x ? x f32, 2>, memref<1 x i32>
+    dma_wait %tag[%zero], %num_elt : memref<1 x i32>
+  }
+  return
+// CHECK:          dma_start %arg0[%c0_0, %c0_0], %3[%6#0, %c0_0, %c0_0], %c512, %4[%6#1, %c0_0]
+// CHECK:          dma_wait %4[%8, %c0_0], %c512 : memref<2x1xi32>
+// CHECK:       }
+// CHECK:       dma_wait %4[%11, %c0_0], %c512 : memref<2x1xi32>
+// CHECK-NEXT:  return
+}