[mlir][openacc] Add acc.serial operation
authorValentin Clement <clementval@gmail.com>
Thu, 13 Apr 2023 17:35:37 +0000 (10:35 -0700)
committerValentin Clement <clementval@gmail.com>
Thu, 13 Apr 2023 17:35:56 +0000 (10:35 -0700)
The acc.serial operation models the OpenACC serial construct.
The serial construct defines a region of a program that is to be
executed sequentially on the current device.
The operation is modelled on the acc.parallel operation and will
receive similar updates when the data operands operations will
be implemented.

Reviewed By: PeteSteinfeld

Differential Revision: https://reviews.llvm.org/D148250

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir

index 0635c5c..f0f0acc 100644 (file)
@@ -157,6 +157,88 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
 }
 
 //===----------------------------------------------------------------------===//
+// 2.5.2 serial Construct
+//===----------------------------------------------------------------------===//
+
+def OpenACC_SerialOp : OpenACC_Op<"serial", [AttrSizedOperandSegments]> {
+  let summary = "serial construct";
+  let description = [{
+    The "acc.serial" operation represents a serial construct block. It has
+    one region to be executed in serial on the current device.
+
+    Example:
+
+    ```mlir
+    acc.serial private(%c : memref<10xf32>) {
+      // serial region
+    }
+    ```
+  }];
+
+  let arguments = (ins Optional<IntOrIndex>:$async,
+                       UnitAttr:$asyncAttr,
+                       Variadic<IntOrIndex>:$waitOperands,
+                       UnitAttr:$waitAttr,
+                       Optional<I1>:$ifCond,
+                       Optional<I1>:$selfCond,
+                       UnitAttr:$selfAttr,
+                       OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
+                       Variadic<AnyType>:$reductionOperands,
+                       Variadic<AnyType>:$copyOperands,
+                       Variadic<AnyType>:$copyinOperands,
+                       Variadic<AnyType>:$copyinReadonlyOperands,
+                       Variadic<AnyType>:$copyoutOperands,
+                       Variadic<AnyType>:$copyoutZeroOperands,
+                       Variadic<AnyType>:$createOperands,
+                       Variadic<AnyType>:$createZeroOperands,
+                       Variadic<AnyType>:$noCreateOperands,
+                       Variadic<AnyType>:$presentOperands,
+                       Variadic<AnyType>:$devicePtrOperands,
+                       Variadic<AnyType>:$attachOperands,
+                       Variadic<AnyType>:$gangPrivateOperands,
+                       Variadic<AnyType>:$gangFirstPrivateOperands,
+                       OptionalAttr<DefaultValueAttr>:$defaultAttr);
+
+  let regions = (region AnyRegion:$region);
+
+  let extraClassDeclaration = [{
+    /// The number of data operands.
+    unsigned getNumDataOperands();
+
+    /// The i-th data operand passed.
+    Value getDataOperand(unsigned i);
+  }];
+
+  let assemblyFormat = [{
+    oilist(
+        `attach` `(` $attachOperands `:` type($attachOperands) `)`
+      | `async` `(` $async `:` type($async) `)`
+      | `copy` `(` $copyOperands `:` type($copyOperands) `)`
+      | `copyin` `(` $copyinOperands `:` type($copyinOperands) `)`
+      | `copyin_readonly` `(` $copyinReadonlyOperands `:`
+          type($copyinReadonlyOperands) `)`
+      | `copyout` `(` $copyoutOperands `:` type($copyoutOperands) `)`
+      | `copyout_zero` `(` $copyoutZeroOperands `:`
+          type($copyoutZeroOperands) `)`
+      | `create` `(` $createOperands `:` type($createOperands) `)`
+      | `create_zero` `(` $createZeroOperands `:`
+          type($createZeroOperands) `)`
+      | `deviceptr` `(` $devicePtrOperands `:` type($devicePtrOperands) `)`
+      | `firstprivate` `(` $gangFirstPrivateOperands `:`
+            type($gangFirstPrivateOperands) `)`
+      | `no_create` `(` $noCreateOperands `:` type($noCreateOperands) `)`
+      | `private` `(` $gangPrivateOperands `:` type($gangPrivateOperands) `)`
+      | `present` `(` $presentOperands `:` type($presentOperands) `)`
+      | `wait` `(` $waitOperands `:` type($waitOperands) `)`
+      | `self` `(` $selfCond `)`
+      | `if` `(` $ifCond `)`
+      | `reduction` `(` $reductionOperands `:` type($reductionOperands) `)`
+    )
+    $region attr-dict-with-keyword
+  }];
+}
+
+//===----------------------------------------------------------------------===//
 // 2.6.5 data Construct
 //===----------------------------------------------------------------------===//
 
index 7de89dc..1dbb044 100644 (file)
@@ -188,6 +188,27 @@ Value ParallelOp::getDataOperand(unsigned i) {
 }
 
 //===----------------------------------------------------------------------===//
+// SerialOp
+//===----------------------------------------------------------------------===//
+
+unsigned SerialOp::getNumDataOperands() {
+  return getReductionOperands().size() + getCopyOperands().size() +
+         getCopyinOperands().size() + getCopyinReadonlyOperands().size() +
+         getCopyoutOperands().size() + getCopyoutZeroOperands().size() +
+         getCreateOperands().size() + getCreateZeroOperands().size() +
+         getNoCreateOperands().size() + getPresentOperands().size() +
+         getDevicePtrOperands().size() + getAttachOperands().size() +
+         getGangPrivateOperands().size() + getGangFirstPrivateOperands().size();
+}
+
+Value SerialOp::getDataOperand(unsigned i) {
+  unsigned numOptional = getAsync() ? 1 : 0;
+  numOptional += getIfCond() ? 1 : 0;
+  numOptional += getSelfCond() ? 1 : 0;
+  return getOperand(getWaitOperands().size() + numOptional + i);
+}
+
+//===----------------------------------------------------------------------===//
 // LoopOp
 //===----------------------------------------------------------------------===//
 
index 91d29d4..c659467 100644 (file)
@@ -484,6 +484,102 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
 
 // -----
 
+// -----
+
+func.func @testserialop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
+  %i64value = arith.constant 1 : i64
+  %i32value = arith.constant 1 : i32
+  %idxValue = arith.constant 1 : index
+  acc.serial async(%i64value: i64) {
+  }
+  acc.serial async(%i32value: i32) {
+  }
+  acc.serial async(%idxValue: index) {
+  }
+  acc.serial wait(%i64value: i64) {
+  }
+  acc.serial wait(%i32value: i32) {
+  }
+  acc.serial wait(%idxValue: index) {
+  }
+  acc.serial wait(%i64value, %i32value, %idxValue : i64, i32, index) {
+  }
+  acc.serial copyin(%a, %b : memref<10xf32>, memref<10xf32>) {
+  }
+  acc.serial copyin_readonly(%a, %b : memref<10xf32>, memref<10xf32>) {
+  }
+  acc.serial copyin(%a: memref<10xf32>) copyout_zero(%b, %c : memref<10xf32>, memref<10x10xf32>) {
+  }
+  acc.serial copyout(%b, %c : memref<10xf32>, memref<10x10xf32>) create(%a: memref<10xf32>) {
+  }
+  acc.serial copyout_zero(%b, %c : memref<10xf32>, memref<10x10xf32>) create_zero(%a: memref<10xf32>) {
+  }
+  acc.serial no_create(%a: memref<10xf32>) present(%b, %c : memref<10xf32>, memref<10x10xf32>) {
+  }
+  acc.serial deviceptr(%a: memref<10xf32>) attach(%b, %c : memref<10xf32>, memref<10x10xf32>) {
+  }
+  acc.serial private(%a, %c : memref<10xf32>, memref<10x10xf32>) firstprivate(%b: memref<10xf32>) {
+  }
+  acc.serial {
+  } attributes {defaultAttr = #acc<defaultvalue none>}
+  acc.serial {
+  } attributes {defaultAttr = #acc<defaultvalue present>}
+  acc.serial {
+  } attributes {asyncAttr}
+  acc.serial {
+  } attributes {waitAttr}
+  acc.serial {
+  } attributes {selfAttr}
+  return
+}
+
+// CHECK:      func @testserialop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
+// CHECK:      [[I64VALUE:%.*]] = arith.constant 1 : i64
+// CHECK:      [[I32VALUE:%.*]] = arith.constant 1 : i32
+// CHECK:      [[IDXVALUE:%.*]] = arith.constant 1 : index
+// CHECK:      acc.serial async([[I64VALUE]] : i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial async([[I32VALUE]] : i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial async([[IDXVALUE]] : index) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial wait([[I64VALUE]] : i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial wait([[I32VALUE]] : i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial wait([[IDXVALUE]] : index) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial wait([[I64VALUE]], [[I32VALUE]], [[IDXVALUE]] : i64, i32, index) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial copyin([[ARGA]], [[ARGB]] : memref<10xf32>, memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial copyin_readonly([[ARGA]], [[ARGB]] : memref<10xf32>, memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial copyin([[ARGA]] : memref<10xf32>) copyout_zero([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial copyout([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) create([[ARGA]] : memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial copyout_zero([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) create_zero([[ARGA]] : memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial no_create([[ARGA]] : memref<10xf32>) present([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial attach([[ARGB]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) deviceptr([[ARGA]] : memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial firstprivate([[ARGB]] : memref<10xf32>) private([[ARGA]], [[ARGC]] : memref<10xf32>, memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.serial {
+// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue none>}
+// CHECK:      acc.serial {
+// CHECK-NEXT: } attributes {defaultAttr = #acc<defaultvalue present>}
+// CHECK:      acc.serial {
+// CHECK-NEXT: } attributes {asyncAttr}
+// CHECK:      acc.serial {
+// CHECK-NEXT: } attributes {waitAttr}
+// CHECK:      acc.serial {
+// CHECK-NEXT: } attributes {selfAttr}
+
+// -----
+
 func.func @testdataop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
   %ifCond = arith.constant true
   acc.data if(%ifCond) present(%a : memref<10xf32>) {