#include "mlir/Transforms/DialectConversion.h"
namespace mlir {
-#define GEN_PASS_DEF_CONVERTSCFTOOPENMP
+#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir
return decl;
}
+/// Returns the LLVM pointer type used to refer to a value of 'elementType'. If
+/// 'useOpaquePointers' is true this is an opaque pointer built from the type's context, else a typed pointer.
+static LLVM::LLVMPointerType getPointerType(Type elementType,
+ bool useOpaquePointers) {
+ if (useOpaquePointers) // Opaque form: created from the context alone.
+ return LLVM::LLVMPointerType::get(elementType.getContext());
+ return LLVM::LLVMPointerType::get(elementType); // Typed form: pointer to 'elementType'.
+}
+
/// Adds an atomic reduction combiner to the given OpenMP reduction declaration
/// using llvm.atomicrmw of the given kind.
static omp::ReductionDeclareOp addAtomicRMW(OpBuilder &builder,
LLVM::AtomicBinOp atomicKind,
omp::ReductionDeclareOp decl,
- scf::ReduceOp reduce) {
+ scf::ReduceOp reduce,
+ bool useOpaquePointers) {
OpBuilder::InsertionGuard guard(builder);
Type type = reduce.getOperand().getType();
- Type ptrType = LLVM::LLVMPointerType::get(type);
+ Type ptrType = getPointerType(type, useOpaquePointers);
Location reduceOperandLoc = reduce.getOperand().getLoc();
builder.createBlock(&decl.getAtomicReductionRegion(),
decl.getAtomicReductionRegion().end(), {ptrType, ptrType},
{reduceOperandLoc, reduceOperandLoc});
Block *atomicBlock = &decl.getAtomicReductionRegion().back();
builder.setInsertionPointToEnd(atomicBlock);
- Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(),
+ Value loaded = builder.create<LLVM::LoadOp>(reduce.getLoc(), decl.getType(),
atomicBlock->getArgument(1));
builder.create<LLVM::AtomicRMWOp>(reduce.getLoc(), atomicKind,
atomicBlock->getArgument(0), loaded,
/// the neutral value, necessary for the OpenMP declaration. If the reduction
/// cannot be recognized, returns null.
static omp::ReductionDeclareOp declareReduction(PatternRewriter &builder,
- scf::ReduceOp reduce) {
+ scf::ReduceOp reduce,
+ bool useOpaquePointers) {
Operation *container = SymbolTable::getNearestSymbolTable(reduce);
SymbolTable symbolTable(container);
if (matchSimpleReduction<arith::AddFOp, LLVM::FAddOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getFloatAttr(type, 0.0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::fadd, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::AddIOp, LLVM::AddOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::add, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::OrIOp, LLVM::OrOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_or, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::XOrIOp, LLVM::XOrOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(builder, symbolTable, reduce,
builder.getIntegerAttr(type, 0));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_xor, decl, reduce,
+ useOpaquePointers);
}
if (matchSimpleReduction<arith::AndIOp, LLVM::AndOp>(reduction)) {
omp::ReductionDeclareOp decl = createDecl(
builder, symbolTable, reduce,
builder.getIntegerAttr(
type, llvm::APInt::getAllOnesValue(type.getIntOrFloatBitWidth())));
- return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce);
+ return addAtomicRMW(builder, LLVM::AtomicBinOp::_and, decl, reduce,
+ useOpaquePointers);
}
// Match simple binary reductions that cannot be expressed with atomicrmw.
builder, symbolTable, reduce, minMaxValueForSignedInt(type, !isMin));
return addAtomicRMW(builder,
isMin ? LLVM::AtomicBinOp::min : LLVM::AtomicBinOp::max,
- decl, reduce);
+ decl, reduce, useOpaquePointers);
}
if (matchSelectReduction<arith::CmpIOp, arith::SelectOp>(
reduction, {arith::CmpIPredicate::ult, arith::CmpIPredicate::ule},
builder, symbolTable, reduce, minMaxValueForUnsignedInt(type, !isMin));
return addAtomicRMW(
builder, isMin ? LLVM::AtomicBinOp::umin : LLVM::AtomicBinOp::umax,
- decl, reduce);
+ decl, reduce, useOpaquePointers);
}
return nullptr;
namespace {
struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
- using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;
+ /// Forwarded to getPointerType and declareReduction to choose between opaque and typed LLVM pointers.
+ bool useOpaquePointers;
+ /// Builds the pattern for 'context' and records the pointer-kind choice.
+ ParallelOpLowering(MLIRContext *context, bool useOpaquePointers)
+ : OpRewritePattern<scf::ParallelOp>(context),
+ useOpaquePointers(useOpaquePointers) {}
LogicalResult matchAndRewrite(scf::ParallelOp parallelOp,
PatternRewriter &rewriter) const override {
// declaration and use it instead of redeclaring.
SmallVector<Attribute> reductionDeclSymbols;
for (auto reduce : parallelOp.getOps<scf::ReduceOp>()) {
- omp::ReductionDeclareOp decl = declareReduction(rewriter, reduce);
+ omp::ReductionDeclareOp decl =
+ declareReduction(rewriter, reduce, useOpaquePointers); // Null when the reduction is not recognized.
if (!decl)
return failure();
reductionDeclSymbols.push_back(
"cannot create a reduction variable if the type is not an LLVM "
"pointer element");
Value storage = rewriter.create<LLVM::AllocaOp>(
- loc, LLVM::LLVMPointerType::get(init.getType()), one, 0);
+ loc, getPointerType(init.getType(), useOpaquePointers), // Result pointer type of the alloca.
+ init.getType(), one, 0); // Element type is passed explicitly; an opaque pointer does not carry it.
rewriter.create<LLVM::StoreOp>(loc, init, storage);
reductionVariables.push_back(storage);
}
// Load loop results.
SmallVector<Value> results;
results.reserve(reductionVariables.size());
- for (Value variable : reductionVariables) {
- Value res = rewriter.create<LLVM::LoadOp>(loc, variable);
+ for (auto [variable, type] : // Walk each reduction variable together with the matching loop result type.
+ llvm::zip(reductionVariables, parallelOp.getResultTypes())) {
+ Value res = rewriter.create<LLVM::LoadOp>(loc, type, variable); // Loaded type is given explicitly rather than taken from the pointer.
results.push_back(res);
}
rewriter.replaceOp(parallelOp, results);
};
/// Applies the conversion patterns in the given function.
-static LogicalResult applyPatterns(ModuleOp module) {
+static LogicalResult applyPatterns(ModuleOp module, bool useOpaquePointers) { // NOTE(review): the comment above says "function" but this takes a ModuleOp; 'useOpaquePointers' is handed to ParallelOpLowering below.
ConversionTarget target(*module.getContext());
target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
memref::MemRefDialect>();
RewritePatternSet patterns(module.getContext());
- patterns.add<ParallelOpLowering>(module.getContext());
+ patterns.add<ParallelOpLowering>(module.getContext(), useOpaquePointers); // Extra argument goes to the ParallelOpLowering constructor.
FrozenRewritePatternSet frozen(std::move(patterns));
return applyPartialConversion(module, target, frozen);
}
/// A pass converting SCF operations to OpenMP operations.
-struct SCFToOpenMPPass : public impl::ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
+struct SCFToOpenMPPass
+ : public impl::ConvertSCFToOpenMPPassBase<SCFToOpenMPPass> {
+ // Inherit the constructors of 'Base' (presumably an alias supplied by the generated pass base - TODO confirm).
+ using Base::Base;
+
/// Pass entry point.
void runOnOperation() override {
- if (failed(applyPatterns(getOperation())))
+ if (failed(applyPatterns(getOperation(), useOpaquePointers))) // 'useOpaquePointers' is not declared in this struct; presumably a pass option member from the generated base - TODO confirm.
signalPassFailure();
}
};
} // namespace
-
-std::unique_ptr<OperationPass<ModuleOp>> mlir::createConvertSCFToOpenMPPass() {
- return std::make_unique<SCFToOpenMPPass>();
-}
--- /dev/null
+// RUN: mlir-opt -convert-scf-to-openmp='use-opaque-pointers=0' -split-input-file %s | FileCheck %s
+
+// CHECK: omp.reduction.declare @[[$REDF1:.*]] : f32
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant(-3.4
+// CHECK: omp.yield(%[[INIT]] : f32)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32)
+// CHECK: %[[CMP:.*]] = arith.cmpf oge, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG0]], %[[ARG1]]
+// CHECK: omp.yield(%[[RES]] : f32)
+
+// CHECK-NOT: atomic
+
+// CHECK: omp.reduction.declare @[[$REDF2:.*]] : i64
+
+// CHECK: init
+// CHECK: %[[INIT:.*]] = llvm.mlir.constant
+// CHECK: omp.yield(%[[INIT]] : i64)
+
+// CHECK: combiner
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: i64, %[[ARG1:.*]]: i64)
+// CHECK: %[[CMP:.*]] = arith.cmpi slt, %[[ARG0]], %[[ARG1]]
+// CHECK: %[[RES:.*]] = arith.select %[[CMP]], %[[ARG1]], %[[ARG0]]
+// CHECK: omp.yield(%[[RES]] : i64)
+
+// CHECK: atomic
+// CHECK: ^{{.*}}(%[[ARG0:.*]]: !llvm.ptr<i64>, %[[ARG1:.*]]: !llvm.ptr<i64>):
+// CHECK: %[[RHS:.*]] = llvm.load %[[ARG1]]
+// CHECK: llvm.atomicrmw max %[[ARG0]], %[[RHS]] monotonic
+
+// CHECK-LABEL: @reduction4
+func.func @reduction4(%arg0 : index, %arg1 : index, %arg2 : index,
+ %arg3 : index, %arg4 : index) -> (f32, i64) { // Two-result parallel loop: one f32 and one i64 reduction.
+ %step = arith.constant 1 : index
+ // CHECK: %[[ZERO:.*]] = arith.constant 0.0
+ %zero = arith.constant 0.0 : f32 // Initial value of the f32 reduction.
+ // CHECK: %[[IONE:.*]] = arith.constant 1
+ %ione = arith.constant 1 : i64 // Initial value of the i64 reduction.
+ // CHECK: %[[BUF1:.*]] = llvm.alloca %{{.*}} x f32
+ // CHECK: llvm.store %[[ZERO]], %[[BUF1]]
+ // CHECK: %[[BUF2:.*]] = llvm.alloca %{{.*}} x i64
+ // CHECK: llvm.store %[[IONE]], %[[BUF2]]
+
+ // CHECK: omp.parallel
+ // CHECK: omp.wsloop
+ // CHECK-SAME: reduction(@[[$REDF1]] -> %[[BUF1]]
+ // CHECK-SAME: @[[$REDF2]] -> %[[BUF2]]
+ // CHECK: memref.alloca_scope
+ %res:2 = scf.parallel (%i0, %i1) = (%arg0, %arg1) to (%arg2, %arg3)
+ step (%arg4, %step) init (%zero, %ione) -> (f32, i64) {
+ %one = arith.constant 1.0 : f32
+ // CHECK: omp.reduction %{{.*}}, %[[BUF1]]
+ scf.reduce(%one) : f32 { // Float maximum: keeps %lhs when %lhs >= %rhs; the expectations at the top of the file require no atomic region for it.
+ ^bb0(%lhs : f32, %rhs: f32):
+ %cmp = arith.cmpf oge, %lhs, %rhs : f32
+ %res = arith.select %cmp, %lhs, %rhs : f32
+ scf.reduce.return %res : f32
+ }
+ // CHECK: arith.fptosi
+ %1 = arith.fptosi %one : f32 to i64
+ // CHECK: omp.reduction %{{.*}}, %[[BUF2]]
+ scf.reduce(%1) : i64 { // Signed integer maximum: picks %rhs when %lhs < %rhs; the expectations at the top of the file require an llvm.atomicrmw max region for it.
+ ^bb1(%lhs: i64, %rhs: i64):
+ %cmp = arith.cmpi slt, %lhs, %rhs : i64
+ %res = arith.select %cmp, %rhs, %lhs : i64
+ scf.reduce.return %res : i64
+ }
+ // CHECK: omp.yield
+ }
+ // CHECK: omp.terminator
+ // CHECK: %[[RES1:.*]] = llvm.load %[[BUF1]] : !llvm.ptr<f32>
+ // CHECK: %[[RES2:.*]] = llvm.load %[[BUF2]] : !llvm.ptr<i64>
+ // CHECK: return %[[RES1]], %[[RES2]]
+ return %res#0, %res#1 : f32, i64
+}