fix simplify-affine-structures bug
authorUday Bondhugula <udayb@iisc.ac.in>
Mon, 7 Oct 2019 17:03:38 +0000 (10:03 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 7 Oct 2019 17:04:50 +0000 (10:04 -0700)
Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Closes tensorflow/mlir#157

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/157 from bondhugula:quickfix bd1fcd79825fc0bd5b4a3e688153fa0993ab703d
PiperOrigin-RevId: 273316498

mlir/lib/Transforms/SimplifyAffineStructures.cpp
mlir/test/Transforms/memref-normalize.mlir

index e243c1b..9512ff7 100644 (file)
@@ -102,9 +102,16 @@ void SimplifyAffineStructures::runOnFunction() {
     }
   });
 
-  // Turn memrefs' non-identity layouts maps into ones with identity.
-  func.walk([](AllocOp op) { normalizeMemRef(op); });
+  // Turn memrefs' non-identity layouts maps into ones with identity. Collect
+  // alloc ops first and then process since normalizeMemRef replaces/erases ops
+  // during memref rewriting.
+  SmallVector<AllocOp, 4> allocOps;
+  func.walk([&](AllocOp op) { allocOps.push_back(op); });
+  for (auto allocOp : allocOps) {
+    normalizeMemRef(allocOp);
+  }
 }
 
 static PassRegistration<SimplifyAffineStructures>
-    pass("simplify-affine-structures", "Simplify affine expressions");
+    pass("simplify-affine-structures",
+         "Simplify affine expressions in maps/sets and normalize memrefs");
index e9b6362..90b3632 100644 (file)
@@ -22,10 +22,12 @@ func @permute() {
 // CHECK-NEXT: dealloc [[MEM]]
 // CHECK-NEXT: return
 
-// CHECK-LABEL: func @shift()
-func @shift() {
-  // CHECK-NOT:  memref<64xf32, (d0) -> (d0 + 1)>
+// CHECK-LABEL: func @shift
+func @shift(%idx : index) {
+  // CHECK-NEXT: alloc() : memref<65xf32>
   %A = alloc() : memref<64xf32, (d0) -> (d0 + 1)>
+  // CHECK-NEXT: affine.load %{{.*}}[symbol(%arg0) + 1] : memref<65xf32>
+  affine.load %A[%idx] : memref<64xf32, (d0) -> (d0 + 1)>
   affine.for %i = 0 to 64 {
     affine.load %A[%i] : memref<64xf32, (d0) -> (d0 + 1)>
     // CHECK: %{{.*}} = affine.load %{{.*}}[%arg{{.*}} + 1] : memref<65xf32>
@@ -59,10 +61,12 @@ func @invalid_map() {
 }
 
 // A tiled layout.
-// CHECK-LABEL: func @data_tiling()
-func @data_tiling() {
+// CHECK-LABEL: func @data_tiling
+func @data_tiling(%idx : index) {
+  // CHECK: alloc() : memref<8x32x8x16xf32>
   %A = alloc() : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>
-  // CHECK: %{{.*}} = alloc() : memref<8x32x8x16xf32>
+  // CHECK: affine.load %{{.*}}[symbol(%arg0) floordiv 8, symbol(%arg0) floordiv 16, symbol(%arg0) mod 8, symbol(%arg0) mod 16]
+  affine.load %A[%idx, %idx] : memref<64x512xf32, (d0, d1) -> (d0 floordiv 8, d1 floordiv 16, d0 mod 8, d1 mod 16)>
   return
 }