[flang][openacc] Lower enter data directive
authorValentin Clement <clementval@gmail.com>
Wed, 4 Nov 2020 20:17:15 +0000 (15:17 -0500)
committerclementval <clementval@gmail.com>
Wed, 4 Nov 2020 20:17:50 +0000 (15:17 -0500)
This patch upstream the lowering of Enter Data directive that was initially done in
https://github.com/flang-compiler/f18-llvm-project/pull/526

Reviewed By: schweitz

Differential Revision: https://reviews.llvm.org/D90470

flang/lib/Lower/OpenACC.cpp

index c254d04..3087203 100644 (file)
@@ -529,6 +529,105 @@ genACC(Fortran::lower::AbstractConverter &converter,
 }
 
 static void
+genACCEnterDataOp(Fortran::lower::AbstractConverter &converter,
+                  const Fortran::parser::AccClauseList &accClauseList) {
+  mlir::Value ifCond, async, waitDevnum;
+  SmallVector<Value, 2> copyinOperands, createOperands, createZeroOperands,
+      attachOperands, waitOperands;
+
+  // Async, wait and self clause have optional values but can be present with
+  // no value as well. When there is no value, the op has an attribute to
+  // represent the clause.
+  bool addAsyncAttr = false;
+  bool addWaitAttr = false;
+
+  auto &firOpBuilder = converter.getFirOpBuilder();
+  auto currentLocation = converter.getCurrentLocation();
+
+  // Lower clauses values mapped to operands.
+  // Keep track of each group of operands separatly as clauses can appear
+  // more than once.
+  for (const auto &clause : accClauseList.v) {
+    if (const auto *ifClause =
+            std::get_if<Fortran::parser::AccClause::If>(&clause.u)) {
+      mlir::Value cond = fir::getBase(
+          converter.genExprValue(*Fortran::semantics::GetExpr(ifClause->v)));
+      ifCond = firOpBuilder.createConvert(currentLocation,
+                                          firOpBuilder.getI1Type(), cond);
+    } else if (const auto *asyncClause =
+                   std::get_if<Fortran::parser::AccClause::Async>(&clause.u)) {
+      const auto &asyncClauseValue = asyncClause->v;
+      if (asyncClauseValue) { // async has a value.
+        async = fir::getBase(converter.genExprValue(
+            *Fortran::semantics::GetExpr(*asyncClauseValue)));
+      } else {
+        addAsyncAttr = true;
+      }
+    } else if (const auto *waitClause =
+                   std::get_if<Fortran::parser::AccClause::Wait>(&clause.u)) {
+      const auto &waitClauseValue = waitClause->v;
+      if (waitClauseValue) { // wait has a value.
+        const Fortran::parser::AccWaitArgument &waitArg = *waitClauseValue;
+        const std::list<Fortran::parser::ScalarIntExpr> &waitList =
+            std::get<std::list<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+        for (const Fortran::parser::ScalarIntExpr &value : waitList) {
+          mlir::Value v = fir::getBase(
+              converter.genExprValue(*Fortran::semantics::GetExpr(value)));
+          waitOperands.push_back(v);
+        }
+
+        const std::optional<Fortran::parser::ScalarIntExpr> &waitDevnumValue =
+            std::get<std::optional<Fortran::parser::ScalarIntExpr>>(waitArg.t);
+        if (waitDevnumValue)
+          waitDevnum = fir::getBase(converter.genExprValue(
+              *Fortran::semantics::GetExpr(*waitDevnumValue)));
+      } else {
+        addWaitAttr = true;
+      }
+    } else if (const auto *copyinClause =
+                   std::get_if<Fortran::parser::AccClause::Copyin>(&clause.u)) {
+      const Fortran::parser::AccObjectListWithModifier &listWithModifier =
+          copyinClause->v;
+      const Fortran::parser::AccObjectList &accObjectList =
+          std::get<Fortran::parser::AccObjectList>(listWithModifier.t);
+      genObjectList(accObjectList, converter, copyinOperands);
+    } else if (const auto *createClause =
+                   std::get_if<Fortran::parser::AccClause::Create>(&clause.u)) {
+      genObjectListWithModifier<Fortran::parser::AccClause::Create>(
+          createClause, converter,
+          Fortran::parser::AccDataModifier::Modifier::Zero, createZeroOperands,
+          createOperands);
+    } else if (const auto *attachClause =
+                   std::get_if<Fortran::parser::AccClause::Attach>(&clause.u)) {
+      genObjectList(attachClause->v, converter, attachOperands);
+    } else {
+      llvm::report_fatal_error(
+          "Unknown clause in ENTER DATA directive lowering");
+    }
+  }
+
+  // Prepare the operand segement size attribute and the operands value range.
+  SmallVector<mlir::Value, 16> operands;
+  SmallVector<int32_t, 8> operandSegments;
+  addOperand(operands, operandSegments, ifCond);
+  addOperand(operands, operandSegments, async);
+  addOperand(operands, operandSegments, waitDevnum);
+  addOperands(operands, operandSegments, waitOperands);
+  addOperands(operands, operandSegments, copyinOperands);
+  addOperands(operands, operandSegments, createOperands);
+  addOperands(operands, operandSegments, createZeroOperands);
+  addOperands(operands, operandSegments, attachOperands);
+
+  auto enterDataOp = createSimpleOp<mlir::acc::EnterDataOp>(
+      firOpBuilder, currentLocation, operands, operandSegments);
+
+  if (addAsyncAttr)
+    enterDataOp.asyncAttr(firOpBuilder.getUnitAttr());
+  if (addWaitAttr)
+    enterDataOp.waitAttr(firOpBuilder.getUnitAttr());
+}
+
+static void
 genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
                  const Fortran::parser::AccClauseList &accClauseList) {
   mlir::Value ifCond, async, waitDevnum;
@@ -605,8 +704,8 @@ genACCExitDataOp(Fortran::lower::AbstractConverter &converter,
   }
 
   // Prepare the operand segement size attribute and the operands value range.
-  SmallVector<Value, 8> operands;
-  SmallVector<int32_t, 8> operandSegments;
+  SmallVector<mlir::Value, 14> operands;
+  SmallVector<int32_t, 7> operandSegments;
   addOperand(operands, operandSegments, ifCond);
   addOperand(operands, operandSegments, async);
   addOperand(operands, operandSegments, waitDevnum);
@@ -636,7 +735,7 @@ genACC(Fortran::lower::AbstractConverter &converter,
       std::get<Fortran::parser::AccClauseList>(standaloneConstruct.t);
 
   if (standaloneDirective.v == llvm::acc::Directive::ACCD_enter_data) {
-    TODO("OpenACC enter data directive not lowered yet!");
+    genACCEnterDataOp(converter, accClauseList);
   } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_exit_data) {
     genACCExitDataOp(converter, accClauseList);
   } else if (standaloneDirective.v == llvm::acc::Directive::ACCD_init) {