[mlir][vector] Masking support for reductions in Linalg vectorizer
authorDiego Caballero <diegocaballero@google.com>
Fri, 13 Jan 2023 20:36:40 +0000 (20:36 +0000)
committerDiego Caballero <diegocaballero@google.com>
Fri, 13 Jan 2023 20:45:04 +0000 (20:45 +0000)
This patch enables vectorization of reductions in Linalg vectorizer
using the vector.mask operation. It also introduces the logic to slice
and propagate the vector mask of a masked multi-reduction to their
respective lowering operations.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D141571

mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/VectorMultiDimReductionTransforms.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir

index 0028abe..deb86df 100644 (file)
@@ -203,6 +203,20 @@ inline bool isReductionIterator(Attribute attr) {
   return attr.cast<IteratorTypeAttr>().getValue() == IteratorType::reduction;
 }
 
+//===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
+/// as masked operation.
+void createMaskOpRegion(OpBuilder &builder, Operation *maskableOp);
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if the mask provided is valid. Otherwise, returns the
+/// maskable operation itself.
+Operation *maskOperation(RewriterBase &rewriter, Operation *maskableOp,
+                         Value mask);
+
 } // namespace vector
 } // namespace mlir
 
index 5a14f0d..8c5d44a 100644 (file)
@@ -340,6 +340,7 @@ def Vector_MultiDimReductionOp :
      PredOpTrait<"source operand and result have same element type",
                  TCresVTEtIsSameAsOpBase<0, 0>>,
      DeclareOpInterfaceMethods<InferTypeOpInterface>,
+     DeclareOpInterfaceMethods<MaskableOpInterface>,
      DeclareOpInterfaceMethods<VectorUnrollOpInterface,
                                ["getShapeForUnroll"]>]>,
     Arguments<(ins Vector_CombiningKindAttr:$kind,
@@ -2338,16 +2339,13 @@ def Vector_MaskOp : Vector_Op<"mask", [
 
   let skipDefaultBuilders = 1;
   let builders = [
-    OpBuilder<(ins "Value":$mask,
-                   CArg<"function_ref<void(OpBuilder &, Location)>",
-                        "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
-                   CArg<"function_ref<void(OpBuilder &, Location)>",
-                        "buildTerminatedBody">:$maskRegion)>,
-    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask,
-                   "Value":$passthru,
-                   CArg<"function_ref<void(OpBuilder &, Location)>",
-                        "buildTerminatedBody">:$maskRegion)>
+    OpBuilder<(ins "Value":$mask, "Operation *":$maskableOp,
+                   CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Operation *":$maskableOp,
+                   CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>,
+    OpBuilder<(ins "TypeRange":$resultTypes, "Value":$mask, "Value":$passthru,
+                   "Operation *":$maskableOp,
+                   CArg<"function_ref<void(OpBuilder &, Operation *)>">:$maskRegion)>
   ];
 
   let extraClassDeclaration = [{
index 1e83350..5f367d1 100644 (file)
@@ -292,25 +292,8 @@ VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
 
   // Wrap the operation with a new `vector.mask` and update D-U chain.
   assert(opToMask && "Expected a valid operation to mask");
-  auto opResults = opToMask->getResultTypes();
-  auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
-    Block *insBlock = builder.getInsertionBlock();
-    // Create a block, put an op in that block. Look for a utility.
-    // Maybe in conversion pattern rewriter. Way to avoid splice.
-    // Set insertion point.
-    insBlock->getOperations().splice(
-        insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
-    builder.create<vector::YieldOp>(loc, opToMask->getResults());
-  };
-  // TODO: Allow multiple results in vector.mask.
-  auto maskOp =
-      opResults.empty()
-          ? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
-                                            createRegionMask)
-          : rewriter.create<vector::MaskOp>(opToMask->getLoc(),
-                                            opToMask->getResultTypes().front(),
-                                            mask, createRegionMask);
-
+  auto maskOp = cast<vector::MaskOp>(
+      mlir::vector::maskOperation(rewriter, opToMask, mask));
   Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
 
   for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
@@ -440,17 +423,16 @@ static Value broadcastIfNeeded(OpBuilder &b, Value value,
 /// initial value.buildMultiDimReduce
 // Note: this is a true builder that notifies the OpBuilder listener.
 // TODO: Consider moving as a static helper on the ReduceOp.
-static Operation *buildMultiDimReduce(OpBuilder &b,
-                                      Operation *reduceOp, Value valueToReduce,
-                                      Value acc,
-                                      const SmallVector<bool> &reductionMask) {
+static Operation *buildMultiDimReduce(OpBuilder &b, Operation *reduceOp,
+                                      Value valueToReduce, Value acc,
+                                      ArrayRef<bool> dimsToMask) {
   auto maybeKind = getCombinerOpKind(reduceOp);
   assert(maybeKind && "Failed precondition: could not get reduction kind");
   return b.create<vector::MultiDimReductionOp>(
-      reduceOp->getLoc(), valueToReduce, acc, reductionMask, *maybeKind);
+      reduceOp->getLoc(), valueToReduce, acc, dimsToMask, *maybeKind);
 }
 
-static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
+static SmallVector<bool> getDimsToReduce(LinalgOp linalgOp) {
   return llvm::to_vector(
       llvm::map_range(linalgOp.getIteratorTypesArray(), isReductionIterator));
 }
@@ -701,8 +683,8 @@ static Operation *reduceIfNeeded(OpBuilder &b, LinalgOp linalgOp, Operation *op,
   if (!reduceType ||
       (outputType && reduceType.getShape() == outputType.getShape()))
     return nullptr;
-  SmallVector<bool> reductionMask = getReductionMask(linalgOp);
-  return buildMultiDimReduce(b, op, reduceVec, outputVec, reductionMask);
+  SmallVector<bool> dimsToMask = getDimsToReduce(linalgOp);
+  return buildMultiDimReduce(b, op, reduceVec, outputVec, dimsToMask);
 }
 
 /// Generic vectorization for a single operation `op`, given already vectorized
@@ -972,11 +954,8 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 }
 
 static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
-  // TODO: Masking only supports dynamic generic ops without reductions for now.
-  if (!isElementwise(op) &&
-      llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
-        return itType != utils::IteratorType::parallel;
-      }))
+  // TODO: Masking only supports dynamic generic ops for now.
+  if (!isa<linalg::GenericOp>(op))
     return failure();
 
   // TODO: 0-d vectors are not supported yet.
index f00d849..9339452 100644 (file)
@@ -342,6 +342,13 @@ LogicalResult MultiDimReductionOp::verify() {
   return success();
 }
 
+/// Returns the mask type expected by this operation.
+Type MultiDimReductionOp::getExpectedMaskType() {
+  auto vecType = getSourceVectorType();
+  return VectorType::get(vecType.getShape(),
+                         IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 namespace {
 // Only unit dimensions that are being reduced are folded. If the dimension is
 // unit, but not reduced, it is not folded, thereby keeping the output type the
@@ -5276,7 +5283,8 @@ void CreateMaskOp::getCanonicalizationPatterns(RewritePatternSet &results,
 
 void MaskOp::build(
     OpBuilder &builder, OperationState &result, Value mask,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
+    Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
   assert(maskRegionBuilder &&
          "builder callback for 'maskRegion' must be present");
 
@@ -5284,21 +5292,22 @@ void MaskOp::build(
   OpBuilder::InsertionGuard guard(builder);
   Region *maskRegion = result.addRegion();
   builder.createBlock(maskRegion);
-  maskRegionBuilder(builder, result.location);
+  maskRegionBuilder(builder, maskableOp);
 }
 
 void MaskOp::build(
     OpBuilder &builder, OperationState &result, TypeRange resultTypes,
-    Value mask, function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, resultTypes, mask, /*passthru=*/Value(),
+    Value mask, Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+  build(builder, result, resultTypes, mask, /*passthru=*/Value(), maskableOp,
         maskRegionBuilder);
 }
 
 void MaskOp::build(
-    OpBuilder &builder, OperationState &result, TypeRange resultTypes,
-    Value mask, Value passthru,
-    function_ref<void(OpBuilder &, Location)> maskRegionBuilder) {
-  build(builder, result, mask, maskRegionBuilder);
+    OpBuilder &builder, OperationState &result, TypeRange resultTypes, Value mask,
+    Value passthru, Operation *maskableOp,
+    function_ref<void(OpBuilder &, Operation *)> maskRegionBuilder) {
+  build(builder, result, mask, maskableOp, maskRegionBuilder);
   if (passthru)
     result.addOperands(passthru);
   result.addTypes(resultTypes);
@@ -5739,6 +5748,34 @@ Value mlir::vector::makeArithReduction(OpBuilder &b, Location loc,
 }
 
 //===----------------------------------------------------------------------===//
+// Vector Masking Utilities
+//===----------------------------------------------------------------------===//
+
+/// Create the vector.yield-ended region of a vector.mask op with `maskableOp`
+/// as masked operation.
+void mlir::vector::createMaskOpRegion(OpBuilder &builder,
+                                      Operation *maskableOp) {
+  assert(maskableOp->getBlock() && "MaskableOp must be inserted into a block");
+  Block *insBlock = builder.getInsertionBlock();
+  // Create a block and move the op to that block.
+  insBlock->getOperations().splice(
+      insBlock->begin(), maskableOp->getBlock()->getOperations(), maskableOp);
+  builder.create<YieldOp>(maskableOp->getLoc(), maskableOp->getResults());
+}
+
+/// Creates a vector.mask operation around a maskable operation. Returns the
+/// vector.mask operation if the mask provided is valid. Otherwise, returns
+/// the maskable operation itself.
+Operation *mlir::vector::maskOperation(RewriterBase &rewriter,
+                                       Operation *maskableOp, Value mask) {
+  if (!mask)
+    return maskableOp;
+  return rewriter.create<MaskOp>(maskableOp->getLoc(),
+                                 maskableOp->getResultTypes(), mask, maskableOp,
+                                 createMaskOpRegion);
+}
+
+//===----------------------------------------------------------------------===//
 // TableGen'd op method definitions
 //===----------------------------------------------------------------------===//
 
index 31a2452..e89059c 100644 (file)
@@ -12,9 +12,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
-#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/IR/TypeUtilities.h"
 
 #define DEBUG_TYPE "vector-multi-reduction"
@@ -40,6 +38,18 @@ public:
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto src = multiReductionOp.getSource();
     auto loc = multiReductionOp.getLoc();
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
@@ -79,6 +89,15 @@ public:
       indices.append(reductionDims.begin(), reductionDims.end());
       indices.append(parallelDims.begin(), parallelDims.end());
     }
+
+    // If masked, transpose the original mask.
+    Value transposedMask;
+    if (maskableOp.isMasked()) {
+      transposedMask = rewriter.create<vector::TransposeOp>(
+          loc, maskableOp.getMaskingOp().getMask(), indices);
+    }
+
+    // Transpose reduction source.
     auto transposeOp = rewriter.create<vector::TransposeOp>(loc, src, indices);
     SmallVector<bool> reductionMask(srcRank, false);
     for (int i = 0; i < reductionSize; ++i) {
@@ -87,9 +106,14 @@ public:
       else
         reductionMask[i] = true;
     }
-    rewriter.replaceOpWithNewOp<vector::MultiDimReductionOp>(
-        multiReductionOp, transposeOp.getResult(), multiReductionOp.getAcc(),
-        reductionMask, multiReductionOp.getKind());
+
+    Operation *newMultiRedOp = rewriter.create<vector::MultiDimReductionOp>(
+        multiReductionOp.getLoc(), transposeOp.getResult(),
+        multiReductionOp.getAcc(), reductionMask, multiReductionOp.getKind());
+    newMultiRedOp =
+        mlir::vector::maskOperation(rewriter, newMultiRedOp, transposedMask);
+
+    rewriter.replaceOp(rootOp, newMultiRedOp->getResult(0));
     return success();
   }
 
@@ -113,6 +137,18 @@ public:
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     auto srcShape = multiReductionOp.getSourceVectorType().getShape();
     auto loc = multiReductionOp.getLoc();
@@ -186,10 +222,22 @@ public:
       std::swap(mask.front(), mask.back());
       std::swap(vectorShape.front(), vectorShape.back());
     }
+
+    Value newVectorMask;
+    if (maskableOp.isMasked()) {
+      Value vectorMask = maskableOp.getMaskingOp().getMask();
+      auto maskCastedType = VectorType::get(
+          vectorShape,
+          vectorMask.getType().cast<VectorType>().getElementType());
+      newVectorMask =
+          rewriter.create<vector::ShapeCastOp>(loc, maskCastedType, vectorMask);
+    }
+
     auto castedType = VectorType::get(
         vectorShape, multiReductionOp.getSourceVectorType().getElementType());
     Value cast = rewriter.create<vector::ShapeCastOp>(
         loc, castedType, multiReductionOp.getSource());
+
     Value acc = multiReductionOp.getAcc();
     if (flattenedParallelDim) {
       auto accType = VectorType::get(
@@ -197,24 +245,26 @@ public:
           multiReductionOp.getSourceVectorType().getElementType());
       acc = rewriter.create<vector::ShapeCastOp>(loc, accType, acc);
     }
-    // 5. Creates the flattened form of vector.multi_reduction with inner/outer
+    // 6. Creates the flattened form of vector.multi_reduction with inner/outer
     // most dim as reduction.
-    auto newOp = rewriter.create<vector::MultiDimReductionOp>(
+    Operation *newMultiDimRedOp = rewriter.create<vector::MultiDimReductionOp>(
         loc, cast, acc, mask, multiReductionOp.getKind());
+    newMultiDimRedOp =
+        mlir::vector::maskOperation(rewriter, newMultiDimRedOp, newVectorMask);
 
-    // 6. If there are no parallel shapes, the result is a scalar.
+    // 7. If there are no parallel shapes, the result is a scalar.
     // TODO: support 0-d vectors when available.
     if (parallelShapes.empty()) {
-      rewriter.replaceOp(multiReductionOp, newOp.getDest());
+      rewriter.replaceOp(rootOp, newMultiDimRedOp->getResult(0));
       return success();
     }
 
-    // 7. Creates shape cast for the output n-D -> 2-D
+    // 8. Creates shape cast for the output n-D -> 2-D.
     VectorType outputCastedType = VectorType::get(
         parallelShapes,
         multiReductionOp.getSourceVectorType().getElementType());
     rewriter.replaceOpWithNewOp<vector::ShapeCastOp>(
-        multiReductionOp, outputCastedType, newOp.getDest());
+        rootOp, outputCastedType, newMultiDimRedOp->getResult(0));
     return success();
   }
 
@@ -230,6 +280,12 @@ struct TwoDimMultiReductionToElementWise
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: Support masking.
+      return failure();
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-2 ["parallel", "reduce"] or bail.
     if (srcRank != 2)
@@ -274,6 +330,18 @@ struct TwoDimMultiReductionToReduction
     if (multiReductionOp.isReducedDim(0) || !multiReductionOp.isReducedDim(1))
       return failure();
 
+    // Vector mask setup.
+    OpBuilder::InsertionGuard guard(rewriter);
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    Operation *rootOp;
+    if (maskableOp.isMasked()) {
+      rewriter.setInsertionPoint(maskableOp.getMaskingOp());
+      rootOp = maskableOp.getMaskingOp();
+    } else {
+      rootOp = multiReductionOp;
+    }
+
     auto loc = multiReductionOp.getLoc();
     Value result = rewriter.create<arith::ConstantOp>(
         loc, multiReductionOp.getDestType(),
@@ -285,13 +353,22 @@ struct TwoDimMultiReductionToReduction
           loc, multiReductionOp.getSource(), ArrayRef<int64_t>{i});
       auto acc = rewriter.create<vector::ExtractOp>(
           loc, multiReductionOp.getAcc(), ArrayRef<int64_t>{i});
-      auto reducedValue = rewriter.create<vector::ReductionOp>(
+      Operation *reductionOp = rewriter.create<vector::ReductionOp>(
           loc, multiReductionOp.getKind(), v, acc);
+
+      // If masked, slice the mask and mask the new reduction operation.
+      if (maskableOp.isMasked()) {
+        Value mask = rewriter.create<vector::ExtractOp>(
+            loc, maskableOp.getMaskingOp().getMask(), ArrayRef<int64_t>{i});
+        reductionOp = mlir::vector::maskOperation(rewriter, reductionOp, mask);
+      }
+
       result = rewriter.create<vector::InsertElementOp>(
-          loc, reducedValue, result,
+          loc, reductionOp->getResult(0), result,
           rewriter.create<arith::ConstantIndexOp>(loc, i));
     }
-    rewriter.replaceOp(multiReductionOp, result);
+
+    rewriter.replaceOp(rootOp, result);
     return success();
   }
 };
@@ -307,6 +384,12 @@ struct OneDimMultiReductionToTwoDim
 
   LogicalResult matchAndRewrite(vector::MultiDimReductionOp multiReductionOp,
                                 PatternRewriter &rewriter) const override {
+    auto maskableOp =
+        cast<vector::MaskableOpInterface>(multiReductionOp.getOperation());
+    if (maskableOp.isMasked())
+      // TODO: Support masking.
+      return failure();
+
     auto srcRank = multiReductionOp.getSourceVectorType().getRank();
     // Rank-1 or bail.
     if (srcRank != 1)
index 0ccd6c4..d25ffe7 100644 (file)
@@ -1824,6 +1824,82 @@ transform.sequence failures(propagate) {
 
 // -----
 
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>,
+                                       %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0)>],
+                        iterator_types = ["parallel", "reduction"] }
+    ins(%arg0 : tensor<?x?xf32>)
+    outs(%arg1 : tensor<?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_reduction(
+// CHECK-SAME:                                 %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                                 %[[VAL_1:.*]]: tensor<?xf32>) -> tensor<?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_8:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]] : vector<4x8xi1>
+// CHECK:           %[[VAL_9:.*]] = vector.mask %[[VAL_8]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_11:.*]] = vector.create_mask %[[VAL_3]] : vector<4xi1>
+// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_11]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.mask %[[VAL_8]] { vector.multi_reduction <add>, %[[VAL_9]], %[[VAL_12]] [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_11]] { vector.transfer_write %[[VAL_13]], %[[VAL_1]]{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+// CHECK:           return %[[VAL_15]] : tensor<?xf32>
+// CHECK:         }
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>,
+                                                 %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>,
+                                         affine_map<(d0, d1, d2) -> (d2, d1)>],
+                        iterator_types = ["reduction", "parallel", "parallel"] }
+    ins(%arg0 : tensor<?x?x?xf32>)
+    outs(%arg1 : tensor<?x?xf32>) {
+    ^bb(%in: f32, %out: f32) :
+      %0 = arith.addf %in, %out : f32
+      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8, 16]
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_transpose_reduction(
+// CHECK-SAME:                                           %[[VAL_0:.*]]: tensor<?x?x?xf32>,
+// CHECK-SAME:                                           %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = tensor.dim %[[VAL_0]], %[[VAL_2]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_4:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_0]], %[[VAL_4]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_0]], %[[VAL_6]] : tensor<?x?x?xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_3]], %[[VAL_5]], %[[VAL_7]] : vector<4x8x16xi1>
+// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_0]]{{.*}} {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+// CHECK:           %[[VAL_13:.*]] = vector.create_mask %[[VAL_7]], %[[VAL_5]] : vector<16x8xi1>
+// CHECK:           %[[VAL_14:.*]] = vector.mask %[[VAL_13]] { vector.transfer_read %[[VAL_1]]{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_10]] { vector.multi_reduction <add>, %[[VAL_11]], %[[VAL_14]] [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_13]] { vector.transfer_write %[[VAL_15]], %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+
+// -----
+
 // This is a regression test. This IR cannot be vectorized, but
 // structured.vectorize should nevertheless succeed.
 
@@ -1892,4 +1968,3 @@ transform.sequence failures(propagate) {
 // CHECK-LABEL: @wrong_reduction_detection
 // CHECK:         vector.broadcast
 // CHECK:         vector.transfer_write
-
index 6b372c3..ee4ab7a 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns | FileCheck %s
+// RUN: mlir-opt %s -test-vector-multi-reduction-lowering-patterns -split-input-file | FileCheck %s
 
 func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -> vector<2xf32> {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [1] : vector<2x4xf32> to vector<2xf32>
@@ -19,6 +19,8 @@ func.func @vector_multi_reduction(%arg0: vector<2x4xf32>, %acc: vector<2xf32>) -
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV1:.+]], %[[RESULT_VEC_1]][%[[C1]] : index] : vector<2xf32>
 //       CHECK:       return %[[RESULT_VEC]]
 
+// -----
+
 func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -> f32 {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0, 1] : vector<2x4xf32> to f32
     return %0 : f32
@@ -31,6 +33,8 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x4xf32>, %acc: f32) -
 //       CHECK:   %[[RES:.*]] = vector.extract %[[INSERTED]][0] : vector<1xf32>
 //       CHECK:   return %[[RES]]
 
+// -----
+
 func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi32>) -> vector<2x3xi32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [2, 3] : vector<2x3x4x5xi32> to vector<2x3xi32>
     return %0 : vector<2x3xi32>
@@ -72,6 +76,7 @@ func.func @vector_reduction_inner(%arg0: vector<2x3x4x5xi32>, %acc: vector<2x3xi
 //       CHECK:       %[[RESULT:.+]] = vector.shape_cast %[[FLAT_RESULT_VEC]] : vector<6xi32> to vector<2x3xi32>
 //       CHECK:       return %[[RESULT]]
 
+// -----
 
 func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: vector<2x5xf32>) -> vector<2x5xf32> {
     %0 = vector.multi_reduction <add>, %arg0, %acc [1, 2] : vector<2x3x4x5xf32> to vector<2x5xf32>
@@ -85,6 +90,8 @@ func.func @vector_multi_reduction_transposed(%arg0: vector<2x3x4x5xf32>, %acc: v
 //       CHECK:     %[[RESULT:.+]] = vector.shape_cast %{{.*}} : vector<10xf32> to vector<2x5xf32>
 //       CHECK:       return %[[RESULT]]
 
+// -----
+
 func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vector<2x4xf32>) -> vector<2x4xf32> {
     %0 = vector.multi_reduction <mul>, %arg0, %acc [0] : vector<3x2x4xf32> to vector<2x4xf32>
     return %0 : vector<2x4xf32>
@@ -135,3 +142,95 @@ func.func @vector_multi_reduction_ordering(%arg0: vector<3x2x4xf32>, %acc: vecto
 //       CHECK:       %[[RESULT_VEC:.+]] = vector.insertelement %[[RV7:.+]], %[[RESULT_VEC_7]][%[[C7]] : index] : vector<8xf32>
 //       CHECK:       %[[RESHAPED_VEC:.+]] = vector.shape_cast %[[RESULT_VEC]] : vector<8xf32> to vector<2x4xf32>
 //       CHECK:       return %[[RESHAPED_VEC]]
+
+// -----
+
+func.func @vectorize_dynamic_reduction(%arg0: tensor<?x?xf32>, %arg1: tensor<?xf32>) -> tensor<?xf32> {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?xf32>
+  %c1 = arith.constant 1 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?xf32>
+  %c0_1 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim, %dim_0 : vector<4x8xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_1, %c0_1], %cst {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+  %cst_2 = arith.constant 0.000000e+00 : f32
+  %2 = vector.create_mask %dim : vector<4xi1>
+  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_1], %cst_2 {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [1] : vector<4x8xf32> to vector<4xf32> } : vector<4x8xi1> -> vector<4xf32>
+  %c0_3 = arith.constant 0 : index
+  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_3] {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+  return %5 : tensor<?xf32>
+}
+
+// Verify that the original 2-D mask is sliced and propagated properly to the
+// vector.reduction instances.
+
+// CHECK-LABEL:   func.func @vectorize_dynamic_reduction
+// CHECK:           %[[VAL_8:.*]] = tensor.dim
+// CHECK:           %[[VAL_9:.*]] = tensor.dim
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_8]], %[[VAL_9]] : vector<4x8xi1>
+
+// CHECK:           %[[VAL_16:.*]] = vector.extract %[[VAL_10]][0] : vector<4x8xi1>
+// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_16]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_18:.*]] = vector.insertelement
+
+// CHECK:           %[[VAL_21:.*]] = vector.extract %[[VAL_10]][1] : vector<4x8xi1>
+// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_21]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_23:.*]] = vector.insertelement
+
+// CHECK:           %[[VAL_26:.*]] = vector.extract %[[VAL_10]][2] : vector<4x8xi1>
+// CHECK:           %[[VAL_27:.*]] = vector.mask %[[VAL_26]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_28:.*]] = vector.insertelement
+
+// CHECK:           %[[VAL_31:.*]] = vector.extract %[[VAL_10]][3] : vector<4x8xi1>
+// CHECK:           %[[VAL_32:.*]] = vector.mask %[[VAL_31]] { vector.reduction <add>, %{{.*}} : vector<8xf32> into f32 } : vector<8xi1> -> f32
+// CHECK:           %[[VAL_33:.*]] = vector.insertelement
+
+// -----
+
+func.func @vectorize_dynamic_transpose_reduction(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 0 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
+  %c1 = arith.constant 1 : index
+  %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
+  %c2 = arith.constant 2 : index
+  %dim_1 = tensor.dim %arg0, %c2 : tensor<?x?x?xf32>
+  %c0_2 = arith.constant 0 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = vector.create_mask %dim, %dim_0, %dim_1 : vector<4x8x16xi1>
+  %1 = vector.mask %0 { vector.transfer_read %arg0[%c0_2, %c0_2, %c0_2], %cst {in_bounds = [true, true, true]} : tensor<?x?x?xf32>, vector<4x8x16xf32> } : vector<4x8x16xi1> -> vector<4x8x16xf32>
+  %cst_3 = arith.constant 0.000000e+00 : f32
+  %2 = vector.create_mask %dim_1, %dim_0 : vector<16x8xi1>
+  %3 = vector.mask %2 { vector.transfer_read %arg1[%c0_2, %c0_2], %cst_3 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : tensor<?x?xf32>, vector<8x16xf32> } : vector<16x8xi1> -> vector<8x16xf32>
+  %4 = vector.mask %0 { vector.multi_reduction <add>, %1, %3 [0] : vector<4x8x16xf32> to vector<8x16xf32> } : vector<4x8x16xi1> -> vector<8x16xf32>
+  %c0_4 = arith.constant 0 : index
+  %5 = vector.mask %2 { vector.transfer_write %4, %arg1[%c0_4, %c0_4] {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, d0)>} : vector<8x16xf32>, tensor<?x?xf32> } : vector<16x8xi1> -> tensor<?x?xf32>
+  return %5 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_dynamic_transpose_reduction
+// CHECK:           %[[VAL_6:.*]] = tensor.dim
+// CHECK:           %[[VAL_7:.*]] = tensor.dim
+// CHECK:           %[[VAL_8:.*]] = tensor.dim
+// CHECK:           %[[VAL_135:.*]] = vector.create_mask %{{.*}}, %{{.*}}, %{{.*}} : vector<4x8x16xi1>
+// CHECK:           %[[VAL_139:.*]] = vector.transpose %[[VAL_135]], [1, 2, 0] : vector<4x8x16xi1> to vector<8x16x4xi1>
+
+// Just checking a few instances to make sure the vector mask is properly propagated:
+
+// CHECK:           %[[VAL_143:.*]] = vector.extract %[[VAL_139]][0, 0] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_144:.*]] = vector.mask %[[VAL_143]] { vector.reduction <add>
+// CHECK:           %[[VAL_145:.*]] = vector.insertelement %[[VAL_144]]
+
+// CHECK:           %[[VAL_148:.*]] = vector.extract %[[VAL_139]][0, 1] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_149:.*]] = vector.mask %[[VAL_148]] { vector.reduction <add>
+// CHECK:           %[[VAL_150:.*]] = vector.insertelement %[[VAL_149]]
+
+// CHECK:           %[[VAL_153:.*]] = vector.extract %[[VAL_139]][0, 2] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_154:.*]] = vector.mask %[[VAL_153]] { vector.reduction <add>
+// CHECK:           %[[VAL_155:.*]] = vector.insertelement %[[VAL_154]]
+
+// CHECK:           %[[VAL_158:.*]] = vector.extract %[[VAL_139]][0, 3] : vector<8x16x4xi1>
+// CHECK:           %[[VAL_159:.*]] = vector.mask %[[VAL_158]] { vector.reduction <add>
+// CHECK:           %[[VAL_160:.*]] = vector.insertelement %[[VAL_159]]
+