[mlir][openacc] Use new private representation in acc.parallel
authorValentin Clement <clementval@gmail.com>
Mon, 22 May 2023 16:49:07 +0000 (09:49 -0700)
committerValentin Clement <clementval@gmail.com>
Mon, 22 May 2023 16:49:31 +0000 (09:49 -0700)
Update acc.parallel private operands list to use the new design
introduced in D150622.

Test in flang/test/Lower/OpenACC/acc-parallel.f90 and
flang/test/Lower/OpenACC/acc-parallel-loop.f90 are temporarly
disabled and will be enabled with updated lowering in the follow-up
patch.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D150971

flang/test/Lower/OpenACC/acc-parallel-loop.f90
flang/test/Lower/OpenACC/acc-parallel.f90
mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir

index fa7323d..1b91b7c 100644 (file)
@@ -442,18 +442,19 @@ subroutine acc_parallel_loop
 ! CHECK:        acc.yield
 ! CHECK-NEXT: }{{$}}
 
-  !$acc parallel loop private(a) firstprivate(b)
-  DO i = 1, n
-    a(i) = b(i)
-  END DO
-
-! CHECK:      acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
-! CHECK:        acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
-! CHECK:          fir.do_loop
-! CHECK:          acc.yield
-! CHECK-NEXT:   }{{$}}
-! CHECK:        acc.yield
-! CHECK-NEXT: }{{$}}
+! TODO: will be updated after lowering change in privatization to MLIR
+!  !$acc parallel loop private(a) firstprivate(b)
+!  DO i = 1, n
+!    a(i) = b(i)
+!  END DO
+
+! TODO:      acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10xf32>>) private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
+! TODO:        acc.loop private(%[[A]] : !fir.ref<!fir.array<10xf32>>) {
+! TODO:          fir.do_loop
+! TODO:          acc.yield
+! TODO-NEXT:   }{{$}}
+! TODO:        acc.yield
+! TODO-NEXT: }{{$}}
 
   !$acc parallel loop seq
   DO i = 1, n
index a880f15..93caa36 100644 (file)
@@ -288,11 +288,12 @@ subroutine acc_parallel
 !CHECK: acc.detach accPtr(%[[ATTACH_D]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "d"}
 !CHECK: acc.detach accPtr(%[[ATTACH_E]] : !fir.ptr<f32>) {dataClause = 10 : i64, name = "e"}
 
-  !$acc parallel private(a) firstprivate(b) private(c)
-  !$acc end parallel
+! TODO: will be updated after lowering change in privatization to MLIR
+!  !$acc parallel private(a) firstprivate(b) private(c)
+!  !$acc end parallel
 
-!CHECK:      acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
-!CHECK:        acc.yield
-!CHECK-NEXT: }{{$}}
+!TODO:      acc.parallel firstprivate(%[[B]] : !fir.ref<!fir.array<10x10xf32>>) private(%[[A]], %[[C]] : !fir.ref<!fir.array<10x10xf32>>, !fir.ref<!fir.array<10x10xf32>>) {
+!TODO:        acc.yield
+!TODO-NEXT: }{{$}}
 
 end subroutine acc_parallel
index 83f1eba..4cc9117 100644 (file)
@@ -636,7 +636,8 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
                        UnitAttr:$selfAttr,
                        OptionalAttr<OpenACC_ReductionOperatorAttr>:$reductionOp,
                        Variadic<AnyType>:$reductionOperands,
-                       Variadic<AnyType>:$gangPrivateOperands,
+                       Variadic<OpenACC_PointerLikeTypeInterface>:$gangPrivateOperands,
+                       OptionalAttr<SymbolRefArrayAttr>:$privatizations,
                        Variadic<AnyType>:$gangFirstPrivateOperands,
                        Variadic<OpenACC_PointerLikeTypeInterface>:$dataClauseOperands,
                        OptionalAttr<DefaultValueAttr>:$defaultAttr);
@@ -659,7 +660,9 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
             type($gangFirstPrivateOperands) `)`
       | `num_gangs` `(` $numGangs `:` type($numGangs) `)`
       | `num_workers` `(` $numWorkers `:` type($numWorkers) `)`
-      | `private` `(` $gangPrivateOperands `:` type($gangPrivateOperands) `)`
+      | `private` `(` custom<PrivatizationList>(
+            $gangPrivateOperands, type($gangPrivateOperands), $privatizations)
+        `)`
       | `vector_length` `(` $vectorLength `:` type($vectorLength) `)`
       | `wait` `(` $waitOperands `:` type($waitOperands) `)`
       | `self` `(` $selfCond `)`
index 430582d..89f97e0 100644 (file)
@@ -437,6 +437,43 @@ LogicalResult acc::ReductionRecipeOp::verifyRegions() {
 }
 
 //===----------------------------------------------------------------------===//
+// Custom parser and printer verifier for private clause
+//===----------------------------------------------------------------------===//
+
+static ParseResult parsePrivatizationList(
+    mlir::OpAsmParser &parser,
+    llvm::SmallVectorImpl<mlir::OpAsmParser::UnresolvedOperand> &operands,
+    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &privatizationSymbols) {
+  llvm::SmallVector<SymbolRefAttr> privatizationVec;
+  if (failed(parser.parseCommaSeparatedList([&]() {
+        if (parser.parseAttribute(privatizationVec.emplace_back()) ||
+            parser.parseArrow() ||
+            parser.parseOperand(operands.emplace_back()) ||
+            parser.parseColonType(types.emplace_back()))
+          return failure();
+        return success();
+      })))
+    return failure();
+  llvm::SmallVector<mlir::Attribute> privatizations(privatizationVec.begin(),
+                                                    privatizationVec.end());
+  privatizationSymbols = ArrayAttr::get(parser.getContext(), privatizations);
+  return success();
+}
+
+static void
+printPrivatizationList(mlir::OpAsmPrinter &p, mlir::Operation *op,
+                       mlir::OperandRange privateOperands,
+                       mlir::TypeRange privateTypes,
+                       std::optional<mlir::ArrayAttr> privatizations) {
+  for (unsigned i = 0, e = privatizations->size(); i < e; ++i) {
+    if (i != 0)
+      p << ", ";
+    p << (*privatizations)[i] << " -> " << privateOperands[i] << " : "
+      << privateOperands[i].getType();
+  }
+}
+
+//===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
 
@@ -455,6 +492,45 @@ static LogicalResult checkDataOperands(Op op,
   return success();
 }
 
+static LogicalResult
+checkPrivatizationList(Operation *op,
+                       std::optional<mlir::ArrayAttr> privatizations,
+                       mlir::OperandRange privateOperands) {
+  if (!privateOperands.empty()) {
+    if (!privatizations || privatizations->size() != privateOperands.size())
+      return op->emitOpError() << "expected as many privatizations symbol "
+                                  "reference as private operands";
+  } else {
+    if (privatizations)
+      return op->emitOpError() << "unexpected privatizations symbol reference";
+    return success();
+  }
+
+  llvm::DenseSet<Value> privates;
+  for (auto args : llvm::zip(privateOperands, *privatizations)) {
+    mlir::Value privateOperand = std::get<0>(args);
+
+    if (!privates.insert(privateOperand).second)
+      return op->emitOpError() << "private operand appears more than once";
+
+    mlir::Type varType = privateOperand.getType();
+    auto symbolRef = std::get<1>(args).cast<SymbolRefAttr>();
+    auto decl =
+        SymbolTable::lookupNearestSymbolFrom<PrivateRecipeOp>(op, symbolRef);
+    if (!decl)
+      return op->emitOpError() << "expected symbol reference " << symbolRef
+                               << " to point to a private declaration";
+
+    if (decl.getType() && decl.getType() != varType)
+      return op->emitOpError()
+             << "expected private (" << varType
+             << ") to be the same type as private declaration ("
+             << decl.getType() << ")";
+  }
+
+  return success();
+}
+
 unsigned ParallelOp::getNumDataOperands() {
   return getReductionOperands().size() + getGangPrivateOperands().size() +
          getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
@@ -471,6 +547,9 @@ Value ParallelOp::getDataOperand(unsigned i) {
 }
 
 LogicalResult acc::ParallelOp::verify() {
+  if (failed(checkPrivatizationList(*this, getPrivatizations(),
+                                    getGangPrivateOperands())))
+    return failure();
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
 
index be1973f..c0498f9 100644 (file)
@@ -114,6 +114,16 @@ func.func @compute2(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x
 
 // -----
 
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
 func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) -> memref<10xf32> {
   %lb = arith.constant 0 : index
   %st = arith.constant 1 : index
@@ -126,7 +136,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
   %pc = acc.present varPtr(%c : memref<10xf32>) -> memref<10xf32>
   %pd = acc.present varPtr(%d : memref<10xf32>) -> memref<10xf32>
   acc.data dataOperands(%pa, %pb, %pc, %pd: memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
-    acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) {
+    acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(@privatization_memref_10_f32 -> %c : memref<10xf32>) {
       acc.loop gang {
         scf.for %x = %lb to %c10 step %st {
           acc.loop worker {
@@ -168,7 +178,7 @@ func.func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10x
 // CHECK-NEXT:   [[NUMGANG:%.*]] = arith.constant 10 : i64
 // CHECK-NEXT:   [[NUMWORKERS:%.*]] = arith.constant 10 : i64
 // CHECK:        acc.data dataOperands(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : memref<10x10xf32>, memref<10x10xf32>, memref<10xf32>, memref<10xf32>) {
-// CHECK-NEXT:     acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private([[ARG2]] : memref<10xf32>) {
+// CHECK-NEXT:     acc.parallel num_gangs([[NUMGANG]] : i64) num_workers([[NUMWORKERS]] : i64) private(@privatization_memref_10_f32 -> [[ARG2]] : memref<10xf32>) {
 // CHECK-NEXT:       acc.loop gang {
 // CHECK-NEXT:         scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:           acc.loop worker {
@@ -358,6 +368,26 @@ func.func @acc_loop_multiple_block() {
 
 // -----
 
+acc.private.recipe @privatization_memref_10_f32 : memref<10xf32> init {
+^bb0(%arg0: memref<10xf32>):
+  %0 = memref.alloc() : memref<10xf32>
+  acc.yield %0 : memref<10xf32>
+} destroy {
+^bb0(%arg0: memref<10xf32>):
+  memref.dealloc %arg0 : memref<10xf32> 
+  acc.terminator
+}
+
+acc.private.recipe @privatization_memref_10_10_f32 : memref<10x10xf32> init {
+^bb0(%arg0: memref<10x10xf32>):
+  %0 = memref.alloc() : memref<10x10xf32>
+  acc.yield %0 : memref<10x10xf32>
+} destroy {
+^bb0(%arg0: memref<10x10xf32>):
+  memref.dealloc %arg0 : memref<10x10xf32> 
+  acc.terminator
+}
+
 func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
   %i64value = arith.constant 1 : i64
   %i32value = arith.constant 1 : i32
@@ -394,7 +424,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
   }
   acc.parallel vector_length(%idxValue: index) {
   }
-  acc.parallel private(%a, %c : memref<10xf32>, memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
+  acc.parallel private(@privatization_memref_10_f32 -> %a : memref<10xf32>, @privatization_memref_10_10_f32 -> %c : memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
   }
   acc.parallel {
   } attributes {defaultAttr = #acc<defaultvalue none>}
@@ -445,7 +475,7 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
 // CHECK-NEXT: }
 // CHECK:      acc.parallel vector_length([[IDXVALUE]] : index) {
 // CHECK-NEXT: }
-// CHECK:      acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private([[ARGA]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
+// CHECK:      acc.parallel firstprivate([[ARGB]] : memref<10xf32>) private(@privatization_memref_10_f32 -> [[ARGA]] : memref<10xf32>, @privatization_memref_10_10_f32 -> [[ARGC]] : memref<10x10xf32>) {
 // CHECK-NEXT: }
 // CHECK:      acc.parallel {
 // CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}