From f49b6afc231242dfee027d5da69734836097cd43 Mon Sep 17 00:00:00 2001 From: Nimish Mishra Date: Mon, 27 Feb 2023 17:24:02 +0530 Subject: [PATCH] [flang][OpenMP] Handle lastprivate on sections construct This patch adds support for lastprivate on sections construct. One omp.sections operation can have several omp.section operations. As such, the privatization happens in the lexically last omp.section operation. Reviewed By: kiranchandramohan, peixin Differential Revision: https://reviews.llvm.org/D133686 --- flang/lib/Lower/OpenMP.cpp | 66 ++++++++++++++++++++- flang/test/Lower/OpenMP/sections.f90 | 110 +++++++++++++++++++++++++++++++++++ 2 files changed, 173 insertions(+), 3 deletions(-) diff --git a/flang/lib/Lower/OpenMP.cpp b/flang/lib/Lower/OpenMP.cpp index ea1c371..da58962 100644 --- a/flang/lib/Lower/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP.cpp @@ -120,8 +120,68 @@ static bool privatizeVars(Op &op, Fortran::lower::AbstractConverter &converter, } else if (const auto &lastPrivateClause = std::get_if( &clause.u)) { - // TODO: Add lastprivate support for sections construct, simd construct - if (std::is_same_v) { + // TODO: Add lastprivate support for simd construct + if (std::is_same_v) { + if (&eval == &eval.parentConstruct->getLastNestedEvaluation()) { + // For `omp.sections`, lastprivatized variables occur in + // lexically final `omp.section` operation. The following FIR + // shall be generated for the same: + // + // omp.sections lastprivate(...) { + // omp.section {...} + // omp.section {...} + // omp.section { + // fir.allocate for `private`/`firstprivate` + // + // scf.if %true { + // ^%lpv_update_blk + // } + // } + // } + // + // To keep code consistency while handling privatization + // through this control flow, add a `scf.if` operation + // that always evaluates to true, in order to create + // a dedicated sub-region in `omp.section` where + // lastprivate FIR can reside. Later canonicalizations + // will optimize away this operation. 
+ + omp::SectionOp *sectionOp = dyn_cast(&op); + mlir::scf::IfOp ifOp = firOpBuilder.create( + sectionOp->getLoc(), + firOpBuilder.createIntegerConstant( + sectionOp->getLoc(), firOpBuilder.getIntegerType(1), 0x1), + /*else*/ false); + firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + + const Fortran::parser::OpenMPConstruct *parentOmpConstruct = + eval.parentConstruct->getIf(); + assert(parentOmpConstruct && + "Expected a valid enclosing OpenMP construct"); + const Fortran::parser::OpenMPSectionsConstruct *sectionsConstruct = + std::get_if( + &parentOmpConstruct->u); + assert(sectionsConstruct && + "Expected an enclosing omp.sections construct"); + const Fortran::parser::OmpClauseList §ionsEndClauseList = + std::get( + std::get( + sectionsConstruct->t) + .t); + for (const Fortran::parser::OmpClause &otherClause : + sectionsEndClauseList.v) + if (std::get_if(&otherClause.u)) + // Emit implicit barrier to synchronize threads and avoid data + // races on post-update of lastprivate variables when `nowait` + // clause is present. + firOpBuilder.create( + converter.getCurrentLocation()); + firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front()); + lastPrivIP = firOpBuilder.saveInsertionPoint(); + firOpBuilder.setInsertionPoint(ifOp); + insPt = firOpBuilder.saveInsertionPoint(); + } + } else if (std::is_same_v) { omp::WsLoopOp *wsLoopOp = dyn_cast(&op); mlir::Operation *lastOper = wsLoopOp->getRegion().back().getTerminator(); @@ -549,7 +609,7 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter, // new control flow, changes the insertion point, // thus restore it. // TODO: Clean up later a bit to avoid this many sets and resets. 
- if (lastPrivateOp) + if (lastPrivateOp && !std::is_same_v) resetBeforeTerminator(firOpBuilder, storeOp, block); } diff --git a/flang/test/Lower/OpenMP/sections.f90 b/flang/test/Lower/OpenMP/sections.f90 index 358f317..3bec578 100644 --- a/flang/test/Lower/OpenMP/sections.f90 +++ b/flang/test/Lower/OpenMP/sections.f90 @@ -108,3 +108,113 @@ subroutine firstprivate(alpha) alpha = alpha * 5 !$omp end sections end subroutine + +subroutine lastprivate() + integer :: x +!CHECK: %[[X:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlastprivateEx"} +!CHECK: omp.sections { + !$omp sections lastprivate(x) +!CHECK: omp.section { +!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFlastprivateEx"} +!CHECK: %[[const:.*]] = arith.constant 10 : i32 +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: %[[result:.*]] = arith.muli %c10_i32, %[[temp]] : i32 +!CHECK: fir.store %[[result]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: omp.terminator +!CHECK: } + !$omp section + x = x * 10 +!CHECK: omp.section { +!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFlastprivateEx"} +!CHECK: %[[true:.*]] = arith.constant true +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: %[[const:.*]] = arith.constant 1 : i32 +!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[const]] : i32 +!CHECK: fir.store %[[result]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: scf.if %[[true]] { +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: fir.store %[[temp]] to %[[X]] : !fir.ref +!CHECK: } +!CHECK: omp.terminator +!CHECK: } + !$omp section + x = x + 1 +!CHECK: omp.terminator +!CHECK: } + !$omp end sections + +!CHECK: omp.sections { + !$omp sections firstprivate(x) lastprivate(x) +!CHECK: omp.section { +!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFlastprivateEx"} +!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: fir.store %[[temp]] to %[[PRIVATE_X]] 
: !fir.ref +!CHECK: omp.barrier +!CHECK: %[[const:.*]] = arith.constant 10 : i32 +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: %[[result:.*]] = arith.muli %c10_i32, %[[temp]] : i32 +!CHECK: fir.store %[[result]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: omp.terminator +!CHECK: } + !$omp section + x = x * 10 +!CHECK: omp.section { +!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFlastprivateEx"} +!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: fir.store %[[temp]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: omp.barrier +!CHECK: %[[true:.*]] = arith.constant true +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: %[[const:.*]] = arith.constant 1 : i32 +!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[const]] : i32 +!CHECK: fir.store %[[result]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: scf.if %true { +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: fir.store %[[temp]] to %[[X]] : !fir.ref +!CHECK: } +!CHECK: omp.terminator +!CHECK: } + !$omp section + x = x + 1 +!CHECK: omp.terminator +!CHECK: } + !$omp end sections + +!CHECK: omp.sections nowait { + !$omp sections firstprivate(x) lastprivate(x) +!CHECK: omp.section { +!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFlastprivateEx"} +!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: fir.store %[[temp]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: omp.barrier +!CHECK: %[[const:.*]] = arith.constant 10 : i32 +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: %[[result:.*]] = arith.muli %c10_i32, %[[temp]] : i32 +!CHECK: fir.store %[[result]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: omp.terminator +!CHECK: } + !$omp section + x = x * 10 +!CHECK: omp.section { +!CHECK: %[[PRIVATE_X:.*]] = fir.alloca i32 {bindc_name = "x", pinned, uniq_name = "_QFlastprivateEx"} +!CHECK: %[[temp:.*]] = fir.load %[[X]] : !fir.ref +!CHECK: fir.store %[[temp]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: 
omp.barrier +!CHECK: %[[true:.*]] = arith.constant true +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: %[[const:.*]] = arith.constant 1 : i32 +!CHECK: %[[result:.*]] = arith.addi %[[temp]], %[[const]] : i32 +!CHECK: fir.store %[[result]] to %[[PRIVATE_X]] : !fir.ref +!CHECK: scf.if %true { +!CHECK: %[[temp:.*]] = fir.load %[[PRIVATE_X]] : !fir.ref +!CHECK: fir.store %[[temp]] to %[[X]] : !fir.ref +!CHECK: omp.barrier +!CHECK: } +!CHECK: omp.terminator +!CHECK: } + !$omp section + x = x + 1 +!CHECK: omp.terminator +!CHECK: } + !$omp end sections nowait +end subroutine -- 2.7.4