From c339f9e1c3276bcd8db806bc87045a5ef2079fec Mon Sep 17 00:00:00 2001
From: Diego Caballero
Date: Wed, 22 Feb 2023 01:20:10 +0000
Subject: [PATCH] [mlir][Vector] Support masking for more contraction flavors

This patch adds masking support for more contraction flavors, including
those with any combiner operation (add, mul, min, max, and, or, etc.)
and regular matmul contractions.

Combiner operations that perform vertical reductions (and are therefore
not represented with a horizontal reduction operation) can be executed
unmasked. However, the previous value of the accumulator must be
propagated for lanes that shouldn't accumulate. We achieve this by
inserting a select operation after the combining operation to choose
between the newly combined value and the previous accumulator value.
This design decision avoids introducing masking support into all the
arithmetic and logical operations in the Arith dialect. VP intrinsics
do not support pass-thru values either, so we would have to generate
the same sequence when lowering to LLVM. The op + select pattern is
peepholed by backends with native masking support for those operations.

Consequently, this patch removes masking support from the vector.fma
operation so that all combiner operations follow the same approach.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144239
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.h    |  15 ++-
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td   |   1 -
 .../VectorToLLVM/ConvertVectorToLLVM.cpp           |  52 +-------
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp           | 121 +++++++++++-------
 .../Dialect/Vector/Transforms/VectorTransforms.cpp |  36 +++---
 .../Conversion/VectorToLLVM/vector-to-llvm.mlir    | 140 ++++++++++++++++++---
 .../Dialect/Vector/vector-contract-transforms.mlir |  72 +++++++----
 7 files changed, 281 insertions(+), 156 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
index deb86df..56f8b4b 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
@@ -191,7 +191,7 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
 /// Return the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value v2);
+                         Value v1, Value acc, Value mask = Value());
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
@@ -214,8 +214,17 @@ void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
 /// Creates a vector.mask operation around a maskable operation. Returns the
 /// vector.mask operation if the mask provided is valid. Otherwise, returns the
 /// maskable operation itself.
-Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
-                         Value mask);
+Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
+                         Value mask, Value passthru = Value());
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value for masked-out or speculatively executed
+/// lanes. VP intrinsics do not support pass-thru values and every masked-out
+/// lane is set to poison. LLVM backends are usually able to match op + select
+/// patterns and fold them into native target instructions.
+Value selectPassthru(OpBuilder &builder, Value mask, Value newValue, + Value passthru); } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index c5ebe9f..04fb36a 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -633,7 +633,6 @@ def Vector_ExtractOp : def Vector_FMAOp : Op, - DeclareOpInterfaceMethods, DeclareOpInterfaceMethods ] # ElementwiseMappable.traits>, Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs, diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp index 159bae8..b73c01a 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -704,11 +704,6 @@ public: Value acc = adaptor.getAcc(); Location loc = reductionOp.getLoc(); - // Masked reductions are lowered separately. - auto maskableOp = cast(reductionOp.getOperation()); - if (maskableOp.isMasked()) - return failure(); - if (eltType.isIntOrIndex()) { // Integer reductions: add/mul/min/max/and/or/xor. Value result; @@ -1108,47 +1103,12 @@ public: if (vType.getRank() > 1) return failure(); - // Masked fmas are lowered separately. - auto maskableOp = cast(fmaOp.getOperation()); - if (maskableOp.isMasked()) - return failure(); - rewriter.replaceOpWithNewOp( fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc()); return success(); } }; -/// Conversion pattern that turns a masked vector.fma on a 1-D vector into their -/// LLVM counterpart representation. Non side effecting VP intrinsics are not -/// fully supported by some backends, including x86, and they don't support -/// pass-through values either. For these reasons, we generate an unmasked -/// fma followed by a select instrution to emulate the masking behavior. -/// This pattern is peepholed by some backends with support for masked fma -/// instructions. This pattern does not match vectors of n >= 2 rank. -class MaskedFMAOp1DConversion - : public VectorMaskOpConversionBase { -public: - using VectorMaskOpConversionBase::VectorMaskOpConversionBase; - - MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr) - : VectorMaskOpConversionBase(converter) {} - - virtual LogicalResult matchAndRewriteMaskableOp( - vector::MaskOp maskOp, MaskableOpInterface maskableOp, - ConversionPatternRewriter &rewriter) const override { - auto fmaOp = cast(maskableOp.getOperation()); - Type llvmType = typeConverter->convertType(fmaOp.getVectorType()); - - Value fmulAddOp = rewriter.create( - fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(), - fmaOp.getAcc()); - rewriter.replaceOpWithNewOp( - maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc()); - return success(); - } -}; - class VectorInsertElementOpConversion : public ConvertOpToLLVMPattern { public: @@ -1315,11 +1275,6 @@ public: if (vType.getRank() < 2) return failure(); - // Masked fmas are lowered separately. 
- auto maskableOp = cast(op.getOperation()); - if (maskableOp.isMasked()) - return failure(); - auto loc = op.getLoc(); auto elemType = vType.getElementType(); Value zero = rewriter.create( @@ -1748,10 +1703,9 @@ void mlir::populateVectorToLLVMConversionPatterns( patterns .add, VectorLoadStoreConversion, diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 8c6609d..eb58f90 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1790,16 +1790,6 @@ std::optional> FMAOp::getShapeForUnroll() { return llvm::to_vector<4>(getVectorType().getShape()); } -// MaskableOpInterface methods. - -/// Returns the mask type expected by this operation. Mostly used for -/// verification purposes. It requires the operation to be vectorized." -Type FMAOp::getExpectedMaskType() { - auto vecType = this->getVectorType(); - return VectorType::get(vecType.getShape(), - IntegerType::get(vecType.getContext(), /*width=*/1)); -} - //===----------------------------------------------------------------------===// // BroadcastOp //===----------------------------------------------------------------------===// @@ -5807,53 +5797,71 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) { } Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, - CombiningKind kind, Value v1, Value v2) { + CombiningKind kind, Value v1, Value acc, + Value mask) { Type t1 = getElementTypeOrSelf(v1.getType()); - Type t2 = getElementTypeOrSelf(v2.getType()); + Type tAcc = getElementTypeOrSelf(acc.getType()); + Value result; + switch (kind) { case CombiningKind::ADD: - if (t1.isIntOrIndex() && t2.isIntOrIndex()) - return b.createOrFold(loc, v1, v2); - else if (t1.isa() && t2.isa()) - return b.createOrFold(loc, v1, v2); - llvm_unreachable("invalid value types for ADD reduction"); + if (t1.isIntOrIndex() && tAcc.isIntOrIndex()) + result = b.createOrFold(loc, v1, acc); + else if (t1.isa() && tAcc.isa()) + result = b.createOrFold(loc, v1, acc); + else + llvm_unreachable("invalid value types for ADD reduction"); + break; case CombiningKind::AND: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); + assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); + result = b.createOrFold(loc, v1, acc); + break; case CombiningKind::MAXF: - assert(t1.isa() && t2.isa() && + assert(t1.isa() && tAcc.isa() && "expected float values"); - return b.createOrFold(loc, v1, v2); + result = b.createOrFold(loc, v1, acc); + break; case CombiningKind::MINF: - assert(t1.isa() && t2.isa() && + assert(t1.isa() && tAcc.isa() && "expected float values"); - return b.createOrFold(loc, v1, v2); + result = b.createOrFold(loc, v1, acc); + break; case CombiningKind::MAXSI: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); + assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); + result = b.createOrFold(loc, v1, acc); + break; case CombiningKind::MINSI: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); + assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); + result = b.createOrFold(loc, v1, acc); + break; case CombiningKind::MAXUI: - assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values"); - return b.createOrFold(loc, v1, v2); + assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values"); + 
result = b.createOrFold(loc, v1, acc);
+    break;
   case CombiningKind::MINUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold(loc, v1, acc);
+    break;
   case CombiningKind::MUL:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold(loc, v1, v2);
-    else if (t1.isa() && t2.isa())
-      return b.createOrFold(loc, v1, v2);
-    llvm_unreachable("invalid value types for MUL reduction");
+    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+      result = b.createOrFold(loc, v1, acc);
+    else if (t1.isa() && tAcc.isa())
+      result = b.createOrFold(loc, v1, acc);
+    else
+      llvm_unreachable("invalid value types for MUL reduction");
+    break;
   case CombiningKind::OR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold(loc, v1, acc);
+    break;
   case CombiningKind::XOR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold(loc, v1, acc);
+    break;
   };
-  llvm_unreachable("unknown CombiningKind");
+
+  assert(result && "unknown CombiningKind");
+  return selectPassthru(b, mask, result, acc);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5875,13 +5883,34 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
 /// Creates a vector.mask operation around a maskable operation. Returns the
 /// vector.mask operation if the mask provided is valid. Otherwise, returns
 /// the maskable operation itself.
-Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
-                                       Operation *maskableOp, Value mask) {
+Operation *mlir::vector::maskOperation(OpBuilder &builder,
+                                       Operation *maskableOp, Value mask,
+                                       Value passthru) {
   if (!mask)
     return maskableOp;
-  return rewriter.create(maskableOp->getLoc(),
-                         maskableOp->getResultTypes(), mask, maskableOp,
-                         createMaskOpRegion);
+  if (passthru)
+    return builder.create(maskableOp->getLoc(),
+                          maskableOp->getResultTypes(), mask, passthru,
+                          maskableOp, createMaskOpRegion);
+  return builder.create(maskableOp->getLoc(),
+                        maskableOp->getResultTypes(), mask, maskableOp,
+                        createMaskOpRegion);
+}
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value of vector.mask or for cases where only
+/// pass-thru value propagation is needed. VP intrinsics do not support
+/// pass-thru values and every masked-out lane is set to poison. LLVM backends
+/// are usually able to match op + select patterns and fold them into native
+/// target instructions.
+Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask, + Value newValue, Value passthru) { + if (!mask) + return newValue; + + return builder.create(newValue.getLoc(), newValue.getType(), + mask, newValue, passthru); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp index e7b8cd5..eecf970 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp @@ -151,8 +151,7 @@ static SmallVector extractVector(ArrayAttr arrayAttr) { static std::optional createContractArithOp(Location loc, Value x, Value y, Value acc, vector::CombiningKind kind, PatternRewriter &rewriter, - bool isInt, - std::optional maybeMask = std::nullopt) { + bool isInt, Value mask = Value()) { using vector::CombiningKind; Value mul; @@ -171,20 +170,20 @@ createContractArithOp(Location loc, Value x, Value y, Value acc, return std::nullopt; // Special case for fused multiply-add. if (acc && acc.getType().isa() && kind == CombiningKind::ADD) { - Operation *fmaOp = rewriter.create(loc, x, y, acc); - if (maybeMask.has_value() && maybeMask.value()) - fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value()); - return fmaOp->getResult(0); + Value fma = rewriter.create(loc, x, y, acc); + if (mask) + // The fma op doesn't need explicit masking. However, fma ops used in + // reductions must preserve previous 'acc' values for masked-out lanes. + fma = selectPassthru(rewriter, mask, fma, acc); + return fma; } mul = rewriter.create(loc, x, y); } - assert((!maybeMask.has_value() || !maybeMask.value()) && - "Unsupported masked case"); - if (!acc) return std::optional(mul); - return makeArithReduction(rewriter, loc, kind, mul, acc); + + return makeArithReduction(rewriter, loc, kind, mul, acc, mask); } /// Return the positions of the reductions in the given map. @@ -587,13 +586,17 @@ public: for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) { auto pos = rewriter.getI64ArrayAttr(d); Value x = - rewriter.create(loc, eltType, op.getLhs(), pos); + rewriter.create(loc, op.getLhs(), pos); Value a = rewriter.create(loc, rhsType, x); Value r = nullptr; if (acc) - r = rewriter.create(loc, rhsType, acc, pos); + r = rewriter.create(loc, acc, pos); + Value extrMask; + if (mask) + extrMask = rewriter.create(loc, mask, pos); + std::optional m = createContractArithOp( - loc, a, op.getRhs(), r, kind, rewriter, isInt, mask); + loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask); if (!m.has_value()) return failure(); result = rewriter.create(loc, resType, *m, result, pos); @@ -638,6 +641,7 @@ struct ContractOpToElementwise if (vectorTransformOptions.vectorContractLowering != vector::VectorContractLowering::ParallelArith) return failure(); + ArrayRef lhsShape = contractOp.getLhsType().getShape(); ArrayRef rhsShape = contractOp.getRhsType().getShape(); AffineMap lhsMap = contractOp.getIndexingMapsArray()[0]; @@ -1564,8 +1568,7 @@ struct UnrolledOuterProductGenerator mask = maskableOp.getMaskingOp().getMask(); } - Value t(Value v) { - static constexpr std::array perm = {1, 0}; + Value t(Value v, ArrayRef perm = {1, 0}) { if (!v) return v; return rewriter.create(loc, v, perm); @@ -1620,7 +1623,8 @@ struct UnrolledOuterProductGenerator bindDims(rewriter.getContext(), m, n, k); // Classical row-major matmul: Just permute the lhs. 
if (layout({{m, k}, {k, n}, {m, n}})) - return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1)); + return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), + t(mask, {2, 0, 1})); // TODO: may be better to fail and use some vector -> scalar reduction. if (layout({{m, k}, {n, k}, {m, n}})) { Value tlhs = t(lhs); diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir index a72ab15..0b9b86a 100644 --- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -418,16 +418,132 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v // ----- -func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> { +func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> { %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> return %0 : vector<2xf32> } -// We can't check for the intermediate 'vector.mask { vector.fma }' state so we -// just make sure the vector.fma is lowered. +// CHECK-LABEL: func.func @masked_float_add_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> { +// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32> +// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32> -// CHECK: llvm.intr.fmuladd -// CHECK: llvm.select +// ----- + +func.func @masked_float_mul_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func.func @masked_float_mul_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> { +// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32> +// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_8]], %[[VAL_2]] : vector<2xf32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32> + +// ----- + +func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func.func @masked_float_max_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> { +// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32> +// CHECK: %[[VAL_9:.*]] = arith.maxf %[[VAL_8]], %[[VAL_2]] : vector<2xf32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32> + +// ----- + +func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : 
vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + +// CHECK-LABEL: func.func @masked_float_min_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> { +// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32> +// CHECK: %[[VAL_9:.*]] = arith.minf %[[VAL_8]], %[[VAL_2]] : vector<2xf32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32> + +// ----- + +func.func @masked_int_add_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func.func @masked_int_add_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32> +// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_2]] : vector<2xi32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32> + +// ----- + +func.func @masked_int_mul_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func.func @masked_int_mul_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32> +// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_8]], %[[VAL_2]] : vector<2xi32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32> + +// ----- + +func.func @masked_int_max_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func.func @masked_int_max_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32> +// CHECK: %[[VAL_9:.*]] = arith.maxsi %[[VAL_8]], %[[VAL_2]] : vector<2xi32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32> + +// ----- + +func.func @masked_int_min_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func.func @masked_int_min_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32> +// CHECK: %[[VAL_9:.*]] = arith.minui %[[VAL_8]], %[[VAL_2]] : vector<2xi32> +// CHECK: 
%[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32> + +// ----- + +func.func @masked_int_and_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func.func @masked_int_and_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32> +// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_8]], %[[VAL_2]] : vector<2xi32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32> + +// ----- + +func.func @masked_int_or_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> { + %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32> + return %0 : vector<2xi32> +} + +// CHECK-LABEL: func.func @masked_int_or_outerprod( +// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> { +// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32> +// CHECK: %[[VAL_9:.*]] = arith.ori %[[VAL_8]], %[[VAL_2]] : vector<2xi32> +// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32> // ----- @@ -2157,17 +2273,3 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> { %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32> return %0 : vector<8xf32> } - -// ----- - -// CHECK-LABEL: func.func @masked_vector_fma( -// CHECK-SAME: %[[INPUT:.*]]: vector<8xf32>, -// CHECK-SAME: %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32> -// CHECK: %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]]) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32> -// CHECK: llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32> - -func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> { - %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32> - return %0 : vector<8xf32> -} - diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir index 1617cb1..6ad8a09 100644 --- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir @@ -76,6 +76,30 @@ func.func @extract_contract2(%arg0: vector<2x3xf32>, return %0 : vector<2xf32> } +// OUTERPRODUCT-LABEL: func.func @masked_extract_contract2( +// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, +// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>, +// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>, +// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> +// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> +// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1> +// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct + +// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1> +// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct + +// 
OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1> +// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct + +func.func @masked_extract_contract2(%arg0: vector<2x3xf32>, + %arg1: vector<3xf32>, + %arg2: vector<2xf32>, + %m: vector<2x3xi1>) -> vector<2xf32> { + %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 + : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> + return %0 : vector<2xf32> +} + // CHECK-LABEL: func @extract_contract2_int // CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>, // CHECK-SAME: %[[B:.*1]]: vector<3xi32>, @@ -182,6 +206,32 @@ func.func @extract_contract4(%arg0: vector<2x2xf32>, return %0 : vector<2x2xf32> } +// OUTERPRODUCT-LABEL: func.func @masked_extract_contract4( +// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<3x5xf32>, +// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<5x7xf32>, +// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<3x7xf32>, +// OUTERPRODUCT-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> { +// OUTERPRODUCT: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1> +// OUTERPRODUCT: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1> +// OUTERPRODUCT: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// OUTERPRODUCT: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1> +// OUTERPRODUCT: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// OUTERPRODUCT: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1> +// OUTERPRODUCT: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// OUTERPRODUCT: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1> +// OUTERPRODUCT: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> +// OUTERPRODUCT: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1> +// OUTERPRODUCT: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32> + +func.func @masked_extract_contract4(%arg0: vector<3x5xf32>, + %arg1: vector<5x7xf32>, + %arg2: vector<3x7xf32>, + %m : vector<3x7x5xi1>) -> vector<3x7xf32> { + %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2 + : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32> + return %0 : vector<3x7xf32> +} + #contraction2d_accesses = [ affine_map<(i, j) -> (i, j)>, affine_map<(i, j) -> (i, j)>, @@ -1197,26 +1247,4 @@ func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vec return %0 : f32 } -func.func @masked_vector_contract(%arg0: vector<2x3xf32>, - %arg1: vector<3xf32>, - %arg2: vector<2xf32>, - %m: vector<2x3xi1>) -> vector<2xf32> { - %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2 - : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32> - return %0 : vector<2xf32> -} -// OUTERPRODUCT-LABEL: func.func @masked_vector_contract( -// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>, -// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>, -// 
OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>, -// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32> -// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1> -// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct - -// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct - -// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1> -// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct -- 2.7.4
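
Editorial note (illustrative sketch, not part of the patch): the op + select
approach described in the commit message means that a masked integer-add outer
product, written with the new masking support as

  %0 = vector.mask %m {
         vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>}
           : vector<2xi32>, i32
       } : vector<2xi1> -> vector<2xi32>

is expected to lower to an unmasked multiply and add followed by a select that
restores the previous accumulator value for masked-out lanes, roughly:

  // %splat stands for the scalar %rhs broadcast to vector<2xi32>.
  %mul = arith.muli %lhs, %splat : vector<2xi32>
  // Unmasked combine with the accumulator.
  %add = arith.addi %mul, %acc : vector<2xi32>
  // Propagate the previous accumulator for lanes where %m is false.
  %res = arith.select %m, %add, %acc : vector<2xi1>, vector<2xi32>

The SSA names %lhs, %rhs, %acc, %m, and %splat are hypothetical; the muli/addi/
select sequence mirrors the CHECK lines of @masked_int_add_outerprod in
vector-to-llvm.mlir above.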