[mlir][SCFToOpenMP] Add pass option to emit LLVM opaque pointers
authorMarkus Böck <markus.boeck02@gmail.com>
Sun, 12 Feb 2023 20:34:21 +0000 (21:34 +0100)
committerMarkus Böck <markus.boeck02@gmail.com>
Mon, 13 Feb 2023 10:49:37 +0000 (11:49 +0100)
Part of https://discourse.llvm.org/t/rfc-switching-the-llvm-dialect-and-dialect-lowerings-to-opaque-pointers/68179

There were luckily only very few changes that had to be made. To allow users to also specify the pass option from C++ code I have also migrated the pass to use autogenerated constructors to autogenerate a pass option struct.

Differential Revision: https://reviews.llvm.org/D143855

mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Conversion/SCFToOpenMP/SCFToOpenMP.h
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/test/Conversion/SCFToOpenMP/reductions.mlir
mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir
mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir [new file with mode: 0644]

index 2924cc8..0533373 100644 (file)
@@ -746,10 +746,16 @@ def SCFToControlFlow : Pass<"convert-scf-to-cf"> {
 // SCFToOpenMP
 //===----------------------------------------------------------------------===//
 
-def ConvertSCFToOpenMP : Pass<"convert-scf-to-openmp", "ModuleOp"> {
+def ConvertSCFToOpenMPPass : Pass<"convert-scf-to-openmp", "ModuleOp"> {
   let summary = "Convert SCF parallel loop to OpenMP parallel + workshare "
                 "constructs.";
-  let constructor = "mlir::createConvertSCFToOpenMPPass()";
+
+  let options = [
+    Option<"useOpaquePointers", "use-opaque-pointers", "bool",
+                 /*default=*/"false", "Generate LLVM IR using opaque pointers "
+                 "instead of typed pointers">
+  ];
+
   let dependentDialects = ["omp::OpenMPDialect", "LLVM::LLVMDialect",
                            "memref::MemRefDialect"];
 }
index 7dd5315..dfff8e6 100644 (file)
 #include <memory>
 
 namespace mlir {
-class ModuleOp;
-template <typename T>
-class OperationPass;
+class Pass;
 
-#define GEN_PASS_DECL_CONVERTSCFTOOPENMP
+#define GEN_PASS_DECL_CONVERTSCFTOOPENMPPASS
 #include "mlir/Conversion/Passes.h.inc"
 
-std::unique_ptr<OperationPass<ModuleOp>> createConvertSCFToOpenMPPass();
-
 } // namespace mlir
 
 #endif // MLIR_CONVERSION_SCFTOOPENMP_SCFTOOPENMP_H
index a4acf04..78e63a5 100644 (file)
@@ -26,7 +26,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 namespace mlir {
-#define GEN_PASS_DEF_CONVERTSCFTOOPENMP
+#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
 #include "mlir/Conversion/Passes.h.inc"
 } // namespace mlir
 
@@ -212,22 +212,32 @@ static omp::ReductionDeclareOp createDecl(PatternRewriter &builder,
   return decl;
 }
 
+/// Returns an LLVM pointer type with the given element type, or an opaque
+/// pointer if 'useOpaquePointers' is true.
+static LLVM::LLVMPointerType getPointerType(Type elementType,
+                                            bool useOpaquePointers) {
+  if (useOpaquePointers)
+    return LLVM::LLVMPointerType::get(elementType.getContext());
+  return LLVM::LLVMPointerType::get(elementType);
+}
+
 /// Adds an atomic reduction combiner to the given OpenMP reduction declaration
 /// using llvm.atomicrmw of the given kind.
 static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
                                             LLVM::AtomicBinOp atomicKind,
                                             omp::ReductionDeclareOp decl,
-                                            scf::ReduceOp reduce) {
+                                            scf::ReduceOp reduce,
+                                            bool useOpaquePointers) {
   OpBuilder::InsertionGuard guard(builder);
   Type type = reduce.getOperand().getType();
-  Type ptrType = LLVM::LLVMPointerType::get(type);
+  Type ptrType = getPointerType(type, useOpaquePointers);
   Location reduceOperandLoc = reduce.getOperand().getLoc();
   builder.createBlock(&decl.getAtomicReductionRegion(),
                       decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
                       {reduceOperandLoc, reduceOperandLoc});
   Block *atomicBlock = &decl.getAtomicReductionRegion().back();
   builder.setInsertionPointToEnd(atomicBlock);
-  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(),
+  Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
                                               atomicBlock->getArgument(1));
   builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
                                     atomicBlock->getArgument(0), loaded,
@@ -241,7 +251,8 @@ static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
 /// the neutral value, necessary for the OpenMP declaration. If the reduction
 /// cannot be recognized, returns null.
 static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
-                                                scf::ReduceOp reduce) {
+                                                scf::ReduceOp reduce,
+                                                bool useOpaquePointers) {
   Operation *container = SymbolTable::getNearestSymbolTable(reduce);
   SymbolTable symbolTable(container);
 
@@ -262,29 +273,34 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
   if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
                                               builder.getFloatAttr(type, 0.0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
+                        useOpaquePointers);
   }
   if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
                                               builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
+                        useOpaquePointers);
   }
   if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
                                               builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
+                        useOpaquePointers);
   }
   if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
     omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
                                               builder.getIntegerAttr(type, 0));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
+                        useOpaquePointers);
   }
   if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
     omp::ReductionDeclareOp decl = createDecl(
         builder, symbolTable, reduce,
         builder.getIntegerAttr(
             type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth())));
-    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
+    return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
+                        useOpaquePointers);
   }
 
   // Match simple binary reductions that cannot be expressed with atomicrmw.
@@ -316,7 +332,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
         builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
     return addAtomicRMW(builder,
                         isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
-                        decl, reduce);
+                        decl, reduce, useOpaquePointers);
   }
   if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
           reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
@@ -328,7 +344,7 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
         builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
     return addAtomicRMW(
         builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
-        decl, reduce);
+        decl, reduce, useOpaquePointers);
   }
 
   return nullptr;
@@ -337,7 +353,12 @@ static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
 namespace {
 
 struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
-  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+
+  bool useOpaquePointers;
+
+  ParallelOpLowering(MLIRContext *context, bool useOpaquePointers)
+      : OpRewritePattern<scf::ParallelOp>(context),
+        useOpaquePointers(useOpaquePointers) {}
 
   LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
                                 PatternRewriter &rewriter) const override {
@@ -346,7 +367,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     // declaration and use it instead of redeclaring.
     SmallVector<Attribute> reductionDeclSymbols;
     for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
-      omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
+      omp::ReductionDeclareOp decl =
+          declareReduction(rewriter, reduce, useOpaquePointers);
       if (!decl)
         return failure();
       reductionDeclSymbols.push_back(
@@ -366,7 +388,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
              "cannot create a reduction variable if the type is not an LLVM "
              "pointer element");
       Value storage = rewriter.create<LLVM::AllocaOp>(
-          loc, LLVM::LLVMPointerType::get(init.getType()), one, 0);
+          loc, getPointerType(init.getType(), useOpaquePointers),
+          init.getType(), one, 0);
       rewriter.create<LLVM::StoreOp>(loc, init, storage);
       reductionVariables.push_back(storage);
     }
@@ -426,8 +449,9 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     // Load loop results.
     SmallVector<Value> results;
     results.reserve(reductionVariables.size());
-    for (Value variable : reductionVariables) {
-      Value res = rewriter.create<LLVM::LoadOp>(loc, variable);
+    for (auto [variable, type] :
+         llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
+      Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable);
       results.push_back(res);
     }
     rewriter.replaceOp(parallelOp, results);
@@ -437,29 +461,29 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
 };
 
 /// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, bool useOpaquePointers) {
   ConversionTarget target(*module.getContext());
   target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
   target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
                          memref::MemRefDialect>();
 
   RewritePatternSet patterns(module.getContext());
-  patterns.add<ParallelOpLowering>(module.getContext());
+  patterns.add<ParallelOpLowering>(module.getContext(), useOpaquePointers);
   FrozenRewritePatternSet frozen(std::move(patterns));
   return applyPartialConversion(module, target, frozen);
 }
 
 /// A pass converting SCF operations to OpenMP operations.
-struct SCFToOpenMPPass : public impl::ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
+struct SCFToOpenMPPass
+    : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
+
+  using Base::Base;
+
   /// Pass entry point.
   void runOnOperation() override {
-    if (failed(applyPatterns(getOperation())))
+    if (failed(applyPatterns(getOperation(), useOpaquePointers)))
       signalPassFailure();
   }
 };
 
 } // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() {
-  return std::make_unique<SCFToOpenMPPass>();
-}
index d71f757..4cf5f1b 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-openmp -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=1' -split-input-file %s | FileCheck %s
 
 // CHECK: omp.reduction.declare @[[$REDF:.*]] : f32
 
@@ -12,8 +12,8 @@
 // CHECK: omp.yield(%[[RES]] : f32)
 
 // CHECK: atomic
-// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<f32>, %[[ARG1:.*]]: !llvm.ptr<f32>):
-// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> f32
 // CHECK: llvm.atomicrmw fadd %[[ARG0]], %[[RHS]] monotonic
 
 // CHECK-LABEL: @reduction1
@@ -143,8 +143,8 @@ func.func @reduction3(%arg0 : index, %arg1 : index, %arg2 : index,
 // CHECK: omp.yield(%[[RES]] : i64)
 
 // CHECK: atomic
-// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
-// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr, %[[ARG1:.*]]: !llvm.ptr):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]] : !llvm.ptr -> i64
 // CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic
 
 // CHECK-LABEL: @reduction4
@@ -187,8 +187,8 @@ func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
     // CHECK: omp.yield
   }
   // CHECK: omp.terminator
-  // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]]
-  // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]]
+  // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr -> f32
+  // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr -> i64
   // CHECK: return %[[RES1]], %[[RES2]]
   return %res#0, %res#1 : f32, i64
 }
index e0fdcae..508052d 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -convert-scf-to-openmp %s | FileCheck %s
+// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=1' %s | FileCheck %s
 
 // CHECK-LABEL: @parallel
 func.func @parallel(%arg0: index, %arg1: index, %arg2: index,
diff --git a/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir b/mlir/test/Conversion/SCFToOpenMP/typed-pointers.mlir
new file mode 100644 (file)
index 0000000..fb90c5d
--- /dev/null
@@ -0,0 +1,78 @@
+// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=0' -split-input-file %s | FileCheck %s
+
+// CHECK: omp.reduction.declare @[[$REDF1:.*]] : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[CMP:.*]] = arith.cmpf oge, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK: omp.reduction.declare @[[$REDF2:.*]] : i64
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant
+// CHECK: omp.yield(%[[INIT]] : i64)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64)
+// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG0]]
+// CHECK: omp.yield(%[[RES]] : i64)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction4
+func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
+                 %arg3 : index, %arg4 : index) -> (f32, i64) {
+  %step = arith.constant 1 : index
+  // CHECK: %[[ZERO:.*]] = arith.constant 0.0
+  %zero = arith.constant 0.0 : f32
+  // CHECK: %[[IONE:.*]] = arith.constant 1
+  %ione = arith.constant 1 : i64
+  // CHECK: %[[BUF1:.*]] = llvm.alloca %{{.*}} x f32
+  // CHECK: llvm.store %[[ZERO]], %[[BUF1]]
+  // CHECK: %[[BUF2:.*]] = llvm.alloca %{{.*}} x i64
+  // CHECK: llvm.store %[[IONE]], %[[BUF2]]
+
+  // CHECK: omp.parallel
+  // CHECK: omp.wsloop
+  // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
+  // CHECK-SAME:           @[[$REDF2]] -> %[[BUF2]]
+  // CHECK: memref.alloca_scope
+  %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+                        step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
+    %one = arith.constant 1.0 : f32
+    // CHECK: omp.reduction %{{.*}}, %[[BUF1]]
+    scf.reduce(%one) : f32 {
+    ^bb0(%lhs : f32, %rhs: f32):
+      %cmp = arith.cmpf oge, %lhs, %rhs : f32
+      %res = arith.select %cmp, %lhs, %rhs : f32
+      scf.reduce.return %res : f32
+    }
+    // CHECK: arith.fptosi
+    %1 = arith.fptosi %one : f32 to i64
+    // CHECK: omp.reduction %{{.*}}, %[[BUF2]]
+    scf.reduce(%1) : i64 {
+    ^bb1(%lhs: i64, %rhs: i64):
+      %cmp = arith.cmpi slt, %lhs, %rhs : i64
+      %res = arith.select %cmp, %rhs, %lhs : i64
+      scf.reduce.return %res : i64
+    }
+    // CHECK: omp.yield
+  }
+  // CHECK: omp.terminator
+  // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr<f32>
+  // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr<i64>
+  // CHECK: return %[[RES1]], %[[RES2]]
+  return %res#0, %res#1 : f32, i64
+}