[mlir][openacc] Change operand type from index to AnyInteger in parallel op
author     Valentin Clement <clementval@gmail.com>
           Thu, 17 Sep 2020 15:33:31 +0000 (11:33 -0400)
committer  clementval <clementval@gmail.com>
           Thu, 17 Sep 2020 15:33:55 +0000 (11:33 -0400)
This patch changes the type of the async, wait, numGangs, numWorkers, and vectorLength operands from index
to AnyInteger to match acc.loop and the OpenACC specification.
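For illustration, a minimal sketch mirroring the updated ops.mlir test: these clauses now carry an explicit
type annotation and accept any integer type in addition to index, e.g.

    %async = constant 1 : i64
    acc.parallel async(%async: i64) {
    }

Previously only index-typed values were accepted, so no type was printed or parsed for these operands.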

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D87712

mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
mlir/test/Dialect/OpenACC/ops.mlir

index f6350db..3fa26f9 100644
--- a/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
+++ b/mlir/include/mlir/Dialect/OpenACC/OpenACCOps.td
@@ -36,7 +36,7 @@ class OpenACC_Op<string mnemonic, list<OpTrait> traits = []> :
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
-// Reduction operation enumeration
+// Reduction operation enumeration.
 def OpenACC_ReductionOpAdd : StrEnumAttrCase<"redop_add">;
 def OpenACC_ReductionOpMul : StrEnumAttrCase<"redop_mul">;
 def OpenACC_ReductionOpMax : StrEnumAttrCase<"redop_max">;
@@ -60,6 +60,9 @@ def OpenACC_ReductionOpAttr : StrEnumAttr<"ReductionOpAttr",
   let cppNamespace = "::mlir::acc";
 }
 
+// Type used in operation below.
+def IntOrIndex : AnyTypeOf<[AnyInteger, Index]>;
+
 //===----------------------------------------------------------------------===//
 // 2.5.1 parallel Construct
 //===----------------------------------------------------------------------===//
@@ -90,11 +93,11 @@ def OpenACC_ParallelOp : OpenACC_Op<"parallel",
     ```
   }];
 
-  let arguments = (ins Optional<Index>:$async,
-                       Variadic<Index>:$waitOperands,
-                       Optional<Index>:$numGangs,
-                       Optional<Index>:$numWorkers,
-                       Optional<Index>:$vectorLength,
+  let arguments = (ins Optional<IntOrIndex>:$async,
+                       Variadic<IntOrIndex>:$waitOperands,
+                       Optional<IntOrIndex>:$numGangs,
+                       Optional<IntOrIndex>:$numWorkers,
+                       Optional<IntOrIndex>:$vectorLength,
                        Optional<I1>:$ifCond,
                        Optional<I1>:$selfCond,
                        OptionalAttr<OpenACC_ReductionOpAttr>:$reductionOp,
index 6149512..3cae3c8 100644
--- a/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
+++ b/mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp
@@ -101,6 +101,22 @@ static ParseResult parseOptionalOperand(OpAsmParser &parser, StringRef keyword,
   return success();
 }
 
+static OptionalParseResult parseOptionalOperandAndType(OpAsmParser &parser,
+                                                       StringRef keyword,
+                                                       OperationState &result) {
+  OpAsmParser::OperandType operand;
+  Type type;
+  if (succeeded(parser.parseOptionalKeyword(keyword))) {
+    if (parser.parseLParen() || parser.parseOperand(operand) ||
+        parser.parseColonType(type) ||
+        parser.resolveOperand(operand, type, result.operands) ||
+        parser.parseRParen())
+      return failure();
+    return success();
+  }
+  return llvm::None;
+}
+
 //===----------------------------------------------------------------------===//
 // ParallelOp
 //===----------------------------------------------------------------------===//
@@ -142,17 +158,17 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
       createZeroOperandTypes, noCreateOperandTypes, presentOperandTypes,
       deviceptrOperandTypes, attachOperandTypes, privateOperandTypes,
       firstprivateOperandTypes;
-  OpAsmParser::OperandType async, numGangs, numWorkers, vectorLength, ifCond,
-      selfCond;
-  bool hasAsync = false, hasNumGangs = false, hasNumWorkers = false;
-  bool hasVectorLength = false, hasIfCond = false, hasSelfCond = false;
 
-  Type indexType = builder.getIndexType();
+  SmallVector<Type, 8> operandTypes;
+  OpAsmParser::OperandType ifCond, selfCond;
+  bool hasIfCond = false, hasSelfCond = false;
+  OptionalParseResult async, numGangs, numWorkers, vectorLength;
   Type i1Type = builder.getI1Type();
 
   // async()?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getAsyncKeyword(), async,
-                                  indexType, hasAsync, result)))
+  async = parseOptionalOperandAndType(parser, ParallelOp::getAsyncKeyword(),
+                                      result);
+  if (async.hasValue() && failed(*async))
     return failure();
 
   // wait()?
@@ -161,20 +177,21 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
     return failure();
 
   // num_gangs(value)?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getNumGangsKeyword(),
-                                  numGangs, indexType, hasNumGangs, result)))
+  numGangs = parseOptionalOperandAndType(
+      parser, ParallelOp::getNumGangsKeyword(), result);
+  if (numGangs.hasValue() && failed(*numGangs))
     return failure();
 
   // num_workers(value)?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getNumWorkersKeyword(),
-                                  numWorkers, indexType, hasNumWorkers,
-                                  result)))
+  numWorkers = parseOptionalOperandAndType(
+      parser, ParallelOp::getNumWorkersKeyword(), result);
+  if (numWorkers.hasValue() && failed(*numWorkers))
     return failure();
 
   // vector_length(value)?
-  if (failed(parseOptionalOperand(parser, ParallelOp::getVectorLengthKeyword(),
-                                  vectorLength, indexType, hasVectorLength,
-                                  result)))
+  vectorLength = parseOptionalOperandAndType(
+      parser, ParallelOp::getVectorLengthKeyword(), result);
+  if (vectorLength.hasValue() && failed(*vectorLength))
     return failure();
 
   // if()?
@@ -267,29 +284,30 @@ static ParseResult parseParallelOp(OpAsmParser &parser,
   if (failed(parseRegions<ParallelOp>(parser, result)))
     return failure();
 
-  result.addAttribute(ParallelOp::getOperandSegmentSizeAttr(),
-                      builder.getI32VectorAttr(
-                          {static_cast<int32_t>(hasAsync ? 1 : 0),
-                           static_cast<int32_t>(waitOperands.size()),
-                           static_cast<int32_t>(hasNumGangs ? 1 : 0),
-                           static_cast<int32_t>(hasNumWorkers ? 1 : 0),
-                           static_cast<int32_t>(hasVectorLength ? 1 : 0),
-                           static_cast<int32_t>(hasIfCond ? 1 : 0),
-                           static_cast<int32_t>(hasSelfCond ? 1 : 0),
-                           static_cast<int32_t>(reductionOperands.size()),
-                           static_cast<int32_t>(copyOperands.size()),
-                           static_cast<int32_t>(copyinOperands.size()),
-                           static_cast<int32_t>(copyinReadonlyOperands.size()),
-                           static_cast<int32_t>(copyoutOperands.size()),
-                           static_cast<int32_t>(copyoutZeroOperands.size()),
-                           static_cast<int32_t>(createOperands.size()),
-                           static_cast<int32_t>(createZeroOperands.size()),
-                           static_cast<int32_t>(noCreateOperands.size()),
-                           static_cast<int32_t>(presentOperands.size()),
-                           static_cast<int32_t>(devicePtrOperands.size()),
-                           static_cast<int32_t>(attachOperands.size()),
-                           static_cast<int32_t>(privateOperands.size()),
-                           static_cast<int32_t>(firstprivateOperands.size())}));
+  result.addAttribute(
+      ParallelOp::getOperandSegmentSizeAttr(),
+      builder.getI32VectorAttr(
+          {static_cast<int32_t>(async.hasValue() ? 1 : 0),
+           static_cast<int32_t>(waitOperands.size()),
+           static_cast<int32_t>(numGangs.hasValue() ? 1 : 0),
+           static_cast<int32_t>(numWorkers.hasValue() ? 1 : 0),
+           static_cast<int32_t>(vectorLength.hasValue() ? 1 : 0),
+           static_cast<int32_t>(hasIfCond ? 1 : 0),
+           static_cast<int32_t>(hasSelfCond ? 1 : 0),
+           static_cast<int32_t>(reductionOperands.size()),
+           static_cast<int32_t>(copyOperands.size()),
+           static_cast<int32_t>(copyinOperands.size()),
+           static_cast<int32_t>(copyinReadonlyOperands.size()),
+           static_cast<int32_t>(copyoutOperands.size()),
+           static_cast<int32_t>(copyoutZeroOperands.size()),
+           static_cast<int32_t>(createOperands.size()),
+           static_cast<int32_t>(createZeroOperands.size()),
+           static_cast<int32_t>(noCreateOperands.size()),
+           static_cast<int32_t>(presentOperands.size()),
+           static_cast<int32_t>(devicePtrOperands.size()),
+           static_cast<int32_t>(attachOperands.size()),
+           static_cast<int32_t>(privateOperands.size()),
+           static_cast<int32_t>(firstprivateOperands.size())}));
 
   // Additional attributes
   if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
@@ -303,7 +321,8 @@ static void print(OpAsmPrinter &printer, ParallelOp &op) {
 
   // async()?
   if (Value async = op.async())
-    printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ")";
+    printer << " " << ParallelOp::getAsyncKeyword() << "(" << async << ": "
+            << async.getType() << ")";
 
   // wait()?
   printOperandList(op.waitOperands(), ParallelOp::getWaitKeyword(), printer);
@@ -311,17 +330,17 @@ static void print(OpAsmPrinter &printer, ParallelOp &op) {
   // num_gangs()?
   if (Value numGangs = op.numGangs())
     printer << " " << ParallelOp::getNumGangsKeyword() << "(" << numGangs
-            << ")";
+            << ": " << numGangs.getType() << ")";
 
   // num_workers()?
   if (Value numWorkers = op.numWorkers())
     printer << " " << ParallelOp::getNumWorkersKeyword() << "(" << numWorkers
-            << ")";
+            << ": " << numWorkers.getType() << ")";
 
   // vector_length()?
   if (Value vectorLength = op.vectorLength())
     printer << " " << ParallelOp::getVectorLengthKeyword() << "("
-            << vectorLength << ")";
+            << vectorLength << ": " << vectorLength.getType() << ")";
 
   // if()?
   if (Value ifCond = op.ifCond())
index 3398f95..1969498 100644
--- a/mlir/test/Dialect/OpenACC/ops.mlir
+++ b/mlir/test/Dialect/OpenACC/ops.mlir
@@ -8,8 +8,9 @@ func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
   %c0 = constant 0 : index
   %c10 = constant 10 : index
   %c1 = constant 1 : index
+  %async = constant 1 : i64
 
-  acc.parallel async(%c1) {
+  acc.parallel async(%async: i64) {
     acc.loop gang vector {
       scf.for %arg3 = %c0 to %c10 step %c1 {
         scf.for %arg4 = %c0 to %c10 step %c1 {
@@ -35,7 +36,8 @@ func @compute1(%A: memref<10x10xf32>, %B: memref<10x10xf32>, %C: memref<10x10xf3
 //  CHECK-NEXT:   %{{.*}} = constant 0 : index
 //  CHECK-NEXT:   %{{.*}} = constant 10 : index
 //  CHECK-NEXT:   %{{.*}} = constant 1 : index
-//  CHECK-NEXT:   acc.parallel async(%{{.*}}) {
+//  CHECK-NEXT:   [[ASYNC:%.*]] = constant 1 : i64
+//  CHECK-NEXT:   acc.parallel async([[ASYNC]]: i64) {
 //  CHECK-NEXT:     acc.loop gang vector {
 //  CHECK-NEXT:       scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 //  CHECK-NEXT:         scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
@@ -113,9 +115,11 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
   %lb = constant 0 : index
   %st = constant 1 : index
   %c10 = constant 10 : index
+  %numGangs = constant 10 : i64
+  %numWorkers = constant 10 : i64
 
   acc.data present(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>, %d: memref<10xf32>) {
-    acc.parallel num_gangs(%c10) num_workers(%c10) private(%c : memref<10xf32>) {
+    acc.parallel num_gangs(%numGangs: i64) num_workers(%numWorkers: i64) private(%c : memref<10xf32>) {
       acc.loop gang {
         scf.for %x = %lb to %c10 step %st {
           acc.loop worker {
@@ -154,8 +158,10 @@ func @compute3(%a: memref<10x10xf32>, %b: memref<10x10xf32>, %c: memref<10xf32>,
 // CHECK-NEXT:   [[C0:%.*]] = constant 0 : index
 // CHECK-NEXT:   [[C1:%.*]] = constant 1 : index
 // CHECK-NEXT:   [[C10:%.*]] = constant 10 : index
+// CHECK-NEXT:   [[NUMGANG:%.*]] = constant 10 : i64
+// CHECK-NEXT:   [[NUMWORKERS:%.*]] = constant 10 : i64
 // CHECK-NEXT:   acc.data present(%{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10x10xf32>, %{{.*}}: memref<10xf32>, %{{.*}}: memref<10xf32>) {
-// CHECK-NEXT:     acc.parallel num_gangs([[C10]]) num_workers([[C10]]) private([[ARG2]]: memref<10xf32>) {
+// CHECK-NEXT:     acc.parallel num_gangs([[NUMGANG]]: i64) num_workers([[NUMWORKERS]]: i64) private([[ARG2]]: memref<10xf32>) {
 // CHECK-NEXT:       acc.loop gang {
 // CHECK-NEXT:         scf.for %{{.*}} = [[C0]] to [[C10]] step [[C1]] {
 // CHECK-NEXT:           acc.loop worker {
@@ -265,9 +271,42 @@ func @testop(%a: memref<10xf32>) -> () {
 // CHECK-NEXT:   acc.yield
 // CHECK-NEXT: }
 
+
 func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf32>) -> () {
-  %vectorLength = constant 128 : index
-  acc.parallel vector_length(%vectorLength) {
+  %i64value = constant 1 : i64
+  %i32value = constant 1 : i32
+  %idxValue = constant 1 : index
+  acc.parallel async(%i64value: i64) {
+  }
+  acc.parallel async(%i32value: i32) {
+  }
+  acc.parallel async(%idxValue: index) {
+  }
+  acc.parallel wait(%i64value: i64) {
+  }
+  acc.parallel wait(%i32value: i32) {
+  }
+  acc.parallel wait(%idxValue: index) {
+  }
+  acc.parallel wait(%i64value: i64, %i32value: i32, %idxValue: index) {
+  }
+  acc.parallel num_gangs(%i64value: i64) {
+  }
+  acc.parallel num_gangs(%i32value: i32) {
+  }
+  acc.parallel num_gangs(%idxValue: index) {
+  }
+  acc.parallel num_workers(%i64value: i64) {
+  }
+  acc.parallel num_workers(%i32value: i32) {
+  }
+  acc.parallel num_workers(%idxValue: index) {
+  }
+  acc.parallel vector_length(%i64value: i64) {
+  }
+  acc.parallel vector_length(%i32value: i32) {
+  }
+  acc.parallel vector_length(%idxValue: index) {
   }
   acc.parallel copyin(%a: memref<10xf32>, %b: memref<10xf32>) {
   }
@@ -293,26 +332,58 @@ func @testparallelop(%a: memref<10xf32>, %b: memref<10xf32>, %c: memref<10x10xf3
 }
 
 // CHECK:      func @testparallelop([[ARGA:%.*]]: memref<10xf32>, [[ARGB:%.*]]: memref<10xf32>, [[ARGC:%.*]]: memref<10x10xf32>) {
-// CHECK:        [[VECTORLENGTH:%.*]] = constant 128 : index
-// CHECK:        acc.parallel vector_length([[VECTORLENGTH]]) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyin([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyin_readonly([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyin([[ARGA]]: memref<10xf32>) copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyout([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create([[ARGA]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create_zero([[ARGA]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel no_create([[ARGA]]: memref<10xf32>) present([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel deviceptr([[ARGA]]: memref<10xf32>) attach([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel private([[ARGA]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) firstprivate([[ARGB]]: memref<10xf32>) {
-// CHECK-NEXT:   }
-// CHECK:        acc.parallel {
-// CHECK-NEXT:   } attributes {defaultAttr = "none"}
-// CHECK:        acc.parallel {
-// CHECK-NEXT:   } attributes {defaultAttr = "present"}
+// CHECK:      [[I64VALUE:%.*]] = constant 1 : i64
+// CHECK:      [[I32VALUE:%.*]] = constant 1 : i32
+// CHECK:      [[IDXVALUE:%.*]] = constant 1 : index
+// CHECK:      acc.parallel async([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel async([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel async([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel wait([[I64VALUE]]: i64, [[I32VALUE]]: i32, [[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_gangs([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_workers([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_workers([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel num_workers([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel vector_length([[I64VALUE]]: i64) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel vector_length([[I32VALUE]]: i32) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel vector_length([[IDXVALUE]]: index) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyin([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyin_readonly([[ARGA]]: memref<10xf32>, [[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyin([[ARGA]]: memref<10xf32>) copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyout([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create([[ARGA]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel copyout_zero([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) create_zero([[ARGA]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel no_create([[ARGA]]: memref<10xf32>) present([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel deviceptr([[ARGA]]: memref<10xf32>) attach([[ARGB]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel private([[ARGA]]: memref<10xf32>, [[ARGC]]: memref<10x10xf32>) firstprivate([[ARGB]]: memref<10xf32>) {
+// CHECK-NEXT: }
+// CHECK:      acc.parallel {
+// CHECK-NEXT: } attributes {defaultAttr = "none"}
+// CHECK:      acc.parallel {
+// CHECK-NEXT: } attributes {defaultAttr = "present"}