[MLIR] Fix affine LICM pass for unknown region holding ops
authorUday Bondhugula <uday@polymagelabs.com>
Sat, 31 Dec 2022 14:56:40 +0000 (20:26 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Sat, 31 Dec 2022 14:56:50 +0000 (20:26 +0530)
Fix affine LICM pass for unknown region-holding ops. The logic was
completely ignoring regions of unknown ops leading to generation of
invalid IR on hoisting. Handle affine.parallel op among those with
regions that are supported.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D140738

mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
mlir/test/Dialect/Affine/affine-loop-invariant-code-motion.mlir

index 3258165..b124e73 100644 (file)
@@ -84,8 +84,17 @@ bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
     if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs,
                                           opsWithUsers, opsToHoist))
       return false;
+  } else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
+    if (!areAllOpsInTheBlockListInvariant(parOp.getLoopBody(), indVar, iterArgs,
+                                          opsWithUsers, opsToHoist))
+      return false;
   } else if (isa<AffineDmaStartOp, AffineDmaWaitOp>(op)) {
     // TODO: Support DMA ops.
+    // FIXME: This should be fixed to not special-case these affine DMA ops but
+    // instead rely on side effects.
+    return false;
+  } else if (op.getNumRegions() > 0) {
+    // We can't handle region-holding ops we don't know about.
     return false;
   } else if (!matchPattern(&op, m_Constant())) {
     // Register op in the set of ops that have users.
index 72a387c..3aecfde 100644 (file)
@@ -752,3 +752,59 @@ func.func @use_of_iter_args_not_invariant(%m : memref<10xindex>) {
 // CHECK-NEXT:  affine.for
 // CHECK-NEXT:  arith.addi
 // CHECK-NEXT:  affine.yield
+
+#map = affine_map<(d0) -> (64, d0 * -64 + 1020)>
+// CHECK-LABEL: func.func @affine_parallel
+func.func @affine_parallel(%memref_8: memref<4090x2040xf32>, %x: index) {
+  %cst = arith.constant 0.000000e+00 : f32
+  affine.parallel (%arg3) = (0) to (32) {
+    affine.for %arg4 = 0 to 16 {
+      affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %arg3 * -64 + 2040)) {
+        affine.for %arg7 = 0 to min #map(%arg4) {
+          affine.store %cst, %memref_8[%arg5 + 3968, %arg6 + %arg3 * 64] : memref<4090x2040xf32>
+        }
+      }
+    }
+  }
+  // CHECK:       affine.parallel
+  // CHECK-NEXT:    affine.for
+  // CHECK-NEXT:      affine.parallel
+  // CHECK-NEXT:        affine.store
+  // CHECK-NEXT:        affine.for
+
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c32 = arith.constant 32 : index
+  scf.parallel (%arg3) = (%c0) to (%c32) step (%c1) {
+    affine.for %arg4 = 0 to 16 {
+      affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %x * -64 + 2040)) {
+        affine.for %arg7 = 0 to min #map(%arg4) {
+          affine.store %cst, %memref_8[%arg5 + 3968, %arg6] : memref<4090x2040xf32>
+        }
+      }
+    }
+  }
+  // CHECK:       scf.parallel
+  // CHECK-NEXT:    affine.for
+  // CHECK-NEXT:      affine.parallel
+  // CHECK-NEXT:        affine.store
+  // CHECK-NEXT:        affine.for
+
+  affine.for %arg3 = 0 to 32 {
+    affine.for %arg4 = 0 to 16 {
+      affine.parallel (%arg5, %arg6) = (0, 0) to (min(128, 122), min(64, %arg3 * -64 + 2040)) {
+        // Unknown region-holding op for this pass.
+        scf.for %arg7 = %c0 to %x step %c1 {
+          affine.store %cst, %memref_8[%arg5 + 3968, %arg6 + %arg3 * 64] : memref<4090x2040xf32>
+        }
+      }
+    }
+  }
+  // CHECK:       affine.for
+  // CHECK-NEXT:    affine.for
+  // CHECK-NEXT:      affine.parallel
+  // CHECK-NEXT:        scf.for
+  // CHECK-NEXT:          affine.store
+
+  return
+}