}
Value source = cloneOp.getInput();
+ // Aims to find the dealloc op for the canonical source
+ // which otherwise could prevent removal of unnecessary allocs.
+ Value canonicalSource = source;
+ while (auto iface = dyn_cast_or_null<ViewLikeOpInterface>(
+ canonicalSource.getDefiningOp()))
+ canonicalSource = iface.getViewSource();
- // This only finds dealloc operations for the immediate value. It should
- // also consider aliases. That would also make the safety check below
- // redundant.
llvm::Optional<Operation *> maybeCloneDeallocOp =
memref::findDealloc(cloneOp.getOutput());
// Skip if either of them has > 1 deallocate operations.
if (!maybeCloneDeallocOp.has_value())
return failure();
llvm::Optional<Operation *> maybeSourceDeallocOp =
- memref::findDealloc(source);
+ memref::findDealloc(canonicalSource);
if (!maybeSourceDeallocOp.has_value())
return failure();
Operation *cloneDeallocOp = *maybeCloneDeallocOp;
// CHECK: %[[T0:.+]] = bufferization.alloc_tensor() : tensor<4x5x6xf32>
// CHECK: %[[T1:.+]] = tensor.cast %[[T0]] : tensor<4x5x6xf32> to tensor<4x5x?xf32>
// CHECK: return %[[T1]]
+
+// -----
+
+func.func @dealloc_canonicalize_clone_removal(%arg0: memref<?xindex>) -> memref<*xf32> {
+ %c1 = arith.constant 1 : index
+ %0 = memref.alloc(%c1) : memref<?xf32>
+ %1 = memref.reshape %0(%arg0) : (memref<?xf32>, memref<?xindex>) -> memref<*xf32>
+ %2 = bufferization.clone %1 : memref<*xf32> to memref<*xf32>
+ memref.dealloc %0 : memref<?xf32>
+ return %2 : memref<*xf32>
+}
+// CHECK-LABEL: @dealloc_canonicalize_clone_removal
+// CHECK-NOT: bufferization.clone
+// CHECK-NOT: memref.dealloc
+// CHECK: return {{.*}}