[mlir][OpenMP] Add if clause to OpenMP simd construct
authorDominik Adamski <dominik.adamski@amd.com>
Thu, 30 Jun 2022 14:50:03 +0000 (09:50 -0500)
committerDominik Adamski <dominik.adamski@amd.com>
Wed, 6 Jul 2022 12:24:48 +0000 (07:24 -0500)
This patch adds if clause to OpenMP TableGen for simd construct.

Reviewed By: peixin

Differential Revision: https://reviews.llvm.org/D128940

Signed-off-by: Dominik Adamski <dominik.adamski@amd.com>
flang/lib/Lower/OpenMP.cpp
flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
flang/test/Lower/OpenMP/simd.f90
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Dialect/OpenMP/invalid.mlir
mlir/test/Dialect/OpenMP/ops.mlir
mlir/test/Target/LLVMIR/openmp-llvm.mlir

index 0949cb7..9ac78b5 100644 (file)
@@ -507,6 +507,19 @@ static omp::ClauseProcBindKindAttr genProcBindKindAttr(
   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
 }
 
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::StatementContext &stmtCtx,
+                   const Fortran::parser::OmpClause::If *ifClause) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+  auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
+  mlir::Value ifVal = fir::getBase(
+      converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+  return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
+                                    ifVal);
+}
+
 /* When parallel is used in a combined construct, then use this function to
  * create the parallel operation. It handles the parallel specific clauses
  * and leaves the rest for handling at the inner operations.
@@ -532,11 +545,7 @@ createCombinedParallelOp(Fortran::lower::AbstractConverter &converter,
   for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
     if (const auto &ifClause =
             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
-      auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-      mlir::Value ifVal = fir::getBase(
-          converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
-      ifClauseOperand = firOpBuilder.createConvert(
-          currentLocation, firOpBuilder.getI1Type(), ifVal);
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
     } else if (const auto &numThreadsClause =
                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
                        &clause.u)) {
@@ -585,11 +594,7 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   for (const auto &clause : opClauseList.v) {
     if (const auto &ifClause =
             std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
-      auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-      mlir::Value ifVal = fir::getBase(
-          converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
-      ifClauseOperand = firOpBuilder.createConvert(
-          currentLocation, firOpBuilder.getI1Type(), ifVal);
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
     } else if (const auto &numThreadsClause =
                    std::get_if<Fortran::parser::OmpClause::NumThreads>(
                        &clause.u)) {
@@ -760,9 +765,10 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   mlir::Location currentLocation = converter.getCurrentLocation();
   llvm::SmallVector<mlir::Value> lowerBound, upperBound, step, linearVars,
       linearStepVars, reductionVars;
-  mlir::Value scheduleChunkClauseOperand;
+  mlir::Value scheduleChunkClauseOperand, ifClauseOperand;
   mlir::Attribute scheduleClauseOperand, noWaitClauseOperand,
       orderedClauseOperand, orderClauseOperand;
+  Fortran::lower::StatementContext stmtCtx;
   const auto &loopOpClauseList = std::get<Fortran::parser::OmpClauseList>(
       std::get<Fortran::parser::OmpBeginLoopDirective>(loopConstruct.t).t);
 
@@ -823,11 +829,13 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
               std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
                   scheduleClause->v.t)) {
         if (const auto *expr = Fortran::semantics::GetExpr(*chunkExpr)) {
-          Fortran::lower::StatementContext stmtCtx;
           scheduleChunkClauseOperand =
               fir::getBase(converter.genExprValue(*expr, stmtCtx));
         }
       }
+    } else if (const auto &ifClause =
+                   std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
     }
   }
 
@@ -848,7 +856,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   if (llvm::omp::OMPD_simd == ompDirective) {
     TypeRange resultType;
     auto SimdLoopOp = firOpBuilder.create<mlir::omp::SimdLoopOp>(
-        currentLocation, resultType, lowerBound, upperBound, step);
+        currentLocation, resultType, lowerBound, upperBound, step,
+        ifClauseOperand, /*inclusive=*/firOpBuilder.getUnitAttr());
     createBodyOfOp<omp::SimdLoopOp>(SimdLoopOp, converter, currentLocation,
                                     eval, &loopOpClauseList, iv);
     return;
index 32e7bfc..1e2d07c 100644 (file)
@@ -181,7 +181,7 @@ func.func @_QPsimd1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref
   omp.parallel  {
     %1 = fir.alloca i32 {adapt.valuebyref, pinned}
     %2 = fir.load %arg0 : !fir.ref<i32>
-    omp.simdloop (%arg2) : i32 = (%c1_i32) to (%2) step (%c1_i32)  {
+    omp.simdloop for (%arg2) : i32 = (%c1_i32) to (%2) step (%c1_i32)  {
       fir.store %arg2 to %1 : !fir.ref<i32>
       %3 = fir.load %1 : !fir.ref<i32>
       %4 = fir.convert %3 : (i32) -> i64
index 1d08fb7..df4489f 100644 (file)
@@ -9,7 +9,7 @@ integer :: i
   ! CHECK: %[[LB:.*]] = arith.constant 1 : i32
   ! CHECK-NEXT: %[[UB:.*]] = arith.constant 9 : i32
   ! CHECK-NEXT: %[[STEP:.*]] = arith.constant 1 : i32
-  ! CHECK-NEXT: omp.simdloop (%[[I:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) step (%[[STEP]]) { 
+  ! CHECK-NEXT: omp.simdloop for (%[[I:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
   do i=1, 9
     ! CHECK: fir.store %[[I]] to %[[LOCAL:.*]] : !fir.ref<i32>
     ! CHECK: %[[LD:.*]] = fir.load %[[LOCAL]] : !fir.ref<i32>
@@ -18,3 +18,21 @@ integer :: i
   end do
   !$OMP END SIMD 
 end subroutine
+
+!CHECK-LABEL: func @_QPsimdloop_with_if_clause
+subroutine simdloop_with_if_clause(n, threshold)
+integer :: i, n, threshold
+  !$OMP SIMD IF( n .GE. threshold )
+  ! CHECK: %[[LB:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[UB:.*]] = fir.load %arg0
+  ! CHECK: %[[STEP:.*]] = arith.constant 1 : i32
+  ! CHECK: %[[COND:.*]] = arith.cmpi sge
+  ! CHECK: omp.simdloop if(%[[COND:.*]]) for (%[[I:.*]]) : i32 = (%[[LB]]) to (%[[UB]]) inclusive  step (%[[STEP]]) {
+  do i = 1, n
+    ! CHECK: fir.store %[[I]] to %[[LOCAL:.*]] : !fir.ref<i32>
+    ! CHECK: %[[LD:.*]] = fir.load %[[LOCAL]] : !fir.ref<i32>
+    ! CHECK: fir.call @_FortranAioOutputInteger32({{.*}}, %[[LD]]) : (!fir.ref<i8>, i32) -> i1
+    print*, i
+  end do
+  !$OMP END SIMD
+end subroutine
index 761e964..0c85e6b 100644 (file)
@@ -420,13 +420,18 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
     transformed into a SIMD loop (that is, multiple iterations of the loop can 
     be executed concurrently using SIMD instructions).. The lower and upper 
     bounds specify a half-open range: the range includes the lower bound but 
-    does not include the upper bound.
+    does not include the upper bound. If the `inclusive` attribute is specified
+    then the upper bound is also included.
 
     The body region can contain any number of blocks. The region is terminated
     by "omp.yield" instruction without operands.
+
+    When an if clause is present and evaluates to false, the preferred number of
+    iterations to be executed concurrently is one, regardless of whether
+    a simdlen clause is speciļ¬ed.
     ```
-    omp.simdloop (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) 
-                                      step (%c1, %c1) {
+    omp.simdloop <clauses>
+    for (%i1, %i2) : index = (%c0, %c0) to (%c10, %c10) step (%c1, %c1) {
       // block operations
       omp.yield
     }
@@ -436,9 +441,17 @@ def SimdLoopOp : OpenMP_Op<"simdloop", [AttrSizedOperandSegments,
   // TODO: Add other clauses
   let arguments = (ins Variadic<IntLikeType>:$lowerBound,
              Variadic<IntLikeType>:$upperBound,
-             Variadic<IntLikeType>:$step);
+             Variadic<IntLikeType>:$step,
+             Optional<I1>:$if_expr,
+             UnitAttr:$inclusive
+     );
  
   let regions = (region AnyRegion:$region);
+  let assemblyFormat = [{
+    oilist(`if` `(` $if_expr `)`
+    ) `for` custom<LoopControl>($region, $lowerBound, $upperBound, $step,
+                                  type($step), $inclusive) attr-dict
+  }];
 
   let extraClassDeclaration = [{
     /// Returns the number of loops in the simd loop nest.
index d09ef96..96ff6b1 100644 (file)
@@ -571,62 +571,6 @@ void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region,
 }
 
 //===----------------------------------------------------------------------===//
-// SimdLoopOp
-//===----------------------------------------------------------------------===//
-/// Parses an OpenMP Simd construct [2.9.3.1]
-///
-/// simdloop ::= `omp.simdloop` loop-control clause-list
-/// loop-control ::= `(` ssa-id-list `)` `:` type `=`  loop-bounds
-/// loop-bounds := `(` ssa-id-list `)` to `(` ssa-id-list `)` steps
-/// steps := `step` `(`ssa-id-list`)`
-/// clause-list ::= clause clause-list | empty
-/// clause ::= TODO
-ParseResult SimdLoopOp::parse(OpAsmParser &parser, OperationState &result) {
-  // Parse an opening `(` followed by induction variables followed by `)`
-  SmallVector<OpAsmParser::Argument> ivs;
-  Type loopVarType;
-  SmallVector<OpAsmParser::UnresolvedOperand> lower, upper, steps;
-  if (parser.parseArgumentList(ivs, OpAsmParser::Delimiter::Paren) ||
-      parser.parseColonType(loopVarType) ||
-      // Parse loop bounds.
-      parser.parseEqual() ||
-      parser.parseOperandList(lower, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(lower, loopVarType, result.operands) ||
-      parser.parseKeyword("to") ||
-      parser.parseOperandList(upper, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(upper, loopVarType, result.operands) ||
-      // Parse step values.
-      parser.parseKeyword("step") ||
-      parser.parseOperandList(steps, ivs.size(),
-                              OpAsmParser::Delimiter::Paren) ||
-      parser.resolveOperands(steps, loopVarType, result.operands))
-    return failure();
-
-  int numIVs = static_cast<int>(ivs.size());
-  SmallVector<int> segments{numIVs, numIVs, numIVs};
-  // TODO: Add parseClauses() when we support clauses
-  result.addAttribute("operand_segment_sizes",
-                      parser.getBuilder().getI32VectorAttr(segments));
-
-  // Now parse the body.
-  Region *body = result.addRegion();
-  for (auto &iv : ivs)
-    iv.type = loopVarType;
-  return parser.parseRegion(*body, ivs);
-}
-
-void SimdLoopOp::print(OpAsmPrinter &p) {
-  auto args = getRegion().front().getArguments();
-  p << " (" << args << ") : " << args[0].getType() << " = (" << lowerBound()
-    << ") to (" << upperBound() << ") ";
-  p << "step (" << step() << ") ";
-
-  p.printRegion(region(), /*printEntryBlockArgs=*/false);
-}
-
-//===----------------------------------------------------------------------===//
 // Verifier for Simd construct [2.9.3.1]
 //===----------------------------------------------------------------------===//
 
index 87409ee..1c7c11f 100644 (file)
@@ -912,6 +912,11 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,
   SmallVector<llvm::CanonicalLoopInfo *> loopInfos;
   SmallVector<llvm::OpenMPIRBuilder::InsertPointTy> bodyInsertPoints;
   LogicalResult bodyGenStatus = success();
+
+  // TODO: The code generation for if clause is not supported yet.
+  if (loop.if_expr())
+    return failure();
+
   auto bodyGen = [&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *iv) {
     // Make sure further conversions know about the induction variable.
     moduleTranslation.mapValue(
index 27ba2c8..379b8f4 100644 (file)
@@ -197,7 +197,7 @@ func.func @omp_simdloop(%lb : index, %ub : index, %step : i32) -> () {
   "omp.simdloop" (%lb, %ub, %step) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1]> : vector<3xi32>} :
+  }) {operand_segment_sizes = dense<[1,1,1,0]> : vector<4xi32>} :
     (index, index, i32) -> () 
 
   return
index 2682c97..393378f 100644 (file)
@@ -329,21 +329,29 @@ func.func @omp_wsloop_pretty_multiple(%lb1 : i32, %ub1 : i32, %step1 : i32, %lb2
 
 // CHECK-LABEL: omp_simdloop
 func.func @omp_simdloop(%lb : index, %ub : index, %step : index) -> () {
-  // CHECK: omp.simdloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+  // CHECK: omp.simdloop for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
   "omp.simdloop" (%lb, %ub, %step) ({
     ^bb0(%iv: index):
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1]> : vector<3xi32>} :
+  }) {operand_segment_sizes = dense<[1,1,1,0]> : vector<4xi32>} :
     (index, index, index) -> () 
 
   return
 }
 
-
 // CHECK-LABEL: omp_simdloop_pretty
 func.func @omp_simdloop_pretty(%lb : index, %ub : index, %step : index) -> () {
-  // CHECK: omp.simdloop (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
-  omp.simdloop (%iv) : index = (%lb) to (%ub) step (%step) {
+  // CHECK: omp.simdloop for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+  omp.simdloop for (%iv) : index = (%lb) to (%ub) step (%step) {
+    omp.yield
+  }
+  return
+}
+
+// CHECK-LABEL: omp_simdloop_pretty_if
+func.func @omp_simdloop_pretty_if(%lb : index, %ub : index, %step : index, %if_cond : i1) -> () {
+  // CHECK: omp.simdloop if(%{{.*}}) for (%{{.*}}) : index = (%{{.*}}) to (%{{.*}}) step (%{{.*}})
+  omp.simdloop if(%if_cond) for (%iv): index = (%lb) to (%ub) step (%step) {
     omp.yield
   }
   return
@@ -351,8 +359,8 @@ func.func @omp_simdloop_pretty(%lb : index, %ub : index, %step : index) -> () {
 
 // CHECK-LABEL: omp_simdloop_pretty_multiple
 func.func @omp_simdloop_pretty_multiple(%lb1 : index, %ub1 : index, %step1 : index, %lb2 : index, %ub2 : index, %step2 : index) -> () {
-  // CHECK: omp.simdloop (%{{.*}}, %{{.*}}) : index = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
-  omp.simdloop (%iv1, %iv2) : index = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+  // CHECK: omp.simdloop for (%{{.*}}, %{{.*}}) : index = (%{{.*}}, %{{.*}}) to (%{{.*}}, %{{.*}}) step (%{{.*}}, %{{.*}})
+  omp.simdloop for (%iv1, %iv2) : index = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
     omp.yield
   }
   return
index bdb2fad..40ac2e3 100644 (file)
@@ -697,7 +697,7 @@ llvm.func @simdloop_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr<f
       %4 = llvm.getelementptr %arg0[%iv] : (!llvm.ptr<f32>, i64) -> !llvm.ptr<f32>
       llvm.store %3, %4 : !llvm.ptr<f32>
       omp.yield
-  }) {operand_segment_sizes = dense<[1,1,1]> : vector<3xi32>} :
+  }) {operand_segment_sizes = dense<[1,1,1,0]> : vector<4xi32>} :
     (i64, i64, i64) -> () 
 
   llvm.return
@@ -709,7 +709,7 @@ llvm.func @simdloop_simple(%lb : i64, %ub : i64, %step : i64, %arg0: !llvm.ptr<f
 
 // CHECK-LABEL: @simdloop_simple_multiple
 llvm.func @simdloop_simple_multiple(%lb1 : i64, %ub1 : i64, %step1 : i64, %lb2 : i64, %ub2 : i64, %step2 : i64, %arg0: !llvm.ptr<f32>, %arg1: !llvm.ptr<f32>) {
-  omp.simdloop (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
+  omp.simdloop for (%iv1, %iv2) : i64 = (%lb1, %lb2) to (%ub1, %ub2) step (%step1, %step2) {
     %3 = llvm.mlir.constant(2.000000e+00 : f32) : f32
     // The form of the emitted IR is controlled by OpenMPIRBuilder and
     // tested there. Just check that the right metadata is added.