From 9452356ddcf44720387eda9b6316b7cadff1f60d Mon Sep 17 00:00:00 2001
From: Diego Caballero
Date: Wed, 15 Feb 2023 05:46:15 +0000
Subject: [PATCH] [mlir][Vector] Add support for masked vector.contract

This patch adds support for masking vector.contract ops with the
vector.mask approach. It also includes the lowering of masked
vector.contract ops to LLVM through the vector.outerproduct path. For
now, only one of the many potential flavors of
vector.contract/vector.outerproduct is supported, but unsupported cases
fail gracefully.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D143965
---
 mlir/include/mlir/Dialect/Vector/IR/VectorOps.td    |   5 +-
 .../VectorToLLVM/ConvertVectorToLLVM.cpp            |  48 ++++++++-
 mlir/lib/Dialect/Vector/IR/VectorOps.cpp            |  48 +++++++++
 .../Dialect/Vector/Transforms/VectorTransforms.cpp  | 120 ++++++++++++++++----
 .../Conversion/VectorToLLVM/vector-to-llvm.mlir     |  26 +++++
 .../Dialect/Vector/vector-contract-transforms.mlir  |  24 +++++
 6 files changed, 241 insertions(+), 30 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index b366cf8..6f6d80c 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -91,6 +91,7 @@ def Vector_ContractionOp :
     PredOpTrait<"lhs and rhs have same element type", TCopVTEtIsSameAs<0, 1>>,
     PredOpTrait<"third operand acc and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 2>>,
+    DeclareOpInterfaceMethods<MaskableOpInterface>,
     DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
   ]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
@@ -632,6 +633,7 @@ def Vector_ExtractOp :
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [
        Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
+       DeclareOpInterfaceMethods<MaskableOpInterface>,
       DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
     ] # ElementwiseMappable.traits>,
   Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
@@ -923,7 +925,8 @@ def Vector_OuterProductOp :
     PredOpTrait<"lhs operand and result have same element type",
                 TCresVTEtIsSameAsOpBase<0, 0>>,
     PredOpTrait<"rhs operand and result have same element type",
-                TCresVTEtIsSameAsOpBase<0, 1>>]>,
+                TCresVTEtIsSameAsOpBase<0, 1>>,
+    DeclareOpInterfaceMethods<MaskableOpInterface>]>,
   Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
                  Variadic<AnyVector>:$acc,
                  DefaultValuedAttr<Vector_CombiningKindAttr,
                                    "vector::CombiningKind::ADD">:$kind)>,
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 68865d3..a1c1c3d 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1107,12 +1107,48 @@ public:
     VectorType vType = fmaOp.getVectorType();
     if (vType.getRank() > 1)
       return failure();
+
+    // Masked fmas are lowered separately.
+    auto maskableOp = cast<vector::MaskableOpInterface>(fmaOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
 };

+/// Conversion pattern that turns a masked vector.fma on a 1-D vector into its
+/// LLVM counterpart representation. Non-side-effecting VP intrinsics are not
+/// fully supported by some backends, including x86, and they don't support
+/// pass-through values either. For these reasons, we generate an unmasked
+/// fma followed by a select instruction to emulate the masking behavior.
+/// This pattern is peepholed by some backends with support for masked fma
+/// instructions. This pattern does not match vectors of rank >= 2.
+class MaskedFMAOp1DConversion
+    : public VectorMaskOpConversionBase<vector::FMAOp> {
+public:
+  using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
+
+  MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
+      : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
+
+  virtual LogicalResult matchAndRewriteMaskableOp(
+      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+      ConversionPatternRewriter &rewriter) const override {
+    auto fmaOp = cast<vector::FMAOp>(maskableOp.getOperation());
+    Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
+
+    Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
+        fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
+        fmaOp.getAcc());
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
+        maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
+    return success();
+  }
+};
+
 class VectorInsertElementOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
 public:
@@ -1279,6 +1315,11 @@ public:
     if (vType.getRank() < 2)
       return failure();

+    // Masked fmas are lowered separately.
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     auto loc = op.getLoc();
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
@@ -1707,9 +1748,10 @@ void mlir::populateVectorToLLVMConversionPatterns(
   patterns
       .add,
           VectorLoadStoreConversion,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 02cee8c..3e145f1 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -889,6 +889,34 @@ LogicalResult ContractionOp::verify() {
   return success();
 }

+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type ContractionOp::getExpectedMaskType() {
+  auto indexingMaps = this->getIndexingMapsArray();
+  AffineMap lhsIdxMap = indexingMaps[0];
+  AffineMap rhsIdxMap = indexingMaps[1];
+  VectorType lhsType = this->getLhsType();
+  VectorType rhsType = this->getRhsType();
+
+  unsigned numVecDims = lhsIdxMap.getNumDims();
+  SmallVector<int64_t> maskShape(numVecDims, ShapedType::kDynamic);
+
+  // Using the information in the indexing maps, extract the size of each
+  // dimension in the vector.contract operation from the two input operands.
+  for (auto [dimIdx, dimSize] : llvm::enumerate(lhsType.getShape()))
+    maskShape[lhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+  for (auto [dimIdx, dimSize] : llvm::enumerate(rhsType.getShape()))
+    maskShape[rhsIdxMap.getDimPosition(dimIdx)] = dimSize;
+
+  assert(!ShapedType::isDynamicShape(maskShape) &&
+         "Mask shape couldn't be computed");
+
+  return VectorType::get(maskShape,
+                         IntegerType::get(lhsType.getContext(), /*width=*/1));
+}
+
 SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
   return SmallVector<StringRef>{getIndexingMapsAttrName(),
                                 getIteratorTypesAttrName(), getKindAttrName()};
@@ -1760,6 +1788,16 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }

+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type FMAOp::getExpectedMaskType() {
+  auto vecType = this->getVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -2762,6 +2800,16 @@ LogicalResult OuterProductOp::verify() {
   return success();
 }

+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type OuterProductOp::getExpectedMaskType() {
+  auto vecType = this->getVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 //===----------------------------------------------------------------------===//
 // ReshapeOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
index c7aee6d..9976cf7 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -147,13 +147,13 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
 }

 /// Helper to create arithmetic operation associated with a kind of contraction.
-static std::optional<Value> createContractArithOp(Location loc, Value x,
-                                                  Value y, Value acc,
-                                                  vector::CombiningKind kind,
-                                                  PatternRewriter &rewriter,
-                                                  bool isInt) {
+static std::optional<Value>
+createContractArithOp(Location loc, Value x, Value y, Value acc,
+                      vector::CombiningKind kind, PatternRewriter &rewriter,
+                      bool isInt, Optional<Value> maybeMask = std::nullopt) {
   using vector::CombiningKind;
   Value mul;
+
   if (isInt) {
     if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
       // Only valid for floating point types.
       return std::nullopt;
@@ -169,11 +169,17 @@ static std::optional<Value> createContractArithOp(Location loc, Value x,
     return std::nullopt;
   // Special case for fused multiply-add.
   if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
-    return std::optional<Value>(
-        rewriter.create<vector::FMAOp>(loc, x, y, acc));
+    Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+    if (maybeMask.has_value() && maybeMask.value())
+      fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
+    return fmaOp->getResult(0);
   }
     mul = rewriter.create<arith::MulFOp>(loc, x, y);
   }
+
+  assert((!maybeMask.has_value() || !maybeMask.value()) &&
+         "Unsupported masked case");
+
   if (!acc)
     return std::optional<Value>(mul);
   return makeArithReduction(rewriter, loc, kind, mul, acc);
@@ -550,14 +556,27 @@ public:
     Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
     vector::CombiningKind kind = op.getKind();

+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    Operation *rootOp;
+    Value mask;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+      mask = maskableOp.getMaskingOp().getMask();
+    } else {
+      rootOp = op;
+    }
+
     if (!rhsType) {
       // Special case: AXPY operation.
      Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
       std::optional<Value> mult = createContractArithOp(
-          loc, op.getLhs(), b, acc, kind, rewriter, isInt);
+          loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
       if (!mult.has_value())
         return failure();
-      rewriter.replaceOp(op, *mult);
+      rewriter.replaceOp(rootOp, *mult);
       return success();
     }
@@ -571,13 +590,14 @@ public:
       Value r = nullptr;
       if (acc)
         r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
-      std::optional<Value> m =
-          createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
+      std::optional<Value> m = createContractArithOp(
+          loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
       if (!m.has_value())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
     }
-    rewriter.replaceOp(op, result);
+
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -601,7 +621,12 @@ struct ContractOpToElementwise
   LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                 PatternRewriter &rewriter) const override {
-    // TODO: implement masks
+    // TODO: Support vector.mask.
+    auto maskableOp = cast<vector::MaskableOpInterface>(contractOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
+    // TODO: Remove native masks from contraction op?
     if (!contractOp.getMasks().empty())
       return failure();
@@ -1429,7 +1454,12 @@ namespace mlir {
 LogicalResult
 ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
                                                  PatternRewriter &rew) const {
-  // TODO: implement masks
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
   if (vectorTransformOptions.vectorContractLowering !=
@@ -1525,10 +1555,16 @@ struct UnrolledOuterProductGenerator
   UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
       : StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
         kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
-        res(op.getAcc()), lhsType(op.getLhsType()) {}
+        res(op.getAcc()), lhsType(op.getLhsType()) {
+    auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+    if (maskableOp.isMasked())
+      mask = maskableOp.getMaskingOp().getMask();
+  }

   Value t(Value v) {
     static constexpr std::array<int64_t, 2> perm = {1, 0};
+    if (!v)
+      return v;
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
   }
@@ -1547,16 +1583,27 @@ struct UnrolledOuterProductGenerator
     return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
   }

-  Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
+  FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+                             Optional<Value> maybeMask = std::nullopt) {
     assert(reductionSize > 0);
+    // Incremental support for masking.
+    if (mask && !maybeMask.has_value())
+      return failure();
+
     Type resElementType = res.getType().cast<VectorType>().getElementType();
     for (int64_t k = 0; k < reductionSize; ++k) {
       Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
       Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
       extractA = promote(extractA, resElementType);
       extractB = promote(extractB, resElementType);
-      res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), extractA,
-                                                    extractB, res, kind);
+      Value extractMask;
+      if (maybeMask.has_value() && maybeMask.value())
+        extractMask =
+            rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
+
+      Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
+          loc, res.getType(), extractA, extractB, res, kind);
+      res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
     }
     return res;
   }
@@ -1607,7 +1654,7 @@ struct UnrolledOuterProductGenerator
     // Case mat-vec: transpose.
    if (layout({{m, k}, {k}, {m}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
     // Case mat-trans-vec: ready to go.
     if (layout({{k, m}, {k}, {m}}))
       return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
@@ -1646,7 +1693,7 @@ struct UnrolledOuterProductGenerator

 private:
   vector::CombiningKind kind;
-  Value lhs, rhs, res;
+  Value lhs, rhs, res, mask;
   VectorType lhsType;
 };
 } // namespace
@@ -1668,7 +1715,7 @@ private:
 /// otherwise supports any layout permutation of the matrix-multiply.
 LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
     vector::ContractionOp op, PatternRewriter &rewriter) const {
-  // TODO: implement masks
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();

@@ -1679,20 +1726,31 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
   if (failed(filter(op)))
     return failure();

+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  Operation *rootOp;
+  if (maskableOp.isMasked()) {
+    rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+    rootOp = maskableOp.getMaskingOp();
+  } else {
+    rootOp = op;
+  }
+
   UnrolledOuterProductGenerator e(rewriter, op);
   FailureOr<Value> matmatRes = e.matmat();
   if (succeeded(matmatRes)) {
-    rewriter.replaceOp(op, *matmatRes);
+    rewriter.replaceOp(rootOp, *matmatRes);
     return success();
   }
   FailureOr<Value> matvecRes = e.matvec();
   if (succeeded(matvecRes)) {
-    rewriter.replaceOp(op, *matvecRes);
+    rewriter.replaceOp(rootOp, *matvecRes);
     return success();
   }
   FailureOr<Value> tmatvecRes = e.tmatvec();
   if (succeeded(tmatvecRes)) {
-    rewriter.replaceOp(op, *tmatvecRes);
+    rewriter.replaceOp(rootOp, *tmatvecRes);
     return success();
   }

@@ -1702,7 +1760,12 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
 LogicalResult
 ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
                                             PatternRewriter &rewriter) const {
-  // TODO: implement masks
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();

@@ -1834,7 +1897,12 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-  // TODO: implement masks.
+  // TODO: Support vector.mask.
+  auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+  if (maskableOp.isMasked())
+    return failure();
+
+  // TODO: Remove native masks from contraction op?
  if (!op.getMasks().empty())
     return failure();

diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
index f38f799..a72ab15 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
@@ -416,6 +416,18 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
 // CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : !llvm.array<2 x vector<3xf32>> to vector<2x3xf32>
 // CHECK: return %[[T19]] : vector<2x3xf32>

+// -----
+
+func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// We can't check for the intermediate 'vector.mask { vector.fma }' state, so we
+// just make sure the vector.fma is lowered.
+
+// CHECK: llvm.intr.fmuladd
+// CHECK: llvm.select

 // -----

@@ -2145,3 +2157,17 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
   %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
   return %0 : vector<8xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @masked_vector_fma(
+// CHECK-SAME:  %[[INPUT:.*]]: vector<8xf32>,
+// CHECK-SAME:  %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32>
+// CHECK:       %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]]) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
+// CHECK:       llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32>
+
+func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> {
+  %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
+  return %0 : vector<8xf32>
+}
+
diff --git a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
index d4dc35d..1617cb1 100644
--- a/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
+++ b/mlir/test/Dialect/Vector/vector-contract-transforms.mlir
@@ -1196,3 +1196,27 @@ func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vec
                %arg0, %arg1, %arg2 : vector<1x1xf32>, vector<1x1xf32> into f32
   return %0 : f32
 }
+
+func.func @masked_vector_contract(%arg0: vector<2x3xf32>,
+                                  %arg1: vector<3xf32>,
+                                  %arg2: vector<2xf32>,
+                                  %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// OUTERPRODUCT-LABEL: func.func @masked_vector_contract(
+// OUTERPRODUCT-SAME:  %[[VAL_0:.*]]: vector<2x3xf32>,
+// OUTERPRODUCT-SAME:  %[[VAL_1:.*]]: vector<3xf32>,
+// OUTERPRODUCT-SAME:  %[[VAL_2:.*]]: vector<2xf32>,
+// OUTERPRODUCT-SAME:  %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// OUTERPRODUCT:       %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// OUTERPRODUCT:       %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
+// OUTERPRODUCT:       vector.mask %[[MASK0]] { vector.outerproduct
+
+// OUTERPRODUCT:       %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
+// OUTERPRODUCT:       vector.mask %[[MASK1]] { vector.outerproduct
+
+// OUTERPRODUCT:       %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
+// OUTERPRODUCT:       vector.mask %[[MASK2]] { vector.outerproduct
-- 
2.7.4
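
As a reference for reviewers, here is how ContractionOp::getExpectedMaskType arrives at the
vector<2x3xi1> mask type used by the masked_vector_contract tests above. This is a hand-worked
example rather than part of the diff, and it assumes #matvec_trait carries the usual matvec
indexing maps.

  // Worked example (illustrative, not part of the patch):
  //   lhs: vector<2x3xf32>, indexing map (d0, d1) -> (d0, d1)
  //   rhs: vector<3xf32>,   indexing map (d0, d1) -> (d1)
  //   acc: vector<2xf32>,   indexing map (d0, d1) -> (d0)
  //
  // numVecDims = 2, so maskShape starts as [dynamic, dynamic].
  // Walking the lhs shape fills maskShape[d0] = 2 and maskShape[d1] = 3;
  // walking the rhs shape re-fills maskShape[d1] = 3.
  // Expected mask type: vector<2x3xi1>, i.e. one i1 per iteration-space point,
  // which is exactly the type of %m in the masked_vector_contract tests.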
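
To make the overall flow easier to follow, the sketch below strings the pieces together for the
matvec case. It is a hand-written illustration, not compiler output: the SSA names are made up,
#matvec_trait is the trait already defined in vector-contract-transforms.mlir, and the
intermediate IR is paraphrased from the patterns and CHECK lines above.

  // Input: a vector.contract masked with vector.mask (as in the test above).
  func.func @masked_matvec(%A: vector<2x3xf32>, %x: vector<3xf32>,
                           %acc: vector<2xf32>, %m: vector<2x3xi1>) -> vector<2xf32> {
    %0 = vector.mask %m {
           vector.contract #matvec_trait %A, %x, %acc
             : vector<2x3xf32>, vector<3xf32> into vector<2xf32>
         } : vector<2x3xi1> -> vector<2xf32>
    return %0 : vector<2xf32>
  }

  // Step 1: ContractionOpToOuterProductOpLowering + UnrolledOuterProductGenerator
  // unroll the reduction dimension (k = 3) into masked 1-D outer products.
  // Conceptually, for each reduction step k:
  //   %At = vector.transpose %A, [1, 0] : vector<2x3xf32> to vector<3x2xf32>
  //   %Mt = vector.transpose %m, [1, 0] : vector<2x3xi1> to vector<3x2xi1>
  //   %ak = vector.extract %At[k] : vector<3x2xf32>
  //   %xk = vector.extract %x[k]  : vector<3xf32>
  //   %mk = vector.extract %Mt[k] : vector<3x2xi1>
  //   %acc_k = vector.mask %mk {
  //              vector.outerproduct %ak, %xk, %acc_prev {kind = #vector.kind<add>}
  //                : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>

  // Step 2: each masked outer product becomes a masked vector.fma, which
  // MaskedFMAOp1DConversion lowers to an unmasked fmuladd plus a select
  // (matching the CHECK lines in vector-to-llvm.mlir):
  //   %fma = llvm.intr.fmuladd(%ak, %xk_bcast, %acc_prev)
  //            : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
  //   %acc_k = llvm.select %mk, %fma, %acc_prev : vector<2xi1>, vector<2xf32>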