From 3f2f83ef41419507bcdaf751b86713ef193b7de0 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 16 Jan 2023 15:01:02 +0000 Subject: [PATCH] [mlir] accept values with result numbers in gpu.launch_func The parser of gpu.launch_func was incorrectly rejecting SSA values with result numbers (`%0#0`) in the list of function arguments by using the `parseArgument` function intended for region argument declarations, not operands. Fix this by directly parsing comma-separated operands and types. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D141851 --- mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 16 +++++++--------- mlir/test/Dialect/GPU/ops.mlir | 7 +++++++ 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index f901827..7e1f536 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -793,15 +793,13 @@ static ParseResult parseLaunchFuncOperands( if (parser.parseOptionalKeyword("args")) return success(); - SmallVector args; - if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren, - /*allowType=*/true)) - return failure(); - for (auto &arg : args) { - argNames.push_back(arg.ssaName); - argTypes.push_back(arg.type); - } - return success(); + auto parseElement = [&]() -> ParseResult { + return failure(parser.parseOperand(argNames.emplace_back()) || + parser.parseColonType(argTypes.emplace_back())); + }; + + return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, + parseElement, " in argument list"); } static void printLaunchFuncOperands(OpAsmPrinter &printer, Operation *, diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index 301ab91..5bb5efb 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -121,6 +121,8 @@ module attributes {gpu.container_module} { } } + func.func private @two_value_generator() -> (f32, memref) + func.func @foo() { %0 = "op"() : () -> (f32) %1 = "op"() : () -> (memref) @@ -140,6 +142,11 @@ module attributes {gpu.container_module} { // CHECK: %{{.*}} = gpu.launch_func async [%{{.*}}] @kernels::@kernel_2 blocks in (%{{.*}}, %{{.*}}, %{{.*}}) threads in (%{{.*}}, %{{.*}}, %{{.*}}) %t1 = gpu.launch_func async [%t0] @kernels::@kernel_2 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) + // CHECK: %[[VALUES:.*]]:2 = call + %values:2 = func.call @two_value_generator() : () -> (f32, memref) + // CHECK: gpu.launch_func @kernels::@kernel_1 {{.*}} args(%[[VALUES]]#0 : f32, %[[VALUES]]#1 : memref) + gpu.launch_func @kernels::@kernel_1 blocks in (%cst, %cst, %cst) threads in (%cst, %cst, %cst) args(%values#0 : f32, %values#1 : memref) + return } -- 2.7.4