From cccc7e5aa8088b3b721e1f430c47d199575fae9b Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Fri, 25 Jun 2021 16:16:23 -0400 Subject: [PATCH] [MLIR] Don't remove memref allocation if stored into another allocation A canonicalization accidentally will remove a memref allocation if it is only stored into. However, this is incorrect if the allocation is the value being stored, not the allocation being stored into. Differential Revision: https://reviews.llvm.org/D104947 --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 6 ++++-- mlir/test/Dialect/MemRef/canonicalize.mlir | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index cc4e7a4..6f358d8 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -174,8 +174,10 @@ struct SimplifyDeadAlloc : public OpRewritePattern { LogicalResult matchAndRewrite(T alloc, PatternRewriter &rewriter) const override { - if (llvm::any_of(alloc->getUsers(), [](Operation *op) { - return !isa(op); + if (llvm::any_of(alloc->getUsers(), [&](Operation *op) { + if (auto storeOp = dyn_cast(op)) + return storeOp.value() == alloc; + return !isa(op); })) return failure(); diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index cbf2126..c59d1d3 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -420,3 +420,13 @@ func @alloc_const_fold_with_symbols2() -> memref { %0 = memref.alloc(%c1)[%c1, %c1] : memref return %0 : memref } + +// ----- +// CHECK-LABEL: func @allocator +// CHECK: %[[alloc:.+]] = memref.alloc +// CHECK: memref.store %[[alloc:.+]], %arg0 +func @allocator(%arg0 : memref>, %arg1 : index) { + %0 = memref.alloc(%arg1) : memref + memref.store %0, %arg0[] : memref> + return +} -- 2.7.4