[Flang][OpenMP] Support depend clause for task construct, excluding array sections
authorPrabhdeep Singh Soni <prabhdeep.singh.soni3@huawei.com>
Thu, 23 Mar 2023 15:26:50 +0000 (11:26 -0400)
committerPrabhdeep Singh Soni <prabhdeep.singh.soni3@huawei.com>
Tue, 6 Jun 2023 14:21:05 +0000 (10:21 -0400)
This patch adds support for the OpenMP 4.0 depend clause for the task
construct, excluding array sections, to Flang lowering from parse-tree
to MLIR.

Reviewed By: kiranchandramohan

Differential Revision: https://reviews.llvm.org/D146766

flang/lib/Lower/OpenMP.cpp
flang/test/Lower/OpenMP/task.f90

index 13b7c5891b7271a338988bfb3392120869324ac5..94bd8fb4ed860287dcaf7749aadad904d6f01212 100644 (file)
@@ -1006,6 +1006,31 @@ static omp::ClauseProcBindKindAttr genProcBindKindAttr(
   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
 }
 
+static omp::ClauseTaskDependAttr
+genDependKindAttr(fir::FirOpBuilder &firOpBuilder,
+                  const Fortran::parser::OmpClause::Depend *dependClause) {
+  omp::ClauseTaskDepend pbKind;
+  switch (
+      std::get<Fortran::parser::OmpDependenceType>(
+          std::get<Fortran::parser::OmpDependClause::InOut>(dependClause->v.u)
+              .t)
+          .v) {
+  case Fortran::parser::OmpDependenceType::Type::In:
+    pbKind = omp::ClauseTaskDepend::taskdependin;
+    break;
+  case Fortran::parser::OmpDependenceType::Type::Out:
+    pbKind = omp::ClauseTaskDepend::taskdependout;
+    break;
+  case Fortran::parser::OmpDependenceType::Type::Inout:
+    pbKind = omp::ClauseTaskDepend::taskdependinout;
+    break;
+  default:
+    llvm_unreachable("unknown parser task dependence type");
+    break;
+  }
+  return omp::ClauseTaskDependAttr::get(firOpBuilder.getContext(), pbKind);
+}
+
 /* When parallel is used in a combined construct, then use this function to
  * create the parallel operation. It handles the parallel specific clauses
  * and leaves the rest for handling at the inner operations.
@@ -1072,7 +1097,8 @@ genOMP(Fortran::lower::AbstractConverter &converter,
   mlir::Value ifClauseOperand, numThreadsClauseOperand, finalClauseOperand,
       priorityClauseOperand;
   mlir::omp::ClauseProcBindKindAttr procBindKindAttr;
-  SmallVector<Value> allocateOperands, allocatorOperands;
+  SmallVector<Value> allocateOperands, allocatorOperands, dependOperands;
+  SmallVector<Attribute> dependTypeOperands;
   mlir::UnitAttr nowaitAttr, untiedAttr, mergeableAttr;
 
   const auto &opClauseList =
@@ -1148,6 +1174,42 @@ genOMP(Fortran::lower::AbstractConverter &converter,
            "Reduction in OpenMP " +
                llvm::omp::getOpenMPDirectiveName(blockDirective.v) +
                " construct");
+    } else if (const auto &dependClause =
+                   std::get_if<Fortran::parser::OmpClause::Depend>(&clause.u)) {
+      const std::list<Fortran::parser::Designator> &depVal =
+          std::get<std::list<Fortran::parser::Designator>>(
+              std::get<Fortran::parser::OmpDependClause::InOut>(
+                  dependClause->v.u)
+                  .t);
+      omp::ClauseTaskDependAttr dependTypeOperand =
+          genDependKindAttr(firOpBuilder, dependClause);
+      dependTypeOperands.insert(dependTypeOperands.end(), depVal.size(),
+                                dependTypeOperand);
+      for (const Fortran::parser::Designator &ompObject : depVal) {
+        Fortran::semantics::Symbol *sym = nullptr;
+        std::visit(
+            Fortran::common::visitors{
+                [&](const Fortran::parser::DataRef &designator) {
+                  if (const Fortran::parser::Name *name =
+                          std::get_if<Fortran::parser::Name>(&designator.u)) {
+                    sym = name->symbol;
+                  } else if (const Fortran::common::Indirection<
+                                 Fortran::parser::ArrayElement> *a =
+                                 std::get_if<Fortran::common::Indirection<
+                                     Fortran::parser::ArrayElement>>(
+                                     &designator.u)) {
+                    TODO(converter.getCurrentLocation(),
+                         "array sections not supported for task depend");
+                  }
+                },
+                [&](const Fortran::parser::Substring &designator) {
+                  TODO(converter.getCurrentLocation(),
+                       "substring not supported for task depend");
+                }},
+            (ompObject).u);
+        const mlir::Value variable = converter.getSymbolAddress(*sym);
+        dependOperands.push_back(((variable)));
+      }
     } else {
       TODO(converter.getCurrentLocation(), "OpenMP Block construct clause");
     }
@@ -1185,8 +1247,12 @@ genOMP(Fortran::lower::AbstractConverter &converter,
     auto taskOp = firOpBuilder.create<mlir::omp::TaskOp>(
         currentLocation, ifClauseOperand, finalClauseOperand, untiedAttr,
         mergeableAttr, /*in_reduction_vars=*/ValueRange(),
-        /*in_reductions=*/nullptr, priorityClauseOperand, /*depends=*/nullptr,
-        /*depend_vars=*/ValueRange(), allocateOperands, allocatorOperands);
+        /*in_reductions=*/nullptr, priorityClauseOperand,
+        dependTypeOperands.empty()
+            ? nullptr
+            : mlir::ArrayAttr::get(firOpBuilder.getContext(),
+                                   dependTypeOperands),
+        dependOperands, allocateOperands, allocatorOperands);
     createBodyOfOp(taskOp, converter, currentLocation, eval, &opClauseList);
   } else if (blockDirective.v == llvm::omp::OMPD_taskgroup) {
     // TODO: Add task_reduction support
index 810e3b521a26b0fb93708caab9b19630a495c857..d7419bd1100e69075cef2dd9ea0cb2abc5d56383 100644 (file)
@@ -99,6 +99,79 @@ subroutine task_allocate()
   !$omp end task
 end subroutine task_allocate
 
+!===============================================================================
+! `depend` clause
+!===============================================================================
+
+!CHECK-LABEL: func @_QPtask_depend
+subroutine task_depend()
+  integer :: x
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<i32>) {
+  !$omp task depend(in : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend
+
+!CHECK-LABEL: func @_QPtask_depend_non_int
+subroutine task_depend_non_int()
+  character(len = 15) :: x
+  integer, allocatable :: y
+  complex :: z
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<!fir.char<1,15>>, taskdependin -> %{{.+}} : !fir.ref<!fir.box<!fir.heap<i32>>>, taskdependin ->  %{{.+}} : !fir.ref<!fir.complex<4>>) {
+  !$omp task depend(in : x, y, z)
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_non_int
+
+!CHECK-LABEL: func @_QPtask_depend_all_kinds_one_task
+subroutine task_depend_all_kinds_one_task()
+  integer :: x
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<i32>, taskdependout -> %{{.+}} : !fir.ref<i32>, taskdependinout -> %{{.+}} : !fir.ref<i32>) {
+  !$omp task depend(in : x) depend(out : x) depend(inout : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_all_kinds_one_task
+
+!CHECK-LABEL: func @_QPtask_depend_multi_var
+subroutine task_depend_multi_var()
+  integer :: x
+  integer :: y
+  !CHECK: omp.task depend(taskdependin -> %{{.*}} : !fir.ref<i32>, taskdependin -> %{{.+}} : !fir.ref<i32>) {
+  !$omp task depend(in :x,y)
+  !CHECK: arith.addi
+  x = x + 12
+  y = y + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_multi_var
+
+!CHECK-LABEL: func @_QPtask_depend_multi_task
+subroutine task_depend_multi_task()
+  integer :: x
+  !CHECK: omp.task depend(taskdependout -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(out : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+  !CHECK: omp.task depend(taskdependinout -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(inout : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+  !CHECK: omp.task depend(taskdependin -> %{{.+}} : !fir.ref<i32>)
+  !$omp task depend(in : x)
+  !CHECK: arith.addi
+  x = x + 12
+  !CHECK: omp.terminator
+  !$omp end task
+end subroutine task_depend_multi_task
+
 !===============================================================================
 ! `private` clause
 !===============================================================================