[mlir][cf] Add support for opaque pointers to ControlFlowToLLVM lowering
authorMarkus Böck <markus.boeck02@gmail.com>
Wed, 8 Feb 2023 15:09:04 +0000 (16:09 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Wed, 8 Feb 2023 20:23:23 +0000 (21:23 +0100)
Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179

This is a very simple patch since there is only one use of pointers types in `cf.assert` that has to be changed. Pointer types are conditionally created with element types and the GEP had to be adjusted to use the array type as base type.

Differential Revision: https://reviews.llvm.org/D143583

mlir/include/mlir/Conversion/Passes.td
mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp
mlir/test/Conversion/ControlFlowToLLVM/assert.mlir [new file with mode: 0644]

index d834b66..7fe6e63 100644 (file)
@@ -248,6 +248,9 @@ def ConvertControlFlowToLLVM : Pass<"convert-cf-to-llvm", "ModuleOp"> {
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
            "Bitwidth of the index type, 0 to use size of machine word">,
+    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+                   /*default=*/"false", "Generate LLVM IR using opaque pointers "
+                   "instead of typed pointers">,
   ];
 }
 
index d448b05..6748b7b 100644 (file)
@@ -45,7 +45,7 @@ static std::string generateGlobalMsgSymbolName(ModuleOp moduleOp) {
 
 /// Generate IR that prints the given string to stderr.
 static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
-                           StringRef msg) {
+                           StringRef msg, LLVMTypeConverter &typeConverter) {
   auto ip = builder.saveInsertionPoint();
   builder.setInsertionPointToStart(moduleOp.getBody());
   MLIRContext *ctx = builder.getContext();
@@ -68,12 +68,13 @@ static void createPrintMsg(OpBuilder &builder, Location loc, ModuleOp moduleOp,
   // Emit call to `printStr` in runtime library.
   builder.restoreInsertionPoint(ip);
   auto msgAddr = builder.create<LLVM::AddressOfOp>(
-      loc, LLVM::LLVMPointerType::get(arrayTy), globalOp.getName());
+      loc, typeConverter.getPointerType(arrayTy), globalOp.getName());
   SmallVector<LLVM::GEPArg> indices(1, 0);
   Value gep = builder.create<LLVM::GEPOp>(
-      loc, LLVM::LLVMPointerType::get(builder.getI8Type()), msgAddr, indices);
-  Operation *printer =
-      LLVM::lookupOrCreatePrintStrFn(moduleOp, /*TODO: opaquePointers=*/false);
+      loc, typeConverter.getPointerType(builder.getI8Type()), arrayTy, msgAddr,
+      indices);
+  Operation *printer = LLVM::lookupOrCreatePrintStrFn(
+      moduleOp, typeConverter.useOpaquePointers());
   builder.create<LLVM::CallOp>(loc, TypeRange(), SymbolRefAttr::get(printer),
                                gep);
 }
@@ -102,7 +103,7 @@ struct AssertOpLowering : public ConvertOpToLLVMPattern<cf::AssertOp> {
 
     // Failed block: Generate IR to print the message and call `abort`.
     Block *failureBlock = rewriter.createBlock(opBlock->getParent());
-    createPrintMsg(rewriter, loc, module, op.getMsg());
+    createPrintMsg(rewriter, loc, module, op.getMsg(), *getTypeConverter());
     if (abortOnFailedAssert) {
       // Insert the `abort` declaration if necessary.
       auto abortFunc = module.lookupSymbol<LLVM::LLVMFuncOp>("abort");
@@ -274,6 +275,7 @@ struct ConvertControlFlowToLLVM
     LowerToLLVMOptions options(&getContext());
     if (indexBitwidth != kDeriveIndexBitwidthFromDataLayout)
       options.overrideIndexBitwidth(indexBitwidth);
+    options.useOpaquePointers = useOpaquePointers;
 
     LLVMTypeConverter converter(&getContext(), options);
     mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
diff --git a/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir b/mlir/test/Conversion/ControlFlowToLLVM/assert.mlir
new file mode 100644 (file)
index 0000000..67804b6
--- /dev/null
@@ -0,0 +1,17 @@
+// RUN: mlir-opt %s -convert-cf-to-llvm='use-opaque-pointers=1' | FileCheck %s
+
+func.func @main() {
+  %a = arith.constant 0 : i1
+  cf.assert %a, "assertion foo"
+  return
+}
+
+// CHECK: llvm.func @puts(!llvm.ptr)
+
+// CHECK-LABEL: @main
+// CHECK: llvm.cond_br %{{.*}}, ^{{.*}}, ^[[FALSE_BRANCH:[[:alnum:]]+]]
+
+// CHECK: ^[[FALSE_BRANCH]]:
+// CHECK: %[[ADDRESS_OF:.*]] = llvm.mlir.addressof @{{.*}} : !llvm.ptr{{$}}
+// CHECK: %[[GEP:.*]] = llvm.getelementptr %[[ADDRESS_OF]][0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<{{[0-9]+}} x i8>
+// CHECK: llvm.call @puts(%[[GEP]]) : (!llvm.ptr) -> ()