return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
}
+//===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Creates the vector.yield-terminated region of a vector.mask op, with
+/// `maskableOp` as the masked operation.
+void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if the provided mask is valid (non-null). Otherwise,
+/// returns the maskable operation itself.
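+///
+/// A sketch of the resulting IR (illustrative example; the mask shape must
+/// match the masked op's expected mask type):
+///
+///   %res = vector.mask %m { vector.multi_reduction <add>, %src, %acc [1]
+///       : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>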
+Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
+ Value mask);
+
} // namespace vector
} // namespace mlir
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<InferTypeOpInterface>,
+ DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
let skipDefaultBuilders = 1;
let builders = [
- OpBuilder<(ins "Value":$mask,
- CArg<"function_ref<void(OpBuilder &, Location)>",
- "buildTerminatedBody">:$maskRegion)>,
- OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
- CArg<"function_ref<void(OpBuilder &, Location)>",
- "buildTerminatedBody">:$maskRegion)>,
- OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
- "Value":$passthru,
- CArg<"function_ref<void(OpBuilder &, Location)>",
- "buildTerminatedBody">:$maskRegion)>
+ OpBuilder<(ins "Value":$mask, "Operation *":$maskableOp,
+ CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Operation *":$maskableOp,
+ CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+ OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru,
+ "Operation *":$maskableOp,
+ CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>
];
let extraClassDeclaration = [{
    // Wrap the operation with a new `vector.mask` and update the def-use chain.
assert(opToMask && "Expected a valid operation to mask");
- auto opResults = opToMask->getResultTypes();
- auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
- Block *insBlock = builder.getInsertionBlock();
- // Create a block, put an op in that block. Look for a utility.
- // Maybe in conversion pattern rewriter. Way to avoid splice.
- // Set insertion point.
- insBlock->getOperations().splice(
- insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
- builder.create<vector::YieldOp>(loc, opToMask->getResults());
- };
- // TODO: Allow multiple results in vector.mask.
- auto maskOp =
- opResults.empty()
- ? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
- createRegionMask)
- : rewriter.create<vector::MaskOp>(opToMask->getLoc(),
- opToMask->getResultTypes().front(),
- mask, createRegionMask);
-
+ auto maskOp = cast<vector::MaskOp>(
+ mlir::vector::maskOperation(rewriter, opToMask, mask));
Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
/// initial value.
// Note: this is a true builder that notifies the OpBuilder listener.
// TODO: Consider moving as a static helper on the ReduceOp.
-static Operation *buildMultiDimReduce(OpBuilder &b,
- Operation *reduceOp, Value valueToReduce,
- Value acc,
- const SmallVector<bool> &reductionMask) {
+static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+ Value valueToReduce, Value acc,
+ ArrayRef<bool> dimsToMask) {
auto maybeKind = getCombinerOpKind(reduceOp);
assert(maybeKind && "Failed precondition: could not get reduction kind");
return b.create<vector::MultiDimReductionOp>(
- reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
+ reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
}
-static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
+static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
return llvm::to_vector(
llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
}
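+
+// For example, iterator_types ["parallel", "reduction"] in `linalgOp` yields
+// {false, true}: only the second dimension is reduced.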
if (!reduceType ||
(outputType && reduceType.getShape() == outputType.getShape()))
return nullptr;
- SmallVector<bool> reductionMask = getReductionMask(linalgOp);
- return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
+ SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
+ return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
}
/// Generic vectorization for a single operation `op`, given already vectorized
}
static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
- // TODO: Masking only supports dynamic generic ops without reductions for now.
- if (!isElementwise(op) &&
- llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
- return itType != utils::IteratorType::parallel;
- }))
+ // TODO: Masking only supports dynamic generic ops for now.
+ if (!isa<linalg::GenericOp>(op))
return failure();
// TODO: 0-d vectors are not supported yet.
return success();
}
+/// Returns the mask type expected by this operation, i.e., an i1 vector with
+/// the same shape as the source vector.
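+///
+/// For example (illustrative), reducing dim 1 of a vector<4x8xf32> still
+/// expects a full-shape vector<4x8xi1> mask:
+///
+///   vector.mask %m { vector.multi_reduction <add>, %v, %acc [1]
+///       : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>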
+Type MultiDimReductionOp::getExpectedMaskType() {
+ auto vecType = getSourceVectorType();
+ return VectorType::get(vecType.getShape(),
+ IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
namespace {
// Only unit dimensions that are being reduced are folded. If the dimension is
// unit, but not reduced, it is not folded, thereby keeping the output type the
void MaskOp::build(
OpBuilder &builder, OperationState &result, Value mask,
- function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+ Operation *maskableOp,
+ function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
assert(maskRegionBuilder &&
"builder callback for 'maskRegion' must be present");
OpBuilder::InsertionGuard guard(builder);
Region *maskRegion = result.addRegion();
builder.createBlock(maskRegion);
- maskRegionBuilder(builder, result.location);
+ maskRegionBuilder(builder, maskableOp);
}
void MaskOp::build(
OpBuilder &builder, OperationState &result, TypeRange resultTypes,
- Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
- build(builder, result, resultTypes, mask, /*passthru=*/Value(),
+ Value mask, Operation *maskableOp,
+ function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+ build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
maskRegionBuilder);
}
void MaskOp::build(
- OpBuilder &builder, OperationState &result, TypeRange resultTypes,
- Value mask, Value passthru,
- function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
- build(builder, result, mask, maskRegionBuilder);
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
+    Value mask, Value passthru, Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+ build(builder, result, mask, maskableOp, maskRegionBuilder);
if (passthru)
result.addOperands(passthru);
result.addTypes(resultTypes);
}
//===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Creates the vector.yield-terminated region of a vector.mask op, with
+/// `maskableOp` as the masked operation.
+void mlir::vector::createMaskOpRegion(OpBuilder &builder,
+ Operation *maskableOp) {
+ assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
+ Block *insBlock = builder.getInsertionBlock();
+  // Splice the op out of its original block and into the mask region's
+  // block, which the MaskOp builder already created.
+ insBlock->getOperations().splice(
+ insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
+ builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
+}
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if the provided mask is valid (non-null). Otherwise,
+/// returns the maskable operation itself.
+Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
+ Operation *maskableOp, Value mask) {
+ if (!mask)
+ return maskableOp;
+ return rewriter.create<MaskOp>(maskableOp->getLoc(),
+ maskableOp->getResultTypes(), mask, maskableOp,
+ createMaskOpRegion);
+}
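+
+// Example usage (a hypothetical caller, mirroring the lowering patterns in
+// this patch): wrap a newly created reduction, falling back to the unmasked
+// op when no mask value is available.
+//
+//   Operation *redOp = rewriter.create<vector::MultiDimReductionOp>(
+//       loc, src, acc, dimsToMask, kind);
+//   redOp = mlir::vector::maskOperation(rewriter, redOp, mask);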
+
+//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/Builders.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
#define DEBUG_TYPE "vector-multi-reduction"
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ Operation *rootOp;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
auto src = multiReductionOp.getSource();
auto loc = multiReductionOp.getLoc();
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
indices.append(reductionDims.begin(), reductionDims.end());
indices.append(parallelDims.begin(), parallelDims.end());
}
+
+ // If masked, transpose the original mask.
+ Value transposedMask;
+ if (maskableOp.isMasked()) {
+ transposedMask = rewriter.create<vector::TransposeOp>(
+ loc, maskableOp.getMaskingOp().getMask(), indices);
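+      // E.g. (illustrative): a vector<4x8x16xi1> mask transposed with
+      // [1, 2, 0] becomes vector<8x16x4xi1>, in sync with the transposed
+      // source below.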
+ }
+
+ // Transpose reduction source.
auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
SmallVector<bool> reductionMask(srcRank, false);
for (int i = 0; i < reductionSize; ++i) {
else
reductionMask[i] = true;
}
- rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
- multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
- reductionMask, multiReductionOp.getKind());
+
+ Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
+ multiReductionOp.getLoc(), transposeOp.getResult(),
+ multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
+ newMultiRedOp =
+ mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
+
+ rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
return success();
}
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ Operation *rootOp;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
auto srcShape = multiReductionOp.getSourceVectorType().getShape();
auto loc = multiReductionOp.getLoc();
std::swap(mask.front(), mask.back());
std::swap(vectorShape.front(), vectorShape.back());
}
+
+    // 5. If masked, shape-cast the original mask to the flattened shape.
+    Value newVectorMask;
+ if (maskableOp.isMasked()) {
+ Value vectorMask = maskableOp.getMaskingOp().getMask();
+ auto maskCastedType = VectorType::get(
+ vectorShape,
+ vectorMask.getType().cast<VectorType>().getElementType());
+ newVectorMask =
+ rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
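+      // E.g. (illustrative): with reduction dims [2, 3], a vector<2x3x4x5xi1>
+      // mask flattens to vector<6x20xi1> (or vector<20x6xi1>, depending on
+      // whether the reduction dims are moved innermost or outermost).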
+ }
+
auto castedType = VectorType::get(
vectorShape, multiReductionOp.getSourceVectorType().getElementType());
Value cast = rewriter.create<vector::ShapeCastOp>(
loc, castedType, multiReductionOp.getSource());
+
Value acc = multiReductionOp.getAcc();
if (flattenedParallelDim) {
auto accType = VectorType::get(
multiReductionOp.getSourceVectorType().getElementType());
acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
}
- // 5. Creates the flattened form of vector.multi_reduction with inner/outer
+ // 6. Creates the flattened form of vector.multi_reduction with inner/outer
// most dim as reduction.
- auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+ Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
loc, cast, acc, mask, multiReductionOp.getKind());
+ newMultiDimRedOp =
+ mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
- // 6. If there are no parallel shapes, the result is a scalar.
+ // 7. If there are no parallel shapes, the result is a scalar.
// TODO: support 0-d vectors when available.
if (parallelShapes.empty()) {
- rewriter.replaceOp(multiReductionOp, newOp.getDest());
+ rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
return success();
}
- // 7. Creates shape cast for the output n-D -> 2-D
+ // 8. Creates shape cast for the output n-D -> 2-D.
VectorType outputCastedType = VectorType::get(
parallelShapes,
multiReductionOp.getSourceVectorType().getElementType());
rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
- multiReductionOp, outputCastedType, newOp.getDest());
+ rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
return success();
}
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ if (maskableOp.isMasked())
+ // TODO: Support masking.
+ return failure();
+
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Rank-2 ["parallel", "reduce"] or bail.
if (srcRank != 2)
if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
return failure();
+ // Vector mask setup.
+ OpBuilder::InsertionGuard guard(rewriter);
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ Operation *rootOp;
+ if (maskableOp.isMasked()) {
+ rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+ rootOp = maskableOp.getMaskingOp();
+ } else {
+ rootOp = multiReductionOp;
+ }
+
auto loc = multiReductionOp.getLoc();
Value result = rewriter.create<arith::ConstantOp>(
loc, multiReductionOp.getDestType(),
loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
auto acc = rewriter.create<vector::ExtractOp>(
loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
- auto reducedValue = rewriter.create<vector::ReductionOp>(
+ Operation *reductionOp = rewriter.create<vector::ReductionOp>(
loc, multiReductionOp.getKind(), v, acc);
+
+ // If masked, slice the mask and mask the new reduction operation.
+ if (maskableOp.isMasked()) {
+ Value mask = rewriter.create<vector::ExtractOp>(
+ loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
+ reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
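+        // E.g. (illustrative): row `i` of a vector<4x8xi1> mask is a
+        // vector<8xi1>, masking the corresponding vector<8xf32> row
+        // reduction.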
+ }
+
result = rewriter.create<vector::InsertElementOp>(
- loc, reducedValue, result,
+ loc, reductionOp->getResult(0), result,
rewriter.create<arith::ConstantIndexOp>(loc, i));
}
- rewriter.replaceOp(multiReductionOp, result);
+
+ rewriter.replaceOp(rootOp, result);
return success();
}
};
LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
PatternRewriter &rewriter) const override {
+ auto maskableOp =
+ cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+ if (maskableOp.isMasked())
+ // TODO: Support masking.
+ return failure();
+
auto srcRank = multiReductionOp.getSourceVectorType().getRank();
// Rank-1 or bail.
if (srcRank != 1)
// -----
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>,
+ %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+ affine_map<(d0, d1) -> (d0)>],
+ iterator_types = ["parallel", "reduction"] }
+ ins(%arg0 : tensor<?x?xf32>)
+ outs(%arg1 : tensor<?xf32>) {
+ ^bb(%in: f32, %out: f32) :
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ } -> tensor<?xf32>
+ return %0 : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+
+// CHECK-LABEL: @vectorize_dynamic_reduction(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK: %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x8xi1>
+// CHECK: %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK: %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK: %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK: %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+// CHECK: return %[[VAL_15]] : tensor<?xf32>
+// CHECK: }
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>,
+ %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+ affine_map<(d0, d1, d2) -> (d2, d1)>],
+ iterator_types = ["reduction", "parallel", "parallel"] }
+ ins(%arg0 : tensor<?x?x?xf32>)
+ outs(%arg1 : tensor<?x?xf32>) {
+ ^bb(%in: f32, %out: f32) :
+ %0 = arith.addf %in, %out : f32
+ linalg.yield %0 : f32
+ } -> tensor<?x?xf32>
+ return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+ %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+ transform.structured.masked_vectorize %0 vector_sizes [4, 8, 16]
+}
+
+// CHECK-LABEL: @vectorize_dynamic_transpose_reduction(
+// CHECK-SAME: %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME: %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK: %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
+// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]], %[[VAL_7]] : vector<4x8x16xi1>
+// CHECK: %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK: %[[VAL_13:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_5]] : vector<16x8xi1>
+// CHECK: %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_15:.*]] = vector.mask %[[VAL_10]] { vector.multi_reduction <add>, %[[VAL_11]], %[[VAL_14]] [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_13]] { vector.transfer_write %[[VAL_15]], %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+
+// -----
+
// This is a regression test. This IR cannot be vectorized, but
// structured.vectorize should nevertheless succeed.
// CHECK-LABEL: @wrong_reduction_detection
// CHECK: vector.broadcast
// CHECK: vector.transfer_write
-
-// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s
func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
%0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
// CHECK: return %[[RESULT_VEC]]
+// -----
+
func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
%0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
return %0 : f32
// CHECK: %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
// CHECK: return %[[RES]]
+// -----
+
func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
%0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
return %0 : vector<2x3xi32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
// CHECK: return %[[RESULT]]
+// -----
func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
%0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
// CHECK: %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
// CHECK: return %[[RESULT]]
+// -----
+
func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
%0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
return %0 : vector<2x4xf32>
// CHECK: %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
// CHECK: %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
// CHECK: return %[[RESHAPED_VEC]]
+
+// -----
+
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+ %c1 = arith.constant 1 : index
+ %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+ %c0_1 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
+ %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+ %cst_2 = arith.constant 0.000000e+00 : f32
+ %2 = vector.create_mask %dim : vector<4xi1>
+ %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+ %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+ %c0_3 = arith.constant 0 : index
+ %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+ return %5 : tensor<?xf32>
+}
+
+// Verify that the original 2-D mask is sliced and propagated properly to the
+// vector.reduction instances.
+
+// CHECK-LABEL: func.func @vectorize_dynamic_reduction
+// CHECK: %[[VAL_8:.*]] = tensor.dim
+// CHECK: %[[VAL_9:.*]] = tensor.dim
+// CHECK: %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1>
+
+// CHECK: %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<4x8xi1>
+// CHECK: %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_18:.*]] = vector.insertelement
+
+// CHECK: %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<4x8xi1>
+// CHECK: %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_23:.*]] = vector.insertelement
+
+// CHECK: %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<4x8xi1>
+// CHECK: %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_28:.*]] = vector.insertelement
+
+// CHECK: %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<4x8xi1>
+// CHECK: %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK: %[[VAL_33:.*]] = vector.insertelement
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+ %c1 = arith.constant 1 : index
+ %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+ %c2 = arith.constant 2 : index
+ %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+ %c0_2 = arith.constant 0 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1>
+ %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+ %cst_3 = arith.constant 0.000000e+00 : f32
+ %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1>
+ %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+ %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+ %c0_4 = arith.constant 0 : index
+ %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+ return %5 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL: func.func @vectorize_dynamic_transpose_reduction
+// CHECK: %[[VAL_6:.*]] = tensor.dim
+// CHECK: %[[VAL_7:.*]] = tensor.dim
+// CHECK: %[[VAL_8:.*]] = tensor.dim
+// CHECK: %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1>
+// CHECK: %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>
+
+// Just checking a few instances to make sure the vector mask is properly propagated:
+
+// CHECK: %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<8x16x4xi1>
+// CHECK: %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction <add>
+// CHECK: %[[VAL_145:.*]] = vector.insertelement %[[VAL_144]]
+
+// CHECK: %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<8x16x4xi1>
+// CHECK: %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction <add>
+// CHECK: %[[VAL_150:.*]] = vector.insertelement %[[VAL_149]]
+
+// CHECK: %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<8x16x4xi1>
+// CHECK: %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction <add>
+// CHECK: %[[VAL_155:.*]] = vector.insertelement %[[VAL_154]]
+
+// CHECK: %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<8x16x4xi1>
+// CHECK: %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
+// CHECK: %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
+