/// Return the result value of reducing two scalar/vector values with the
/// corresponding arith operation.
Value makeArithReduction(OpBuilder &b, Location loc, CombiningKind kind,
- Value v1, Value v2);
+ Value v1, Value acc, Value mask = Value());
/// Returns true if `attr` has "parallel" iterator type semantics.
inline bool isParallelIterator(Attribute attr) {
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns the
/// maskable operation itself.
-Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
- Value mask);
+Operation *maskOperation(OpBuilder &builder, Operation *maskableOp,
+ Value mask, Value passthru = Value());
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value for masked-out or speculatively executed
+/// lanes. VP intrinsics do not support pass-thru values and every masked-out
+/// lane is set to poison. LLVM backends are usually able to match op + select
+/// patterns and fold them into native target instructions.
+Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
+ Value passthru);
} // namespace vector
} // namespace mlir
def Vector_FMAOp :
Op<Vector_Dialect, "fma", [
Pure, AllTypesMatch<["lhs", "rhs", "acc", "result"]>,
- DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
] # ElementwiseMappable.traits>,
Arguments<(ins VectorOfAnyRankOf<[AnyFloat]>:$lhs,
Value acc = adaptor.getAcc();
Location loc = reductionOp.getLoc();
- // Masked reductions are lowered separately.
- auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
if (eltType.isIntOrIndex()) {
// Integer reductions: add/mul/min/max/and/or/xor.
Value result;
if (vType.getRank() > 1)
return failure();
- // Masked fmas are lowered separately.
- auto maskableOp = cast<MaskableOpInterface>(fmaOp.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
rewriter.replaceOpWithNewOp<LLVM::FMulAddOp>(
fmaOp, adaptor.getLhs(), adaptor.getRhs(), adaptor.getAcc());
return success();
}
};
-/// Conversion pattern that turns a masked vector.fma on a 1-D vector into their
-/// LLVM counterpart representation. Non side effecting VP intrinsics are not
-/// fully supported by some backends, including x86, and they don't support
-/// pass-through values either. For these reasons, we generate an unmasked
-/// fma followed by a select instrution to emulate the masking behavior.
-/// This pattern is peepholed by some backends with support for masked fma
-/// instructions. This pattern does not match vectors of n >= 2 rank.
-class MaskedFMAOp1DConversion
- : public VectorMaskOpConversionBase<vector::FMAOp> {
-public:
- using VectorMaskOpConversionBase<vector::FMAOp>::VectorMaskOpConversionBase;
-
- MaskedFMAOp1DConversion(LLVMTypeConverter &converter, bool fullVPIntr)
- : VectorMaskOpConversionBase<vector::FMAOp>(converter) {}
-
- virtual LogicalResult matchAndRewriteMaskableOp(
- vector::MaskOp maskOp, MaskableOpInterface maskableOp,
- ConversionPatternRewriter &rewriter) const override {
- auto fmaOp = cast<FMAOp>(maskableOp.getOperation());
- Type llvmType = typeConverter->convertType(fmaOp.getVectorType());
-
- Value fmulAddOp = rewriter.create<LLVM::FMulAddOp>(
- fmaOp.getLoc(), llvmType, fmaOp.getLhs(), fmaOp.getRhs(),
- fmaOp.getAcc());
- rewriter.replaceOpWithNewOp<LLVM::SelectOp>(
- maskOp, llvmType, maskOp.getMask(), fmulAddOp, fmaOp.getAcc());
- return success();
- }
-};
-
class VectorInsertElementOpConversion
: public ConvertOpToLLVMPattern<vector::InsertElementOp> {
public:
if (vType.getRank() < 2)
return failure();
- // Masked fmas are lowered separately.
- auto maskableOp = cast<MaskableOpInterface>(op.getOperation());
- if (maskableOp.isMasked())
- return failure();
-
auto loc = op.getLoc();
auto elemType = vType.getElementType();
Value zero = rewriter.create<arith::ConstantOp>(
patterns
.add<VectorBitCastOpConversion, VectorShuffleOpConversion,
VectorExtractElementOpConversion, VectorExtractOpConversion,
- VectorFMAOp1DConversion, MaskedFMAOp1DConversion,
- VectorInsertElementOpConversion, VectorInsertOpConversion,
- VectorPrintOpConversion, VectorTypeCastOpConversion,
- VectorScaleOpConversion,
+ VectorFMAOp1DConversion, VectorInsertElementOpConversion,
+ VectorInsertOpConversion, VectorPrintOpConversion,
+ VectorTypeCastOpConversion, VectorScaleOpConversion,
VectorLoadStoreConversion<vector::LoadOp, vector::LoadOpAdaptor>,
VectorLoadStoreConversion<vector::MaskedLoadOp,
vector::MaskedLoadOpAdaptor>,
return llvm::to_vector<4>(getVectorType().getShape());
}
-// MaskableOpInterface methods.
-
-/// Returns the mask type expected by this operation. Mostly used for
-/// verification purposes. It requires the operation to be vectorized."
-Type FMAOp::getExpectedMaskType() {
- auto vecType = this->getVectorType();
- return VectorType::get(vecType.getShape(),
- IntegerType::get(vecType.getContext(), /*width=*/1));
-}
-
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//
}
Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
- CombiningKind kind, Value v1, Value v2) {
+ CombiningKind kind, Value v1, Value acc,
+ Value mask) {
Type t1 = getElementTypeOrSelf(v1.getType());
- Type t2 = getElementTypeOrSelf(v2.getType());
+ Type tAcc = getElementTypeOrSelf(acc.getType());
+ Value result;
+
switch (kind) {
case CombiningKind::ADD:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::AddIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::AddFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for ADD reduction");
+ if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+ result = b.createOrFold<arith::AddIOp>(loc, v1, acc);
+ else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+ result = b.createOrFold<arith::AddFOp>(loc, v1, acc);
+ else
+ llvm_unreachable("invalid value types for ADD reduction");
+ break;
case CombiningKind::AND:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::AndIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::AndIOp>(loc, v1, acc);
+ break;
case CombiningKind::MAXF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
"expected float values");
- return b.createOrFold<arith::MaxFOp>(loc, v1, v2);
+ result = b.createOrFold<arith::MaxFOp>(loc, v1, acc);
+ break;
case CombiningKind::MINF:
- assert(t1.isa<FloatType>() && t2.isa<FloatType>() &&
+ assert(t1.isa<FloatType>() && tAcc.isa<FloatType>() &&
"expected float values");
- return b.createOrFold<arith::MinFOp>(loc, v1, v2);
+ result = b.createOrFold<arith::MinFOp>(loc, v1, acc);
+ break;
case CombiningKind::MAXSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxSIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MaxSIOp>(loc, v1, acc);
+ break;
case CombiningKind::MINSI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinSIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MinSIOp>(loc, v1, acc);
+ break;
case CombiningKind::MAXUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MaxUIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MaxUIOp>(loc, v1, acc);
+ break;
case CombiningKind::MINUI:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::MinUIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::MinUIOp>(loc, v1, acc);
+ break;
case CombiningKind::MUL:
- if (t1.isIntOrIndex() && t2.isIntOrIndex())
- return b.createOrFold<arith::MulIOp>(loc, v1, v2);
- else if (t1.isa<FloatType>() && t2.isa<FloatType>())
- return b.createOrFold<arith::MulFOp>(loc, v1, v2);
- llvm_unreachable("invalid value types for MUL reduction");
+ if (t1.isIntOrIndex() && tAcc.isIntOrIndex())
+ result = b.createOrFold<arith::MulIOp>(loc, v1, acc);
+ else if (t1.isa<FloatType>() && tAcc.isa<FloatType>())
+ result = b.createOrFold<arith::MulFOp>(loc, v1, acc);
+ else
+ llvm_unreachable("invalid value types for MUL reduction");
+ break;
case CombiningKind::OR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::OrIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::OrIOp>(loc, v1, acc);
+ break;
case CombiningKind::XOR:
- assert(t1.isIntOrIndex() && t2.isIntOrIndex() && "expected int values");
- return b.createOrFold<arith::XOrIOp>(loc, v1, v2);
+ assert(t1.isIntOrIndex() && tAcc.isIntOrIndex() && "expected int values");
+ result = b.createOrFold<arith::XOrIOp>(loc, v1, acc);
+ break;
};
- llvm_unreachable("unknown CombiningKind");
+
+ assert(result && "unknown CombiningKind");
+ return selectPassthru(b, mask, result, acc);
}
//===----------------------------------------------------------------------===//
/// Creates a vector.mask operation around a maskable operation. Returns the
/// vector.mask operation if the mask provided is valid. Otherwise, returns
/// the maskable operation itself.
-Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
- Operation *maskableOp, Value mask) {
+Operation *mlir::vector::maskOperation(OpBuilder &builder,
+ Operation *maskableOp, Value mask,
+ Value passthru) {
if (!mask)
return maskableOp;
- return rewriter.create<MaskOp>(maskableOp->getLoc(),
- maskableOp->getResultTypes(), mask, maskableOp,
- createMaskOpRegion);
+ if (passthru)
+ return builder.create<MaskOp>(maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, passthru,
+ maskableOp, createMaskOpRegion);
+ return builder.create<MaskOp>(maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, maskableOp,
+ createMaskOpRegion);
+}
+
+/// Creates a vector select operation that picks values from `newValue` or
+/// `passthru` for each result vector lane based on `mask`. This utility is used
+/// to propagate the pass-thru value of vector.mask or for cases where only the
+/// pass-thru value propagation is needed. VP intrinsics do not support
+/// pass-thru values and every masked-out lane is set to poison. LLVM backends
+/// are usually able to match op + select patterns and fold them into native
+/// target instructions.
+Value mlir::vector::selectPassthru(OpBuilder &builder, Value mask,
+ Value newValue, Value passthru) {
+ if (!mask)
+ return newValue;
+
+ return builder.create<arith::SelectOp>(newValue.getLoc(), newValue.getType(),
+ mask, newValue, passthru);
}
//===----------------------------------------------------------------------===//
static std::optional<Value>
createContractArithOp(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind, PatternRewriter &rewriter,
- bool isInt,
- std::optional<Value> maybeMask = std::nullopt) {
+ bool isInt, Value mask = Value()) {
using vector::CombiningKind;
Value mul;
return std::nullopt;
// Special case for fused multiply-add.
if (acc && acc.getType().isa<VectorType>() && kind == CombiningKind::ADD) {
- Operation *fmaOp = rewriter.create<vector::FMAOp>(loc, x, y, acc);
- if (maybeMask.has_value() && maybeMask.value())
- fmaOp = maskOperation(rewriter, fmaOp, maybeMask.value());
- return fmaOp->getResult(0);
+ Value fma = rewriter.create<vector::FMAOp>(loc, x, y, acc);
+ if (mask)
+ // The fma op doesn't need explicit masking. However, fma ops used in
+ // reductions must preserve previous 'acc' values for masked-out lanes.
+ fma = selectPassthru(rewriter, mask, fma, acc);
+ return fma;
}
mul = rewriter.create<arith::MulFOp>(loc, x, y);
}
- assert((!maybeMask.has_value() || !maybeMask.value()) &&
- "Unsupported masked case");
-
if (!acc)
return std::optional<Value>(mul);
- return makeArithReduction(rewriter, loc, kind, mul, acc);
+
+ return makeArithReduction(rewriter, loc, kind, mul, acc, mask);
}
/// Return the positions of the reductions in the given map.
for (int64_t d = 0, e = resType.getDimSize(0); d < e; ++d) {
auto pos = rewriter.getI64ArrayAttr(d);
Value x =
- rewriter.create<vector::ExtractOp>(loc, eltType, op.getLhs(), pos);
+ rewriter.create<vector::ExtractOp>(loc, op.getLhs(), pos);
Value a = rewriter.create<vector::BroadcastOp>(loc, rhsType, x);
Value r = nullptr;
if (acc)
- r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
+ r = rewriter.create<vector::ExtractOp>(loc, acc, pos);
+ Value extrMask;
+ if (mask)
+ extrMask = rewriter.create<vector::ExtractOp>(loc, mask, pos);
+
std::optional<Value> m = createContractArithOp(
- loc, a, op.getRhs(), r, kind, rewriter, isInt, mask);
+ loc, a, op.getRhs(), r, kind, rewriter, isInt, extrMask);
if (!m.has_value())
return failure();
result = rewriter.create<vector::InsertOp>(loc, resType, *m, result, pos);
if (vectorTransformOptions.vectorContractLowering !=
vector::VectorContractLowering::ParallelArith)
return failure();
+
ArrayRef<int64_t> lhsShape = contractOp.getLhsType().getShape();
ArrayRef<int64_t> rhsShape = contractOp.getRhsType().getShape();
AffineMap lhsMap = contractOp.getIndexingMapsArray()[0];
mask = maskableOp.getMaskingOp().getMask();
}
- Value t(Value v) {
- static constexpr std::array<int64_t, 2> perm = {1, 0};
+ Value t(Value v, ArrayRef<int64_t> perm = {1, 0}) {
if (!v)
return v;
return rewriter.create<vector::TransposeOp>(loc, v, perm);
bindDims(rewriter.getContext(), m, n, k);
// Classical row-major matmul: Just permute the lhs.
if (layout({{m, k}, {k, n}, {m, n}}))
- return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1));
+ return outerProd(t(lhs), rhs, res, lhsType.getDimSize(1),
+ t(mask, {2, 0, 1}));
// TODO: may be better to fail and use some vector<k> -> scalar reduction.
if (layout({{m, k}, {n, k}, {m, n}})) {
Value tlhs = t(lhs);
// -----
-func.func @masked_vector_contract(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+func.func @masked_float_add_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
%0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
return %0 : vector<2xf32>
}
-// We can't check for the intermediate 'vector.mask { vector.fma }' state so we
-// just make sure the vector.fma is lowered.
+// CHECK-LABEL: func.func @masked_float_add_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = llvm.intr.fmuladd(%[[VAL_0]], %{{.*}}, %[[VAL_2]]) : (vector<2xf32>, vector<2xf32>, vector<2xf32>) -> vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.select %[[VAL_3]], %[[VAL_8]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
-// CHECK: llvm.intr.fmuladd
-// CHECK: llvm.select
+// -----
+
+func.func @masked_float_mul_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_mul_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_max_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_max_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.maxf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_float_min_outerprod(%arg0: vector<2xf32>, %arg1: f32, %arg2: vector<2xf32>, %m: vector<2xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minf>} : vector<2xf32>, f32 } : vector<2xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
+// CHECK-LABEL: func.func @masked_float_min_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xf32>, %[[VAL_1:.*]]: f32, %[[VAL_2:.*]]: vector<2xf32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xf32> {
+// CHECK: %[[VAL_8:.*]] = arith.mulf %[[VAL_0]], %{{.*}} : vector<2xf32>
+// CHECK: %[[VAL_9:.*]] = arith.minf %[[VAL_8]], %[[VAL_2]] : vector<2xf32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xf32>
+
+// -----
+
+func.func @masked_int_add_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<add>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_add_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.addi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_mul_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<mul>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_mul_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_max_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<maxsi>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_max_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.maxsi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_min_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<minui>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_min_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.minui %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_and_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<and>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_and_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.andi %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
+
+// -----
+
+func.func @masked_int_or_outerprod(%arg0: vector<2xi32>, %arg1: i32, %arg2: vector<2xi32>, %m: vector<2xi1>) -> vector<2xi32> {
+ %0 = vector.mask %m { vector.outerproduct %arg0, %arg1, %arg2 {kind = #vector.kind<or>} : vector<2xi32>, i32 } : vector<2xi1> -> vector<2xi32>
+ return %0 : vector<2xi32>
+}
+
+// CHECK-LABEL: func.func @masked_int_or_outerprod(
+// CHECK-SAME: %[[VAL_0:.*]]: vector<2xi32>, %[[VAL_1:.*]]: i32, %[[VAL_2:.*]]: vector<2xi32>, %[[VAL_3:.*]]: vector<2xi1>) -> vector<2xi32> {
+// CHECK: %[[VAL_8:.*]] = arith.muli %[[VAL_0]], %{{.*}} : vector<2xi32>
+// CHECK: %[[VAL_9:.*]] = arith.ori %[[VAL_8]], %[[VAL_2]] : vector<2xi32>
+// CHECK: %[[VAL_10:.*]] = arith.select %[[VAL_3]], %[[VAL_9]], %[[VAL_2]] : vector<2xi1>, vector<2xi32>
// -----
%0 = vector.scalable.extract %vec[0] : vector<8xf32> from vector<[4]xf32>
return %0 : vector<8xf32>
}
-
-// -----
-
-// CHECK-LABEL: func.func @masked_vector_fma(
-// CHECK-SAME: %[[INPUT:.*]]: vector<8xf32>,
-// CHECK-SAME: %[[MASK:.*]]: vector<8xi1>) -> vector<8xf32>
-// CHECK: %[[FMA:.*]] = llvm.intr.fmuladd(%[[INPUT]], %[[INPUT]], %[[INPUT]]) : (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32>
-// CHECK: llvm.select %[[MASK]], %[[FMA]], %[[INPUT]] : vector<8xi1>, vector<8xf32>
-
-func.func @masked_vector_fma(%a: vector<8xf32>, %m: vector<8xi1>) -> vector<8xf32> {
- %0 = vector.mask %m { vector.fma %a, %a, %a : vector<8xf32> } : vector<8xi1> -> vector<8xf32>
- return %0 : vector<8xf32>
-}
-
return %0 : vector<2xf32>
}
+// OUTERPRODUCT-LABEL: func.func @masked_extract_contract2(
+// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>,
+// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
+// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
+// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct
+
+// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct
+
+// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
+// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct
+
+func.func @masked_extract_contract2(%arg0: vector<2x3xf32>,
+ %arg1: vector<3xf32>,
+ %arg2: vector<2xf32>,
+ %m: vector<2x3xi1>) -> vector<2xf32> {
+ %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
+ : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
+ return %0 : vector<2xf32>
+}
+
// CHECK-LABEL: func @extract_contract2_int
// CHECK-SAME: %[[A:.*0]]: vector<2x3xi32>,
// CHECK-SAME: %[[B:.*1]]: vector<3xi32>,
return %0 : vector<2x2xf32>
}
+// OUTERPRODUCT-LABEL: func.func @masked_extract_contract4(
+// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<3x5xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<5x7xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<3x7xf32>,
+// OUTERPRODUCT-SAME: %[[VAL_3:.*]]: vector<3x7x5xi1>) -> vector<3x7xf32> {
+// OUTERPRODUCT: %[[VAL_5:.*]] = vector.transpose %[[VAL_3]], [2, 0, 1] : vector<3x7x5xi1> to vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_8:.*]] = vector.extract %[[VAL_5]][0] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_12:.*]] = vector.extract %[[VAL_5]][1] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_16:.*]] = vector.extract %[[VAL_5]][2] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_20:.*]] = vector.extract %[[VAL_5]][3] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_21:.*]] = vector.mask %[[VAL_20]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+// OUTERPRODUCT: %[[VAL_24:.*]] = vector.extract %[[VAL_5]][4] : vector<5x3x7xi1>
+// OUTERPRODUCT: %[[VAL_25:.*]] = vector.mask %[[VAL_24]] { vector.outerproduct %{{.*}} {kind = #vector.kind<add>} : vector<3xf32>, vector<7xf32> } : vector<3x7xi1> -> vector<3x7xf32>
+
+func.func @masked_extract_contract4(%arg0: vector<3x5xf32>,
+ %arg1: vector<5x7xf32>,
+ %arg2: vector<3x7xf32>,
+ %m : vector<3x7x5xi1>) -> vector<3x7xf32> {
+ %0 = vector.mask %m { vector.contract #matmat_trait %arg0, %arg1, %arg2
+ : vector<3x5xf32>, vector<5x7xf32> into vector<3x7xf32> } : vector<3x7x5xi1> -> vector<3x7xf32>
+ return %0 : vector<3x7xf32>
+}
+
#contraction2d_accesses = [
affine_map<(i, j) -> (i, j)>,
affine_map<(i, j) -> (i, j)>,
return %0 : f32
}
-func.func @masked_vector_contract(%arg0: vector<2x3xf32>,
- %arg1: vector<3xf32>,
- %arg2: vector<2xf32>,
- %m: vector<2x3xi1>) -> vector<2xf32> {
- %0 = vector.mask %m { vector.contract #matvec_trait %arg0, %arg1, %arg2
- : vector<2x3xf32>, vector<3xf32> into vector<2xf32> } : vector<2x3xi1> -> vector<2xf32>
- return %0 : vector<2xf32>
-}
-// OUTERPRODUCT-LABEL: func.func @masked_vector_contract(
-// OUTERPRODUCT-SAME: %[[VAL_0:.*]]: vector<2x3xf32>,
-// OUTERPRODUCT-SAME: %[[VAL_1:.*]]: vector<3xf32>,
-// OUTERPRODUCT-SAME: %[[VAL_2:.*]]: vector<2xf32>,
-// OUTERPRODUCT-SAME: %[[IN_MASK:.*]]: vector<2x3xi1>) -> vector<2xf32>
-// OUTERPRODUCT: %[[T_MASK:.*]] = vector.transpose %[[IN_MASK]], [1, 0] : vector<2x3xi1> to vector<3x2xi1>
-// OUTERPRODUCT: %[[MASK0:.*]] = vector.extract %[[T_MASK]][0] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK0]] { vector.outerproduct
-
-// OUTERPRODUCT: %[[MASK1:.*]] = vector.extract %[[T_MASK]][1] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK1]] { vector.outerproduct
-
-// OUTERPRODUCT: %[[MASK2:.*]] = vector.extract %[[T_MASK]][2] : vector<3x2xi1>
-// OUTERPRODUCT: vector.mask %[[MASK2]] { vector.outerproduct