}
// Replace the loop.
+ auto omp = rewriter.create<omp::ParallelOp>(parallelOp.getLoc());
+ Block *block = rewriter.createBlock(&omp.getRegion());
+ rewriter.setInsertionPointToStart(block);
auto loop = rewriter.create<omp::WsLoopOp>(
parallelOp.getLoc(), parallelOp.lowerBound(), parallelOp.upperBound(),
parallelOp.step());
rewriter.inlineRegionBefore(parallelOp.region(), loop.region(),
loop.region().begin());
+ rewriter.create<omp::TerminatorOp>(parallelOp.getLoc());
+
rewriter.eraseOp(parallelOp);
return success();
}
};
-/// Inserts OpenMP "parallel" operations around top-level SCF "parallel"
-/// operations in the given function. This is implemented as a direct IR
-/// modification rather than as a conversion pattern because it does not
-/// modify the top-level operation it matches, which is a requirement for
-/// rewrite patterns.
-//
-// TODO: consider creating nested parallel operations when necessary.
-static void insertOpenMPParallel(FuncOp func) {
- // Collect top-level SCF "parallel" ops.
- SmallVector<scf::ParallelOp, 4> topLevelParallelOps;
- func.walk([&topLevelParallelOps](scf::ParallelOp parallelOp) {
- // Ignore ops that are already within OpenMP parallel construct.
- if (!parallelOp->getParentOfType<scf::ParallelOp>())
- topLevelParallelOps.push_back(parallelOp);
- });
-
- // Wrap SCF ops into OpenMP "parallel" ops.
- for (scf::ParallelOp parallelOp : topLevelParallelOps) {
- OpBuilder builder(parallelOp);
- auto omp = builder.create<omp::ParallelOp>(parallelOp.getLoc());
- Block *block = builder.createBlock(&omp.getRegion());
- builder.create<omp::TerminatorOp>(parallelOp.getLoc());
- block->getOperations().splice(block->begin(),
- parallelOp->getBlock()->getOperations(),
- parallelOp.getOperation());
- }
-}
-
/// Applies the conversion patterns in the given function.
static LogicalResult applyPatterns(FuncOp func) {
ConversionTarget target(*func.getContext());
struct SCFToOpenMPPass : public ConvertSCFToOpenMPBase<SCFToOpenMPPass> {
/// Pass entry point.
void runOnFunction() override {
- insertOpenMPParallel(getFunction());
if (failed(applyPatterns(getFunction())))
signalPassFailure();
}
%arg3: index, %arg4: index, %arg5: index) {
// CHECK: omp.parallel {
// CHECK: omp.wsloop (%[[LVAR_OUT1:.*]]) : index = (%arg0) to (%arg2) step (%arg4) {
- // CHECK-NOT: omp.parallel
scf.parallel (%i) = (%arg0) to (%arg2) step (%arg4) {
+ // CHECK: omp.parallel
// CHECK: omp.wsloop (%[[LVAR_IN1:.*]]) : index = (%arg1) to (%arg3) step (%arg5) {
scf.parallel (%j) = (%arg1) to (%arg3) step (%arg5) {
// CHECK: "test.payload"(%[[LVAR_OUT1]], %[[LVAR_IN1]]) : (index, index) -> ()