[mlir] use getNumDimAndSymbolVars when iterate dims/symbols of FlatAffineValueConstraints
authorXiang <python3kgae@outlook.com>
Sun, 29 Jan 2023 23:12:47 +0000 (18:12 -0500)
committerXiang <python3kgae@outlook.com>
Mon, 30 Jan 2023 14:07:49 +0000 (09:07 -0500)
Fixes #59443  https://github.com/llvm/llvm-project/issues/59443

getNumVars will add locals and cause out of bound access.

Differential Revision: https://reviews.llvm.org/D142851

mlir/lib/Dialect/Affine/Analysis/Utils.cpp
mlir/test/Dialect/SCF/for-loop-canonicalization.mlir

index ff88acd..581ffb0 100644 (file)
@@ -1546,7 +1546,7 @@ mlir::simplifyConstrainedMinMaxOp(Operation *op,
   unpackOptionalValues(constraints.getMaybeValues(), newOperands);
   // If dims/symbols have known constant values, use those in order to simplify
   // the affine map further.
-  for (int64_t i = 0, e = constraints.getNumVars(); i < e; ++i) {
+  for (int64_t i = 0, e = constraints.getNumDimAndSymbolVars(); i < e; ++i) {
     // Skip unused operands and operands that are already constants.
     if (!newOperands[i] || getConstantIntValue(newOperands[i]))
       continue;
index 4638a21..83c236e 100644 (file)
@@ -391,3 +391,32 @@ func.func @regression_multiplication_with_sym(%A : memref<i64>) {
   }
   return
 }
+
+// -----
+
+// Make sure min is transformed into zero.
+
+// CHECK: %[[ZERO:.+]] = arith.constant 0 : index
+// CHECK: scf.index_switch %[[ZERO]] -> i1
+
+#map6 = affine_map<(d0, d1, d2) -> (d0 floordiv 64)>
+#map29 = affine_map<(d0, d1, d2) -> (d2 * 64 - 2, 5, (d1 mod 4) floordiv 8)>
+module {
+  func.func @func1() {
+    %true = arith.constant true
+    %c0 = arith.constant 0 : index
+    %c5 = arith.constant 5 : index
+    %c11 = arith.constant 11 : index
+    %c14 = arith.constant 14 : index
+    %c15 = arith.constant 15 : index
+    %alloc_249 = memref.alloc() : memref<7xf32>
+    %135 = affine.apply #map6(%c15, %c0, %c14)
+    %163 = affine.min #map29(%c5, %135, %c11)
+    %196 = scf.index_switch %163 -> i1
+    default {
+      memref.assume_alignment %alloc_249, 1 : memref<7xf32>
+      scf.yield %true : i1
+    }
+    return
+  }
+}