Add support for parsing multiple result name groups.
authorRiver Riddle <riverriddle@google.com>
Fri, 25 Oct 2019 16:33:32 +0000 (09:33 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 25 Oct 2019 16:34:02 +0000 (09:34 -0700)
This allows for parsing things like:

%name_1, %name_2:5, %name_3:2 = "my.op" ...

This is useful for operations that have groups of variadic result values. The
total number of results is expected to match the number of results defined by
the operation.

PiperOrigin-RevId: 276703280

mlir/lib/Parser/Parser.cpp
mlir/test/IR/parser.mlir

index 8813cdc..c3f9081 100644 (file)
@@ -3115,43 +3115,40 @@ Value *OperationParser::createForwardRefPlaceholder(SMLoc loc, Type type) {
 ///
 ParseResult OperationParser::parseOperation() {
   auto loc = getToken().getLoc();
-  SmallVector<std::pair<StringRef, SMLoc>, 1> resultIDs;
-  size_t numExpectedResults;
+  SmallVector<std::tuple<StringRef, unsigned, SMLoc>, 1> resultIDs;
+  size_t numExpectedResults = 0;
   if (getToken().is(Token::percent_identifier)) {
-    // Parse the first result id.
-    resultIDs.emplace_back(getTokenSpelling(), loc);
-    consumeToken(Token::percent_identifier);
-
-    // If the next token is a ':', we parse the expected result count.
-    if (consumeIf(Token::colon)) {
-      // Check that the next token is an integer.
-      if (!getToken().is(Token::integer))
-        return emitError("expected integer number of results");
-
-      // Check that number of results is > 0.
-      auto val = getToken().getUInt64IntegerValue();
-      if (!val.hasValue() || val.getValue() < 1)
-        return emitError("expected named operation to have atleast 1 result");
-      consumeToken(Token::integer);
-      numExpectedResults = *val;
-    } else {
-      // Otherwise, this is a comma separated list of result ids.
-      if (consumeIf(Token::comma)) {
-        auto parseNextResult = [&]() -> ParseResult {
-          // Parse the next result id.
-          if (!getToken().is(Token::percent_identifier))
-            return emitError("expected valid ssa identifier");
-
-          resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc());
-          consumeToken(Token::percent_identifier);
-          return success();
-        };
-
-        if (parseCommaSeparatedList(parseNextResult))
-          return failure();
+    // Parse the group of result ids.
+    auto parseNextResult = [&]() -> ParseResult {
+      // Parse the next result id.
+      if (!getToken().is(Token::percent_identifier))
+        return emitError("expected valid ssa identifier");
+
+      Token nameTok = getToken();
+      consumeToken(Token::percent_identifier);
+
+      // If the next token is a ':', we parse the expected result count.
+      size_t expectedSubResults = 1;
+      if (consumeIf(Token::colon)) {
+        // Check that the next token is an integer.
+        if (!getToken().is(Token::integer))
+          return emitError("expected integer number of results");
+
+        // Check that number of results is > 0.
+        auto val = getToken().getUInt64IntegerValue();
+        if (!val.hasValue() || val.getValue() < 1)
+          return emitError("expected named operation to have atleast 1 result");
+        consumeToken(Token::integer);
+        expectedSubResults = *val;
       }
-      numExpectedResults = resultIDs.size();
-    }
+
+      resultIDs.emplace_back(nameTok.getSpelling(), expectedSubResults,
+                             nameTok.getLoc());
+      numExpectedResults += expectedSubResults;
+      return success();
+    };
+    if (parseCommaSeparatedList(parseNextResult))
+      return failure();
 
     if (parseToken(Token::equal, "expected '=' after SSA name"))
       return failure();
@@ -3178,19 +3175,14 @@ ParseResult OperationParser::parseOperation() {
              << op->getNumResults() << " results but was provided "
              << numExpectedResults << " to bind";
 
-    // If the number of result names matches the number of operation results, we
-    // can directly use the provided names.
-    if (resultIDs.size() == op->getNumResults()) {
-      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
-        if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second},
-                          op->getResult(i)))
-          return failure();
-    } else {
-      // Otherwise, we use the same name for all results.
-      StringRef name = resultIDs.front().first;
-      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
-        if (addDefinition({name, i, loc}, op->getResult(i)))
+    // Add definitions for each of the result groups.
+    unsigned opResI = 0;
+    for (std::tuple<StringRef, unsigned, SMLoc> &resIt : resultIDs) {
+      for (unsigned subRes : llvm::seq<unsigned>(0, std::get<1>(resIt))) {
+        if (addDefinition({std::get<0>(resIt), subRes, std::get<2>(resIt)},
+                          op->getResult(opResI++)))
           return failure();
+      }
     }
   }
 
index 31452e0..37f85e7 100644 (file)
@@ -874,9 +874,16 @@ func @pretty_form_multi_result() -> (i16, i16) {
   return %quot, %rem : i16, i16
 }
 
+// CHECK-LABEL: func @pretty_form_multi_result_groups
+func @pretty_form_multi_result_groups() -> (i16, i16, i16, i16, i16) {
+  // CHECK: %[[RES:.*]]:5 =
+  // CHECK: return %[[RES]]#0, %[[RES]]#1, %[[RES]]#2, %[[RES]]#3, %[[RES]]#4
+  %group_1:2, %group_2, %group_3:2 = "foo_test"() : () -> (i16, i16, i16, i16, i16)
+  return %group_1#0, %group_1#1, %group_2, %group_3#0, %group_3#1 : i16, i16, i16, i16, i16
+}
+
 // CHECK-LABEL: func @pretty_dialect_attribute()
 func @pretty_dialect_attribute() {
-
   // CHECK: "foo.unknown_op"() {foo = #foo.simple_attr} : () -> ()
   "foo.unknown_op"() {foo = #foo.simple_attr} : () -> ()