[mlir] accept values with result numbers in gpu.launch_func
authorAlex Zinenko <zinenko@google.com>
Mon, 16 Jan 2023 15:01:02 +0000 (15:01 +0000)
committerAlex Zinenko <zinenko@google.com>
Mon, 16 Jan 2023 19:26:42 +0000 (19:26 +0000)
The parser of gpu.launch_func was incorrectly rejecting SSA values with
result numbers (`%0#0`) in the list of function arguments by using the
`parseArgument` function intended for region argument declarations, not
operands. Fix this by directly parsing comma-separated operands and
types.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141851

mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
mlir/test/Dialect/GPU/ops.mlir

index f901827..7e1f536 100644 (file)
@@ -793,15 +793,13 @@ static ParseResult parseLaunchFuncOperands(
   if (parser.parseOptionalKeyword("args"))
     return success();
 
-  SmallVector<OpAsmParser::Argument> args;
-  if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
-                               /*allowType=*/true))
-    return failure();
-  for (auto &arg : args) {
-    argNames.push_back(arg.ssaName);
-    argTypes.push_back(arg.type);
-  }
-  return success();
+  auto parseElement = [&]() -> ParseResult {
+    return failure(parser.parseOperand(argNames.emplace_back()) ||
+                   parser.parseColonType(argTypes.emplace_back()));
+  };
+
+  return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
+                                        parseElement, " in argument list");
 }
 
 static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *,
index 301ab91..5bb5efb 100644 (file)
@@ -121,6 +121,8 @@ module attributes {gpu.container_module} {
     }
   }
 
+  func.func private @two_value_generator() -> (f32, memref<?xf32, 1>)
+
   func.func @foo() {
     %0 = "op"() : () -> (f32)
     %1 = "op"() : () -> (memref<?xf32, 1>)
@@ -140,6 +142,11 @@ module attributes {gpu.container_module} {
     // CHECK: %{{.*}} = gpu.launch_func async [%{{.*}}] @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}})
     %t1 = gpu.launch_func async [%t0] @kernels::@kernel_2  blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst)
 
+    // CHECK: %[[VALUES:.*]]:2 = call
+    %values:2 = func.call @two_value_generator() : () -> (f32, memref<?xf32, 1>)
+    // CHECK: gpu.launch_func @kernels::@kernel_1 {{.*}} args(%[[VALUES]]#0 : f32, %[[VALUES]]#1 : memref<?xf32, 1>)
+    gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%values#0 : f32, %values#1 : memref<?xf32, 1>)
+
     return
   }