[mlir][openacc] Refine data operation data clause attribute
authorRazvan Lupusoru <rlupusoru@nvidia.com>
Wed, 19 Apr 2023 17:59:41 +0000 (10:59 -0700)
committerRazvan Lupusoru <rlupusoru@nvidia.com>
Wed, 19 Apr 2023 22:06:21 +0000 (15:06 -0700)
The data operations added in D148389 hold two data clause fields:
"dataClause" and "decomposedFrom". However, in most cases, dataClause
field holds a default value (except for acc_copyin_readonly,
acc_create_zero, and acc_copyout_zero).

The decomposedFrom field holds the original clause specified by user.
As work began on lowering to these new operations [1], it seems that
having both fields adds a bit of ambiguity. There is only one scenario
where we actually intended to use both:
acc data copyout(zero:)

The original intent was that this clause would be decomposed to
the following operations:
acc.create {dataClause = acc_create_zero, decomposedFrom =
acc_copyout_zero}
...
acc.copyout {dataClause = acc_copyout_zero}

However, we can encode the zero semantics like so without need of both:
acc.create {dataClause = acc_copyout_zero}
...
acc.copyout {dataClause = acc_copyout_zero}

Thus get rid of the decomposedFrom field and update verifier checks
to check for all data clauses that can be decomposed to the particular
operation.

So now the dataClause holds the original user's clause which simplifies
understanding of the operation.

[1] https://reviews.llvm.org/D148721

Reviewed By: clementval

Differential Revision: https://reviews.llvm.org/D148731

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir

index 405ff78..ec6b28c 100644 (file)
@@ -139,7 +139,6 @@ class OpenACC_DataEntryOp<string mnemonic, string clause, list<Trait> traits = [
                        Optional<OpenACC_PointerLikeTypeInterface>:$varPtrPtr,
                        Variadic<OpenACC_DataBoundsType>:$bounds, /* rank-0 to rank-{n-1} */
                        DefaultValuedAttr<OpenACC_DataClauseEnum,clause>:$dataClause,
-                       OptionalAttr<OpenACC_DataClauseEnum>:$decomposedFrom,
                        DefaultValuedAttr<BoolAttr, "true">:$structured,
                        DefaultValuedAttr<BoolAttr, "false">:$implicit,
                        OptionalAttr<StrAttr>:$name);
@@ -242,7 +241,6 @@ class OpenACC_DataExitOp<string mnemonic, string clause, list<Trait> traits = []
                        OpenACC_PointerLikeTypeInterface:$accPtr,
                        Variadic<OpenACC_DataBoundsType>:$bounds,
                        DefaultValuedAttr<OpenACC_DataClauseEnum,clause>:$dataClause,
-                       OptionalAttr<OpenACC_DataClauseEnum>:$decomposedFrom,
                        DefaultValuedAttr<BoolAttr, "true">:$structured,
                        DefaultValuedAttr<BoolAttr, "false">:$implicit,
                        OptionalAttr<StrAttr>:$name);
index 17f1bbd..ff431bf 100644 (file)
@@ -88,10 +88,13 @@ LogicalResult acc::PresentOp::verify() {
 // CopyinOp
 //===----------------------------------------------------------------------===//
 LogicalResult acc::CopyinOp::verify() {
+  // Test for all clauses this operation can be decomposed from:
   if (getDataClause() != acc::DataClause::acc_copyin &&
-      getDataClause() != acc::DataClause::acc_copyin_readonly) {
+      getDataClause() != acc::DataClause::acc_copyin_readonly &&
+      getDataClause() != acc::DataClause::acc_copy) {
     return emitError(
-        "data clause associated with copyin operation must match its intent");
+        "data clause associated with copyin operation must match its intent"
+        " or specify original clause this operation was decomposed from");
   }
   return success();
 }
@@ -104,16 +107,22 @@ bool acc::CopyinOp::isCopyinReadonly() {
 // CreateOp
 //===----------------------------------------------------------------------===//
 LogicalResult acc::CreateOp::verify() {
+  // Test for all clauses this operation can be decomposed from:
   if (getDataClause() != acc::DataClause::acc_create &&
-      getDataClause() != acc::DataClause::acc_create_zero) {
+      getDataClause() != acc::DataClause::acc_create_zero &&
+      getDataClause() != acc::DataClause::acc_copyout &&
+      getDataClause() != acc::DataClause::acc_copyout_zero) {
     return emitError(
-        "data clause associated with create operation must match its intent");
+        "data clause associated with create operation must match its intent"
+        " or specify original clause this operation was decomposed from");
   }
   return success();
 }
 
 bool acc::CreateOp::isCreateZero() {
-  return getDataClause() == acc::DataClause::acc_create_zero;
+  // The zero modifier is encoded in the data clause.
+  return getDataClause() == acc::DataClause::acc_create_zero ||
+         getDataClause() == acc::DataClause::acc_copyout_zero;
 }
 
 //===----------------------------------------------------------------------===//
@@ -142,7 +151,13 @@ LogicalResult acc::AttachOp::verify() {
 // GetDevicePtrOp
 //===----------------------------------------------------------------------===//
 LogicalResult acc::GetDevicePtrOp::verify() {
-  if (getDataClause() != acc::DataClause::acc_getdeviceptr) {
+  // This operation is also created for use in unstructured constructs
+  // when we need an "accPtr" to feed to exit operation. Thus we test
+  // for those cases as well:
+  if (getDataClause() != acc::DataClause::acc_getdeviceptr &&
+      getDataClause() != acc::DataClause::acc_copyout &&
+      getDataClause() != acc::DataClause::acc_delete &&
+      getDataClause() != acc::DataClause::acc_detach) {
     return emitError("getDevicePtr mismatch");
   }
   return success();
@@ -152,10 +167,13 @@ LogicalResult acc::GetDevicePtrOp::verify() {
 // CopyoutOp
 //===----------------------------------------------------------------------===//
 LogicalResult acc::CopyoutOp::verify() {
+  // Test for all clauses this operation can be decomposed from:
   if (getDataClause() != acc::DataClause::acc_copyout &&
-      getDataClause() != acc::DataClause::acc_copyout_zero) {
+      getDataClause() != acc::DataClause::acc_copyout_zero &&
+      getDataClause() != acc::DataClause::acc_copy) {
     return emitError(
-        "data clause associated with copyout operation must match its intent");
+        "data clause associated with copyout operation must match its intent"
+        " or specify original clause this operation was decomposed from");
   }
   if (!getVarPtr() || !getAccPtr()) {
     return emitError("must have both host and device pointers");
@@ -171,9 +189,13 @@ bool acc::CopyoutOp::isCopyoutZero() {
 // DeleteOp
 //===----------------------------------------------------------------------===//
 LogicalResult acc::DeleteOp::verify() {
-  if (getDataClause() != acc::DataClause::acc_delete) {
+  // Test for all clauses this operation can be decomposed from:
+  if (getDataClause() != acc::DataClause::acc_delete &&
+      getDataClause() != acc::DataClause::acc_create &&
+      getDataClause() != acc::DataClause::acc_create_zero) {
     return emitError(
-        "data clause associated with delete operation must match its intent");
+        "data clause associated with delete operation must match its intent"
+        " or specify original clause this operation was decomposed from");
   }
   if (!getVarPtr() && !getAccPtr()) {
     return emitError("must have either host or device pointer");
@@ -185,9 +207,12 @@ LogicalResult acc::DeleteOp::verify() {
 // DetachOp
 //===----------------------------------------------------------------------===//
 LogicalResult acc::DetachOp::verify() {
-  if (getDataClause() != acc::DataClause::acc_detach) {
+  // Test for all clauses this operation can be decomposed from:
+  if (getDataClause() != acc::DataClause::acc_detach &&
+      getDataClause() != acc::DataClause::acc_attach) {
     return emitError(
-        "data clause associated with detach operation must match its intent");
+        "data clause associated with detach operation must match its intent"
+        " or specify original clause this operation was decomposed from");
   }
   if (!getVarPtr() && !getAccPtr()) {
     return emitError("must have either host or device pointer");
index 825bc06..6bdb8a2 100644 (file)
@@ -994,19 +994,19 @@ func.func @teststructureddataclauseops(%a: memref<10xf32>, %b: memref<memref<10x
   acc.kernels dataOperands(%copyinreadonly : memref<10xf32>) {
   }
 
-  %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 3}
+  %copyinfromcopy = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = 3}
   acc.serial dataOperands(%copyinfromcopy : memref<10xf32>) {
   }
-  acc.copyout accPtr(%copyinfromcopy : memref<10xf32>) to varPtr(%a : memref<10xf32>) {decomposedFrom = 3}
+  acc.copyout accPtr(%copyinfromcopy : memref<10xf32>) to varPtr(%a : memref<10xf32>) {dataClause = 3}
 
   %create = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32>
   %createimplicit = acc.create varPtr(%c : memref<10x20xf32>) -> memref<10x20xf32> {implicit = true}
   acc.parallel dataOperands(%create, %createimplicit : memref<10xf32>, memref<10x20xf32>) {
   }
-  acc.delete accPtr(%create : memref<10xf32>) {decomposedFrom = 7}
-  acc.delete accPtr(%createimplicit : memref<10x20xf32>) {decomposedFrom = 7, implicit = true}
+  acc.delete accPtr(%create : memref<10xf32>) {dataClause = 7}
+  acc.delete accPtr(%createimplicit : memref<10x20xf32>) {dataClause = 7, implicit = true}
 
-  %copyoutzero = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 5}
+  %copyoutzero = acc.create varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = 5}
   acc.parallel dataOperands(%copyoutzero: memref<10xf32>) {
   }
   acc.copyout accPtr(%copyoutzero : memref<10xf32>) to varPtr(%a : memref<10xf32>) {dataClause = 5}
@@ -1014,12 +1014,12 @@ func.func @teststructureddataclauseops(%a: memref<10xf32>, %b: memref<memref<10x
   %attach = acc.attach varPtr(%b : memref<memref<10xf32>>) -> memref<memref<10xf32>>
   acc.parallel dataOperands(%attach : memref<memref<10xf32>>) {
   }
-  acc.detach accPtr(%attach : memref<memref<10xf32>>) {decomposedFrom = 10}
+  acc.detach accPtr(%attach : memref<memref<10xf32>>) {dataClause = 10}
 
-  %copyinparent = acc.copyin varPtr(%a : memref<10xf32>) varPtrPtr(%b : memref<memref<10xf32>>) -> memref<10xf32> {decomposedFrom = 3}
+  %copyinparent = acc.copyin varPtr(%a : memref<10xf32>) varPtrPtr(%b : memref<memref<10xf32>>) -> memref<10xf32> {dataClause = 3}
   acc.parallel dataOperands(%copyinparent : memref<10xf32>) {
   }
-  acc.copyout accPtr(%copyinparent : memref<10xf32>) to varPtr(%a : memref<10xf32>) {decomposedFrom = 3}
+  acc.copyout accPtr(%copyinparent : memref<10xf32>) to varPtr(%a : memref<10xf32>) {dataClause = 3}
 
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
@@ -1037,10 +1037,10 @@ func.func @teststructureddataclauseops(%a: memref<10xf32>, %b: memref<memref<10x
   }
 
   %bounds1partial = acc.bounds lowerbound(%c4 : index) upperbound(%c9 : index) stride(%c1 : index)
-  %copyinpartial = acc.copyin varPtr(%a : memref<10xf32>) bounds(%bounds1partial) -> memref<10xf32> {decomposedFrom = 3}
+  %copyinpartial = acc.copyin varPtr(%a : memref<10xf32>) bounds(%bounds1partial) -> memref<10xf32> {dataClause = 3}
   acc.parallel dataOperands(%copyinpartial : memref<10xf32>) {
   }
-  acc.copyout accPtr(%copyinpartial : memref<10xf32>) bounds(%bounds1partial) to varPtr(%a : memref<10xf32>) {decomposedFrom = 3}
+  acc.copyout accPtr(%copyinpartial : memref<10xf32>) bounds(%bounds1partial) to varPtr(%a : memref<10xf32>) {dataClause = 3}
 
   return
 }
@@ -1058,28 +1058,28 @@ func.func @teststructureddataclauseops(%a: memref<10xf32>, %b: memref<memref<10x
 // CHECK: [[COPYINRO:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 2 : i64}
 // CHECK-NEXT: acc.kernels dataOperands([[COPYINRO]] : memref<10xf32>) {
 // CHECK-NEXT: }
-// CHECK: [[COPYINCOPY:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 3 : i64}
+// CHECK: [[COPYINCOPY:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 3 : i64}
 // CHECK-NEXT: acc.serial dataOperands([[COPYINCOPY]] : memref<10xf32>) {
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.copyout accPtr([[COPYINCOPY]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {decomposedFrom = 3 : i64}
+// CHECK-NEXT: acc.copyout accPtr([[COPYINCOPY]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 3 : i64}
 // CHECK: [[CREATE:%.*]] = acc.create varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32>
 // CHECK-NEXT: [[CREATEIMP:%.*]] = acc.create varPtr([[ARGC]] : memref<10x20xf32>) -> memref<10x20xf32> {implicit = true}
 // CHECK-NEXT: acc.parallel dataOperands([[CREATE]], [[CREATEIMP]] : memref<10xf32>, memref<10x20xf32>) {
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.delete accPtr([[CREATE]] : memref<10xf32>) {decomposedFrom = 7 : i64}
-// CHECK-NEXT: acc.delete accPtr([[CREATEIMP]] : memref<10x20xf32>) {decomposedFrom = 7 : i64, implicit = true}
-// CHECK: [[COPYOUTZ:%.*]] = acc.create varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 5 : i64}
+// CHECK-NEXT: acc.delete accPtr([[CREATE]] : memref<10xf32>) {dataClause = 7 : i64}
+// CHECK-NEXT: acc.delete accPtr([[CREATEIMP]] : memref<10x20xf32>) {dataClause = 7 : i64, implicit = true}
+// CHECK: [[COPYOUTZ:%.*]] = acc.create varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 5 : i64}
 // CHECK-NEXT: acc.parallel dataOperands([[COPYOUTZ]] : memref<10xf32>) {
 // CHECK-NEXT: }
 // CHECK-NEXT: acc.copyout accPtr([[COPYOUTZ]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 5 : i64}
 // CHECK: [[ATTACH:%.*]] = acc.attach varPtr([[ARGB]] : memref<memref<10xf32>>) -> memref<memref<10xf32>>
 // CHECK-NEXT: acc.parallel dataOperands([[ATTACH]] : memref<memref<10xf32>>) {
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.detach accPtr([[ATTACH]] : memref<memref<10xf32>>) {decomposedFrom = 10 : i64}
-// CHECK: [[COPYINP:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) varPtrPtr([[ARGB]] : memref<memref<10xf32>>) -> memref<10xf32> {decomposedFrom = 3 : i64}
+// CHECK-NEXT: acc.detach accPtr([[ATTACH]] : memref<memref<10xf32>>) {dataClause = 10 : i64}
+// CHECK: [[COPYINP:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) varPtrPtr([[ARGB]] : memref<memref<10xf32>>) -> memref<10xf32> {dataClause = 3 : i64}
 // CHECK-NEXT: acc.parallel dataOperands([[COPYINP]] : memref<10xf32>) {
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.copyout accPtr([[COPYINP]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {decomposedFrom = 3 : i64}
+// CHECK-NEXT: acc.copyout accPtr([[COPYINP]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 3 : i64}
 // CHECK-DAG: [[CON0:%.*]] = arith.constant 0 : index
 // CHECK-DAG: [[CON1:%.*]] = arith.constant 1 : index
 // CHECK-DAG: [[CON4:%.*]] = arith.constant 4 : index
@@ -1092,10 +1092,10 @@ func.func @teststructureddataclauseops(%a: memref<10xf32>, %b: memref<memref<10x
 // CHECK-NEXT: acc.parallel dataOperands([[COPYINF1]], [[COPYINF2]] : memref<10xf32>, memref<10x20xf32>) {
 // CHECK-NEXT: }
 // CHECK: [[BOUNDS1P:%.*]] = acc.bounds lowerbound([[CON4]] : index) upperbound([[CON9]] : index) stride([[CON1]] : index)
-// CHECK-NEXT: [[COPYINPART:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) bounds([[BOUNDS1P]]) -> memref<10xf32> {decomposedFrom = 3 : i64}
+// CHECK-NEXT: [[COPYINPART:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) bounds([[BOUNDS1P]]) -> memref<10xf32> {dataClause = 3 : i64}
 // CHECK-NEXT: acc.parallel dataOperands([[COPYINPART]] : memref<10xf32>) {
 // CHECK-NEXT: }
-// CHECK-NEXT: acc.copyout accPtr([[COPYINPART]] : memref<10xf32>) bounds([[BOUNDS1P]]) to varPtr([[ARGA]] : memref<10xf32>) {decomposedFrom = 3 : i64}
+// CHECK-NEXT: acc.copyout accPtr([[COPYINPART]] : memref<10xf32>) bounds([[BOUNDS1P]]) to varPtr([[ARGA]] : memref<10xf32>) {dataClause = 3 : i64}
 
 // -----
 
@@ -1103,7 +1103,7 @@ func.func @testunstructuredclauseops(%a: memref<10xf32>) -> () {
   %copyin = acc.copyin varPtr(%a : memref<10xf32>) -> memref<10xf32> {structured = false}
   acc.enter_data dataOperands(%copyin : memref<10xf32>)
 
-  %devptr = acc.getdeviceptr varPtr(%a : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 4}
+  %devptr = acc.getdeviceptr varPtr(%a : memref<10xf32>) -> memref<10xf32> {dataClause = 4}
   acc.exit_data dataOperands(%devptr : memref<10xf32>)
   acc.copyout accPtr(%devptr : memref<10xf32>) to varPtr(%a : memref<10xf32>) {structured = false}
 
@@ -1113,6 +1113,6 @@ func.func @testunstructuredclauseops(%a: memref<10xf32>) -> () {
 // CHECK: func.func @testunstructuredclauseops([[ARGA:%.*]]: memref<10xf32>) {
 // CHECK: [[COPYIN:%.*]] = acc.copyin varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {structured = false}
 // CHECK-NEXT: acc.enter_data dataOperands([[COPYIN]] : memref<10xf32>)
-// CHECK: [[DEVPTR:%.*]] = acc.getdeviceptr varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {decomposedFrom = 4 : i64}
+// CHECK: [[DEVPTR:%.*]] = acc.getdeviceptr varPtr([[ARGA]] : memref<10xf32>) -> memref<10xf32> {dataClause = 4 : i64}
 // CHECK-NEXT: acc.exit_data dataOperands([[DEVPTR]] : memref<10xf32>)
 // CHECK-NEXT: acc.copyout accPtr([[DEVPTR]] : memref<10xf32>) to varPtr([[ARGA]] : memref<10xf32>) {structured = false}