From afc3756e6c6da68f066703f384aca6c2ffb54286 Mon Sep 17 00:00:00 2001 From: Diego Caballero Date: Fri, 13 Jan 2023 20:36:40 +0000 Subject: [PATCH] [mlir][vector] Masking support for reductions in Linalg vectorizer This patch enables vectorization of reductions in the Linalg vectorizer using the vector.mask operation. It also introduces the logic to slice and propagate the vector mask of a masked multi-reduction to its respective lowering operations. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D141571 --- mlir/include/mlir/Dialect/Vector/IR/VectorOps.h | 14 +++ mlir/include/mlir/Dialect/Vector/IR/VectorOps.td | 18 ++-- .../Dialect/Linalg/Transforms/Vectorization.cpp | 43 ++------ mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 53 ++++++++-- .../VectorMultiDimReductionTransforms.cpp | 111 ++++++++++++++++++--- mlir/test/Dialect/Linalg/vectorization.mlir | 77 +++++++++++++- .../Vector/vector-multi-reduction-lowering.mlir | 101 ++++++++++++++++++- 7 files changed, 351 insertions(+), 66 deletions(-) diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h index 0028abe..deb86df 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.h @@ -203,6 +203,20 @@ inline bool isReductionIterator(Attribute attr) { return attr.cast().getValue() == IteratorType::reduction; } +//===----------------------------------------------------------------------===// +// Vector Masking Utilities +//===----------------------------------------------------------------------===// + +/// Create the vector.yield-ended region of a vector.mask op with `maskableOp` +/// as masked operation. +void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp); + +/// Creates a vector.mask operation around a maskable operation. Returns the +/// vector.mask operation if the mask provided is valid. Otherwise, returns the +/// maskable operation itself. 
+Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp, + Value mask); + } // namespace vector } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index 5a14f0d..8c5d44a 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -340,6 +340,7 @@ def Vector_MultiDimReductionOp : PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]>, Arguments<(ins Vector_CombiningKindAttr:$kind, @@ -2338,16 +2339,13 @@ def Vector_MaskOp : Vector_Op<"mask", [ let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "Value":$mask, - CArg<"function_ref", - "buildTerminatedBody">:$maskRegion)>, - OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, - CArg<"function_ref", - "buildTerminatedBody">:$maskRegion)>, - OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, - "Value":$passthru, - CArg<"function_ref", - "buildTerminatedBody">:$maskRegion)> + OpBuilder<(ins "Value":$mask, "Operation *":$maskableOp, + CArg<"function_ref">:$maskRegion)>, + OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Operation *":$maskableOp, + CArg<"function_ref">:$maskRegion)>, + OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru, + "Operation *":$maskableOp, + CArg<"function_ref">:$maskRegion)> ]; let extraClassDeclaration = [{ diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 1e83350..5f367d1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -292,25 +292,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask, // Wrap the operation with a new `vector.mask` and update D-U chain. 
assert(opToMask && "Expected a valid operation to mask"); - auto opResults = opToMask->getResultTypes(); - auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) { - Block *insBlock = builder.getInsertionBlock(); - // Create a block, put an op in that block. Look for a utility. - // Maybe in conversion pattern rewriter. Way to avoid splice. - // Set insertion point. - insBlock->getOperations().splice( - insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask); - builder.create(loc, opToMask->getResults()); - }; - // TODO: Allow multiple results in vector.mask. - auto maskOp = - opResults.empty() - ? rewriter.create(opToMask->getLoc(), mask, - createRegionMask) - : rewriter.create(opToMask->getLoc(), - opToMask->getResultTypes().front(), - mask, createRegionMask); - + auto maskOp = cast( + mlir::vector::maskOperation(rewriter, opToMask, mask)); Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back(); for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults())) @@ -440,17 +423,16 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value, /// initial value.buildMultiDimReduce // Note: this is a true builder that notifies the OpBuilder listener. // TODO: Consider moving as a static helper on the ReduceOp. 
-static Operation *buildMultiDimReduce(OpBuilder &b, - Operation *reduceOp, Value valueToReduce, - Value acc, - const SmallVector &reductionMask) { +static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp, + Value valueToReduce, Value acc, + ArrayRef dimsToMask) { auto maybeKind = getCombinerOpKind(reduceOp); assert(maybeKind && "Failed precondition: could not get reduction kind"); return b.create( - reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind); + reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind); } -static SmallVector getReductionMask(LinalgOp linalgOp) { +static SmallVector getDimsToReduce(LinalgOp linalgOp) { return llvm::to_vector( llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator)); } @@ -701,8 +683,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op, if (!reduceType || (outputType && reduceType.getShape() == outputType.getShape())) return nullptr; - SmallVector reductionMask = getReductionMask(linalgOp); - return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask); + SmallVector dimsToMask = getDimsToReduce(linalgOp); + return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask); } /// Generic vectorization for a single operation `op`, given already vectorized @@ -972,11 +954,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) { } static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) { - // TODO: Masking only supports dynamic generic ops without reductions for now. - if (!isElementwise(op) && - llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) { - return itType != utils::IteratorType::parallel; - })) + // TODO: Masking only supports dynamic generic ops for now. + if (!isa(op)) return failure(); // TODO: 0-d vectors are not supported yet. 
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index f00d849..9339452 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -342,6 +342,13 @@ LogicalResult MultiDimReductionOp::verify() { return success(); } +/// Returns the mask type expected by this operation. +Type MultiDimReductionOp::getExpectedMaskType() { + auto vecType = getSourceVectorType(); + return VectorType::get(vecType.getShape(), + IntegerType::get(vecType.getContext(), /*width=*/1)); +} + namespace { // Only unit dimensions that are being reduced are folded. If the dimension is // unit, but not reduced, it is not folded, thereby keeping the output type the @@ -5276,7 +5283,8 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results, void MaskOp::build( OpBuilder &builder, OperationState &result, Value mask, - function_ref maskRegionBuilder) { + Operation *maskableOp, + function_ref maskRegionBuilder) { assert(maskRegionBuilder && "builder callback for 'maskRegion' must be present"); @@ -5284,21 +5292,22 @@ void MaskOp::build( OpBuilder::InsertionGuard guard(builder); Region *maskRegion = result.addRegion(); builder.createBlock(maskRegion); - maskRegionBuilder(builder, result.location); + maskRegionBuilder(builder, maskableOp); } void MaskOp::build( OpBuilder &builder, OperationState &result, TypeRange resultTypes, - Value mask, function_ref maskRegionBuilder) { - build(builder, result, resultTypes, mask, /*passthru=*/Value(), + Value mask, Operation *maskableOp, + function_ref maskRegionBuilder) { + build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp, maskRegionBuilder); } void MaskOp::build( - OpBuilder &builder, OperationState &result, TypeRange resultTypes, - Value mask, Value passthru, - function_ref maskRegionBuilder) { - build(builder, result, mask, maskRegionBuilder); + OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value mask, + 
Value passthru, Operation *maskableOp, + function_ref maskRegionBuilder) { + build(builder, result, mask, maskableOp, maskRegionBuilder); if (passthru) result.addOperands(passthru); result.addTypes(resultTypes); @@ -5739,6 +5748,34 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc, } //===----------------------------------------------------------------------===// +// Vector Masking Utilities +//===----------------------------------------------------------------------===// + +/// Create the vector.yield-ended region of a vector.mask op with `maskableOp` +/// as masked operation. +void mlir::vector::createMaskOpRegion(OpBuilder &builder, + Operation *maskableOp) { + assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block"); + Block *insBlock = builder.getInsertionBlock(); + // Create a block and move the op to that block. + insBlock->getOperations().splice( + insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp); + builder.create(maskableOp->getLoc(), maskableOp->getResults()); +} + +/// Creates a vector.mask operation around a maskable operation. Returns the +/// vector.mask operation if the mask provided is valid. Otherwise, returns +/// the maskable operation itself. 
+Operation *mlir::vector::maskOperation(RewriterBase &rewriter, + Operation *maskableOp, Value mask) { + if (!mask) + return maskableOp; + return rewriter.create(maskableOp->getLoc(), + maskableOp->getResultTypes(), mask, maskableOp, + createMaskOpRegion); +} + +//===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp index 31a2452..e89059c 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp @@ -12,9 +12,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" -#include "mlir/Dialect/Vector/Utils/VectorUtils.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/IR/TypeUtilities.h" #define DEBUG_TYPE "vector-multi-reduction" @@ -40,6 +38,18 @@ public: LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { + // Vector mask setup. + OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = + cast(multiReductionOp.getOperation()); + Operation *rootOp; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + } else { + rootOp = multiReductionOp; + } + auto src = multiReductionOp.getSource(); auto loc = multiReductionOp.getLoc(); auto srcRank = multiReductionOp.getSourceVectorType().getRank(); @@ -79,6 +89,15 @@ public: indices.append(reductionDims.begin(), reductionDims.end()); indices.append(parallelDims.begin(), parallelDims.end()); } + + // If masked, transpose the original mask. 
+ Value transposedMask; + if (maskableOp.isMasked()) { + transposedMask = rewriter.create( + loc, maskableOp.getMaskingOp().getMask(), indices); + } + + // Transpose reduction source. auto transposeOp = rewriter.create(loc, src, indices); SmallVector reductionMask(srcRank, false); for (int i = 0; i < reductionSize; ++i) { @@ -87,9 +106,14 @@ public: else reductionMask[i] = true; } - rewriter.replaceOpWithNewOp( - multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(), - reductionMask, multiReductionOp.getKind()); + + Operation *newMultiRedOp = rewriter.create( + multiReductionOp.getLoc(), transposeOp.getResult(), + multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind()); + newMultiRedOp = + mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask); + + rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0)); return success(); } @@ -113,6 +137,18 @@ public: LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { + // Vector mask setup. 
+ OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = + cast(multiReductionOp.getOperation()); + Operation *rootOp; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + } else { + rootOp = multiReductionOp; + } + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); auto srcShape = multiReductionOp.getSourceVectorType().getShape(); auto loc = multiReductionOp.getLoc(); @@ -186,10 +222,22 @@ public: std::swap(mask.front(), mask.back()); std::swap(vectorShape.front(), vectorShape.back()); } + + Value newVectorMask; + if (maskableOp.isMasked()) { + Value vectorMask = maskableOp.getMaskingOp().getMask(); + auto maskCastedType = VectorType::get( + vectorShape, + vectorMask.getType().cast().getElementType()); + newVectorMask = + rewriter.create(loc, maskCastedType, vectorMask); + } + auto castedType = VectorType::get( vectorShape, multiReductionOp.getSourceVectorType().getElementType()); Value cast = rewriter.create( loc, castedType, multiReductionOp.getSource()); + Value acc = multiReductionOp.getAcc(); if (flattenedParallelDim) { auto accType = VectorType::get( @@ -197,24 +245,26 @@ public: multiReductionOp.getSourceVectorType().getElementType()); acc = rewriter.create(loc, accType, acc); } - // 5. Creates the flattened form of vector.multi_reduction with inner/outer + // 6. Creates the flattened form of vector.multi_reduction with inner/outer // most dim as reduction. - auto newOp = rewriter.create( + Operation *newMultiDimRedOp = rewriter.create( loc, cast, acc, mask, multiReductionOp.getKind()); + newMultiDimRedOp = + mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask); - // 6. If there are no parallel shapes, the result is a scalar. + // 7. If there are no parallel shapes, the result is a scalar. // TODO: support 0-d vectors when available. 
if (parallelShapes.empty()) { - rewriter.replaceOp(multiReductionOp, newOp.getDest()); + rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0)); return success(); } - // 7. Creates shape cast for the output n-D -> 2-D + // 8. Creates shape cast for the output n-D -> 2-D. VectorType outputCastedType = VectorType::get( parallelShapes, multiReductionOp.getSourceVectorType().getElementType()); rewriter.replaceOpWithNewOp( - multiReductionOp, outputCastedType, newOp.getDest()); + rootOp, outputCastedType, newMultiDimRedOp->getResult(0)); return success(); } @@ -230,6 +280,12 @@ struct TwoDimMultiReductionToElementWise LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { + auto maskableOp = + cast(multiReductionOp.getOperation()); + if (maskableOp.isMasked()) + // TODO: Support masking. + return failure(); + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Rank-2 ["parallel", "reduce"] or bail. if (srcRank != 2) @@ -274,6 +330,18 @@ struct TwoDimMultiReductionToReduction if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1)) return failure(); + // Vector mask setup. + OpBuilder::InsertionGuard guard(rewriter); + auto maskableOp = + cast(multiReductionOp.getOperation()); + Operation *rootOp; + if (maskableOp.isMasked()) { + rewriter.setInsertionPoint(maskableOp.getMaskingOp()); + rootOp = maskableOp.getMaskingOp(); + } else { + rootOp = multiReductionOp; + } + auto loc = multiReductionOp.getLoc(); Value result = rewriter.create( loc, multiReductionOp.getDestType(), @@ -285,13 +353,22 @@ struct TwoDimMultiReductionToReduction loc, multiReductionOp.getSource(), ArrayRef{i}); auto acc = rewriter.create( loc, multiReductionOp.getAcc(), ArrayRef{i}); - auto reducedValue = rewriter.create( + Operation *reductionOp = rewriter.create( loc, multiReductionOp.getKind(), v, acc); + + // If masked, slice the mask and mask the new reduction operation. 
+ if (maskableOp.isMasked()) { + Value mask = rewriter.create( + loc, maskableOp.getMaskingOp().getMask(), ArrayRef{i}); + reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask); + } + result = rewriter.create( - loc, reducedValue, result, + loc, reductionOp->getResult(0), result, rewriter.create(loc, i)); } - rewriter.replaceOp(multiReductionOp, result); + + rewriter.replaceOp(rootOp, result); return success(); } }; @@ -307,6 +384,12 @@ struct OneDimMultiReductionToTwoDim LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp, PatternRewriter &rewriter) const override { + auto maskableOp = + cast(multiReductionOp.getOperation()); + if (maskableOp.isMasked()) + // TODO: Support masking. + return failure(); + auto srcRank = multiReductionOp.getSourceVectorType().getRank(); // Rank-1 or bail. if (srcRank != 1) diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 0ccd6c4..d25ffe7 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1824,6 +1824,82 @@ transform.sequence failures(propagate) { // ----- +func.func @vectorize_dynamic_reduction(%arg0: tensor, + %arg1: tensor) -> tensor { + %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, + affine_map<(d0, d1) -> (d0)>], + iterator_types = ["parallel", "reduction"] } + ins(%arg0 : tensor) + outs(%arg1 : tensor) { + ^bb(%in: f32, %out: f32) : + %0 = arith.addf %in, %out : f32 + linalg.yield %0 : f32 + } -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.masked_vectorize %0 vector_sizes [4, 8] +} + +// CHECK-LABEL: @vectorize_dynamic_reduction( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = 
tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x8xi1> +// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true]} : tensor, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32> +// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1> +// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true]} : tensor, vector<4xf32> } : vector<4xi1> -> vector<4xf32> +// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction , %[[VAL_9]], %[[VAL_12]] [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32> +// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor } : vector<4xi1> -> tensor +// CHECK: return %[[VAL_15]] : tensor +// CHECK: } + +// ----- + +func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor, + %arg1: tensor) -> tensor { + %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, + affine_map<(d0, d1, d2) -> (d2, d1)>], + iterator_types = ["reduction", "parallel", "parallel"] } + ins(%arg0 : tensor) + outs(%arg1 : tensor) { + ^bb(%in: f32, %out: f32) : + %0 = arith.addf %in, %out : f32 + linalg.yield %0 : f32 + } -> tensor + return %0 : tensor +} + +transform.sequence failures(propagate) { +^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + transform.structured.masked_vectorize %0 vector_sizes [4, 8, 16] +} + +// CHECK-LABEL: @vectorize_dynamic_transpose_reduction( +// CHECK-SAME: %[[VAL_0:.*]]: tensor, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor { +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index +// CHECK: %[[VAL_3:.*]] = tensor.dim 
%[[VAL_0]], %[[VAL_2]] : tensor +// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor +// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index +// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor +// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]], %[[VAL_7]] : vector<4x8x16xi1> +// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true, true]} : tensor, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32> +// CHECK: %[[VAL_13:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_5]] : vector<16x8xi1> +// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32> +// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_10]] { vector.multi_reduction , %[[VAL_11]], %[[VAL_14]] [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32> +// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_13]] { vector.transfer_write %[[VAL_15]], %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<8x16xf32>, tensor } : vector<16x8xi1> -> tensor + +// ----- + // This is a regression test. This IR cannot be vectorized, but // structured.vectorize should nevertheless succeed. 
@@ -1892,4 +1968,3 @@ transform.sequence failures(propagate) { // CHECK-LABEL: @wrong_reduction_detection // CHECK: vector.broadcast // CHECK: vector.transfer_write - diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir index 6b372c3..ee4ab7a 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s +// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> { %0 = vector.multi_reduction , %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32> @@ -19,6 +19,8 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) - // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32> // CHECK: return %[[RESULT_VEC]] +// ----- + func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 { %0 = vector.multi_reduction , %arg0, %acc [0, 1] : vector<2x4xf32> to f32 return %0 : f32 @@ -31,6 +33,8 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) - // CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32> // CHECK: return %[[RES]] +// ----- + func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> { %0 = vector.multi_reduction , %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32> return %0 : vector<2x3xi32> @@ -72,6 +76,7 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi // CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32> // CHECK: return %[[RESULT]] +// ----- func.func 
@vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> { %0 = vector.multi_reduction , %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32> @@ -85,6 +90,8 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: v // CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32> // CHECK: return %[[RESULT]] +// ----- + func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> { %0 = vector.multi_reduction , %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32> return %0 : vector<2x4xf32> @@ -135,3 +142,95 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vecto // CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32> // CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32> // CHECK: return %[[RESHAPED_VEC]] + +// ----- + +func.func @vectorize_dynamic_reduction(%arg0: tensor, %arg1: tensor) -> tensor { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor + %c1 = arith.constant 1 : index + %dim_0 = tensor.dim %arg0, %c1 : tensor + %c0_1 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1> + %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32> + %cst_2 = arith.constant 0.000000e+00 : f32 + %2 = vector.create_mask %dim : vector<4xi1> + %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor, vector<4xf32> } : vector<4xi1> -> vector<4xf32> + %4 = vector.mask %0 { vector.multi_reduction , %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32> + %c0_3 = arith.constant 0 : index + %5 = vector.mask %2 { vector.transfer_write %4, 
%arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor } : vector<4xi1> -> tensor + return %5 : tensor +} + +// Verify that the original 2-D mask is sliced and propagated properly to the +// vector.reduction instances. + +// CHECK-LABEL: func.func @vectorize_dynamic_reduction +// CHECK: %[[VAL_8:.*]] = tensor.dim +// CHECK: %[[VAL_9:.*]] = tensor.dim +// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1> + +// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<4x8xi1> +// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction , %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32 +// CHECK: %[[VAL_18:.*]] = vector.insertelement + +// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<4x8xi1> +// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction , %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32 +// CHECK: %[[VAL_23:.*]] = vector.insertelement + +// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<4x8xi1> +// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction , %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32 +// CHECK: %[[VAL_28:.*]] = vector.insertelement + +// CHECK: %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<4x8xi1> +// CHECK: %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction , %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32 +// CHECK: %[[VAL_33:.*]] = vector.insertelement + +// ----- + +func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor, %arg1: tensor) -> tensor { + %c0 = arith.constant 0 : index + %dim = tensor.dim %arg0, %c0 : tensor + %c1 = arith.constant 1 : index + %dim_0 = tensor.dim %arg0, %c1 : tensor + %c2 = arith.constant 2 : index + %dim_1 = tensor.dim %arg0, %c2 : tensor + %c0_2 = arith.constant 0 : index + %cst = arith.constant 0.000000e+00 : f32 + %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1> + %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, 
%c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32> + %cst_3 = arith.constant 0.000000e+00 : f32 + %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1> + %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32> + %4 = vector.mask %0 { vector.multi_reduction , %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32> + %c0_4 = arith.constant 0 : index + %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor } : vector<16x8xi1> -> tensor + return %5 : tensor +} + +// CHECK-LABEL: func.func @vectorize_dynamic_transpose_reduction +// CHECK: %[[VAL_6:.*]] = tensor.dim +// CHECK: %[[VAL_7:.*]] = tensor.dim +// CHECK: %[[VAL_8:.*]] = tensor.dim +// CHECK: %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1> +// CHECK: %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1> + +// Just checking a few instances to make sure the vector mask is properly propagated: + +// CHECK: %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<8x16x4xi1> +// CHECK: %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction +// CHECK: %[[VAL_145:.*]] = vector.insertelement %[[VAL_144]] + +// CHECK: %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<8x16x4xi1> +// CHECK: %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction +// CHECK: %[[VAL_150:.*]] = vector.insertelement %[[VAL_149]] + +// CHECK: %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<8x16x4xi1> +// CHECK: %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction +// CHECK: %[[VAL_155:.*]] = vector.insertelement %[[VAL_154]] + 
+// CHECK: %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<8x16x4xi1> +// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction +// CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]] + -- 2.7.4