[mlir][Vector] Support masking for more contraction flavors
authorDiego Caballero <diegocaballero@google.com>
Wed, 22 Feb 2023 01:20:10 +0000 (01:20 +0000)
committerDiego Caballero <diegocaballero@google.com>
Wed, 22 Feb 2023 01:47:44 +0000 (01:47 +0000)
This patch adds masking support for more contraction flavors including those
with any combiner operation (add, mul, min, max, and, or, etc.) and
regular matmul contractions.

Combiner operations that are performing vertical reductions (and,
therefore, they are not represented with a horizontal reduction
operation) can be executed unmasked. However, the previous value of
the accumulator must be propagated for lanes that shouldn't accumulate.
We achieve this goal by introducing a select operation after the
accumulator to choose between the combined and the previous accumulator
value. This design decision is made to avoid introducing masking support
to all the arithmetic and logical operations in the Arith dialect. VP
intrinsics do not support pass-thru values either so we would have to
generate the same sequence when lowering to LLVM. The op + select
pattern is peepholed by some backend with native masking support for those
operations.

Consequently, this patch removes masking support from the vector.fma
operation to follow the same approach for all the combiner operations.

Reviewed By: ThomasRaoux

Differential Revision: https://reviews.llvm.org/D144239

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/Dialect/Vector/vector-contract-transforms.mlir

index deb86df..56f8b4b 100644 (file)
@@ -191,7 +191,7 @@ bool isDisjointTransferSet(VectorTransferOpInterface transferA,
 /// Return the result value of reducing two scalar/vector values with the
 /// corresponding arith operation.
 Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
-                         Value v1, Value v2);
+                         Value v1, Value acc, Value mask = Value());
 
 /// Returns true if `attr` has "parallel" iterator type semantics.
 inline bool isParallelIterator(Attribute attr) {
@@ -214,8 +214,17 @@ void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
 /// Creates a vector.mask operation around a maskable operation. Returns the
 /// vector.mask operation if the mask provided is valid. Otherwise, returns the
 /// maskable operation itself.
-Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
-                         Value mask);
+Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
+                         Value mask, Value passthru = Value());
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value for masked-out or expeculatively executed
+/// lanes. VP intrinsics do not support pass-thru values and every mask-out lane
+/// is set to poison. LLVM backends are usually able to match op + select
+/// patterns and fold them into a native target instructions.
+Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
+                     Value passthru);
 
 } // namespace vector
 } // namespace mlir
index c5ebe9f..04fb36a 100644 (file)
@@ -633,7 +633,6 @@ def Vector_ExtractOp :
 def Vector_FMAOp :
   Op<Vector_Dialect, "fma", [
        Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
-       DeclareOpInterfaceMethods<MaskableOpInterface>,
        DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
      ] # ElementwiseMappable.traits>,
     Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
index 159bae8..b73c01a 100644 (file)
@@ -704,11 +704,6 @@ public:
     Value acc = adaptor.getAcc();
     Location loc = reductionOp.getLoc();
 
-    // Masked reductions are lowered separately.
-    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
     if (eltType.isIntOrIndex()) {
       // Integer reductions: add/mul/min/max/and/or/xor.
       Value result;
@@ -1108,47 +1103,12 @@ public:
     if (vType.getRank() > 1)
       return failure();
 
-    // Masked fmas are lowered separately.
-    auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
     rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
         fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
     return success();
   }
 };
 
-/// Conversion pattern that turns a masked vector.fma on a 1-D vector into their
-/// LLVM counterpart representation. Non side effecting VP intrinsics are not
-/// fully supported by some backends, including x86, and they don't support
-/// pass-through values either. For these reasons, we generate an unmasked
-/// fma followed by a select instrution to emulate the masking behavior.
-/// This pattern is peepholed by some backends with support for masked fma
-/// instructions. This pattern does not match vectors of n >= 2 rank.
-class MaskedFMAOp1DConversion
-    : public VectorMaskOpConversionBase<vector::FMAOp> {
-public:
-  using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
-
-  MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
-      : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
-
-  virtual LogicalResult matchAndRewriteMaskableOp(
-      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
-      ConversionPatternRewriter &rewriter) const override {
-    auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
-    Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
-
-    Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
-        fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
-        fmaOp.getAcc());
-    rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
-        maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
-    return success();
-  }
-};
-
 class VectorInsertElementOpConversion
     : public ConvertOpToLLVMPattern<vector::InsertElementOp> {
 public:
@@ -1315,11 +1275,6 @@ public:
     if (vType.getRank() < 2)
       return failure();
 
-    // Masked fmas are lowered separately.
-    auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
-    if (maskableOp.isMasked())
-      return failure();
-
     auto loc = op.getLoc();
     auto elemType = vType.getElementType();
     Value zero = rewriter.create<arith::ConstantOp>(
@@ -1748,10 +1703,9 @@ void mlir::populateVectorToLLVMConversionPatterns(
   patterns
       .add<VectorBitCastOpConversion, VectorShuffleOpConversion,
            VectorExtractElementOpConversion, VectorExtractOpConversion,
-           VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
-           VectorInsertElementOpConversion, VectorInsertOpConversion,
-           VectorPrintOpConversion, VectorTypeCastOpConversion,
-           VectorScaleOpConversion,
+           VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+           VectorInsertOpConversion, VectorPrintOpConversion,
+           VectorTypeCastOpConversion, VectorScaleOpConversion,
            VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
            VectorLoadStoreConversion<vector::MaskedLoadOp,
                                      vector::MaskedLoadOpAdaptor>,
index 8c6609d..eb58f90 100644 (file)
@@ -1790,16 +1790,6 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
   return llvm::to_vector<4>(getVectorType().getShape());
 }
 
-// MaskableOpInterface methods.
-
-/// Returns the mask type expected by this operation. Mostly used for
-/// verification purposes. It requires the operation to be vectorized."
-Type FMAOp::getExpectedMaskType() {
-  auto vecType = this->getVectorType();
-  return VectorType::get(vecType.getShape(),
-                         IntegerType::get(vecType.getContext(), /*width=*/1));
-}
-
 //===----------------------------------------------------------------------===//
 // BroadcastOp
 //===----------------------------------------------------------------------===//
@@ -5807,53 +5797,71 @@ bool WarpExecuteOnLane0Op::areTypesCompatible(Type lhs, Type rhs) {
 }
 
 Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
-                                       CombiningKind kind, Value v1, Value v2) {
+                                       CombiningKind kind, Value v1, Value acc,
+                                       Value mask) {
   Type t1 = getElementTypeOrSelf(v1.getType());
-  Type t2 = getElementTypeOrSelf(v2.getType());
+  Type tAcc = getElementTypeOrSelf(acc.getType());
+  Value result;
+
   switch (kind) {
   case CombiningKind::ADD:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::AddIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::AddFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for ADD reduction");
+    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+      result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
+    else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+      result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+    else
+      llvm_unreachable("invalid value types for ADD reduction");
+    break;
   case CombiningKind::AND:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
+    break;
   case CombiningKind::MAXF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+    assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
            "expected float values");
-    return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+    result = b.createOrFold<arith::MaxFOp>(loc, v1, acc);
+    break;
   case CombiningKind::MINF:
-    assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+    assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
            "expected float values");
-    return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+    result = b.createOrFold<arith::MinFOp>(loc, v1, acc);
+    break;
   case CombiningKind::MAXSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
+    break;
   case CombiningKind::MINSI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
+    break;
   case CombiningKind::MAXUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
+    break;
   case CombiningKind::MINUI:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
+    break;
   case CombiningKind::MUL:
-    if (t1.isIntOrIndex() && t2.isIntOrIndex())
-      return b.createOrFold<arith::MulIOp>(loc, v1, v2);
-    else if (t1.isa<FloatType>() && t2.isa<FloatType>())
-      return b.createOrFold<arith::MulFOp>(loc, v1, v2);
-    llvm_unreachable("invalid value types for MUL reduction");
+    if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+      result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
+    else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+      result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+    else
+      llvm_unreachable("invalid value types for MUL reduction");
+    break;
   case CombiningKind::OR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
+    break;
   case CombiningKind::XOR:
-    assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
-    return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+    assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+    result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
+    break;
   };
-  llvm_unreachable("unknown CombiningKind");
+
+  assert(result && "unknown CombiningKind");
+  return selectPassthru(b, mask, result, acc);
 }
 
 //===----------------------------------------------------------------------===//
@@ -5875,13 +5883,34 @@ void mlir::vector::createMaskOpRegion(OpBuilder &builder,
 /// Creates a vector.mask operation around a maskable operation. Returns the
 /// vector.mask operation if the mask provided is valid. Otherwise, returns
 /// the maskable operation itself.
-Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
-                                       Operation *maskableOp, Value mask) {
+Operation *mlir::vector::maskOperation(OpBuilder &builder,
+                                       Operation *maskableOp, Value mask,
+                                       Value passthru) {
   if (!mask)
     return maskableOp;
-  return rewriter.create<MaskOp>(maskableOp->getLoc(),
-                                 maskableOp->getResultTypes(), mask, maskableOp,
-                                 createMaskOpRegion);
+  if (passthru)
+    return builder.create<MaskOp>(maskableOp->getLoc(),
+                                  maskableOp->getResultTypes(), mask, passthru,
+                                  maskableOp, createMaskOpRegion);
+  return builder.create<MaskOp>(maskableOp->getLoc(),
+                                maskableOp->getResultTypes(), mask, maskableOp,
+                                createMaskOpRegion);
+}
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value of vector.mask or for cases where only the
+/// pass-thru value propagation is needed. VP intrinsics do not support
+/// pass-thru values and every mask-out lane is set to poison. LLVM backends are
+/// usually able to match op + select patterns and fold them into a native
+/// target instructions.
+Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
+                                   Value newValue, Value passthru) {
+  if (!mask)
+    return newValue;
+
+  return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
+                                         mask, newValue, passthru);
 }
 
 //===----------------------------------------------------------------------===//
index e7b8cd5..eecf970 100644 (file)
@@ -151,8 +151,7 @@ static SmallVector<IntType> extractVector(ArrayAttr arrayAttr) {
 static std::optional<Value>
 createContractArithOp(Location loc, Value x, Value y, Value acc,
                       vector::CombiningKind kind, PatternRewriter &rewriter,
-                      bool isInt,
-                      std::optional<Value> maybeMask = std::nullopt) {
+                      bool isInt, Value mask = Value()) {
   using vector::CombiningKind;
   Value mul;
 
@@ -171,20 +170,20 @@ createContractArithOp(Location loc, Value x, Value y, Value acc,
       return std::nullopt;
     // Special case for fused multiply-add.
     if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
-      Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
-      if (maybeMask.has_value() && maybeMask.value())
-        fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
-      return fmaOp->getResult(0);
+      Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+      if (mask)
+        // The fma op doesn't need explicit masking. However, fma ops used in
+        // reductions must preserve previous 'acc' values for masked-out lanes.
+        fma = selectPassthru(rewriter, mask, fma, acc);
+      return fma;
     }
     mul = rewriter.create<arith::MulFOp>(loc, x, y);
   }
 
-  assert((!maybeMask.has_value() || !maybeMask.value()) &&
-         "Unsupported masked case");
-
   if (!acc)
     return std::optional<Value>(mul);
-  return makeArithReduction(rewriter, loc, kind, mul, acc);
+
+  return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
 }
 
 /// Return the positions of the reductions in the given map.
@@ -587,13 +586,17 @@ public:
     for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
       auto pos = rewriter.getI64ArrayAttr(d);
       Value x =
-          rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
+          rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
       Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
       Value r = nullptr;
       if (acc)
-        r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+        r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+      Value extrMask;
+      if (mask)
+        extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+
       std::optional<Value> m = createContractArithOp(
-          loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
+          loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
       if (!m.has_value())
         return failure();
       result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
@@ -638,6 +641,7 @@ struct ContractOpToElementwise
     if (vectorTransformOptions.vectorContractLowering !=
         vector::VectorContractLowering::ParallelArith)
       return failure();
+
     ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
     ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
     AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
@@ -1564,8 +1568,7 @@ struct UnrolledOuterProductGenerator
       mask = maskableOp.getMaskingOp().getMask();
   }
 
-  Value t(Value v) {
-    static constexpr std::array<int64_t, 2> perm = {1, 0};
+  Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
     if (!v)
       return v;
     return rewriter.create<vector::TransposeOp>(loc, v, perm);
@@ -1620,7 +1623,8 @@ struct UnrolledOuterProductGenerator
     bindDims(rewriter.getContext(), m, n, k);
     // Classical row-major matmul:  Just permute the lhs.
     if (layout({{m, k}, {k, n}, {m, n}}))
-      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+      return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
+                       t(mask, {2, 0, 1}));
     // TODO: may be better to fail and use some vector<k> -> scalar reduction.
     if (layout({{m, k}, {n, k}, {m, n}})) {
       Value tlhs = t(lhs);
index a72ab15..0b9b86a 100644 (file)
@@ -418,16 +418,132 @@ func.func @outerproduct_add(%arg0: vector<2xf32>, %arg1: vector<3xf32>, %arg2: v
 
 // -----
 
-func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
   %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
   return %0 : vector<2xf32>
 }
 
-// We can't check for the intermediate 'vector.mask { vector.fma }' state so we
-// just make sure the vector.fma is lowered.
+// CHECK-LABEL:   func.func @masked_float_add_outerprod(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK:           %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]])  : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
 
-// CHECK:           llvm.intr.fmuladd
-// CHECK:           llvm.select
+// -----
+
+func.func @masked_float_mul_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_float_mul_outerprod(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.mulf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_float_max_outerprod(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.maxf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL:   func.func @masked_float_min_outerprod(
+// CHECK-SAME:                                          %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK:           %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK:           %[[VAL_9:.*]] = arith.minf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_int_add_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL:   func.func @masked_int_add_outerprod(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_mul_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL:   func.func @masked_int_mul_outerprod(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.muli %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_max_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxsi>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL:   func.func @masked_int_max_outerprod(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.maxsi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_min_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minui>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL:   func.func @masked_int_min_outerprod(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.minui %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_and_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<and>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL:   func.func @masked_int_and_outerprod(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.andi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_or_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+  %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<or>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+  return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL:   func.func @masked_int_or_outerprod(
+// CHECK-SAME:                                       %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK:           %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK:           %[[VAL_9:.*]] = arith.ori %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
 
 // -----
 
@@ -2157,17 +2273,3 @@ func.func @vector_scalable_extract(%vec: vector<[4]xf32>) -> vector<8xf32> {
   %0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
   return %0 : vector<8xf32>
 }
-
-// -----
-
-// CHECK-LABEL:   func.func @masked_vector_fma(
-// CHECK-SAME:                                 %[[INPUT:.*]]: vector<8xf32>,
-// CHECK-SAME:                                 %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32>
-// CHECK:           %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]])  : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-// CHECK:           llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32>
-
-func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> {
-  %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
-  return %0 : vector<8xf32>
-}
-
index 1617cb1..6ad8a09 100644 (file)
@@ -76,6 +76,30 @@ func.func @extract_contract2(%arg0: vector<2x3xf32>,
   return %0 : vector<2xf32>
 }
 
+// OUTERPRODUCT-LABEL:   func.func @masked_extract_contract2(
+// OUTERPRODUCT-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
+// OUTERPRODUCT-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// OUTERPRODUCT:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// OUTERPRODUCT:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK0]] { vector.outerproduct
+
+// OUTERPRODUCT:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK1]] { vector.outerproduct
+
+// OUTERPRODUCT:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
+// OUTERPRODUCT:           vector.mask %[[MASK2]] { vector.outerproduct
+
+func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
+                                    %arg1: vector<3xf32>,
+                                    %arg2: vector<2xf32>,
+                                    %m: vector<2x3xi1>) -> vector<2xf32> {
+  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+  return %0 : vector<2xf32>
+}
+
 // CHECK-LABEL: func @extract_contract2_int
 // CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
 // CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
@@ -182,6 +206,32 @@ func.func @extract_contract4(%arg0: vector<2x2xf32>,
   return %0 : vector<2x2xf32>
 }
 
+// OUTERPRODUCT-LABEL: func.func @masked_extract_contract4(
+// OUTERPRODUCT-SAME:                                      %[[VAL_0:.*]]: vector<3x5xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_1:.*]]: vector<5x7xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_2:.*]]: vector<3x7xf32>,
+// OUTERPRODUCT-SAME:                                      %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// OUTERPRODUCT:         %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// OUTERPRODUCT:         %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1>
+// OUTERPRODUCT:         %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT:         %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1>
+// OUTERPRODUCT:         %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT:         %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1>
+// OUTERPRODUCT:         %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT:         %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1>
+// OUTERPRODUCT:         %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT:         %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1>
+// OUTERPRODUCT:         %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
+                                    %arg1: vector<5x7xf32>,
+                                    %arg2: vector<3x7xf32>,
+                                    %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
+  %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+  : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
+  return %0 : vector<3x7xf32>
+}
+
 #contraction2d_accesses = [
   affine_map<(i, j) -> (i, j)>,
   affine_map<(i, j) -> (i, j)>,
@@ -1197,26 +1247,4 @@ func.func @parrallel_contract_lowering_scalar(%arg0: vector<1x1xf32>, %arg1: vec
   return %0 : f32
 }
 
-func.func @masked_vector_contract(%arg0: vector<2x3xf32>,
-                                  %arg1: vector<3xf32>,
-                                  %arg2: vector<2xf32>,
-                                  %m: vector<2x3xi1>) -> vector<2xf32> {
-  %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
-          : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
-  return %0 : vector<2xf32>
-}
 
-// OUTERPRODUCT-LABEL:   func.func @masked_vector_contract(
-// OUTERPRODUCT-SAME:                                      %[[VAL_0:.*]]: vector<2x3xf32>,
-// OUTERPRODUCT-SAME:                                      %[[VAL_1:.*]]: vector<3xf32>,
-// OUTERPRODUCT-SAME:                                      %[[VAL_2:.*]]: vector<2xf32>,
-// OUTERPRODUCT-SAME:                                      %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
-// OUTERPRODUCT:           %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
-// OUTERPRODUCT:           %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
-// OUTERPRODUCT:           vector.mask %[[MASK0]] { vector.outerproduct
-
-// OUTERPRODUCT:           %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
-// OUTERPRODUCT:           vector.mask %[[MASK1]] { vector.outerproduct
-
-// OUTERPRODUCT:           %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
-// OUTERPRODUCT:           vector.mask %[[MASK2]] { vector.outerproduct