Change the muli-return syntax for operations. The name of the operation result now...
authorRiver Riddle <riverriddle@google.com>
Thu, 28 Mar 2019 21:58:52 +0000 (14:58 -0700)
committerjpienaar <jpienaar@google.com>
Sat, 30 Mar 2019 00:51:32 +0000 (17:51 -0700)
Example:
    %call:2 = call @multi_return() : () -> (f32, i32)
    use(%calltensorflow/mlir#0, %calltensorflow/mlir#1)

This cl also adds parser support for uniquely named result values. This means that a test writer can now write something like:
    %foo, %bar = call @multi_return() : () -> (f32, i32)
    use(%foo, %bar)

Note: The printer will still print the collapsed form.
PiperOrigin-RevId: 240860058

mlir/g3doc/ConversionToLLVMDialect.md
mlir/g3doc/LangRef.md
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/EDSC/builder-api-test.cpp
mlir/test/IR/invalid.mlir
mlir/test/IR/parser.mlir
mlir/test/LLVMIR/convert-to-llvmir.mlir
mlir/test/Transforms/Vectorize/vectorize_1d.mlir
mlir/test/Transforms/unroll.mlir

index 418dfbf..27f9d3a 100644 (file)
@@ -198,7 +198,7 @@ func @foo(%arg0: i32, %arg1: i64) -> (i32, i64) {
 func @bar() {
   %0 = constant 42 : i32
   %1 = constant 17 : i64
-  %2 = call @foo(%0, %1) : (i32, i64) -> (i32, i64)
+  %2:2 = call @foo(%0, %1) : (i32, i64) -> (i32, i64)
   "use_i32"(%2#0) : (i32) -> ()
   "use_i64"(%2#1) : (i64) -> ()
 }
index c8a4373..a6b684d 100644 (file)
@@ -1250,8 +1250,9 @@ case: they become arguments to the entry block
 Syntax:
 
 ``` {.ebnf}
-instruction ::= (ssa-id `=`)? string-literal `(` ssa-use-list? `)`
+instruction ::= inst-result? string-literal `(` ssa-use-list? `)`
               (`[` successor-list `]`)? attribute-dict? `:` function-type
+inst-result ::= ssa-id ((`:` integer-literal) | (`,` ssa-id)*) `=`
 successor-list ::= successor (`,` successor)*
 ```
 
@@ -1272,9 +1273,16 @@ used to indicate the types of the results and operands.
 Example:
 
 ```mlir {.mlir}
+// Call a function that returns two results.
+// The results of %call can be accessed via the <name> `#` <opNo> syntax.
+%call:2 = call @multi_return() : () -> (f32, i32)
+
+// Pretty form that defines a unique name for each result.
+%foo, %bar = call @multi_return() : () -> (f32, i32)
+
 // Invoke a TensorFlow function called tf.scramble with two inputs
 // and an attribute "fruit".
-%2 = "tf.scramble"(%42, %12){fruit: "banana"} : (f32, i32) -> f32
+%2 = "tf.scramble"(%call#0, %bar){fruit: "banana"} : (f32, i32) -> f32
 
 // Invoke the TPU specific add instruction that takes two vector register
 // as input and produces a vector register.
@@ -1386,8 +1394,7 @@ single function to return.
 Syntax:
 
 ``` {.ebnf}
-operation ::=
-    ssa-id `=` `call` function-id `(` ssa-use-list? `)` `:` function-type
+operation ::= `call` function-id `(` ssa-use-list? `)` `:` function-type
 ```
 
 The `call` operation represents a direct call to a function. The operands and
@@ -1406,8 +1413,7 @@ Example:
 Syntax:
 
 ``` {.ebnf}
-operation ::= ssa-id `=` `call_indirect` ssa-use
-                `(` ssa-use-list? `)` `:` function-type
+operation ::= `call_indirect` ssa-use `(` ssa-use-list? `)` `:` function-type
 ```
 
 The `call_indirect` operation represents an indirect call to a value of function
@@ -2260,7 +2266,7 @@ Example:
 
 ```mlir {.mlir}
 // LLVM: %x = call {i16, i1} @llvm.sadd.with.overflow.i16(i16 %a, i16 %b)
-%x = "llvm.sadd.with.overflow.i16"(%a, %b) : (i16, i16) -> (i16, i1)
+%x:2 = "llvm.sadd.with.overflow.i16"(%a, %b) : (i16, i16) -> (i16, i1)
 ```
 
 These operations only work when targeting LLVM as a backend (e.g. for CPUs and
index 82d5813..b9f7921 100644 (file)
@@ -1412,11 +1412,6 @@ void FunctionPrinter::printValueID(Value *value, bool printResultNo) const {
       resultNo = result->getResultNumber();
       lookupValue = result->getOwner()->getResult(0);
     }
-  } else if (auto *result = dyn_cast<OpResult>(value)) {
-    if (result->getOwner()->getNumResults() != 1) {
-      resultNo = result->getResultNumber();
-      lookupValue = result->getOwner()->getResult(0);
-    }
   }
 
   auto it = valueIDs.find(lookupValue);
@@ -1439,8 +1434,10 @@ void FunctionPrinter::printValueID(Value *value, bool printResultNo) const {
 }
 
 void FunctionPrinter::printOperation(Operation *op) {
-  if (op->getNumResults()) {
+  if (size_t numResults = op->getNumResults()) {
     printValueID(op->getResult(0), /*printResultNo=*/false);
+    if (numResults > 1)
+      os << ':' << numResults;
     os << " = ";
   }
 
index e2ba5a0..b8d4294 100644 (file)
@@ -2789,16 +2789,50 @@ ParseResult FunctionParser::parseOptionalBlockArgList(
 /// Parse an operation.
 ///
 ///  operation ::=
-///    (ssa-id `=`)? string '(' ssa-use-list? ')' attribute-dict?
+///    operation-result? string '(' ssa-use-list? ')' attribute-dict?
 ///    `:` function-type trailing-location?
+///  operation-result ::= ssa-id ((`:` integer-literal) | (`,` ssa-id)*) `=`
 ///
 ParseResult FunctionParser::parseOperation() {
   auto loc = getToken().getLoc();
-
-  StringRef resultID;
+  SmallVector<std::pair<StringRef, SMLoc>, 1> resultIDs;
+  size_t numExpectedResults;
   if (getToken().is(Token::percent_identifier)) {
-    resultID = getTokenSpelling();
+    // Parse the first result id.
+    resultIDs.emplace_back(getTokenSpelling(), loc);
     consumeToken(Token::percent_identifier);
+
+    // If the next token is a ':', we parse the expected result count.
+    if (consumeIf(Token::colon)) {
+      // Check that the next token is an integer.
+      if (!getToken().is(Token::integer))
+        return emitError("expected integer number of results");
+
+      // Check that number of results is > 0.
+      auto val = getToken().getUInt64IntegerValue();
+      if (!val.hasValue() || val.getValue() < 1)
+        return emitError("expected named operation to have atleast 1 result");
+      consumeToken(Token::integer);
+      numExpectedResults = *val;
+    } else {
+      // Otherwise, this is a comma separated list of result ids.
+      if (consumeIf(Token::comma)) {
+        auto parseNextResult = [&] {
+          // Parse the next result id.
+          if (!getToken().is(Token::percent_identifier))
+            return emitError("expected valid ssa identifier");
+
+          resultIDs.emplace_back(getTokenSpelling(), getToken().getLoc());
+          consumeToken(Token::percent_identifier);
+          return ParseSuccess;
+        };
+
+        if (parseCommaSeparatedList(parseNextResult))
+          return ParseFailure;
+      }
+      numExpectedResults = resultIDs.size();
+    }
+
     if (parseToken(Token::equal, "expected '=' after SSA name"))
       return ParseFailure;
   }
@@ -2816,13 +2850,28 @@ ParseResult FunctionParser::parseOperation() {
     return ParseFailure;
 
   // If the operation had a name, register it.
-  if (!resultID.empty()) {
+  if (!resultIDs.empty()) {
     if (op->getNumResults() == 0)
       return emitError(loc, "cannot name an operation with no results");
-
-    for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
-      if (addDefinition({resultID, i, loc}, op->getResult(i)))
-        return ParseFailure;
+    if (numExpectedResults != op->getNumResults())
+      return emitError(loc, "operation defines more results than expected : " +
+                                Twine(op->getNumResults()) + " vs " +
+                                Twine(numExpectedResults));
+
+    // If the number of result names matches the number of operation results, we
+    // can use the names directly.
+    if (resultIDs.size() == op->getNumResults()) {
+      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
+        if (addDefinition({resultIDs[i].first, 0, resultIDs[i].second},
+                          op->getResult(i)))
+          return ParseFailure;
+    } else {
+      // Otherwise, we use the first name for each result.
+      StringRef name = resultIDs.front().first;
+      for (unsigned i = 0, e = op->getNumResults(); i != e; ++i)
+        if (addDefinition({name, i, loc}, op->getResult(i)))
+          return ParseFailure;
+    }
   }
 
   // Try to parse the optional trailing location.
index 25a2e13..fc173b4 100644 (file)
@@ -398,7 +398,7 @@ TEST_FUNC(custom_ops) {
   // CHECK:   affine.for %i1 {{.*}}
   // CHECK:     {{.*}} = "my_custom_op"{{.*}} : (index, index) -> index
   // CHECK:     "my_custom_op_0"{{.*}} : (index, index) -> ()
-  // CHECK:     [[TWO:%[a-z0-9]+]] = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
+  // CHECK:     [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index)
   // CHECK:     {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index
   // clang-format on
   f->print(llvm::outs());
index 9fc0bb9..f5ce967 100644 (file)
@@ -373,7 +373,7 @@ func @bbargMismatch(i32, f32) {
 
 func @br_mismatch() {
 ^bb0:
-  %0 = "foo"() : () -> (i1, i17)
+  %0:2 = "foo"() : () -> (i1, i17)
   // expected-error @+1 {{branch has 2 operands, but target block has 1}}
   br ^bb1(%0#1, %0#0 : i17, i1)
 
@@ -992,3 +992,70 @@ func @invalid_region_dominance() {
   }
   return
 }
+
+// -----
+
+func @multi_result_missing_count() {
+  // expected-error@+1 {{expected integer number of results}}
+  %0: = "foo" () : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func @multi_result_zero_count() {
+  // expected-error@+1 {{expected named operation to have atleast 1 result}}
+  %0:0 = "foo" () : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func @multi_result_invalid_identifier() {
+  // expected-error@+1 {{expected valid ssa identifier}}
+  %0, = "foo" () : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func @multi_result_mismatch_count() {
+  // expected-error@+1 {{operation defines more results than expected : 2 vs 1}}
+  %0:1 = "foo" () : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func @multi_result_mismatch_count() {
+  // expected-error@+1 {{operation defines more results than expected : 2 vs 3}}
+  %0, %1, %3 = "foo" () : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func @no_result_with_name() {
+  // expected-error@+1 {{cannot name an operation with no results}}
+  %0 = "foo" () : () -> ()
+  return
+}
+
+// -----
+
+func @conflicting_names() {
+  // expected-error@+1 {{previously defined here}}
+  %foo, %bar  = "foo" () : () -> (i32, i32)
+
+  // expected-error@+1 {{redefinition of SSA value '%bar'}}
+  %bar, %baz  = "foo" () : () -> (i32, i32)
+  return
+}
+
+// -----
+
+func @ssa_name_missing_eq() {
+  // expected-error@+1 {{expected '=' after SSA name}}
+  %0:2 "foo" () : () -> (i32, i32)
+  return
+}
index 797244b..63d3700 100644 (file)
@@ -136,7 +136,7 @@ func @simpleCFG(%arg0: i32, %f: f32) -> i1 {
   // CHECK: %0 = "foo"() : () -> i64
   %1 = "foo"() : ()->i64
   // CHECK: "bar"(%0) : (i64) -> (i1, i1, i1)
-  %2 = "bar"(%1) : (i64) -> (i1,i1,i1)
+  %2:3 = "bar"(%1) : (i64) -> (i1,i1,i1)
   // CHECK: return %1#1
   return %2#1 : i1
 // CHECK: }
@@ -146,7 +146,7 @@ func @simpleCFG(%arg0: i32, %f: f32) -> i1 {
 func @simpleCFGUsingBBArgs(i32, i64) {
 ^bb42 (%arg0: i32, %f: i64):
   // CHECK: "bar"(%arg1) : (i64) -> (i1, i1, i1)
-  %2 = "bar"(%f) : (i64) -> (i1,i1,i1)
+  %2:3 = "bar"(%f) : (i64) -> (i1,i1,i1)
   // CHECK: return{{$}}
   return
 // CHECK: }
@@ -177,8 +177,8 @@ func @func_with_one_arg(%c : i1) -> i2 {
 
 // CHECK-LABEL: func @func_with_two_args(%arg0: f16, %arg1: i8) -> (i1, i32) {
 func @func_with_two_args(%a : f16, %b : i8) -> (i1, i32) {
-  // CHECK: %0 = "foo"(%arg0, %arg1) : (f16, i8) -> (i1, i32)
-  %c = "foo"(%a, %b) : (f16, i8)->(i1, i32)
+  // CHECK: %0:2 = "foo"(%arg0, %arg1) : (f16, i8) -> (i1, i32)
+  %c:2 = "foo"(%a, %b) : (f16, i8)->(i1, i32)
   return %c#0, %c#1 : i1, i32  // CHECK: return %0#0, %0#1 : i1, i32
 } // CHECK: }
 
@@ -381,38 +381,38 @@ func @attributes() {
 
 // CHECK-LABEL: func @ssa_values() -> (i16, i8) {
 func @ssa_values() -> (i16, i8) {
-  // CHECK: %0 = "foo"() : () -> (i1, i17)
-  %0 = "foo"() : () -> (i1, i17)
+  // CHECK: %0:2 = "foo"() : () -> (i1, i17)
+  %0:2 = "foo"() : () -> (i1, i17)
   br ^bb2
 
 ^bb1:       // CHECK: ^bb1: // pred: ^bb2
-  // CHECK: %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
-  %1 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
+  // CHECK: %1:2 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
+  %1:2 = "baz"(%2#1, %2#0, %0#1) : (f32, i11, i17) -> (i16, i8)
 
   // CHECK: return %1#0, %1#1 : i16, i8
   return %1#0, %1#1 : i16, i8
 
 ^bb2:       // CHECK: ^bb2:  // pred: ^bb0
-  // CHECK: %2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
-  %2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
+  // CHECK: %2:2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
+  %2:2 = "bar"(%0#0, %0#1) : (i1, i17) -> (i11, f32)
   br ^bb1
 }
 
 // CHECK-LABEL: func @bbargs() -> (i16, i8) {
 func @bbargs() -> (i16, i8) {
-  // CHECK: %0 = "foo"() : () -> (i1, i17)
-  %0 = "foo"() : () -> (i1, i17)
+  // CHECK: %0:2 = "foo"() : () -> (i1, i17)
+  %0:2 = "foo"() : () -> (i1, i17)
   br ^bb1(%0#1, %0#0 : i17, i1)
 
 ^bb1(%x: i17, %y: i1):       // CHECK: ^bb1(%1: i17, %2: i1):
-  // CHECK: %3 = "baz"(%1, %2, %0#1) : (i17, i1, i17) -> (i16, i8)
-  %1 = "baz"(%x, %y, %0#1) : (i17, i1, i17) -> (i16, i8)
+  // CHECK: %3:2 = "baz"(%1, %2, %0#1) : (i17, i1, i17) -> (i16, i8)
+  %1:2 = "baz"(%x, %y, %0#1) : (i17, i1, i17) -> (i16, i8)
   return %1#0, %1#1 : i16, i8
 }
 
 // CHECK-LABEL: func @verbose_terminators() -> (i1, i17)
 func @verbose_terminators() -> (i1, i17) {
-  %0 = "foo"() : () -> (i1, i17)
+  %0:2 = "foo"() : () -> (i1, i17)
 // CHECK:  br ^bb1(%0#0, %0#1 : i1, i17)
   "std.br"()[^bb1(%0#0, %0#1 : i1, i17)] : () -> ()
 
@@ -825,3 +825,10 @@ func @tuple_multi_element(tuple<i32, i16, f32>)
 
 // CHECK-LABEL: func @tuple_nested(tuple<tuple<tuple<i32>>>)
 func @tuple_nested(tuple<tuple<tuple<i32>>>)
+
+// CHECK-LABEL: func @pretty_form_multi_result
+func @pretty_form_multi_result() -> (i16, i16) {
+  // CHECK: %0:2 = "foo_div"() : () -> (i16, i16)
+  %quot, %rem = "foo_div"() : () -> (i16, i16)
+  return %quot, %rem : i16, i16
+}
index b96189c..d2b988a 100644 (file)
@@ -345,7 +345,7 @@ func @multireturn_caller() {
 // CHECK-NEXT:   {{.*}} = "llvm.extractvalue"({{.*}}) {position: [0]} : (!llvm<"{ i64, float, { float*, i64, i64 } }">) -> !llvm<"i64">
 // CHECK-NEXT:   {{.*}} = "llvm.extractvalue"({{.*}}) {position: [1]} : (!llvm<"{ i64, float, { float*, i64, i64 } }">) -> !llvm<"float">
 // CHECK-NEXT:   {{.*}} = "llvm.extractvalue"({{.*}}) {position: [2]} : (!llvm<"{ i64, float, { float*, i64, i64 } }">) -> !llvm<"{ float*, i64, i64 }">
-  %0 = call @multireturn() : () -> (i64, f32, memref<42x?x10x?xf32>)
+  %0:3 = call @multireturn() : () -> (i64, f32, memref<42x?x10x?xf32>)
   %1 = constant 42 : i64
 // CHECK:   {{.*}} = "llvm.add"({{.*}}, {{.*}}) : (!llvm<"i64">, !llvm<"i64">) -> !llvm<"i64">
   %2 = addi %0#0, %1 : i64
index 15b6ba7..df829fe 100644 (file)
@@ -46,7 +46,7 @@ func @vec1d_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK-NEXT:   {{.*}} = vector_transfer_read %arg0, [[C0]], [[APP3]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32>
    affine.for %i3 = 0 to %M { // vectorized
      %r3 = affine.apply (d0) -> (d0) (%i3)
-     %a3 = load %A[%cst0, %r3#0] : memref<?x?xf32>
+     %a3 = load %A[%cst0, %r3] : memref<?x?xf32>
    }
    return
 }
@@ -161,7 +161,7 @@ func @vec_rejected_2(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // CHECK:   affine.for %i{{[0-9]*}} = 0 to [[ARG_M]] {
    affine.for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1
      %r2 = affine.apply (d0) -> (d0) (%i2)
-     %a2 = load %A[%r2#0, %cst0] : memref<?x?xf32>
+     %a2 = load %A[%r2, %cst0] : memref<?x?xf32>
    }
    return
 }
index 5bbf3b8..9fd0088 100644 (file)
@@ -126,17 +126,17 @@ func @loop_nest_multiple_results() {
     // UNROLL-FULL: %0 = affine.apply [[MAP4]](%i0, %c0)
     // UNROLL-FULL-NEXT: %1 = "addi32"(%0, %0) : (index, index) -> index
     // UNROLL-FULL-NEXT: %2 = affine.apply #map{{.*}}(%i0, %c0)
-    // UNROLL-FULL-NEXT: %3 = "fma"(%2, %0, %0) : (index, index, index) -> (index, index)
+    // UNROLL-FULL-NEXT: %3:2 = "fma"(%2, %0, %0) : (index, index, index) -> (index, index)
     // UNROLL-FULL-NEXT: %4 = affine.apply #map{{.*}}(%c0)
     // UNROLL-FULL-NEXT: %5 = affine.apply #map{{.*}}(%i0, %4)
     // UNROLL-FULL-NEXT: %6 = "addi32"(%5, %5) : (index, index) -> index
     // UNROLL-FULL-NEXT: %7 = affine.apply #map{{.*}}(%i0, %4)
-    // UNROLL-FULL-NEXT: %8 = "fma"(%7, %5, %5) : (index, index, index) -> (index, index)
+    // UNROLL-FULL-NEXT: %8:2 = "fma"(%7, %5, %5) : (index, index, index) -> (index, index)
     affine.for %j = 0 to 2 step 1 {
       %x = affine.apply (d0, d1) -> (d0 + 1) (%i, %j)
       %y = "addi32"(%x, %x) : (index, index) -> index
       %z = affine.apply (d0, d1) -> (d0 + 3) (%i, %j)
-      %w = "fma"(%z, %x, %x) : (index, index, index) -> (index, index)
+      %w:2 = "fma"(%z, %x, %x) : (index, index, index) -> (index, index)
     }
   }       // UNROLL-FULL:  }
   return  // UNROLL-FULL:  return