From e9b82a5c4fb6fa1c0af1b8e2536252b0730f41ef Mon Sep 17 00:00:00 2001
From: Diego Caballero
Date: Mon, 13 Feb 2023 19:31:23 +0000
Subject: [PATCH] [mlir][Vector] Add LLVM lowering for masked reductions

This patch adds the conversion patterns to lower masked reduction
operations to the corresponding vp intrinsics in LLVM.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D142177
---
 .../VectorToLLVM/ConvertVectorToLLVM.cpp       | 406 +++++++++++++++++++--
 .../VectorToLLVM/vector-reduction-to-llvm.mlir | 196 +++++++++-
 2 files changed, 564 insertions(+), 38 deletions(-)

diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index 6eabe49..68865d3 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
@@ -408,15 +409,154 @@ public:
   }
 };
 
+/// Reduction neutral classes for overloading.
+class ReductionNeutralZero {};
+class ReductionNeutralIntOne {};
+class ReductionNeutralFPOne {};
+class ReductionNeutralAllOnes {};
+class ReductionNeutralSIntMin {};
+class ReductionNeutralUIntMin {};
+class ReductionNeutralSIntMax {};
+class ReductionNeutralUIntMax {};
+class ReductionNeutralFPMin {};
+class ReductionNeutralFPMax {};
+
+/// Create the reduction neutral zero value.
+static Value createReductionNeutralValue(ReductionNeutralZero neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(loc, llvmType,
+                                           rewriter.getZeroAttr(llvmType));
+}
+
+/// Create the reduction neutral integer one value.
+static Value createReductionNeutralValue(ReductionNeutralIntOne neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType, rewriter.getIntegerAttr(llvmType, 1));
+}
+
+/// Create the reduction neutral fp one value.
+static Value createReductionNeutralValue(ReductionNeutralFPOne neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType, rewriter.getFloatAttr(llvmType, 1.0));
+}
+
+/// Create the reduction neutral all-ones value.
+static Value createReductionNeutralValue(ReductionNeutralAllOnes neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getIntegerAttr(
+          llvmType, llvm::APInt::getAllOnes(llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral signed int minimum value.
+static Value createReductionNeutralValue(ReductionNeutralSIntMin neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMinValue(
+                                            llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral unsigned int minimum value.
+static Value createReductionNeutralValue(ReductionNeutralUIntMin neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMinValue(
+                                            llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral signed int maximum value.
+static Value createReductionNeutralValue(ReductionNeutralSIntMax neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getIntegerAttr(llvmType, llvm::APInt::getSignedMaxValue(
+                                            llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral unsigned int maximum value.
+static Value createReductionNeutralValue(ReductionNeutralUIntMax neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getIntegerAttr(llvmType, llvm::APInt::getMaxValue(
+                                            llvmType.getIntOrFloatBitWidth())));
+}
+
+/// Create the reduction neutral fp minimum value.
+static Value createReductionNeutralValue(ReductionNeutralFPMin neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  auto floatType = llvmType.cast<FloatType>();
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getFloatAttr(
+          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
+                                           /*Negative=*/false)));
+}
+
+/// Create the reduction neutral fp maximum value.
+static Value createReductionNeutralValue(ReductionNeutralFPMax neutral,
+                                         ConversionPatternRewriter &rewriter,
+                                         Location loc, Type llvmType) {
+  auto floatType = llvmType.cast<FloatType>();
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, llvmType,
+      rewriter.getFloatAttr(
+          llvmType, llvm::APFloat::getQNaN(floatType.getFloatSemantics(),
+                                           /*Negative=*/true)));
+}
+
+/// Returns `accumulator` if it has a valid value. Otherwise, creates and
+/// returns a new accumulator value using `ReductionNeutral`.
+template <class ReductionNeutral>
+static Value getOrCreateAccumulator(ConversionPatternRewriter &rewriter,
+                                    Location loc, Type llvmType,
+                                    Value accumulator) {
+  if (accumulator)
+    return accumulator;
+
+  return createReductionNeutralValue(ReductionNeutral(), rewriter, loc,
+                                     llvmType);
+}
+
+/// Creates a constant value with the 1-D vector shape provided in `llvmType`.
+/// This is used as the effective vector length by some intrinsics that support
+/// dynamic vector lengths at runtime.
+static Value createVectorLengthValue(ConversionPatternRewriter &rewriter,
+                                     Location loc, Type llvmType) {
+  VectorType vType = cast<VectorType>(llvmType);
+  auto vShape = vType.getShape();
+  assert(vShape.size() == 1 && "Unexpected multi-dim vector type");
+
+  return rewriter.create<LLVM::ConstantOp>(
+      loc, rewriter.getI32Type(),
+      rewriter.getIntegerAttr(rewriter.getI32Type(), vShape[0]));
+}
+
 /// Helper method to lower a `vector.reduction` op that performs an arithmetic
 /// operation like add,mul, etc.. `VectorOp` is the LLVM vector intrinsic to use
 /// and `ScalarOp` is the scalar operation used to add the accumulation value if
 /// non-null.
-template <class VectorOp, class ScalarOp>
+template <class LLVMRedIntrinOp, class ScalarOp>
 static Value createIntegerReductionArithmeticOpLowering(
     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
     Value vectorOperand, Value accumulator) {
-  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+
+  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+
   if (accumulator)
     result = rewriter.create<ScalarOp>(loc, accumulator, result);
   return result;
@@ -426,11 +566,11 @@ static Value createIntegerReductionArithmeticOpLowering(
 /// a comparison operation like `min`/`max`. `VectorOp` is the LLVM vector
 /// intrinsic to use and `predicate` is the predicate to use to compare+combine
 /// the accumulator value if non-null.
-template <class VectorOp>
+template <class LLVMRedIntrinOp>
 static Value createIntegerReductionComparisonOpLowering(
     ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
     Value vectorOperand, Value accumulator, LLVM::ICmpPredicate predicate) {
-  Value result = rewriter.create<VectorOp>(loc, llvmType, vectorOperand);
+  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
   if (accumulator) {
     Value cmp =
         rewriter.create<LLVM::ICmpOp>(loc, predicate, accumulator, result);
@@ -460,6 +600,91 @@ static Value createMinMaxF(OpBuilder &builder, Location loc, Value lhs,
   return builder.create<LLVM::SelectOp>(loc, isNan, nan, sel);
 }
 
+template <class LLVMRedIntrinOp>
+static Value createFPReductionComparisonOpLowering(
+    ConversionPatternRewriter &rewriter, Location loc, Type llvmType,
+    Value vectorOperand, Value accumulator, bool isMin) {
+  Value result = rewriter.create<LLVMRedIntrinOp>(loc, llvmType, vectorOperand);
+
+  if (accumulator)
+    result = createMinMaxF(rewriter, loc, result, accumulator, /*isMin=*/isMin);
+
+  return result;
+}
+
+/// Overloaded methods to lower a reduction to an llvm intrinsic that requires
+/// a start value. This start value format is used by the unmasked fp
+/// reductions and by all the masked reduction intrinsics.
+template <class LLVMRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
+                                          /*startValue=*/accumulator,
+                                          vectorOperand);
+}
+
+template <class LLVMRedIntrinOp, class ReductionNeutral>
+static Value
+lowerReductionWithStartValue(ConversionPatternRewriter &rewriter, Location loc,
+                             Type llvmType, Value vectorOperand,
+                             Value accumulator, bool reassociateFPReds) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  return rewriter.create<LLVMRedIntrinOp>(loc, llvmType,
+                                          /*startValue=*/accumulator,
+                                          vectorOperand, reassociateFPReds);
+}
+
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator, Value mask) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  Value vectorLength =
+      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
+  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+                                            /*startValue=*/accumulator,
+                                            vectorOperand, mask, vectorLength);
+}
+
+template <class LLVMVPRedIntrinOp, class ReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator, Value mask,
+                                          bool reassociateFPReds) {
+  accumulator = getOrCreateAccumulator<ReductionNeutral>(rewriter, loc,
+                                                         llvmType, accumulator);
+  Value vectorLength =
+      createVectorLengthValue(rewriter, loc, vectorOperand.getType());
+  return rewriter.create<LLVMVPRedIntrinOp>(loc, llvmType,
+                                            /*startValue=*/accumulator,
+                                            vectorOperand, mask, vectorLength,
+                                            reassociateFPReds);
+}
+
+template <class LLVMIntVPRedIntrinOp, class IntReductionNeutral,
+          class LLVMFPVPRedIntrinOp, class FPReductionNeutral>
+static Value lowerReductionWithStartValue(ConversionPatternRewriter &rewriter,
+                                          Location loc, Type llvmType,
+                                          Value vectorOperand,
+                                          Value accumulator, Value mask) {
+  if (llvmType.isIntOrIndex())
+    return lowerReductionWithStartValue<LLVMIntVPRedIntrinOp,
+                                        IntReductionNeutral>(
+        rewriter, loc, llvmType, vectorOperand, accumulator, mask);
+
+  // FP dispatch.
+  return lowerReductionWithStartValue<LLVMFPVPRedIntrinOp, FPReductionNeutral>(
+      rewriter, loc, llvmType, vectorOperand, accumulator, mask);
+}
+
 /// Conversion pattern for all vector reductions.
 class VectorReductionOpConversion
     : public ConvertOpToLLVMPattern<vector::ReductionOp> {
 public:
@@ -478,6 +703,12 @@ public:
     Value operand = adaptor.getVector();
     Value acc = adaptor.getAcc();
     Location loc = reductionOp.getLoc();
+
+    // Masked reductions are lowered separately.
+    auto maskableOp = cast<MaskableOpInterface>(reductionOp.getOperation());
+    if (maskableOp.isMasked())
+      return failure();
+
     if (eltType.isIntOrIndex()) {
       // Integer reductions: add/mul/min/max/and/or/xor.
       Value result;
@@ -544,45 +775,31 @@ public:
       return failure();
 
     // Floating-point reductions: add/mul/min/max
+    Value result;
     if (kind == vector::CombiningKind::ADD) {
-      // Optional accumulator (or zero).
-      Value acc = adaptor.getOperands().size() > 1
-                      ? adaptor.getOperands()[1]
-                      : rewriter.create<LLVM::ConstantOp>(
-                            reductionOp->getLoc(), llvmType,
-                            rewriter.getZeroAttr(eltType));
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fadd>(
-          reductionOp, llvmType, acc, operand,
-          rewriter.getBoolAttr(reassociateFPReductions));
+      result = lowerReductionWithStartValue<LLVM::vector_reduce_fadd,
+                                            ReductionNeutralZero>(
+          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
     } else if (kind == vector::CombiningKind::MUL) {
-      // Optional accumulator (or one).
-      Value acc = adaptor.getOperands().size() > 1
-                      ? adaptor.getOperands()[1]
-                      : rewriter.create<LLVM::ConstantOp>(
-                            reductionOp->getLoc(), llvmType,
-                            rewriter.getFloatAttr(eltType, 1.0));
-      rewriter.replaceOpWithNewOp<LLVM::vector_reduce_fmul>(
-          reductionOp, llvmType, acc, operand,
-          rewriter.getBoolAttr(reassociateFPReductions));
+      result = lowerReductionWithStartValue<LLVM::vector_reduce_fmul,
+                                            ReductionNeutralFPOne>(
+          rewriter, loc, llvmType, operand, acc, reassociateFPReductions);
     } else if (kind == vector::CombiningKind::MINF) {
       // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
       // NaNs/-0.0/+0.0 in the same way.
-      Value result =
-          rewriter.create<LLVM::vector_reduce_fmin>(loc, llvmType, operand);
-      if (acc)
-        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/true);
-      rewriter.replaceOp(reductionOp, result);
+      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmin>(
+          rewriter, loc, llvmType, operand, acc,
+          /*isMin=*/true);
     } else if (kind == vector::CombiningKind::MAXF) {
       // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
      // NaNs/-0.0/+0.0 in the same way.
-      Value result =
-          rewriter.create<LLVM::vector_reduce_fmax>(loc, llvmType, operand);
-      if (acc)
-        result = createMinMaxF(rewriter, loc, result, acc, /*isMin=*/false);
-      rewriter.replaceOp(reductionOp, result);
+      result = createFPReductionComparisonOpLowering<LLVM::vector_reduce_fmax>(
+          rewriter, loc, llvmType, operand, acc,
+          /*isMin=*/false);
     } else
       return failure();
 
+    rewriter.replaceOp(reductionOp, result);
     return success();
   }
 
@@ -590,6 +807,127 @@ private:
   const bool reassociateFPReductions;
 };
 
+/// Base class to convert a `vector.mask` operation while matching traits
+/// of the maskable operation nested inside. A `VectorMaskOpConversionBase`
+/// instance matches against a `vector.mask` operation. The `matchAndRewrite`
+/// method performs a second match against the maskable operation `MaskedOp`.
+/// Finally, it invokes the virtual method `matchAndRewriteMaskableOp` to be
+/// implemented by the concrete conversion classes. This method can match
+/// against specific traits of the `vector.mask` and the maskable operation. It
+/// must replace the `vector.mask` operation.
+template <class MaskedOp>
+class VectorMaskOpConversionBase
+    : public ConvertOpToLLVMPattern<vector::MaskOp> {
+public:
+  using ConvertOpToLLVMPattern<vector::MaskOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override final {
+    // Match against the maskable operation kind.
+    Operation *maskableOp = maskOp.getMaskableOp();
+    if (!isa<MaskedOp>(maskableOp))
+      return failure();
+    return matchAndRewriteMaskableOp(
+        maskOp, cast<MaskedOp>(maskOp.getMaskableOp()), rewriter);
+  }
+
+protected:
+  virtual LogicalResult
+  matchAndRewriteMaskableOp(vector::MaskOp maskOp,
+                            vector::MaskableOpInterface maskableOp,
+                            ConversionPatternRewriter &rewriter) const = 0;
+};
+
+class MaskedReductionOpConversion
+    : public VectorMaskOpConversionBase<vector::ReductionOp> {
+
+public:
+  using VectorMaskOpConversionBase<
+      vector::ReductionOp>::VectorMaskOpConversionBase;
+
+  virtual LogicalResult matchAndRewriteMaskableOp(
+      vector::MaskOp maskOp, MaskableOpInterface maskableOp,
+      ConversionPatternRewriter &rewriter) const override {
+    auto reductionOp = cast<vector::ReductionOp>(maskableOp.getOperation());
+    auto kind = reductionOp.getKind();
+    Type eltType = reductionOp.getDest().getType();
+    Type llvmType = typeConverter->convertType(eltType);
+    Value operand = reductionOp.getVector();
+    Value acc = reductionOp.getAcc();
+    Location loc = reductionOp.getLoc();
+
+    Value result;
+    switch (kind) {
+    case vector::CombiningKind::ADD:
+      result = lowerReductionWithStartValue<
+          LLVM::VPReduceAddOp, ReductionNeutralZero, LLVM::VPReduceFAddOp,
+          ReductionNeutralZero>(rewriter, loc, llvmType, operand, acc,
+                                maskOp.getMask());
+      break;
+    case vector::CombiningKind::MUL:
+      result = lowerReductionWithStartValue<
+          LLVM::VPReduceMulOp, ReductionNeutralIntOne, LLVM::VPReduceFMulOp,
+          ReductionNeutralFPOne>(rewriter, loc, llvmType, operand, acc,
+                                 maskOp.getMask());
+      break;
+    case vector::CombiningKind::MINUI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceUMinOp,
+                                            ReductionNeutralUIntMax>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::MINSI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceSMinOp,
+                                            ReductionNeutralSIntMax>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::MAXUI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceUMaxOp,
+                                            ReductionNeutralUIntMin>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::MAXSI:
+      result = lowerReductionWithStartValue<LLVM::VPReduceSMaxOp,
+                                            ReductionNeutralSIntMin>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::AND:
+      result = lowerReductionWithStartValue<LLVM::VPReduceAndOp,
+                                            ReductionNeutralAllOnes>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::OR:
+      result = lowerReductionWithStartValue<LLVM::VPReduceOrOp,
+                                            ReductionNeutralZero>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::XOR:
+      result = lowerReductionWithStartValue<LLVM::VPReduceXorOp,
+                                            ReductionNeutralZero>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::MINF:
+      // FIXME: MLIR's 'minf' and LLVM's 'vector_reduce_fmin' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
+      result = lowerReductionWithStartValue<LLVM::VPReduceFMinOp,
+                                            ReductionNeutralFPMax>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    case vector::CombiningKind::MAXF:
+      // FIXME: MLIR's 'maxf' and LLVM's 'vector_reduce_fmax' do not handle
+      // NaNs/-0.0/+0.0 in the same way.
+      result = lowerReductionWithStartValue<LLVM::VPReduceFMaxOp,
+                                            ReductionNeutralFPMin>(
+          rewriter, loc, llvmType, operand, acc, maskOp.getMask());
+      break;
+    }
+
+    // Replace `vector.mask` operation altogether.
+    rewriter.replaceOp(maskOp, result);
+    return success();
+  }
+};
+
 class VectorShuffleOpConversion
     : public ConvertOpToLLVMPattern<vector::ShuffleOp> {
 public:
@@ -1381,8 +1719,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
            VectorGatherOpConversion, VectorScatterOpConversion,
            VectorExpandLoadOpConversion, VectorCompressStoreOpConversion,
            VectorSplatOpLowering, VectorSplatNdOpLowering,
-           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering>(
-          converter);
+           VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
+           MaskedReductionOpConversion>(converter);
   // Transfer ops with rank > 1 are handled by VectorToSCF.
   populateVectorTransferLoweringPatterns(patterns, /*maxTransferRank=*/1);
 }
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
index 8da65b8..0b61200 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-reduction-to-llvm.mlir
@@ -1,7 +1,6 @@
-// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' | FileCheck %s
-// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' | FileCheck %s --check-prefix=REASSOC
+// RUN: mlir-opt %s -convert-vector-to-llvm='use-opaque-pointers=1' -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -convert-vector-to-llvm='reassociate-fp-reductions use-opaque-pointers=1' -split-input-file | FileCheck %s --check-prefix=REASSOC
 
-//
 // CHECK-LABEL: @reduce_add_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 // CHECK: %[[C:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
@@ -21,7 +20,8 @@ func.func @reduce_add_f32(%arg0: vector<16xf32>) -> f32 {
   return %0 : f32
 }
 
-//
+// -----
+
 // CHECK-LABEL: @reduce_mul_f32(
 // CHECK-SAME: %[[A:.*]]: vector<16xf32>)
 // CHECK: %[[C:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
@@ -40,3 +40,191 @@ func.func @reduce_mul_f32(%arg0: vector<16xf32>) -> f32 {
   %0 = vector.reduction <mul>, %arg0 : vector<16xf32> into f32
   return %0 : f32
 }
+
+// -----
+
+func.func @masked_reduce_add_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0.000000e+00 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fadd"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+
+// -----
+
+func.func @masked_reduce_mul_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_mul_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1.000000e+00 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fmul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+
+// -----
+
+func.func @masked_reduce_minf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <minf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0xFFC00000 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fmin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+// -----
+
+func.func @masked_reduce_maxf_f32(%arg0: vector<16xf32>, %mask : vector<16xi1>) -> f32 {
+  %0 = vector.mask %mask { vector.reduction <maxf>, %arg0 : vector<16xf32> into f32 } : vector<16xi1> -> f32
+  return %0 : f32
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maxf_f32(
+// CHECK-SAME: %[[INPUT:.*]]: vector<16xf32>,
+// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>) -> f32 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0x7FC00000 : f32) : f32
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(16 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.fmax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (f32, vector<16xf32>, vector<16xi1>, i32) -> f32
+
+// -----
+
+func.func @masked_reduce_add_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <add>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_add_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.add"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
+// -----
+
+func.func @masked_reduce_mul_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <mul>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_mul_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(1 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: %[[VAL_4:.*]] = "llvm.intr.vp.reduce.mul"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_minui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <minui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minui_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.umin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_maxui_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <maxui>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maxui_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.umax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_minsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <minsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_minsi_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(127 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.smin"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_maxsi_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <maxsi>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_maxsi_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-128 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.smax"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_or_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <or>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_or_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.or"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
+// -----
+
+func.func @masked_reduce_and_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <and>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_and_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(-1 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.and"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+// -----
+
+func.func @masked_reduce_xor_i8(%arg0: vector<32xi8>, %mask : vector<32xi1>) -> i8 {
+  %0 = vector.mask %mask { vector.reduction <xor>, %arg0 : vector<32xi8> into i8 } : vector<32xi1> -> i8
+  return %0 : i8
+}
+
+// CHECK-LABEL: func.func @masked_reduce_xor_i8(
+// CHECK-SAME: %[[INPUT:.*]]: vector<32xi8>,
+// CHECK-SAME: %[[MASK:.*]]: vector<32xi1>) -> i8 {
+// CHECK: %[[NEUTRAL:.*]] = llvm.mlir.constant(0 : i8) : i8
+// CHECK: %[[VL:.*]] = llvm.mlir.constant(32 : i32) : i32
+// CHECK: "llvm.intr.vp.reduce.xor"(%[[NEUTRAL]], %[[INPUT]], %[[MASK]], %[[VL]]) : (i8, vector<32xi8>, vector<32xi1>, i32) -> i8
+
+
-- 
2.7.4