[flang][openacc] Add basic lowering to new data operations for acc.enter_data
authorValentin Clement <clementval@gmail.com>
Wed, 19 Apr 2023 22:52:24 +0000 (15:52 -0700)
committerValentin Clement <clementval@gmail.com>
Wed, 19 Apr 2023 22:52:46 +0000 (15:52 -0700)
This is an initial patch that lowers acc.enter_data copyin/create/attach
clauses to the newly added data operand operations. Follow up patches will
add support for array section and derived type and derived type component
as well as support in other data operation.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D148721

flang/lib/Lower/OpenACC.cpp
flang/test/Lower/OpenACC/acc-enter-data.f90

index d17dab6..a72db4c 100644 (file)
@@ -100,6 +100,80 @@ genObjectList(const Fortran::parser::AccObjectList &objectList,
   }
 }
 
+template <typename Op>
+static void
+genDataOperandOperations(const Fortran::parser::AccObjectList &objectList,
+                         Fortran::lower::AbstractConverter &converter,
+                         Fortran::semantics::SemanticsContext &semanticsContext,
+                         Fortran::lower::StatementContext &stmtCtx,
+                         llvm::SmallVectorImpl<mlir::Value> &dataOperands,
+                         mlir::acc::DataClause dataClause, bool structured) {
+
+  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+  auto createOpAndAddOperand = [&](Fortran::lower::SymbolRef sym,
+                                   llvm::StringRef name, mlir::Location loc) {
+    mlir::Value symAddr = converter.getSymbolAddress(sym);
+    // TODO: Might need revisiting to handle for non-shared clauses
+    if (!symAddr) {
+      if (const auto *details =
+              sym->detailsIf<Fortran::semantics::HostAssocDetails>())
+        symAddr = converter.getSymbolAddress(details->symbol());
+    }
+
+    if (!symAddr)
+      llvm::report_fatal_error("could not retrieve symbol address");
+
+    Op op = builder.create<Op>(loc, symAddr.getType(), symAddr);
+    op.setNameAttr(builder.getStringAttr(name));
+    op.setStructured(structured);
+    op.setDataClause(dataClause);
+    op->setAttr(Op::getOperandSegmentSizeAttr(),
+                builder.getDenseI32ArrayAttr({1, 0, 0}));
+    dataOperands.push_back(op.getAccPtr());
+  };
+
+  for (const auto &accObject : objectList.v) {
+    std::visit(
+        Fortran::common::visitors{
+            [&](const Fortran::parser::Designator &designator) {
+              mlir::Location operandLocation =
+                  converter.genLocation(designator.source);
+              if (auto expr{Fortran::semantics::AnalyzeExpr(semanticsContext,
+                                                            designator)}) {
+                if ((*expr).Rank() > 0 &&
+                    Fortran::parser::Unwrap<Fortran::parser::ArrayElement>(
+                        designator)) {
+                  TODO(operandLocation, "OpenACC array section data operand");
+                } else if (Fortran::parser::Unwrap<
+                               Fortran::parser::StructureComponent>(
+                               designator)) {
+                  TODO(operandLocation, "OpenACC derived-type data operand");
+                } else {
+                  // Scalar or full array.
+                  if (const auto *dataRef{std::get_if<Fortran::parser::DataRef>(
+                          &designator.u)}) {
+                    const Fortran::parser::Name &name =
+                        Fortran::parser::GetLastName(*dataRef);
+                    createOpAndAddOperand(*name.symbol, name.ToString(),
+                                          operandLocation);
+                  } else { // Unsupported
+                    llvm::report_fatal_error(
+                        "Unsupported type of OpenACC operand");
+                  }
+                }
+              }
+            },
+            [&](const Fortran::parser::Name &name) {
+              mlir::Location operandLocation =
+                  converter.genLocation(name.source);
+              createOpAndAddOperand(*name.symbol, name.ToString(),
+                                    operandLocation);
+            }},
+        accObject.u);
+  }
+}
+
 template <typename Clause>
 static void genObjectListWithModifier(
     const Clause *x, Fortran::lower::AbstractConverter &converter,
@@ -794,18 +868,33 @@ genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
           copyinClause->v;
       const auto &accObjectList =
           std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
-      genObjectList(accObjectList, converter, semanticsContext, stmtCtx,
-                    copyinOperands);
+      genDataOperandOperations<mlir::acc::CopyinOp>(
+          accObjectList, converter, semanticsContext, stmtCtx,
+          dataClauseOperands, mlir::acc::DataClause::acc_copyin, false);
     } else if (const auto *createClause =
                    std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
-      genObjectListWithModifier<Fortran::parser::AccClause::Create>(
-          createClause, converter, semanticsContext, stmtCtx,
-          Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands,
-          createOperands);
+      const Fortran::parser::AccObjectListWithModifier &listWithModifier =
+          createClause->v;
+      const auto &accObjectList =
+          std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
+      const auto &modifier =
+          std::get<std::optional<Fortran::parser::AccDataModifier>>(
+              listWithModifier.t);
+      if (modifier &&
+          (*modifier).v == Fortran::parser::AccDataModifier::Modifier::Zero) {
+        genDataOperandOperations<mlir::acc::CreateOp>(
+            accObjectList, converter, semanticsContext, stmtCtx,
+            dataClauseOperands, mlir::acc::DataClause::acc_create_zero, false);
+      } else {
+        genDataOperandOperations<mlir::acc::CreateOp>(
+            accObjectList, converter, semanticsContext, stmtCtx,
+            dataClauseOperands, mlir::acc::DataClause::acc_create, false);
+      }
     } else if (const auto *attachClause =
                    std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
-      genObjectList(attachClause->v, converter, semanticsContext, stmtCtx,
-                    attachOperands);
+      genDataOperandOperations<mlir::acc::AttachOp>(
+          attachClause->v, converter, semanticsContext, stmtCtx,
+          dataClauseOperands, mlir::acc::DataClause::acc_attach, false);
     } else {
       llvm::report_fatal_error(
           "Unknown clause in ENTER DATA directive lowering");
index ca26c8f..a2eff3a 100644 (file)
@@ -8,62 +8,82 @@ subroutine acc_enter_data
   real, pointer :: d
   logical :: ifCondition = .TRUE.
 
-!CHECK: [[A:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
-!CHECK: [[B:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
-!CHECK: [[C:%.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ec"}
-!CHECK: [[D:%.*]] = fir.alloca !fir.box<!fir.ptr<f32>> {bindc_name = "d", uniq_name = "{{.*}}Ed"}
+!CHECK: %[[A:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ea"}
+!CHECK: %[[B:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Eb"}
+!CHECK: %[[C:.*]] = fir.alloca !fir.array<10x10xf32> {{{.*}}uniq_name = "{{.*}}Ec"}
+!CHECK: %[[D:.*]] = fir.alloca !fir.box<!fir.ptr<f32>> {bindc_name = "d", uniq_name = "{{.*}}Ed"}
 
   !$acc enter data create(a)
-!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
 
   !$acc enter data create(a) if(.true.)
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
 !CHECK: [[IF1:%.*]] = arith.constant true
-!CHECK: acc.enter_data if([[IF1]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+!CHECK: acc.enter_data if([[IF1]]) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
 
   !$acc enter data create(a) if(ifCondition)
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
 !CHECK: [[IFCOND:%.*]] = fir.load %{{.*}} : !fir.ref<!fir.logical<4>>
 !CHECK: [[IF2:%.*]] = fir.convert [[IFCOND]] : (!fir.logical<4>) -> i1
-!CHECK: acc.enter_data if([[IF2]]) create([[A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+!CHECK: acc.enter_data if([[IF2]]) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
 
   !$acc enter data create(a) create(b) create(c)
-!CHECK: acc.enter_data create([[A]], [[B]], [[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[CREATE_B:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "b", structured = false}
+!CHECK: %[[CREATE_C:.*]] = acc.create varPtr(%[[C]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "c", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE_A]], %[[CREATE_B]], %[[CREATE_C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
 
   !$acc enter data create(a) create(b) create(zero: c)
-!CHECK: acc.enter_data create([[A]], [[B]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) create_zero([[C]] : !fir.ref<!fir.array<10x10xf32>>){{$}}
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[CREATE_B:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "b", structured = false}
+!CHECK: %[[CREATE_C:.*]] = acc.create varPtr(%[[C]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {dataClause = 8 : i64, name = "c", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE_A]], %[[CREATE_B]], %[[CREATE_C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>){{$}}
 
   !$acc enter data copyin(a) create(b) attach(d)
-!CHECK: acc.enter_data copyin([[A]] : !fir.ref<!fir.array<10x10xf32>>) create([[B]] : !fir.ref<!fir.array<10x10xf32>>) attach([[D]] : !fir.ref<!fir.box<!fir.ptr<f32>>>){{$}}
+!CHECK: %[[COPYIN_A:.*]] = acc.copyin varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[CREATE_B:.*]] = acc.create varPtr(%[[B]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "b", structured = false}
+!CHECK: %[[ATTACH_D:.*]] = acc.attach varPtr(%[[D]] : !fir.ref<!fir.box<!fir.ptr<f32>>>)   -> !fir.ref<!fir.box<!fir.ptr<f32>>> {name = "d", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[COPYIN_A]], %[[CREATE_B]], %[[ATTACH_D]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.box<!fir.ptr<f32>>>){{$}}
 
   !$acc enter data create(a) async
-!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async}
 
   !$acc enter data create(a) wait
-!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {wait}
 
   !$acc enter data create(a) async wait
-!CHECK: acc.enter_data create([[A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: acc.enter_data dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>) attributes {async, wait}
 
   !$acc enter data create(a) async(1)
-!CHECK: [[ASYNC1:%.*]] = arith.constant 1 : i32
-!CHECK: acc.enter_data async([[ASYNC1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[ASYNC1:.*]] = arith.constant 1 : i32
+!CHECK: acc.enter_data async(%[[ASYNC1]] : i32) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
   !$acc enter data create(a) async(async)
-!CHECK: [[ASYNC2:%.*]] = fir.load %{{.*}} : !fir.ref<i32>
-!CHECK: acc.enter_data async([[ASYNC2]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[ASYNC2:.*]] = fir.load %{{.*}} : !fir.ref<i32>
+!CHECK: acc.enter_data async(%[[ASYNC2]] : i32) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
   !$acc enter data create(a) wait(1)
-!CHECK: [[WAIT1:%.*]] = arith.constant 1 : i32
-!CHECK: acc.enter_data wait([[WAIT1]] : i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[WAIT1:.*]] = arith.constant 1 : i32
+!CHECK: acc.enter_data wait(%[[WAIT1]] : i32) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
   !$acc enter data create(a) wait(queues: 1, 2)
-!CHECK: [[WAIT2:%.*]] = arith.constant 1 : i32
-!CHECK: [[WAIT3:%.*]] = arith.constant 2 : i32
-!CHECK: acc.enter_data wait([[WAIT2]], [[WAIT3]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[WAIT2:.*]] = arith.constant 1 : i32
+!CHECK: %[[WAIT3:.*]] = arith.constant 2 : i32
+!CHECK: acc.enter_data wait(%[[WAIT2]], %[[WAIT3]] : i32, i32) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
   !$acc enter data create(a) wait(devnum: 1: queues: 1, 2)
-!CHECK: [[WAIT4:%.*]] = arith.constant 1 : i32
-!CHECK: [[WAIT5:%.*]] = arith.constant 2 : i32
-!CHECK: [[WAIT6:%.*]] = arith.constant 1 : i32
-!CHECK: acc.enter_data wait_devnum([[WAIT6]] : i32) wait([[WAIT4]], [[WAIT5]] : i32, i32) create([[A]] : !fir.ref<!fir.array<10x10xf32>>)
+!CHECK: %[[CREATE_A:.*]] = acc.create varPtr(%[[A]] : !fir.ref<!fir.array<10x10xf32>>)   -> !fir.ref<!fir.array<10x10xf32>> {name = "a", structured = false}
+!CHECK: %[[WAIT4:.*]] = arith.constant 1 : i32
+!CHECK: %[[WAIT5:.*]] = arith.constant 2 : i32
+!CHECK: %[[WAIT6:.*]] = arith.constant 1 : i32
+!CHECK: acc.enter_data wait_devnum(%[[WAIT6]] : i32) wait(%[[WAIT4]], %[[WAIT5]] : i32, i32) dataOperands(%[[CREATE_A]] : !fir.ref<!fir.array<10x10xf32>>)
 
 end subroutine acc_enter_data