[mlir][affine] fix affine LICM pass for has effected memory's user
authorlipracer <lipracer@gmail.com>
Thu, 2 Feb 2023 18:23:24 +0000 (10:23 -0800)
committerJeff Niu <jeff@modular.com>
Thu, 2 Feb 2023 18:25:16 +0000 (10:25 -0800)
When the memory is written by dma, its user is moved

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D141106

mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir

index 81f2cb9..320e184 100644 (file)
@@ -106,7 +106,7 @@ bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
       for (auto *user : memref.getUsers()) {
         // If this memref has a user that is a DMA, give up because these
         // operations write to this memref.
-        if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op))
+        if (isa<AffineDmaStartOp, AffineDmaWaitOp>(user))
           return false;
         // If the memref used by the load/store is used in a store elsewhere in
         // the loop nest, we do not hoist. Similarly, if the memref used in a
index 3aecfde..923cead 100644 (file)
@@ -808,3 +808,41 @@ func.func @affine_parallel(%memref_8: memref<4090x2040xf32>, %x: index) {
 
   return
 }
+
+// -----
+
+// CHECK-LABEL: func.func @affine_invariant_use_after_dma
+#map = affine_map<(d0) -> (d0 * 163840)>
+func.func @affine_invariant_use_after_dma(%arg0: memref<10485760xi32>, %arg1: memref<1xi32>, %arg2: memref<10485760xi32>) {
+  %c320 = arith.constant 320 : index
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %alloc = memref.alloc() {alignment = 16 : i64} : memref<0xi32, 2>
+  %alloc_0 = memref.alloc() : memref<1xi32, 2>
+  affine.for %arg3 = 0 to 64 {
+    %0 = affine.apply #map(%arg3)
+    %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<0xi32, 2>
+    %alloc_2 = memref.alloc() : memref<320xi32, 2>
+    affine.dma_start %arg0[%0], %alloc_2[%c0], %alloc_1[%c0], %c320 : memref<10485760xi32>, memref<320xi32, 2>, memref<0xi32, 2>
+    affine.dma_start %arg1[%c0], %alloc_0[%c0], %alloc[%c0], %c1 : memref<1xi32>, memref<1xi32, 2>, memref<0xi32, 2>
+    affine.dma_wait %alloc_1[%c0], %c320 : memref<0xi32, 2>
+    affine.dma_wait %alloc[%c0], %c1 : memref<0xi32, 2>
+    %1 = affine.apply #map(%arg3)
+    %alloc_3 = memref.alloc() {alignment = 16 : i64} : memref<0xi32, 2>
+    %alloc_4 = memref.alloc() : memref<320xi32, 2>
+    affine.for %arg4 = 0 to 320 {
+      %2 = affine.load %alloc_2[%arg4] : memref<320xi32, 2>
+      %3 = affine.load %alloc_0[0] : memref<1xi32, 2>
+      %4 = arith.addi %2, %3 : i32
+      %5 = arith.addi %4, %2 : i32
+      affine.store %5, %alloc_4[%arg4] : memref<320xi32, 2>
+    }
+    affine.dma_start %alloc_4[%c0], %arg2[%1], %alloc_3[%c0], %c320 : memref<320xi32, 2>, memref<10485760xi32>, memref<0xi32, 2>
+    affine.dma_wait %alloc_3[%c0], %c320 : memref<0xi32, 2>
+  }
+  return
+}
+// CHECK: %[[zero:.*]] = arith.constant 0 : index
+// CHECK: %[[scalar_mem:.*]] = memref.alloc() : memref<1xi32, 2>
+// CHECK: affine.dma_start %arg1[%[[zero]]], %alloc_0[%[[zero]]], %alloc[%[[zero]]], %c1
+// CHECK: affine.load %[[scalar_mem]][0]