From b69e8ee049ff1ba1e1ecb90acaa197cf869d9aa3 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 25 Oct 2019 09:33:32 -0700 Subject: [PATCH] Add support for parsing multiple result name groups. This allows for parsing things like: %name_1, %name_2:5, %name_3:2 = "my.op" ... This is useful for operations that have groups of variadic result values. The total number of results is expected to match the number of results defined by the operation. PiperOrigin-RevId: 276703280 --- mlir/lib/Parser/Parser.cpp | 86 +++++++++++++++++++++------------------------- mlir/test/IR/parser.mlir | 9 ++++- 2 files changed, 47 insertions(+), 48 deletions(-) diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 8813cdc..c3f9081 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -3115,43 +3115,40 @@ Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) { /// ParseResult OperationParser::parseOperation() { auto loc = getToken().getLoc(); - SmallVector, 1> resultIDs; - size_t numExpectedResults; + SmallVector, 1> resultIDs; + size_t numExpectedResults = 0; if (getToken().is(Token::percent_identifier)) { - // Parse the first result id. - resultIDs.emplace_back(getTokenSpelling(), loc); - consumeToken(Token::percent_identifier); - - // If the next token is a ':', we parse the expected result count. - if (consumeIf(Token::colon)) { - // Check that the next token is an integer. - if (!getToken().is(Token::integer)) - return emitError("expected integer number of results"); - - // Check that number of results is > 0. - auto val = getToken().getUInt64IntegerValue(); - if (!val.hasValue() || val.getValue() < 1) - return emitError("expected named operation to have atleast 1 result"); - consumeToken(Token::integer); - numExpectedResults = *val; - } else { - // Otherwise, this is a comma separated list of result ids. - if (consumeIf(Token::comma)) { - auto parseNextResult = [&]() -> ParseResult { - // Parse the next result id. - if (!getToken().is(Token::percent_identifier)) - return emitError("expected valid ssa identifier"); - - resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc()); - consumeToken(Token::percent_identifier); - return success(); - }; - - if (parseCommaSeparatedList(parseNextResult)) - return failure(); + // Parse the group of result ids. + auto parseNextResult = [&]() -> ParseResult { + // Parse the next result id. + if (!getToken().is(Token::percent_identifier)) + return emitError("expected valid ssa identifier"); + + Token nameTok = getToken(); + consumeToken(Token::percent_identifier); + + // If the next token is a ':', we parse the expected result count. + size_t expectedSubResults = 1; + if (consumeIf(Token::colon)) { + // Check that the next token is an integer. + if (!getToken().is(Token::integer)) + return emitError("expected integer number of results"); + + // Check that number of results is > 0. + auto val = getToken().getUInt64IntegerValue(); + if (!val.hasValue() || val.getValue() < 1) + return emitError("expected named operation to have atleast 1 result"); + consumeToken(Token::integer); + expectedSubResults = *val; } - numExpectedResults = resultIDs.size(); - } + + resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults, + nameTok.getLoc()); + numExpectedResults += expectedSubResults; + return success(); + }; + if (parseCommaSeparatedList(parseNextResult)) + return failure(); if (parseToken(Token::equal, "expected '=' after SSA name")) return failure(); @@ -3178,19 +3175,14 @@ ParseResult OperationParser::parseOperation() { << op->getNumResults() << " results but was provided " << numExpectedResults << " to bind"; - // If the number of result names matches the number of operation results, we - // can directly use the provided names. - if (resultIDs.size() == op->getNumResults()) { - for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) - if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second}, - op->getResult(i))) - return failure(); - } else { - // Otherwise, we use the same name for all results. - StringRef name = resultIDs.front().first; - for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) - if (addDefinition({name, i, loc}, op->getResult(i))) + // Add definitions for each of the result groups. + unsigned opResI = 0; + for (std::tuple &resIt : resultIDs) { + for (unsigned subRes : llvm::seq(0, std::get<1>(resIt))) { + if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)}, + op->getResult(opResI++))) return failure(); + } } } diff --git a/mlir/test/IR/parser.mlir b/mlir/test/IR/parser.mlir index 31452e0..37f85e7 100644 --- a/mlir/test/IR/parser.mlir +++ b/mlir/test/IR/parser.mlir @@ -874,9 +874,16 @@ func @pretty_form_multi_result() -> (i16, i16) { return %quot, %rem : i16, i16 } +// CHECK-LABEL: func @pretty_form_multi_result_groups +func @pretty_form_multi_result_groups() -> (i16, i16, i16, i16, i16) { + // CHECK: %[[RES:.*]]:5 = + // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3, %[[RES]]#4 + %group_1:2, %group_2, %group_3:2 = "foo_test"() : () -> (i16, i16, i16, i16, i16) + return %group_1#0, %group_1#1, %group_2, %group_3#0, %group_3#1 : i16, i16, i16, i16, i16 +} + // CHECK-LABEL: func @pretty_dialect_attribute() func @pretty_dialect_attribute() { - // CHECK: "foo.unknown_op"() {foo = #foo.simple_attr} : () -> () "foo.unknown_op"() {foo = #foo.simple_attr} : () -> () -- 2.7.4