From fc253e69f9b988e8b2d4c940946146696b2acf5a Mon Sep 17 00:00:00 2001 From: Julian Gross Date: Mon, 3 May 2021 16:59:59 +0200 Subject: [PATCH] Fixed bug in buffer deallocation pass using unranked memref types. In the buffer deallocation pass, unranked memref types are not properly supported. After investigating this issue, it turns out that the Clone and Dealloc operation does not support unranked memref types in the current implementation. This patch adds the missing feature and enables the transformation of any memref type. This patch solves this bug: https://bugs.llvm.org/show_bug.cgi?id=48385 Differential Revision: https://reviews.llvm.org/D101760 --- mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td | 8 +++-- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 8 ----- mlir/test/Conversion/StandardToSPIRV/alloc.mlir | 4 +-- mlir/test/Dialect/MemRef/ops.mlir | 16 ++++++++++ mlir/test/Transforms/buffer-deallocation.mlir | 37 ++++++++++++++++++++++++ mlir/test/lib/Dialect/Test/TestOps.td | 7 +++-- 6 files changed, 64 insertions(+), 16 deletions(-) diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 7b341b1..74afcd0 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -342,8 +342,8 @@ def CloneOp : MemRef_Op<"clone", [ undefined behavior. }]; - let arguments = (ins Arg:$input); - let results = (outs Arg:$output); + let arguments = (ins Arg:$input); + let results = (outs Arg:$output); let extraClassDeclaration = [{ Value getSource() { return input();} @@ -353,6 +353,7 @@ def CloneOp : MemRef_Op<"clone", [ let assemblyFormat = "$input attr-dict `:` type($input) `to` type($output)"; let hasFolder = 1; + let verifier = ?; let hasCanonicalizer = 1; } @@ -376,9 +377,10 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> { ``` }]; - let arguments = (ins Arg:$memref); + let arguments = (ins Arg:$memref); let hasFolder = 1; + let verifier = ?; let assemblyFormat = "$memref attr-dict `:` type($memref)"; } diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 57c1b15..9a50278 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -474,8 +474,6 @@ OpFoldResult CastOp::fold(ArrayRef operands) { // CloneOp //===----------------------------------------------------------------------===// -static LogicalResult verify(CloneOp op) { return success(); } - void CloneOp::getEffects( SmallVectorImpl> &effects) { @@ -544,12 +542,6 @@ OpFoldResult CloneOp::fold(ArrayRef operands) { // DeallocOp //===----------------------------------------------------------------------===// -static LogicalResult verify(DeallocOp op) { - if (!op.memref().getType().isa()) - return op.emitOpError("operand must be a memref"); - return success(); -} - LogicalResult DeallocOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dealloc(memrefcast) -> dealloc diff --git a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir index 2d9dcf4..2d8e84a 100644 --- a/mlir/test/Conversion/StandardToSPIRV/alloc.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/alloc.mlir @@ -139,7 +139,7 @@ module attributes { { func @alloc_dealloc_dynamic_workgroup_mem(%arg0 : memref<4x?xf32, 3>) { // expected-error @+2 {{unhandled deallocation type}} - // expected-error @+1 {{'memref.dealloc' op operand #0 must be memref of any type values}} + // expected-error @+1 {{'memref.dealloc' op operand #0 must be unranked.memref of any type values or memref of any type values}} memref.dealloc %arg0 : memref<4x?xf32, 3> return } @@ -154,7 +154,7 @@ module attributes { { func @alloc_dealloc_mem(%arg0 : memref<4x5xf32>) { // expected-error @+2 {{unhandled deallocation type}} - // expected-error @+1 {{op operand #0 must be memref of any type values}} + // expected-error @+1 {{op operand #0 must be unranked.memref of any type values or memref of any type values}} memref.dealloc %arg0 : memref<4x5xf32> return } diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 9b6a05d..1b57284 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -60,3 +60,19 @@ func @read_global_memref() { %1 = memref.tensor_load %0 : memref<2xf32> return } + +// CHECK-LABEL: func @memref_clone +func @memref_clone() { + %0 = memref.alloc() : memref<2xf32> + %1 = memref.cast %0 : memref<2xf32> to memref<*xf32> + %2 = memref.clone %1 : memref<*xf32> to memref<*xf32> + return +} + +// CHECK-LABEL: func @memref_dealloc +func @memref_dealloc() { + %0 = memref.alloc() : memref<2xf32> + %1 = memref.cast %0 : memref<2xf32> to memref<*xf32> + memref.dealloc %1 : memref<*xf32> + return +} diff --git a/mlir/test/Transforms/buffer-deallocation.mlir b/mlir/test/Transforms/buffer-deallocation.mlir index 35f7bbf..7794511 100644 --- a/mlir/test/Transforms/buffer-deallocation.mlir +++ b/mlir/test/Transforms/buffer-deallocation.mlir @@ -90,6 +90,43 @@ func @condBranchDynamicType( // ----- +// Test case: See above. + +// CHECK-LABEL: func @condBranchUnrankedType +func @condBranchUnrankedType( + %arg0: i1, + %arg1: memref<*xf32>, + %arg2: memref<*xf32>, + %arg3: index) { + cond_br %arg0, ^bb1, ^bb2(%arg3: index) +^bb1: + br ^bb3(%arg1 : memref<*xf32>) +^bb2(%0: index): + %1 = memref.alloc(%0) : memref + %2 = memref.cast %1 : memref to memref<*xf32> + test.buffer_based in(%arg1: memref<*xf32>) out(%2: memref<*xf32>) + br ^bb3(%2 : memref<*xf32>) +^bb3(%3: memref<*xf32>): + test.copy(%3, %arg2) : (memref<*xf32>, memref<*xf32>) + return +} + +// CHECK-NEXT: cond_br +// CHECK: %[[ALLOC0:.*]] = memref.clone +// CHECK-NEXT: br ^bb3(%[[ALLOC0]] +// CHECK: ^bb2(%[[IDX:.*]]:{{.*}}) +// CHECK-NEXT: %[[ALLOC1:.*]] = memref.alloc(%[[IDX]]) +// CHECK: test.buffer_based +// CHECK-NEXT: %[[ALLOC2:.*]] = memref.clone +// CHECK-NEXT: memref.dealloc %[[ALLOC1]] +// CHECK-NEXT: br ^bb3 +// CHECK-NEXT: ^bb3(%[[ALLOC3:.*]]:{{.*}}) +// CHECK: test.copy(%[[ALLOC3]], +// CHECK-NEXT: memref.dealloc %[[ALLOC3]] +// CHECK-NEXT: return + +// ----- + // Test Case: // bb0 // / \ diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index b0c2fe4..795a5af 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1879,8 +1879,8 @@ def CopyOp : TEST_Op<"copy", [CopyOpInterface]> { let description = [{ Represents a copy operation. }]; - let arguments = (ins Res:$source, - Res:$target); + let arguments = (ins Res:$source, + Res:$target); let assemblyFormat = [{ `(` $source `,` $target `)` `:` `(` type($source) `,` type($target) `)` attr-dict @@ -1915,7 +1915,8 @@ class BufferBasedOpBase traits> let description = [{ A buffer based operation, that uses memRefs as input and output. }]; - let arguments = (ins AnyMemRef:$input, AnyMemRef:$output); + let arguments = (ins AnyRankedOrUnrankedMemRef:$input, + AnyRankedOrUnrankedMemRef:$output); } def BufferBasedOp : BufferBasedOpBase<"buffer_based", []>{ -- 2.7.4