[MLIR][memref] Fix findDealloc() to handle > 1 dealloc for the given alloc.
authorRahul Joshi <jurahul@google.com>
Thu, 22 Jul 2021 00:13:40 +0000 (17:13 -0700)
committerRahul Joshi <jurahul@google.com>
Thu, 22 Jul 2021 16:34:19 +0000 (09:34 -0700)
- Change findDealloc() to return Optional<Operation *> and return None if > 1
  dealloc is associated with the given alloc.
- Add findDeallocs() to return all deallocs associated with the given alloc.
- Fix current uses of findDealloc() to bail out if > 1 dealloc is found.

Differential Revision: https://reviews.llvm.org/D106456

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td
mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
mlir/lib/Transforms/BufferUtils.cpp
mlir/test/Dialect/MemRef/canonicalize.mlir

index 5101d5930315892db55ec47b810312b6eca7a415..1ca01b8e9732c592890e235f03446b25fde953ec 100644 (file)
@@ -435,7 +435,7 @@ def CloneOp : MemRef_Op<"clone", [
   let results = (outs Arg<AnyRankedOrUnrankedMemRef, "", []>:$output);
 
   let extraClassDeclaration = [{
-    Value getSource() { return input();}
+    Value getSource() { return input(); }
     Value getTarget() { return output(); }
   }];
 
index 024fe5ebfbc35cdcd5dc33ee69d892a2b762242c..279ea4b1d898a6268f6db8cafdf12d3d0b5dd97b 100644 (file)
 
 namespace mlir {
 
-/// Finds the associated dealloc that can be linked to our allocation nodes (if
-/// any).
-Operation *findDealloc(Value allocValue);
-
+/// Finds a single dealloc operation for the given allocated value. If there
+/// are > 1 deallocates for `allocValue`, returns None, else returns the single
+/// deallocate if it exists or nullptr.
+llvm::Optional<Operation *> findDealloc(Value allocValue);
 } // end namespace mlir
 
 #endif // MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
index 03fa871b1f7a3464ca3c11b633aedffd5a701cad..b65a6833f1fd9560b89299c9e2d51b77d7d526d5 100644 (file)
@@ -175,9 +175,9 @@ struct SimplifyDeadAlloc : public OpRewritePattern<T> {
   LogicalResult matchAndRewrite(T alloc,
                                 PatternRewriter &rewriter) const override {
     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
-        if (auto storeOp = dyn_cast<StoreOp>(op))
-          return storeOp.value() == alloc;
-        return !isa<DeallocOp>(op);
+          if (auto storeOp = dyn_cast<StoreOp>(op))
+            return storeOp.value() == alloc;
+          return !isa<DeallocOp>(op);
         }))
       return failure();
 
@@ -519,8 +519,8 @@ void CloneOp::getEffects(
 }
 
 namespace {
-/// Fold Dealloc operations that are deallocating an AllocOp that is only used
-/// by other Dealloc operations.
+/// Merge the clone and its source (by converting the clone to a cast) when
+/// possible.
 struct SimplifyClones : public OpRewritePattern<CloneOp> {
   using OpRewritePattern<CloneOp>::OpRewritePattern;
 
@@ -536,8 +536,16 @@ struct SimplifyClones : public OpRewritePattern<CloneOp> {
     // This only finds dealloc operations for the immediate value. It should
     // also consider aliases. That would also make the safety check below
     // redundant.
-    Operation *cloneDeallocOp = findDealloc(cloneOp.output());
-    Operation *sourceDeallocOp = findDealloc(source);
+    llvm::Optional<Operation *> maybeCloneDeallocOp =
+        findDealloc(cloneOp.output());
+    // Skip if either of them has > 1 deallocate operations.
+    if (!maybeCloneDeallocOp.hasValue())
+      return failure();
+    llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
+    if (!maybeSourceDeallocOp.hasValue())
+      return failure();
+    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
+    Operation *sourceDeallocOp = *maybeSourceDeallocOp;
 
     // If both are deallocated in the same block, their in-block lifetimes
     // might not fully overlap, so we cannot decide which one to drop.
index eb9817014c2fff5e8c86885a385cf09e1e90991a..edb7e46e7e9a47e65b7d0307c055c92ddfe65a29 100644 (file)
@@ -1,4 +1,4 @@
-//===- Utils.cpp - Utilities to support the MemRef dialect ----------------===//
+//===- MemRefUtils.cpp - Utilities to support the MemRef dialect ----------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
 
 using namespace mlir;
 
-/// Finds associated deallocs that can be linked to our allocation nodes (if
-/// any).
-Operation *mlir::findDealloc(Value allocValue) {
-  auto userIt = llvm::find_if(allocValue.getUsers(), [&](Operation *user) {
+/// Finds a single dealloc operation for the given allocated value.
+llvm::Optional<Operation *> mlir::findDealloc(Value allocValue) {
+  Operation *dealloc = nullptr;
+  for (Operation *user : allocValue.getUsers()) {
     auto effectInterface = dyn_cast<MemoryEffectOpInterface>(user);
     if (!effectInterface)
-      return false;
+      continue;
     // Try to find a free effect that is applied to one of our values
     // that will be automatically freed by our pass.
     SmallVector<MemoryEffects::EffectInstance, 2> effects;
     effectInterface.getEffectsOnValue(allocValue, effects);
-    return llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) {
-      return isa<MemoryEffects::Free>(it.getEffect());
-    });
-  });
-  // Assign the associated dealloc operation (if any).
-  return userIt != allocValue.user_end() ? *userIt : nullptr;
+    const bool isFree =
+        llvm::any_of(effects, [&](MemoryEffects::EffectInstance &it) {
+          return isa<MemoryEffects::Free>(it.getEffect());
+        });
+    if (!isFree)
+      continue;
+    // If we found > 1 dealloc, return None.
+    if (dealloc)
+      return llvm::None;
+    dealloc = user;
+  }
+  return dealloc;
 }
index 0cefd53d2d3478e6bb9863e408a6c88e89396e0c..c24293cb4bd6848060ecc91ae9a1a5c46df72cd2 100644 (file)
@@ -77,7 +77,11 @@ void BufferPlacementAllocs::build(Operation *op) {
     // Get allocation result.
     Value allocValue = allocateResultEffects[0].getValue();
     // Find the associated dealloc value and register the allocation entry.
-    allocs.push_back(std::make_tuple(allocValue, findDealloc(allocValue)));
+    llvm::Optional<Operation *> dealloc = findDealloc(allocValue);
+    // If the allocation has > 1 dealloc associated with it, skip handling it.
+    if (!dealloc.hasValue())
+      return;
+    allocs.push_back(std::make_tuple(allocValue, *dealloc));
   });
 }
 
index 02a8ce4441c3222e578b652dcc0fa24bc09a4abe..c63994b480dbee9d9269cc2f6382c28e9b00e381 100644 (file)
@@ -195,6 +195,44 @@ func @alias_is_freed(%arg0 : memref<?xf32>) {
 
 // -----
 
+// Verify SimplifyClones skips clones with multiple deallocations.
+// CHECK-LABEL: @clone_multiple_dealloc_of_source
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
+func @clone_multiple_dealloc_of_source(%arg0: memref<?xf32>) -> memref<?xf32> {
+  // CHECK-NEXT: %[[RES:.*]] = memref.clone %[[ARG]]
+  // CHECK: memref.dealloc %[[ARG]]
+  // CHECK: memref.dealloc %[[ARG]]
+  // CHECK: return %[[RES]]
+  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
+  "if_else"() ({
+    memref.dealloc %arg0 : memref<?xf32>
+    }, {
+    memref.dealloc %arg0 : memref<?xf32>
+    }) : () -> ()
+  return %0 : memref<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @clone_multiple_dealloc_of_clone
+// CHECK-SAME: %[[ARG:.*]]: memref<?xf32>
+func @clone_multiple_dealloc_of_clone(%arg0: memref<?xf32>) -> memref<?xf32> {
+  // CHECK-NEXT: %[[CLONE:.*]] = memref.clone %[[ARG]]
+  // CHECK: memref.dealloc %[[CLONE]]
+  // CHECK: memref.dealloc %[[CLONE]]
+  // CHECK: return %[[ARG]]
+  %0 = memref.clone %arg0 : memref<?xf32> to memref<?xf32>
+  "use"(%0) : (memref<?xf32>) -> ()
+  "if_else"() ({
+    memref.dealloc %0 : memref<?xf32>
+    }, {
+    memref.dealloc %0 : memref<?xf32>
+    }) : () -> ()
+  return %arg0 : memref<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @dim_of_sized_view
 //  CHECK-SAME:   %{{[a-z0-9A-Z_]+}}: memref<?xi8>
 //  CHECK-SAME:   %[[SIZE:.[a-z0-9A-Z_]+]]: index
@@ -393,7 +431,7 @@ func @alloc_const_fold_with_symbols2() -> memref<?xi32, #map0> {
 func @allocator(%arg0 : memref<memref<?xi32>>, %arg1 : index)  {
   %0 = memref.alloc(%arg1) : memref<?xi32>
   memref.store %0, %arg0[] : memref<memref<?xi32>>
-  return 
+  return
 }
 
 // -----