[mlir][Vector] Add support for high-order masked contractions
authorDiego Caballero <diegocaballero@google.com>
Wed, 22 Feb 2023 06:53:37 +0000 (06:53 +0000)
committerDiego Caballero <diegocaballero@google.com>
Wed, 22 Feb 2023 06:54:02 +0000 (06:54 +0000)
This patch adds support for masked vector.contract ops that needs to be
decomposed using the ContractionOpLowering pattern. It just slices the
mask according to the rest of the lowering.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144427

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Dialect/Vector/vector-contract-transforms.mlir

index fb2d07e..775bfbf 100644 (file)
@@ -510,12 +510,12 @@ private:
   vector::VectorTransformsOptions vectorTransformOptions;
   FilterConstraintType filter;
   // Lower one parallel dimension.
-  FailureOr<Value> lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                                 int64_t rhsIndex,
-                                 PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerParallel(PatternRewriter &rewriter,
+                                 vector::ContractionOp op, int64_t lhsIndex,
+                                 int64_t rhsIndex, Value mask) const;
   // Lower one reduction dimension.
-  FailureOr<Value> lowerReduction(vector::ContractionOp op,
-                                  PatternRewriter &rewriter) const;
+  FailureOr<Value> lowerReduction(PatternRewriter &rewriter,
+                                  vector::ContractionOp op, Value mask) const;
 };
 
 } // namespace vector
index eecf970..0844fda 100644 (file)
@@ -1904,11 +1904,6 @@ ContractionOpToDotLowering::matchAndRewrite(vector::ContractionOp op,
 LogicalResult
 ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
                                        PatternRewriter &rewriter) const {
-  // TODO: Support vector.mask.
-  auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-  if (maskableOp.isMasked())
-    return failure();
-
   // TODO: Remove native masks from contraction op?
   if (!op.getMasks().empty())
     return failure();
@@ -1944,15 +1939,25 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   if (succeeded(pat4.matchAndRewrite(op, rewriter)))
     return success();
 
+  // Vector mask setup.
+  OpBuilder::InsertionGuard guard(rewriter);
+  Operation *rootOp = op;
+  Value mask;
+  if (op.isMasked()) {
+    rewriter.setInsertionPoint(op.getMaskingOp());
+    rootOp = op.getMaskingOp();
+    mask = op.getMaskingOp().getMask();
+  }
+
   // Find first batch dimension in LHS/RHS, and lower when found.
   std::vector<std::pair<int64_t, int64_t>> batchDimMap = op.getBatchDimMap();
   if (!batchDimMap.empty()) {
     int64_t lhsIndex = batchDimMap[0].first;
     int64_t rhsIndex = batchDimMap[0].second;
-    auto newOp = lowerParallel(op, lhsIndex, rhsIndex, rewriter);
+    auto newOp = lowerParallel(rewriter, op, lhsIndex, rhsIndex, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(op, *newOp);
+    rewriter.replaceOp(rootOp, *newOp);
     return success();
   }
 
@@ -1970,10 +1975,10 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   VectorType lhsType = op.getLhsType();
   for (int64_t lhsIndex = 0, e = lhsType.getRank(); lhsIndex < e; ++lhsIndex) {
     if (lhsContractingDimSet.count(lhsIndex) == 0) {
-      auto newOp = lowerParallel(op, lhsIndex, /*rhsIndex=*/-1, rewriter);
+      auto newOp = lowerParallel(rewriter, op, lhsIndex, /*rhsIndex=*/-1, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(op, *newOp);
+      rewriter.replaceOp(rootOp, *newOp);
       return success();
     }
   }
@@ -1982,20 +1987,20 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
   VectorType rhsType = op.getRhsType();
   for (int64_t rhsIndex = 0, e = rhsType.getRank(); rhsIndex < e; ++rhsIndex) {
     if (rhsContractingDimSet.count(rhsIndex) == 0) {
-      auto newOp = lowerParallel(op, /*lhsIndex=*/-1, rhsIndex, rewriter);
+      auto newOp = lowerParallel(rewriter, op, /*lhsIndex=*/-1, rhsIndex, mask);
       if (failed(newOp))
         return failure();
-      rewriter.replaceOp(op, *newOp);
+      rewriter.replaceOp(rootOp, *newOp);
       return success();
     }
   }
 
   // Lower the first remaining reduction dimension.
   if (!contractingDimMap.empty()) {
-    auto newOp = lowerReduction(op, rewriter);
+    auto newOp = lowerReduction(rewriter, op, mask);
     if (failed(newOp))
       return failure();
-    rewriter.replaceOp(op, *newOp);
+    rewriter.replaceOp(rootOp, *newOp);
     return success();
   }
 
@@ -2005,10 +2010,11 @@ ContractionOpLowering::matchAndRewrite(vector::ContractionOp op,
 // Lower one parallel dimension.
 // Incidentally also tolerates unit-size (hence trivial) reduction dimensions.
 // TODO: consider reusing existing contract unrolling
-FailureOr<Value>
-ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
-                                     int64_t rhsIndex,
-                                     PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerParallel(PatternRewriter &rewriter,
+                                                      vector::ContractionOp op,
+                                                      int64_t lhsIndex,
+                                                      int64_t rhsIndex,
+                                                      Value mask) const {
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
   VectorType resType = op.getResultType().cast<VectorType>();
@@ -2046,6 +2052,7 @@ ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
       diag << "expected the dimension for iterIndex=" << iterIndex
            << " to either appear in the result map, or to be a unit dimension";
     });
+
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
       adjustMap(iMap[0], iterIndex, rewriter),
@@ -2058,22 +2065,29 @@ ContractionOpLowering::lowerParallel(vector::ContractionOp op, int64_t lhsIndex,
   Location loc = op.getLoc();
   Value result = rewriter.create<arith::ConstantOp>(
       loc, resType, rewriter.getZeroAttr(resType));
+
   for (int64_t d = 0; d < dimSize; ++d) {
     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
     auto acc = reshapeLoad(loc, op.getAcc(), resType, resIndex, d, rewriter);
-    Value lowContract = rewriter.create<vector::ContractionOp>(
+
+    Value lowMask;
+    if (mask)
+      lowMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *lowContract = rewriter.create<vector::ContractionOp>(
         loc, lhs, rhs, acc, lowAffine, lowIter);
-    result =
-        reshapeStore(loc, lowContract, result, resType, resIndex, d, rewriter);
+    lowContract = maskOperation(rewriter, lowContract, lowMask);
+    result = reshapeStore(loc, lowContract->getResult(0), result, resType,
+                          resIndex, d, rewriter);
   }
   return result;
 }
 
 // Lower one reduction dimension.
-FailureOr<Value>
-ContractionOpLowering::lowerReduction(vector::ContractionOp op,
-                                      PatternRewriter &rewriter) const {
+FailureOr<Value> ContractionOpLowering::lowerReduction(
+    PatternRewriter &rewriter, vector::ContractionOp op, Value mask) const {
   auto loc = op.getLoc();
   VectorType lhsType = op.getLhsType();
   VectorType rhsType = op.getRhsType();
@@ -2110,10 +2124,12 @@ ContractionOpLowering::lowerReduction(vector::ContractionOp op,
           op, "When LHS has rank 1, expected also RHS to have rank 1");
     Value m = createMul(loc, op.getLhs(), op.getRhs(), isInt, rewriter);
     auto kind = vector::CombiningKind::ADD;
-    if (auto acc = op.getAcc())
-      return rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
-          .getResult();
-    return rewriter.create<vector::ReductionOp>(loc, kind, m).getResult();
+
+    Value acc = op.getAcc();
+    Operation *reductionOp =
+        acc ? rewriter.create<vector::ReductionOp>(loc, kind, m, acc)
+            : rewriter.create<vector::ReductionOp>(loc, kind, m);
+    return maskOperation(rewriter, reductionOp, mask)->getResult(0);
   }
   // Construct new iterator types and affine map array attribute.
   std::array<AffineMap, 3> lowIndexingMaps = {
@@ -2131,8 +2147,14 @@ ContractionOpLowering::lowerReduction(vector::ContractionOp op,
   for (int64_t d = 0; d < dimSize; ++d) {
     auto lhs = reshapeLoad(loc, op.getLhs(), lhsType, lhsIndex, d, rewriter);
     auto rhs = reshapeLoad(loc, op.getRhs(), rhsType, rhsIndex, d, rewriter);
-    result = rewriter.create<vector::ContractionOp>(loc, lhs, rhs, result,
-                                                    lowAffine, lowIter);
+    Value newMask;
+    if (mask)
+      newMask = reshapeLoad(loc, mask, cast<VectorType>(mask.getType()),
+                            iterIndex, d, rewriter);
+
+    Operation *newContract = rewriter.create<vector::ContractionOp>(
+        loc, lhs, rhs, result, lowAffine, lowIter);
+    result = maskOperation(rewriter, newContract, newMask)->getResult(0);
   }
   return result;
 }
index 6ad8a09..2cbd604 100644 (file)
@@ -28,6 +28,18 @@ func.func @extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2:
   return %0 : f32
 }
 
+// CHECK-LABEL: func @masked_extract_contract1
+//  CHECK-SAME:   %[[A:.*0]]: vector<4xf32>, %[[B:.*1]]: vector<4xf32>, %[[C:.*2]]: f32
+//  CHECK-SAME:   %[[M:.*]]: vector<4xi1>
+//       CHECK:   %[[F:.*]] = arith.mulf %[[A]], %[[B]] : vector<4xf32>
+//       CHECK:   %[[R:.*]] = vector.mask %[[M]] { vector.reduction <add>, %0, %arg2 : vector<4xf32> into f32 } : vector<4xi1> -> f32
+//       CHECK:   return %[[R]] : f32
+
+func.func @masked_extract_contract1(%arg0: vector<4xf32>, %arg1: vector<4xf32>, %arg2: f32, %mask: vector<4xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.contract #dotp_trait %arg0, %arg1, %arg2 : vector<4xf32>, vector<4xf32> into f32 } : vector<4xi1> -> f32
+  return %0 : f32
+}
+
 // CHECK-LABEL: func @extract_contract1_int
 // CHECK-SAME: %[[A:.*0]]: vector<4xi32>,
 // CHECK-SAME: %[[B:.*1]]: vector<4xi32>,