[mlir][vector] Support multiple result types in vector.mask
authorMatthias Springer <springerm@google.com>
Fri, 13 Jan 2023 15:52:48 +0000 (16:52 +0100)
committerMatthias Springer <springerm@google.com>
Fri, 13 Jan 2023 15:59:36 +0000 (16:59 +0100)
The verifier already had support for multiple result types, but the op definition assumed a single, optional result.

Differential Revision: https://reviews.llvm.org/D141683

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Vector/IR/VectorOps.cpp

index 04af8d3..5a14f0d 100644 (file)
@@ -2287,10 +2287,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
     The `vector.mask` is a `MaskingOpInterface` operation that predicates the
     execution of another operation. It takes an `i1` vector mask and an
     optional passthru vector as arguments.
-    A `vector.yield`-terminated region encloses the operation to be masked.
-    Values used within the region are captured from above. Only one *maskable*
-    operation can be masked with a `vector.mask` operation at a time. An
-    operation is *maskable* if it implements the `MaskableOpInterface`.
+
+    A implicitly `vector.yield`-terminated region encloses the operation to be
+    masked. Values used within the region are captured from above. Only one
+    *maskable* operation can be masked with a `vector.mask` operation at a time.
+    An operation is *maskable* if it implements the `MaskableOpInterface`. The
+    terminator yields all results of the maskable operation to the result of
+    this operation.
 
     The vector mask argument holds a bit for each vector lane and determines
     which vector lanes should execute the maskable operation and which ones
@@ -2321,12 +2324,16 @@ def Vector_MaskOp : Vector_Op<"mask", [
     ```
       vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref<?xf32> } : vector<16xi1>
     ```
+
+    ```
+      vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor<?xf32> } : vector<16xi1> -> tensor<?xf32>
+    ```
   }];
 
   // TODO: Support multiple results and passthru values.
   let arguments = (ins VectorOf<[I1]>:$mask,
                    Optional<AnyType>:$passthru);
-  let results = (outs Optional<AnyType>:$results);
+  let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$maskRegion);
 
   let skipDefaultBuilders = 1;
@@ -2334,10 +2341,10 @@ def Vector_MaskOp : Vector_Op<"mask", [
     OpBuilder<(ins "Value":$mask,
                    CArg<"function_ref<void(OpBuilder &, Location)>",
                         "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "Type":$resultType, "Value":$mask,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
                    CArg<"function_ref<void(OpBuilder &, Location)>",
                         "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "Type":$resultType, "Value":$mask,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
                    "Value":$passthru,
                    CArg<"function_ref<void(OpBuilder &, Location)>",
                         "buildTerminatedBody">:$maskRegion)>
index e2a3e61..f00d849 100644 (file)
@@ -5288,20 +5288,20 @@ void MaskOp::build(
 }
 
 void MaskOp::build(
-    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, resultType, mask, /*passthru=*/Value(),
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+  build(builder, result, resultTypes, mask, /*passthru=*/Value(),
         maskRegionBuilder);
 }
 
 void MaskOp::build(
-    OpBuilder &builder, OperationState &result, Type resultType, Value mask,
-    Value passthru,
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    Value mask, Value passthru,
     function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
   build(builder, result, mask, maskRegionBuilder);
   if (passthru)
     result.addOperands(passthru);
-  result.addTypes(resultType);
+  result.addTypes(resultTypes);
 }
 
 ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) {