From f94131a2a502118e0507164dcef160d1cfecb316 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 13 Jan 2023 16:52:48 +0100 Subject: [PATCH] [mlir][vector] Support multiple result types in vector.mask The verifier already had support for multiple result types, but the op definition assumed a single, optional result. Differential Revision: https://reviews.llvm.org/D141683 --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 21 ++++++++++++++------- mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 12 ++++++------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 04af8d3..5a14f0d 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -2287,10 +2287,13 @@ def Vector_MaskOp : Vector_Op<"mask", [ The `vector.mask` is a `MaskingOpInterface` operation that predicates the execution of another operation. It takes an `i1` vector mask and an optional passthru vector as arguments. - A `vector.yield`-terminated region encloses the operation to be masked. - Values used within the region are captured from above. Only one *maskable* - operation can be masked with a `vector.mask` operation at a time. An - operation is *maskable* if it implements the `MaskableOpInterface`. + + A implicitly `vector.yield`-terminated region encloses the operation to be + masked. Values used within the region are captured from above. Only one + *maskable* operation can be masked with a `vector.mask` operation at a time. + An operation is *maskable* if it implements the `MaskableOpInterface`. The + terminator yields all results of the maskable operation to the result of + this operation. The vector mask argument holds a bit for each vector lane and determines which vector lanes should execute the maskable operation and which ones @@ -2321,12 +2324,16 @@ def Vector_MaskOp : Vector_Op<"mask", [ ``` vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, memref } : vector<16xi1> ``` + + ``` + vector.mask %mask { vector.transfer_write %val, %t0[%idx] : vector<16xf32>, tensor } : vector<16xi1> -> tensor + ``` }]; // TODO: Support multiple results and passthru values. let arguments = (ins VectorOf<[I1]>:$mask, Optional:$passthru); - let results = (outs Optional:$results); + let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$maskRegion); let skipDefaultBuilders = 1; @@ -2334,10 +2341,10 @@ def Vector_MaskOp : Vector_Op<"mask", [ OpBuilder<(ins "Value":$mask, CArg<"function_ref", "buildTerminatedBody">:$maskRegion)>, - OpBuilder<(ins "Type":$resultType, "Value":$mask, + OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, CArg<"function_ref", "buildTerminatedBody">:$maskRegion)>, - OpBuilder<(ins "Type":$resultType, "Value":$mask, + OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru, CArg<"function_ref", "buildTerminatedBody">:$maskRegion)> diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index e2a3e61..f00d849 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -5288,20 +5288,20 @@ void MaskOp::build( } void MaskOp::build( - OpBuilder &builder, OperationState &result, Type resultType, Value mask, - function_ref maskRegionBuilder) { - build(builder, result, resultType, mask, /*passthru=*/Value(), + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + Value mask, function_ref maskRegionBuilder) { + build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskRegionBuilder); } void MaskOp::build( - OpBuilder &builder, OperationState &result, Type resultType, Value mask, - Value passthru, + OpBuilder &builder, OperationState &result, TypeRange resultTypes, + Value mask, Value passthru, function_ref maskRegionBuilder) { build(builder, result, mask, maskRegionBuilder); if (passthru) result.addOperands(passthru); - result.addTypes(resultType); + result.addTypes(resultTypes); } ParseResult MaskOp::parse(OpAsmParser &parser, OperationState &result) { -- 2.7.4