VectorType vType = fmaOp.getVectorType();
if (vType.getRank() > 1)
return failure();
+
+ // Masked fmas are lowered separately.
+ auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
return success();
}
};
+/// Conversion pattern that turns a masked vector.fma on a 1-D vector into its
+/// LLVM counterpart representation. Non side effecting VP intrinsics are not
+/// fully supported by some backends, including x86, and they don't support
+/// pass-through values either. For these reasons, we generate an unmasked
+/// fma followed by a select instruction to emulate the masking behavior.
+/// This pattern is peepholed by some backends with support for masked fma
+/// instructions. This pattern does not match vectors of n >= 2 rank.
+class MaskedFMAOp1DConversion
+    : public VectorMaskOpConversionBase<vector::FMAOp> {
+public:
+  using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
+
+  /// The second argument is accepted but never consulted: this pattern always
+  /// emits the fmuladd + select sequence.
+  MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool /*fullVPIntr*/)
+      : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
+
+  /// Replaces `maskOp` (the vector.mask wrapping the fma) with a select
+  /// between the result of an unmasked fmuladd and the fma accumulator,
+  /// steered by the mask. Returns failure, leaving the IR untouched, when the
+  /// fma operates on a vector of rank > 1.
+  LogicalResult matchAndRewriteMaskableOp(
+      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+      ConversionPatternRewriter &rewriter) const override {
+    auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
+
+    // Enforce the documented 1-D restriction with the same rank check the
+    // unmasked 1-D fma conversion uses.
+    VectorType vType = fmaOp.getVectorType();
+    if (vType.getRank() > 1)
+      return failure();
+
+    Type llvmType = typeConverter->convertType(vType);
+
+    Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
+        fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
+        fmaOp.getAcc());
+    // The accumulator doubles as the pass-through value of the select.
+    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
+        maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
+    return success();
+  }
+};
+
class VectorInsertElementOpConversion
: public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
if (vType.getRank() < 2)
return failure();
+ // Masked fmas are lowered separately.
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
auto loc = op.getLoc();
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, VectorInsertElementOpConversion,
- VectorInsertOpConversion, VectorPrintOpConversion,
- VectorTypeCastOpConversion, VectorScaleOpConversion,
+ VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
+ VectorInsertElementOpConversion, VectorInsertOpConversion,
+ VectorPrintOpConversion, VectorTypeCastOpConversion,
+ VectorScaleOpConversion,
VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedLoadOp,
vector::MaskedLoadOpAdaptor>,
return success();
}
+// MaskableOpInterface methods.
+
+/// Computes the mask type this contraction expects, mainly so that a mask can
+/// be verified against it: an i1 vector with one dimension per dimension of
+/// the lhs indexing map, each sized from the lhs/rhs operand shapes. The
+/// operation is expected to be vectorized.
+Type ContractionOp::getExpectedMaskType() {
+  auto maps = this->getIndexingMapsArray();
+  AffineMap lhsMap = maps[0];
+  AffineMap rhsMap = maps[1];
+  VectorType lhsVecType = this->getLhsType();
+  VectorType rhsVecType = this->getRhsType();
+
+  // Start with every mask dimension unknown; the operands fill them in below.
+  unsigned maskRank = lhsMap.getNumDims();
+  SmallVector<int64_t> expectedShape(maskRank, ShapedType::kDynamic);
+
+  // Each operand dimension is tied by its indexing map to one dimension of
+  // the contraction; copy the operand's extent into that mask dimension.
+  auto recordExtents = [&expectedShape](AffineMap operandMap,
+                                        VectorType operandType) {
+    for (auto [operandDim, extent] : llvm::enumerate(operandType.getShape()))
+      expectedShape[operandMap.getDimPosition(operandDim)] = extent;
+  };
+  recordExtents(lhsMap, lhsVecType);
+  recordExtents(rhsMap, rhsVecType);
+
+  assert(!ShapedType::isDynamicShape(expectedShape) &&
+         "Mask shape couldn't be computed");
+
+  Type i1Type = IntegerType::get(lhsVecType.getContext(), /*width=*/1);
+  return VectorType::get(expectedShape, i1Type);
+}
+
SmallVector<StringRef> ContractionOp::getTraitAttrNames() {
return SmallVector<StringRef>{getIndexingMapsAttrName(),
getIteratorTypesAttrName(), getKindAttrName()};
return llvm::to_vector<4>(getVectorType().getShape());
}
+// MaskableOpInterface methods.
+
+/// Yields the type a mask applied to this fma is expected to have: an i1
+/// vector with the same shape as the fma's vector type. Mostly used for
+/// verification purposes. The operation is expected to be vectorized.
+Type FMAOp::getExpectedMaskType() {
+  VectorType fmaType = this->getVectorType();
+  Type i1Type = IntegerType::get(fmaType.getContext(), /*width=*/1);
+  return VectorType::get(fmaType.getShape(), i1Type);
+}
+
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
return success();
}
+// MaskableOpInterface methods.
+
+/// Yields the type a mask applied to this outer product is expected to have:
+/// an i1 vector with the same shape as the op's vector type. Mostly used for
+/// verification purposes. The operation is expected to be vectorized.
+Type OuterProductOp::getExpectedMaskType() {
+  auto outerProdType = this->getVectorType();
+  Type i1Type = IntegerType::get(outerProdType.getContext(), /*width=*/1);
+  return VectorType::get(outerProdType.getShape(), i1Type);
+}
+
//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
}
/// Helper to create arithmetic operation associated with a kind of contraction.
-static std::optional<Value> createContractArithOp(Location loc, Value x,
- Value y, Value acc,
- vector::CombiningKind kind,
- PatternRewriter &rewriter,
- bool isInt) {
+static std::optional<Value>
+createContractArithOp(Location loc, Value x, Value y, Value acc,
+ vector::CombiningKind kind, PatternRewriter &rewriter,
+ bool isInt, Optional<Value> maybeMask = std::nullopt) {
using vector::CombiningKind;
Value mul;
+
if (isInt) {
if (kind == CombiningKind::MINF || kind == CombiningKind::MAXF)
// Only valid for floating point types.
return std::nullopt;
// Special case for fused multiply-add.
if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
- return std::optional<Value>(
- rewriter.create<vector::FMAOp>(loc, x, y, acc));
+ Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+ if (maybeMask.has_value() && maybeMask.value())
+ fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
+ return fmaOp->getResult(0);
}
mul = rewriter.create<arith::MulFOp>(loc, x, y);
}
+
+ assert((!maybeMask.has_value() || !maybeMask.value()) &&
+ "Unsupported masked case");
+
if (!acc)
return std::optional<Value>(mul);
return makeArithReduction(rewriter, loc, kind, mul, acc);
Value acc = (op.getAcc().empty()) ? nullptr : op.getAcc()[0];
vector::CombiningKind kind = op.getKind();
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+ Operation *rootOp;
+ Value mask;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ mask = maskableOp.getMaskingOp().getMask();
+ } else {
+ rootOp = op;
+ }
+
if (!rhsType) {
// Special case: AXPY operation.
Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.getRhs());
std::optional<Value> mult = createContractArithOp(
- loc, op.getLhs(), b, acc, kind, rewriter, isInt);
+ loc, op.getLhs(), b, acc, kind, rewriter, isInt, mask);
if (!mult.has_value())
return failure();
- rewriter.replaceOp(op, *mult);
+ rewriter.replaceOp(rootOp, *mult);
return success();
}
Value r = nullptr;
if (acc)
r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
- std::optional<Value> m =
- createContractArithOp(loc, a, op.getRhs(), r, kind, rewriter, isInt);
+ std::optional<Value> m = createContractArithOp(
+ loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
if (!m.has_value())
return failure();
result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
}
- rewriter.replaceOp(op, result);
+
+ rewriter.replaceOp(rootOp, result);
return success();
}
};
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
PatternRewriter &rewriter) const override {
- // TODO: implement masks
+ // TODO: Support vector.mask.
+ auto maskableOp = cast<MaskableOpInterface>(contractOp.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
+ // TODO: Remove native masks from contraction op?
if (!contractOp.getMasks().empty())
return failure();
LogicalResult
ContractionOpToMatmulOpLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rew) const {
- // TODO: implement masks
+ // TODO: Support vector.mask.
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
+ // TODO: Remove native masks from contraction op?
if (!op.getMasks().empty())
return failure();
if (vectorTransformOptions.vectorContractLowering !=
UnrolledOuterProductGenerator(RewriterBase &b, vector::ContractionOp op)
: StructuredGenerator<vector::ContractionOp, vector::IteratorType>(b, op),
kind(op.getKind()), lhs(op.getLhs()), rhs(op.getRhs()),
- res(op.getAcc()), lhsType(op.getLhsType()) {}
+ res(op.getAcc()), lhsType(op.getLhsType()) {
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ mask = maskableOp.getMaskingOp().getMask(); // Only set for masked contractions; `mask` keeps its default value otherwise.
+ }
Value t(Value v) {
static constexpr std::array<int64_t, 2> perm = {1, 0};
+ if (!v)
+ return v; // A null value is handed back as-is instead of being transposed.
return rewriter.create<vector::TransposeOp>(loc, v, perm);
}
return rewriter.create<arith::ExtSIOp>(loc, promotedType, v);
}
- Value outerProd(Value lhs, Value rhs, Value res, int reductionSize) {
+ FailureOr<Value> outerProd(Value lhs, Value rhs, Value res, int reductionSize,
+ Optional<Value> maybeMask = std::nullopt) {
assert(reductionSize > 0);
+ // Incremental support for masking: a masked op needs a mask passed by the caller.
+ if (mask && !maybeMask.has_value())
+ return failure();
+
Type resElementType = res.getType().cast<VectorType>().getElementType();
for (int64_t k = 0; k < reductionSize; ++k) {
Value extractA = rewriter.create<vector::ExtractOp>(loc, lhs, k);
Value extractB = rewriter.create<vector::ExtractOp>(loc, rhs, k);
extractA = promote(extractA, resElementType);
extractB = promote(extractB, resElementType);
- res = rewriter.create<vector::OuterProductOp>(loc, res.getType(), extractA,
- extractB, res, kind);
+ Value extractMask; // Slice k of the mask; keeps its default value when no mask was supplied.
+ if (maybeMask.has_value() && maybeMask.value())
+ extractMask =
+ rewriter.create<vector::ExtractOp>(loc, maybeMask.value(), k);
+
+ Operation *outerProdOp = rewriter.create<vector::OuterProductOp>(
+ loc, res.getType(), extractA, extractB, res, kind);
+ res = maskOperation(rewriter, outerProdOp, extractMask)->getResult(0);
}
return res;
}
// Case mat-vec: transpose.
if (layout({{m, k}, {k}, {m}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1), t(mask));
// Case mat-trans-vec: ready to go.
if (layout({{k, m}, {k}, {m}}))
return outerProd(lhs, rhs, res, lhsType.getDimSize(0));
private:
vector::CombiningKind kind;
- Value lhs, rhs, res;
+ Value lhs, rhs, res, mask;
VectorType lhsType;
};
} // namespace
/// otherwise supports any layout permutation of the matrix-multiply.
LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
vector::ContractionOp op, PatternRewriter &rewriter) const {
- // TODO: implement masks
+ // TODO: Remove native masks from contraction op?
if (!op.getMasks().empty())
return failure();
if (failed(filter(op)))
return failure();
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp = cast<vector::MaskableOpInterface>(op.getOperation());
+ Operation *rootOp;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ } else {
+ rootOp = op;
+ }
+
UnrolledOuterProductGenerator e(rewriter, op);
FailureOr<Value> matmatRes = e.matmat();
if (succeeded(matmatRes)) {
- rewriter.replaceOp(op, *matmatRes);
+ rewriter.replaceOp(rootOp, *matmatRes);
return success();
}
FailureOr<Value> matvecRes = e.matvec();
if (succeeded(matvecRes)) {
- rewriter.replaceOp(op, *matvecRes);
+ rewriter.replaceOp(rootOp, *matvecRes);
return success();
}
FailureOr<Value> tmatvecRes = e.tmatvec();
if (succeeded(tmatvecRes)) {
- rewriter.replaceOp(op, *tmatvecRes);
+ rewriter.replaceOp(rootOp, *tmatvecRes);
return success();
}
LogicalResult
ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const {
- // TODO: implement masks
+ // TODO: Support vector.mask.
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
+ // TODO: Remove native masks from contraction op?
if (!op.getMasks().empty())
return failure();
LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const {
- // TODO: implement masks.
+ // TODO: Support vector.mask.
+ auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
+ if (maskableOp.isMasked())
+ return failure();
+
+ // TODO: Remove native masks from contraction op?
if (!op.getMasks().empty())
return failure();