[mlir][openacc] Add wait attribute and variadic operand
authorValentin Clement <clementval@gmail.com>
Fri, 30 Jun 2023 17:04:17 +0000 (10:04 -0700)
committerValentin Clement <clementval@gmail.com>
Fri, 30 Jun 2023 17:04:49 +0000 (10:04 -0700)
OpenACC 3.2 allowed the wait clause to the data construct. This patch
adds a unit attribute and a variadic operand to the data operation to model
the wait clause in a similar way it was added to other data operation.
The attribute models the presence of the clause without any argument. When
arguments are provided they are placed in the wait operand.

Depends on D154111

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D154131

flang/lib/Lower/OpenACC.cpp
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir

index 2e32abc..4313413 100644 (file)
@@ -1353,19 +1353,20 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
                          Fortran::semantics::SemanticsContext &semanticsContext,
                          Fortran::lower::StatementContext &stmtCtx,
                          const Fortran::parser::AccClauseList &accClauseList) {
-  mlir::Value ifCond, async;
+  mlir::Value ifCond, async, waitDevnum;
   llvm::SmallVector<mlir::Value> attachEntryOperands, createEntryOperands,
-      copyEntryOperands, copyoutEntryOperands, dataClauseOperands;
+      copyEntryOperands, copyoutEntryOperands, dataClauseOperands, waitOperands;
 
-  // Async has an optional value but can be present with
+  // Async and wait have an optional value but can be present with
   // no value as well. When there is no value, the op has an attribute to
   // represent the clause.
   bool addAsyncAttr = false;
+  bool addWaitAttr = false;
 
   fir::FirOpBuilder &builder = converter.getFirOpBuilder();
 
   // Lower clauses values mapped to operands.
-  // Keep track of each group of operands separatly as clauses can appear
+  // Keep track of each group of operands separately as clauses can appear
   // more than once.
   for (const Fortran::parser::AccClause &clause : accClauseList.v) {
     mlir::Location clauseLocation = converter.genLocation(clause.source);
@@ -1450,12 +1451,15 @@ static void genACCDataOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<int32_t> operandSegments;
   addOperand(operands, operandSegments, ifCond);
   addOperand(operands, operandSegments, async);
+  addOperand(operands, operandSegments, waitDevnum);
+  addOperands(operands, operandSegments, waitOperands);
   addOperands(operands, operandSegments, dataClauseOperands);
 
   auto dataOp = createRegionOp<mlir::acc::DataOp, mlir::acc::TerminatorOp>(
       builder, currentLocation, operands, operandSegments);
 
   dataOp.setAsyncAttr(addAsyncAttr);
+  dataOp.setAsyncAttr(addWaitAttr);
 
   auto insPt = builder.saveInsertionPoint();
   builder.setInsertionPointAfter(dataOp);
index 5960dfa..a1af1cf 100644 (file)
@@ -867,6 +867,9 @@ def OpenACC_DataOp : OpenACC_Op<"data",
   let arguments = (ins Optional<I1>:$ifCond,
                        Optional<IntOrIndex>:$async,
                        UnitAttr:$asyncAttr,
+                       Optional<IntOrIndex>:$waitDevnum,
+                       Variadic<IntOrIndex>:$waitOperands,
+                       UnitAttr:$waitAttr,
                        Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
                        OptionalAttr<DefaultValueAttr>:$defaultAttr);
 
@@ -885,6 +888,8 @@ def OpenACC_DataOp : OpenACC_Op<"data",
         `if` `(` $ifCond `)`
       | `async` `(` $async `:` type($async) `)`
       | `dataOperands` `(` $dataClauseOperands `:` type($dataClauseOperands) `)`
+      | `wait_devnum` `(` $waitDevnum `:` type($waitDevnum) `)`
+      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
     )
     $region attr-dict-with-keyword
   }];
index 2ac6899..9c6cffa 100644 (file)
@@ -868,6 +868,7 @@ unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
 Value DataOp::getDataOperand(unsigned i) {
   unsigned numOptional = getIfCond() ? 1 : 0;
   numOptional += getAsync() ? 1 : 0;
+  numOptional += getWaitOperands().size();
   return getOperand(numOptional + i);
 }
 
index a1f9432..fa2f8c9 100644 (file)
@@ -822,6 +822,17 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
   acc.data async(%a1 : i64) {
   } attributes { defaultAttr = #acc<defaultvalue none>, async }
 
+  acc.data {
+  } attributes { defaultAttr = #acc<defaultvalue none>, wait }
+
+  %w1 = arith.constant 1 : i64
+  acc.data wait(%w1 : i64) {
+  } attributes { defaultAttr = #acc<defaultvalue none>, wait }
+
+  %wd1 = arith.constant 1 : i64
+  acc.data wait_devnum(%wd1 : i64) wait(%w1 : i64) {
+  } attributes { defaultAttr = #acc<defaultvalue none>, wait }
+
   return
 }
 
@@ -927,6 +938,15 @@ func.func @testdataop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {
 // CHECK:      acc.data async(%{{.*}} : i64) {
 // CHECK-NEXT: } attributes {async, defaultAttr = #acc<defaultvalue none>}
 
+// CHECK:      acc.data {
+// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
+
+// CHECK:      acc.data wait(%{{.*}} : i64) {
+// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
+
+// CHECK:      acc.data wait_devnum(%{{.*}} : i64) wait(%{{.*}} : i64) {
+// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>, wait}
+
 // -----
 
 func.func @testupdateop(%a: memref<f32>, %b: memref<f32>, %c: memref<f32>) -> () {