[mlir] Fix bug in copy removal
authorEhsan Toosi <ehsan.nadjaran_toosi@dfki.de>
Mon, 24 Aug 2020 11:19:50 +0000 (13:19 +0200)
committerEhsan Toosi <ehsan.nadjaran_toosi@dfki.de>
Tue, 8 Sep 2020 12:17:13 +0000 (14:17 +0200)
A crash could happen due to copy removal. The bug is fixed and two more
test cases are added.

Differential Revision: https://reviews.llvm.org/D87128

mlir/lib/Transforms/CopyRemoval.cpp
mlir/test/Transforms/copy-removal.mlir

index ccfd026..c5a8da6 100644 (file)
@@ -30,16 +30,35 @@ public:
       reuseCopySourceAsTarget(copyOp);
       reuseCopyTargetAsSource(copyOp);
     });
+    for (std::pair<Value, Value> &pair : replaceList)
+      pair.first.replaceAllUsesWith(pair.second);
     for (Operation *op : eraseList)
       op->erase();
   }
 
 private:
   /// List of operations that need to be removed.
-  DenseSet<Operation *> eraseList;
+  llvm::SmallPtrSet<Operation *, 4> eraseList;
+
+  /// List of values that need to be replaced with their counterparts.
+  llvm::SmallDenseSet<std::pair<Value, Value>, 4> replaceList;
+
+  /// Returns the allocation operation for `value` in `block` if it exists.
+  /// nullptr otherwise.
+  Operation *getAllocationOpInBlock(Value value, Block *block) {
+    assert(block && "Block cannot be null");
+    Operation *op = value.getDefiningOp();
+    if (op && op->getBlock() == block) {
+      auto effects = dyn_cast<MemoryEffectOpInterface>(op);
+      if (effects && effects.hasEffect<Allocate>())
+        return op;
+    }
+    return nullptr;
+  }
 
   /// Returns the deallocation operation for `value` in `block` if it exists.
-  Operation *getDeallocationInBlock(Value value, Block *block) {
+  /// nullptr otherwise.
+  Operation *getDeallocationOpInBlock(Value value, Block *block) {
     assert(block && "Block cannot be null");
     auto valueUsers = value.getUsers();
     auto it = llvm::find_if(valueUsers, [&](Operation *op) {
@@ -119,9 +138,10 @@ private:
     Value to = copyOp.getTarget();
 
     Operation *copy = copyOp.getOperation();
+    Block *copyBlock = copy->getBlock();
     Operation *fromDefiningOp = from.getDefiningOp();
-    Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
-    Operation *toDefiningOp = to.getDefiningOp();
+    Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
+    Operation *toDefiningOp = getAllocationOpInBlock(to, copyBlock);
     if (!fromDefiningOp || !fromFreeingOp || !toDefiningOp ||
         !areOpsInTheSameBlock({fromFreeingOp, toDefiningOp, copy}) ||
         hasUsersBetween(to, toDefiningOp, copy) ||
@@ -129,7 +149,7 @@ private:
         hasMemoryEffectOpBetween(copy, fromFreeingOp))
       return;
 
-    to.replaceAllUsesWith(from);
+    replaceList.insert({to, from});
     eraseList.insert(copy);
     eraseList.insert(toDefiningOp);
     eraseList.insert(fromFreeingOp);
@@ -169,8 +189,9 @@ private:
     Value to = copyOp.getTarget();
 
     Operation *copy = copyOp.getOperation();
-    Operation *fromDefiningOp = from.getDefiningOp();
-    Operation *fromFreeingOp = getDeallocationInBlock(from, copy->getBlock());
+    Block *copyBlock = copy->getBlock();
+    Operation *fromDefiningOp = getAllocationOpInBlock(from, copyBlock);
+    Operation *fromFreeingOp = getDeallocationOpInBlock(from, copyBlock);
     if (!fromDefiningOp || !fromFreeingOp ||
         !areOpsInTheSameBlock({fromFreeingOp, fromDefiningOp, copy}) ||
         hasUsersBetween(to, fromDefiningOp, copy) ||
@@ -178,7 +199,7 @@ private:
         hasMemoryEffectOpBetween(copy, fromFreeingOp))
       return;
 
-    from.replaceAllUsesWith(to);
+    replaceList.insert({from, to});
     eraseList.insert(copy);
     eraseList.insert(fromDefiningOp);
     eraseList.insert(fromFreeingOp);
index f750dab..a0d1193 100644 (file)
@@ -283,3 +283,67 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
   dealloc %temp : memref<2xf32>
   return
 }
+
+// -----
+
+// The only redundant copy is linalg.copy(%4, %5)
+
+// CHECK-LABEL: func @loop_alloc
+func @loop_alloc(%arg0: index, %arg1: index, %arg2: index, %arg3: memref<2xf32>, %arg4: memref<2xf32>) {
+  // CHECK: %{{.*}} = alloc()
+  %0 = alloc() : memref<2xf32>
+  dealloc %0 : memref<2xf32>
+  // CHECK: %{{.*}} = alloc()
+  %1 = alloc() : memref<2xf32>
+  // CHECK: linalg.copy
+  linalg.copy(%arg3, %1) : memref<2xf32>, memref<2xf32>
+  %2 = scf.for %arg5 = %arg0 to %arg1 step %arg2 iter_args(%arg6 = %1) -> (memref<2xf32>) {
+    %3 = cmpi "eq", %arg5, %arg1 : index
+    // CHECK: dealloc
+    dealloc %arg6 : memref<2xf32>
+    // CHECK: %[[PERCENT4:.*]] = alloc()
+    %4 = alloc() : memref<2xf32>
+    // CHECK-NOT: alloc
+    // CHECK-NOT: linalg.copy
+    // CHECK-NOT: dealloc
+    %5 = alloc() : memref<2xf32>
+    linalg.copy(%4, %5) : memref<2xf32>, memref<2xf32>
+    dealloc %4 : memref<2xf32>
+    // CHECK: %[[PERCENT6:.*]] = alloc()
+    %6 = alloc() : memref<2xf32>
+    // CHECK: linalg.copy(%[[PERCENT4]], %[[PERCENT6]])
+    linalg.copy(%5, %6) : memref<2xf32>, memref<2xf32>
+    scf.yield %6 : memref<2xf32>
+  }
+  // CHECK: linalg.copy
+  linalg.copy(%2, %arg4) : memref<2xf32>, memref<2xf32>
+  dealloc %2 : memref<2xf32>
+  return
+}
+
+// -----
+
+// The linalg.copy operation can be removed in addition to alloc and dealloc
+// operations. All uses of %0 is then replaced with %arg2.
+
+// CHECK-LABEL: func @check_with_affine_dialect
+func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg2: memref<4xf32>) {
+  // CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32>, %[[ARG1:.*]]: memref<4xf32>, %[[RES:.*]]: memref<4xf32>)
+  // CHECK-NOT: alloc
+  %0 = alloc() : memref<4xf32>
+  affine.for %arg3 = 0 to 4 {
+    %5 = affine.load %arg0[%arg3] : memref<4xf32>
+    %6 = affine.load %arg1[%arg3] : memref<4xf32>
+    %7 = cmpf "ogt", %5, %6 : f32
+    // CHECK: %[[SELECT_RES:.*]] = select
+    %8 = select %7, %5, %6 : f32
+    // CHECK-NEXT: affine.store %[[SELECT_RES]], %[[RES]]
+    affine.store %8, %0[%arg3] : memref<4xf32>
+  }
+  // CHECK-NOT: linalg.copy
+  // CHECK-NOT: dealloc
+  "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
+  dealloc %0 : memref<4xf32>
+  //CHECK: return
+  return
+}