Add a lowering for memref.dealloc with unranked memrefs.
authorJohannes Reifferscheid <jreiffers@google.com>
Thu, 16 Feb 2023 13:17:28 +0000 (14:17 +0100)
committerJohannes Reifferscheid <jreiffers@google.com>
Thu, 16 Feb 2023 13:19:16 +0000 (14:19 +0100)
This is permitted by the op, but the current lowering generates invalid IR.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D144090

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp
mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir

index 35091a3..7da7b66 100644 (file)
@@ -382,14 +382,27 @@ struct DeallocOpLowering : public ConvertOpToLLVMPattern<memref::DeallocOp> {
     // Insert the `free` declaration if it is not already present.
     LLVM::LLVMFuncOp freeFunc =
         getFreeFn(getTypeConverter(), op->getParentOfType<ModuleOp>());
-    MemRefDescriptor memref(adaptor.getMemref());
-    Value allocatedPtr = memref.allocatedPtr(rewriter, op.getLoc());
-    Value casted = allocatedPtr;
+    Value allocatedPtr;
+    if (auto unrankedTy =
+            llvm::dyn_cast<UnrankedMemRefType>(op.getMemref().getType())) {
+      Type elementType = unrankedTy.getElementType();
+      Type llvmElementTy = getTypeConverter()->convertType(elementType);
+      LLVM::LLVMPointerType elementPtrTy = getTypeConverter()->getPointerType(
+          llvmElementTy, unrankedTy.getMemorySpaceAsInt());
+      allocatedPtr = UnrankedMemRefDescriptor::allocatedPtr(
+          rewriter, op.getLoc(),
+          UnrankedMemRefDescriptor(adaptor.getMemref())
+              .memRefDescPtr(rewriter, op.getLoc()),
+          elementPtrTy);
+    } else {
+      allocatedPtr = MemRefDescriptor(adaptor.getMemref())
+                         .allocatedPtr(rewriter, op.getLoc());
+    }
     if (!getTypeConverter()->useOpaquePointers())
-      casted = rewriter.create<LLVM::BitcastOp>(op.getLoc(), getVoidPtrType(),
-                                                allocatedPtr);
+      allocatedPtr = rewriter.create<LLVM::BitcastOp>(
+          op.getLoc(), getVoidPtrType(), allocatedPtr);
 
-    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, casted);
+    rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, freeFunc, allocatedPtr);
     return success();
   }
 };
index 7b9c00c..b703184 100644 (file)
@@ -42,6 +42,17 @@ func.func @mixed_dealloc(%arg0: memref<?x42x?xf32>) {
 
 // -----
 
+// CHECK-LABEL: func @unranked_dealloc
+func.func @unranked_dealloc(%arg0: memref<*xf32>) {
+//      CHECK: %[[memref:.*]] = llvm.extractvalue %{{.*}} : !llvm.struct<(i64, ptr)>
+//      CHECK: %[[ptr:.*]] = llvm.load %[[memref]]
+// CHECK-NEXT: llvm.call @free(%[[ptr]])
+  memref.dealloc %arg0 : memref<*xf32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @dynamic_alloc(
 //       CHECK:   %[[Marg:.*]]: index, %[[Narg:.*]]: index)
 func.func @dynamic_alloc(%arg0: index, %arg1: index) -> memref<?x?xf32> {