[MLIR][OpenMP] Add Lowering support for OpenMP Target Data, Exit Data and Enter Data...
authorAkash Banerjee <Akash.Banerjee@amd.com>
Tue, 24 Jan 2023 11:29:50 +0000 (11:29 +0000)
committerAkash Banerjee <Akash.Banerjee@amd.com>
Thu, 26 Jan 2023 10:59:55 +0000 (10:59 +0000)
This patch adds Fortran Lowering support for the OpenMP Target Data, Target Exit Data and Target Enter Data constructs.
operation.

Differential Revision: https://reviews.llvm.org/D142357

flang/lib/Lower/OpenMP.cpp
flang/test/Lower/OpenMP/target_data.f90 [new file with mode: 0644]

index 8e13c24..4bb734b 100644 (file)
@@ -404,6 +404,19 @@ static void genObjectList(const Fortran::parser::OmpObjectList &objectList,
   }
 }
 
+static mlir::Value
+getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
+                   Fortran::lower::StatementContext &stmtCtx,
+                   const Fortran::parser::OmpClause::If *ifClause) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  mlir::Location currentLocation = converter.getCurrentLocation();
+  auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
+  mlir::Value ifVal = fir::getBase(
+      converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
+  return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
+                                    ifVal);
+}
+
 static mlir::Type getLoopVarType(Fortran::lower::AbstractConverter &converter,
                                  std::size_t loopVarTypeSize) {
   // OpenMP runtime requires 32-bit or 64-bit loop variables.
@@ -547,6 +560,130 @@ createBodyOfOp(Op &op, Fortran::lower::AbstractConverter &converter,
   }
 }
 
+static void
+createTargetDataOp(Fortran::lower::AbstractConverter &converter,
+                   const Fortran::parser::OmpClauseList &opClauseList,
+                   const llvm::omp::Directive &directive) {
+  Fortran::lower::StatementContext stmtCtx;
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  mlir::Value ifClauseOperand, deviceOperand;
+  mlir::UnitAttr nowaitAttr;
+  llvm::SmallVector<mlir::Value> useDevicePtrOperand, useDeviceAddrOperand,
+      mapOperands;
+  llvm::SmallVector<mlir::IntegerAttr> mapTypes;
+
+  auto addMapClause = [&firOpBuilder, &converter, &mapOperands,
+                       &mapTypes](const auto &mapClause) {
+    auto mapType = std::get<Fortran::parser::OmpMapType::Type>(
+        std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
+            ->t);
+    llvm::omp::OpenMPOffloadMappingFlags mapTypeBits =
+        llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_NONE;
+    switch (mapType) {
+    case Fortran::parser::OmpMapType::Type::To:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO;
+      break;
+    case Fortran::parser::OmpMapType::Type::From:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+      break;
+    case Fortran::parser::OmpMapType::Type::Tofrom:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_TO |
+                     llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_FROM;
+      break;
+    case Fortran::parser::OmpMapType::Type::Alloc:
+    case Fortran::parser::OmpMapType::Type::Release:
+      // alloc and release is the default map_type for the Target Data Ops, i.e.
+      // if no bits for map_type is supplied then alloc/release is implicitly
+      // assumed based on the target directive. Default value for Target Data
+      // and Enter Data is alloc and for Exit Data it is release.
+      break;
+    case Fortran::parser::OmpMapType::Type::Delete:
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_DELETE;
+    }
+    if (std::get<std::optional<Fortran::parser::OmpMapType::Always>>(
+            std::get<std::optional<Fortran::parser::OmpMapType>>(mapClause->v.t)
+                ->t)
+            .has_value())
+      mapTypeBits |= llvm::omp::OpenMPOffloadMappingFlags::OMP_MAP_ALWAYS;
+
+    // TODO: Add support MapTypeModifiers close, mapper, present, iterator
+
+    mlir::IntegerAttr mapTypeAttr = firOpBuilder.getIntegerAttr(
+        firOpBuilder.getI64Type(),
+        static_cast<
+            std::underlying_type_t<llvm::omp::OpenMPOffloadMappingFlags>>(
+            mapTypeBits));
+
+    llvm::SmallVector<mlir::Value> mapOperand;
+    genObjectList(std::get<Fortran::parser::OmpObjectList>(mapClause->v.t),
+                  converter, mapOperand);
+
+    for (mlir::Value mapOp : mapOperand) {
+      mapOperands.push_back(mapOp);
+      mapTypes.push_back(mapTypeAttr);
+    }
+  };
+
+  for (const Fortran::parser::OmpClause &clause : opClauseList.v) {
+    mlir::Location currentLocation = converter.genLocation(clause.source);
+    if (const auto &ifClause =
+            std::get_if<Fortran::parser::OmpClause::If>(&clause.u)) {
+      ifClauseOperand = getIfClauseOperand(converter, stmtCtx, ifClause);
+    } else if (const auto &deviceClause =
+                   std::get_if<Fortran::parser::OmpClause::Device>(&clause.u)) {
+      if (auto deviceModifier = std::get<
+              std::optional<Fortran::parser::OmpDeviceClause::DeviceModifier>>(
+              deviceClause->v.t)) {
+        if (deviceModifier ==
+            Fortran::parser::OmpDeviceClause::DeviceModifier::Ancestor) {
+          TODO(currentLocation, "OMPD_target Device Modifier Ancestor");
+        }
+      }
+      if (const auto *deviceExpr = Fortran::semantics::GetExpr(
+              std::get<Fortran::parser::ScalarIntExpr>(deviceClause->v.t))) {
+        deviceOperand =
+            fir::getBase(converter.genExprValue(*deviceExpr, stmtCtx));
+      }
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDevicePtr>(
+                   &clause.u)) {
+      TODO(currentLocation, "OMPD_target Use Device Ptr");
+    } else if (std::get_if<Fortran::parser::OmpClause::UseDeviceAddr>(
+                   &clause.u)) {
+      TODO(currentLocation, "OMPD_target Use Device Addr");
+    } else if (std::get_if<Fortran::parser::OmpClause::Nowait>(&clause.u)) {
+      nowaitAttr = firOpBuilder.getUnitAttr();
+    } else if (const auto &mapClause =
+                   std::get_if<Fortran::parser::OmpClause::Map>(&clause.u)) {
+      addMapClause(mapClause);
+    } else {
+      TODO(currentLocation, "OMPD_target unhandled clause");
+    }
+  }
+
+  llvm::SmallVector<mlir::Attribute> mapTypesAttr(mapTypes.begin(),
+                                                  mapTypes.end());
+  mlir::ArrayAttr mapTypesArrayAttr =
+      ArrayAttr::get(firOpBuilder.getContext(), mapTypesAttr);
+  mlir::Location currentLocation = converter.getCurrentLocation();
+
+  if (directive == llvm::omp::Directive::OMPD_target_data) {
+    firOpBuilder.create<omp::DataOp>(
+        currentLocation, ifClauseOperand, deviceOperand, useDevicePtrOperand,
+        useDeviceAddrOperand, mapOperands, mapTypesArrayAttr);
+  } else if (directive == llvm::omp::Directive::OMPD_target_enter_data) {
+    firOpBuilder.create<omp::EnterDataOp>(currentLocation, ifClauseOperand,
+                                          deviceOperand, nowaitAttr,
+                                          mapOperands, mapTypesArrayAttr);
+  } else if (directive == llvm::omp::Directive::OMPD_target_exit_data) {
+    firOpBuilder.create<omp::ExitDataOp>(currentLocation, ifClauseOperand,
+                                         deviceOperand, nowaitAttr, mapOperands,
+                                         mapTypesArrayAttr);
+  } else {
+    TODO(currentLocation, "OMPD_target directive unknown");
+  }
+}
+
 static void genOMP(Fortran::lower::AbstractConverter &converter,
                    Fortran::lower::pft::Evaluation &eval,
                    const Fortran::parser::OpenMPSimpleStandaloneConstruct
@@ -554,25 +691,27 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
   const auto &directive =
       std::get<Fortran::parser::OmpSimpleStandaloneDirective>(
           simpleStandaloneConstruct.t);
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+  const Fortran::parser::OmpClauseList &opClauseList =
+      std::get<Fortran::parser::OmpClauseList>(simpleStandaloneConstruct.t);
+
   switch (directive.v) {
   default:
     break;
   case llvm::omp::Directive::OMPD_barrier:
-    converter.getFirOpBuilder().create<mlir::omp::BarrierOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<omp::BarrierOp>(converter.getCurrentLocation());
     break;
   case llvm::omp::Directive::OMPD_taskwait:
-    converter.getFirOpBuilder().create<mlir::omp::TaskwaitOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<omp::TaskwaitOp>(converter.getCurrentLocation());
     break;
   case llvm::omp::Directive::OMPD_taskyield:
-    converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(
-        converter.getCurrentLocation());
+    firOpBuilder.create<omp::TaskyieldOp>(converter.getCurrentLocation());
     break;
+  case llvm::omp::Directive::OMPD_target_data:
   case llvm::omp::Directive::OMPD_target_enter_data:
-    TODO(converter.getCurrentLocation(), "OMPD_target_enter_data");
   case llvm::omp::Directive::OMPD_target_exit_data:
-    TODO(converter.getCurrentLocation(), "OMPD_target_exit_data");
+    createTargetDataOp(converter, opClauseList, directive.v);
+    break;
   case llvm::omp::Directive::OMPD_target_update:
     TODO(converter.getCurrentLocation(), "OMPD_target_update");
   case llvm::omp::Directive::OMPD_ordered:
@@ -669,19 +808,6 @@ static omp::ClauseProcBindKindAttr genProcBindKindAttr(
   return omp::ClauseProcBindKindAttr::get(firOpBuilder.getContext(), pbKind);
 }
 
-static mlir::Value
-getIfClauseOperand(Fortran::lower::AbstractConverter &converter,
-                   Fortran::lower::StatementContext &stmtCtx,
-                   const Fortran::parser::OmpClause::If *ifClause) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-  mlir::Location currentLocation = converter.getCurrentLocation();
-  auto &expr = std::get<Fortran::parser::ScalarLogicalExpr>(ifClause->v.t);
-  mlir::Value ifVal = fir::getBase(
-      converter.genExprValue(*Fortran::semantics::GetExpr(expr), stmtCtx));
-  return firOpBuilder.createConvert(currentLocation, firOpBuilder.getI1Type(),
-                                    ifVal);
-}
-
 /* When parallel is used in a combined construct, then use this function to
  * create the parallel operation. It handles the parallel specific clauses
  * and leaves the rest for handling at the inner operations.
diff --git a/flang/test/Lower/OpenMP/target_data.f90 b/flang/test/Lower/OpenMP/target_data.f90
new file mode 100644 (file)
index 0000000..d3994bd
--- /dev/null
@@ -0,0 +1,105 @@
+!RUN: %flang_fc1 -emit-fir -fopenmp %s -o - | FileCheck %s
+
+!===============================================================================
+! Target_Enter Simple
+!===============================================================================
+
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_simple() {
+subroutine omp_target_enter_simple
+   integer :: a(1024)
+   !CHECK: omp.target_enter_data   map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a)
+end subroutine omp_target_enter_simple
+
+!===============================================================================
+! Target_Enter Map types
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_mt() {
+subroutine omp_target_enter_mt
+   integer :: a(1024)
+   integer :: b(1024)
+   integer :: c(1024)
+   integer :: d(1024)
+   !CHECK: omp.target_enter_data   map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (always, alloc -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a, b) map(always, alloc: c) map(to: d)
+end subroutine omp_target_enter_mt
+
+!===============================================================================
+! `Nowait` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_nowait() {
+subroutine omp_target_enter_nowait
+   integer :: a(1024)
+   !CHECK: omp.target_enter_data   nowait map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a) nowait
+end subroutine omp_target_enter_nowait
+
+!===============================================================================
+! `if` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_if() {
+subroutine omp_target_enter_if
+   integer :: a(1024)
+   integer :: i
+   i = 5
+   !CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1:.*]] : !fir.ref<i32>
+   !CHECK: %[[VAL_4:.*]] = arith.constant 10 : i32
+   !CHECK: %[[VAL_5:.*]] = arith.cmpi slt, %[[VAL_3:.*]], %[[VAL_4:.*]] : i32
+   !CHECK: omp.target_enter_data   if(%[[VAL_5:.*]] : i1) map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data if(i<10) map(to: a)
+end subroutine omp_target_enter_if
+
+!===============================================================================
+! `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_enter_device() {
+subroutine omp_target_enter_device
+   integer :: a(1024)
+   !CHECK: %[[VAL_1:.*]] = arith.constant 2 : i32
+   !CHECK: omp.target_enter_data   device(%[[VAL_1:.*]] : i32) map((to -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target enter data map(to: a) device(2)
+end subroutine omp_target_enter_device
+
+!===============================================================================
+! Target_Exit Simple
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_simple() {
+subroutine omp_target_exit_simple
+   integer :: a(1024)
+   !CHECK: omp.target_exit_data   map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a)
+end subroutine omp_target_exit_simple
+
+!===============================================================================
+! Target_Exit Map types
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_mt() {
+subroutine omp_target_exit_mt
+   integer :: a(1024)
+   integer :: b(1024)
+   integer :: c(1024)
+   integer :: d(1024)
+   integer :: e(1024)
+   !CHECK: omp.target_exit_data   map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (release -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (always, delete -> {{.*}} : !fir.ref<!fir.array<1024xi32>>), (from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a,b) map(release: c) map(always, delete: d) map(from: e)
+end subroutine omp_target_exit_mt
+
+!===============================================================================
+! `device` clause
+!===============================================================================
+
+!CHECK-LABEL: func.func @_QPomp_target_exit_device() {
+subroutine omp_target_exit_device
+   integer :: a(1024)
+   integer :: d
+   !CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_1:.*]] : !fir.ref<i32>
+   !CHECK: omp.target_exit_data   device(%[[VAL_2:.*]] : i32) map((from -> {{.*}} : !fir.ref<!fir.array<1024xi32>>))
+   !$omp target exit data map(from: a) device(d)
+end subroutine omp_target_exit_device