[mlir][openacc] Switch numGangs to a variadic operand
authorValentin Clement <clementval@gmail.com>
Tue, 27 Jun 2023 18:08:22 +0000 (11:08 -0700)
committerValentin Clement <clementval@gmail.com>
Tue, 27 Jun 2023 18:08:44 +0000 (11:08 -0700)
In the latest spec, the `num_gangs` clause accepts up to three
arguments. Update the dialect to swicth `numGangs` operands from
optional single operand to a variadic operand. The verifier limits
the number of operands to three as specified in the spec.

Reviewed By: razvanlupusoru

Differential Revision: https://reviews.llvm.org/D153796

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/invalid.mlir
mlir/test/Dialect/OpenACC/ops.mlir

index 32882bd..076faa7 100644 (file)
@@ -656,7 +656,7 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
                        UnitAttr:$asyncAttr,
                        Variadic<IntOrIndex>:$waitOperands,
                        UnitAttr:$waitAttr,
-                       Optional<IntOrIndex>:$numGangs,
+                       Variadic<IntOrIndex>:$numGangs,
                        Optional<IntOrIndex>:$numWorkers,
                        Optional<IntOrIndex>:$vectorLength,
                        Optional<I1>:$ifCond,
@@ -802,7 +802,7 @@ def OpenACC_KernelsOp : OpenACC_Op<"kernels",
                        UnitAttr:$asyncAttr,
                        Variadic<IntOrIndex>:$waitOperands,
                        UnitAttr:$waitAttr,
-                       Optional<IntOrIndex>:$numGangs,
+                       Variadic<IntOrIndex>:$numGangs,
                        Optional<IntOrIndex>:$numWorkers,
                        Optional<IntOrIndex>:$vectorLength,
                        Optional<I1>:$ifCond,
index 9b04d6c..77de452 100644 (file)
@@ -573,7 +573,7 @@ unsigned ParallelOp::getNumDataOperands() {
 
 Value ParallelOp::getDataOperand(unsigned i) {
   unsigned numOptional = getAsync() ? 1 : 0;
-  numOptional += getNumGangs() ? 1 : 0;
+  numOptional += getNumGangs().size();
   numOptional += getNumWorkers() ? 1 : 0;
   numOptional += getVectorLength() ? 1 : 0;
   numOptional += getIfCond() ? 1 : 0;
@@ -590,6 +590,8 @@ LogicalResult acc::ParallelOp::verify() {
           *this, getReductionRecipes(), getReductionOperands(), "reduction",
           "reductions", false)))
     return failure();
+  if (getNumGangs().size() > 3)
+    return emitOpError() << "num_gangs expects a maximum of 3 values";
   return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
 }
 
@@ -631,12 +633,18 @@ unsigned KernelsOp::getNumDataOperands() {
 
 Value KernelsOp::getDataOperand(unsigned i) {
   unsigned numOptional = getAsync() ? 1 : 0;
+  numOptional += getWaitOperands().size();
+  numOptional += getNumGangs().size();
+  numOptional += getNumWorkers() ? 1 : 0;
+  numOptional += getVectorLength() ? 1 : 0;
   numOptional += getIfCond() ? 1 : 0;
   numOptional += getSelfCond() ? 1 : 0;
-  return getOperand(getWaitOperands().size() + numOptional + i);
+  return getOperand(numOptional + i);
 }
 
 LogicalResult acc::KernelsOp::verify() {
+  if (getNumGangs().size() > 3)
+    return emitOpError() << "num_gangs expects a maximum of 3 values";
   return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
 }
 
index a3d9386..2a36cca 100644 (file)
@@ -486,3 +486,10 @@ acc.loop gang() {
   "test.openacc_dummy_op"() : () -> ()
   acc.yield
 }
+
+// -----
+
+%i64value = arith.constant 1 : i64
+// expected-error@+1 {{num_gangs expects a maximum of 3 values}}
+acc.parallel num_gangs(%i64value, %i64value, %i64value, %i64value : i64, i64, i64, i64) {
+}
index 92d2a92..e07ab8c 100644 (file)
@@ -443,6 +443,8 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
   }
   acc.parallel num_gangs(%idxValue: index) {
   }
+  acc.parallel num_gangs(%i64value, %i64value, %idxValue : i64, i64, index) {
+  }
   acc.parallel num_workers(%i64value: i64) {
   }
   acc.parallel num_workers(%i32value: i32) {
@@ -494,6 +496,8 @@ func.func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x
 // CHECK-NEXT: }
 // CHECK:      acc.parallel num_gangs([[IDXVALUE]] : index) {
 // CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[I64VALUE]], [[I64VALUE]], [[IDXVALUE]] : i64, i64, index) {
+// CHECK-NEXT: }
 // CHECK:      acc.parallel num_workers([[I64VALUE]] : i64) {
 // CHECK-NEXT: }
 // CHECK:      acc.parallel num_workers([[I32VALUE]] : i32) {