LogicalResult
ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
PatternRewriter &rewriter) const {
- // TODO: Support vector.mask.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
// TODO: Remove native masks from contraction op?
if (!op.getMasks().empty())
return failure();
if (succeeded(pat4.matchAndRewrite(op, rewriter)))
return success();
+ // Vector mask setup: if masked, insert before and replace the masking op.
+ OpBuilder::InsertionGuard guard(rewriter);
+ Operation *rootOp = op;
+ Value mask;
+ if (op.isMasked()) {
+ rewriter.setInsertionPoint(op.getMaskingOp());
+ rootOp = op.getMaskingOp();
+ mask = op.getMaskingOp().getMask();
+ }
+
// Find first batch dimension in LHS/RHS, and lower when found.
std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
if (!batchDimMap.empty()) {
int64_t lhsIndex = batchDimMap[0].first;
int64_t rhsIndex = batchDimMap[0].second;
- auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
+ auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(op, *newOp);
+ rewriter.replaceOp(rootOp, *newOp);
return success();
}
VectorType lhsType = op.getLhsType();
for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
if (lhsContractingDimSet.count(lhsIndex) == 0) {
- auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
+ auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(op, *newOp);
+ rewriter.replaceOp(rootOp, *newOp);
return success();
}
}
VectorType rhsType = op.getRhsType();
for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
if (rhsContractingDimSet.count(rhsIndex) == 0) {
- auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
+ auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(op, *newOp);
+ rewriter.replaceOp(rootOp, *newOp);
return success();
}
}
// Lower the first remaining reduction dimension.
if (!contractingDimMap.empty()) {
- auto newOp = lowerReduction(op, rewriter);
+ auto newOp = lowerReduction(rewriter, op, mask);
if (failed(newOp))
return failure();
- rewriter.replaceOp(op, *newOp);
+ rewriter.replaceOp(rootOp, *newOp);
return success();
}
// Lower one parallel dimension.
// Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
// TODO: consider reusing existing contract unrolling
-FailureOr<Value>
-ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
- int64_t rhsIndex,
- PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
+ vector::ContractionOp op,
+ int64_t lhsIndex,
+ int64_t rhsIndex,
+ Value mask) const {
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
VectorType resType = op.getResultType().cast<VectorType>();
diag << "expected the dimension for iterIndex=" << iterIndex
<< " to either appear in the result map, or to be a unit dimension";
});
+
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
adjustMap(iMap[0], iterIndex, rewriter),
Location loc = op.getLoc();
Value result = rewriter.create<arith::ConstantOp>(
loc, resType, rewriter.getZeroAttr(resType));
+
for (int64_t d = 0; d < dimSize; ++d) {
auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
- Value lowContract = rewriter.create<vector::ContractionOp>(
+
+ Value lowMask;
+ if (mask)
+ lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+ iterIndex, d, rewriter);
+
+ Operation *lowContract = rewriter.create<vector::ContractionOp>(
loc, lhs, rhs, acc, lowAffine, lowIter);
- result =
- reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
+ lowContract = maskOperation(rewriter, lowContract, lowMask);
+ result = reshapeStore(loc, lowContract->getResult(0), result, resType,
+ resIndex, d, rewriter);
}
return result;
}
// Lower one reduction dimension.
-FailureOr<Value>
-ContractionOpLowering::lowerReduction(vector::ContractionOp op,
- PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerReduction(
+ PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
auto loc = op.getLoc();
VectorType lhsType = op.getLhsType();
VectorType rhsType = op.getRhsType();
op, "When LHS has rank 1, expected also RHS to have rank 1");
Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
auto kind = vector::CombiningKind::ADD;
- if (auto acc = op.getAcc())
- return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
- .getResult();
- return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
+
+ Value acc = op.getAcc();
+ Operation *reductionOp =
+ acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+ : rewriter.create<vector::ReductionOp>(loc, kind, m);
+ return maskOperation(rewriter, reductionOp, mask)->getResult(0);
}
// Construct new iterator types and affine map array attribute.
std::array<AffineMap, 3> lowIndexingMaps = {
for (int64_t d = 0; d < dimSize; ++d) {
auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
- result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
- lowAffine, lowIter);
+ Value newMask;
+ if (mask)
+ newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+ iterIndex, d, rewriter);
+
+ Operation *newContract = rewriter.create<vector::ContractionOp>(
+ loc, lhs, rhs, result, lowAffine, lowIter);
+ result = maskOperation(rewriter, newContract, newMask)->getResult(0);
}
return result;
}