[mlir][Vector] Initial masking support in Linalg vectorizer
authorDiego Caballero <diegocaballero@google.com>
Thu, 24 Nov 2022 02:16:46 +0000 (02:16 +0000)
committerDiego Caballero <diegocaballero@google.com>
Tue, 13 Dec 2022 01:33:06 +0000 (01:33 +0000)
This patch introduces the initial bits to support vector masking
using the `vector.mask` operation. Vectorization changes should be
NFC for non-masked cases. Masked cases are exercised through the new
`transform.structured.masked_vectorize` transform op added in this
patch.
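
For reference, masked vectorization of a dynamically-shaped
elementwise op produces IR along these lines (a sketch distilled from
the vectorize_dynamic_identity test added below; SSA names and the
%pad padding value are illustrative):

  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
  %mask = vector.create_mask %dim : vector<4xi1>
  %lhs = vector.mask %mask { vector.transfer_read %arg0[%c0], %pad
           {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> }
         : vector<4xi1> -> vector<4xf32>
  %rhs = vector.mask %mask { vector.transfer_read %arg1[%c0], %pad
           {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> }
         : vector<4xi1> -> vector<4xf32>
  %add = arith.addf %lhs, %rhs : vector<4xf32>
  %res = vector.mask %mask { vector.transfer_write %add, %arg2[%c0]
           {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> }
         : vector<4xi1> -> tensor<?xf32>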

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D137690

13 files changed:
mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
mlir/include/mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td
mlir/include/mlir/Dialect/Vector/Transforms/Passes.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/lib/Dialect/Vector/IR/CMakeLists.txt
mlir/lib/Dialect/Vector/IR/VectorOps.cpp
mlir/lib/Dialect/Vector/Transforms/LowerVectorMask.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

index e5006ae..a04c48f 100644 (file)
@@ -591,6 +591,36 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         return result;
       }]
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Given a dimension of the iteration space of a Linalg operation, finds
+        an operand in the operation that is defined on that dimension. Returns
+        whether such an operand was found and, if so, also returns the operand
+        value and the dimension position within the operand.
+      }],
+      /*retTy=*/"LogicalResult",
+      /*methodName=*/"mapIterationSpaceDimToOperandDim",
+      /*args=*/(ins "unsigned":$dimPos,
+                    "::mlir::Value &":$operand,
+                    "unsigned &":$operandDimPos),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        // Retrieve the operand and its dimension position from the first
+        // operand with a permutation map that is defined on that dimension.
+        for (auto [i, idxMap] : llvm::enumerate($_op.getIndexingMapsArray())) {
+          if (idxMap.isProjectedPermutation()) {
+            if (auto mayOperandDim = idxMap.getResultPosition(
+                getAffineDimExpr(dimPos, idxMap.getContext()))) {
+              operand = $_op->getOperand(i);
+              operandDimPos = *mayOperandDim;
+              return success();
+            }
+          }
+        }
+
+        return failure();
+      }]
+    >,
     //===------------------------------------------------------------------===//
     // Linalg generalization hooks.
     //===------------------------------------------------------------------===//
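
To illustrate the new interface method with a hypothetical example (not part
of this patch): for a matmul-like generic whose operands %A and %B use the
indexing maps

  affine_map<(d0, d1, d2) -> (d0, d2)>  // %A
  affine_map<(d0, d1, d2) -> (d2, d1)>  // %B

the query mapIterationSpaceDimToOperandDim(/*dimPos=*/1, operand,
operandDimPos) succeeds with operand = %B and operandDimPos = 1, since d1
appears as result 1 of %B's projected permutation map. The vectorizer uses
this below to decide which operand to 'tensor.dim' when materializing a
dynamic mask bound.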
index 9fe6536..9c80f56 100644 (file)
@@ -1115,4 +1115,45 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 }
 
+def MaskedVectorizeOp : Op<Transform_Dialect, "structured.masked_vectorize",
+    [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+     TransformOpInterface]> {
+  let description = [{
+    Vectorize the target ops, which must be Linalg ops, with masked vectors
+    of the specified size.
+
+    The vector sizes can be either static or dynamic (SSA values). In case of
+    SSA values, the handle must be mapped to exactly one payload op with
+    exactly one index-typed result.
+
+    #### Return modes:
+
+    This operation produces a definite failure if the dynamic vector sizes (SSA
+    values) do not satisfy the constraints mentioned above. It produces a
+    silenceable failure if at least one target op is not a Linalg op or fails to
+    vectorize.
+  }];
+
+  let arguments = (ins PDL_Operation:$target,
+                       Variadic<PDL_Operation>:$vector_sizes,
+                       DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:
+                          $static_vector_sizes);
+  let results = (outs);
+  let assemblyFormat = [{
+      $target
+      `vector_sizes` custom<DynamicIndexList>($vector_sizes,
+                                              $static_vector_sizes)
+      attr-dict
+  }];
+
+  let extraClassDeclaration = [{
+    // TODO: applyToOne.
+    ::mlir::DiagnosedSilenceableFailure apply(
+        ::mlir::transform::TransformResults &transformResults,
+        ::mlir::transform::TransformState &state);
+
+    ::llvm::SmallVector<::mlir::OpFoldResult> getMixedVectorSizes();
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
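
Concretely, the new transform op is driven as in the tests added below:

  transform.sequence failures(propagate) {
  ^bb1(%arg1: !pdl.operation):
    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
    transform.structured.masked_vectorize %0 vector_sizes [4, 8]
  }
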
index 7d6a584..7b3e072 100644 (file)
@@ -344,8 +344,14 @@ promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView,
 FailureOr<LinalgOp> promoteSubViews(OpBuilder &b, LinalgOp op,
                                     const LinalgPromotionOptions &options);
 
-/// Emit a suitable vector form for a Linalg op with fully static shape.
-LogicalResult vectorize(RewriterBase &builder, LinalgOp linalgOp,
+/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
+/// are used to vectorize this operation. `inputVectorSizes` must match the rank
+/// of the iteration space of the operation, and each size must be greater than
+/// or equal to its counterpart iteration space size, if static.
+/// `inputVectorSizes` also enables the vectorization of operations with dynamic
+/// shapes.
+LogicalResult vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+                        ArrayRef<int64_t> inputVectorSizes = {},
                         bool vectorizeNDExtract = false);
 
 /// Emit a suitable vector form for a Copy op with fully static shape.
@@ -372,8 +378,10 @@ LogicalResult promoteSubviewsPrecondition(Operation *op,
                                           LinalgPromotionOptions options);
 
 /// Return success if the operation can be vectorized.
-LogicalResult vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                                            bool vectorizeNDExtract = false);
+LogicalResult
+vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                              ArrayRef<int64_t> inputVectorSizes = {},
+                              bool vectorizeNDExtract = false);
 
 //===----------------------------------------------------------------------===//
 // Transformations exposed as rewrite patterns.
index f035ab3..176d709 100644 (file)
@@ -450,13 +450,13 @@ def Vector_BroadcastOp :
     /// source tensor and thus correspond to "dim-1" broadcasting.
     llvm::SetVector<int64_t> computeBroadcastedUnitDims();
 
-    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the 
+    /// Broadcast `value` to a vector of `dstShape`, knowing that exactly the
     /// `broadcastedDims` dimensions in the dstShape are broadcasted.
-    /// This requires (and asserts) that the broadcast is free of dim-1 
+    /// This requires (and asserts) that the broadcast is free of dim-1
     /// broadcasting.
     /// Since vector.broadcast only allows expanding leading dimensions, an extra
     /// vector.transpose may be inserted to make the broadcast possible.
-    /// `value`, `dstShape` and `broadcastedDims` must be properly specified or 
+    /// `value`, `dstShape` and `broadcastedDims` must be properly specified or
     /// the helper will assert. This means:
     ///   1. `dstShape` must not be empty.
     ///   2. `broadcastedDims` must be confined to [0 .. rank(value.getVectorType)]
@@ -1179,6 +1179,8 @@ def Vector_ExtractStridedSliceOp :
   let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
 }
 
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
 def Vector_TransferReadOp :
   Vector_Op<"transfer_read", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
@@ -1394,6 +1396,8 @@ def Vector_TransferReadOp :
   let hasVerifier = 1;
 }
 
+// TODO: Tighten semantics so that masks and inbounds can't be used
+// simultaneously within the same transfer op.
 def Vector_TransferWriteOp :
   Vector_Op<"transfer_write", [
       DeclareOpInterfaceMethods<VectorTransferOpInterface>,
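
The mask/in_bounds overlap these TODOs point at shows up directly in the
transfers the vectorizer now emits (taken from the tests added below): once
the mask guarantees that every enabled lane is in-bounds, the in_bounds
attribute carries no extra information:

  %0 = vector.mask %mask {
         vector.transfer_read %t[%c0], %pad {in_bounds = [true]}
           : tensor<?xf32>, vector<4xf32>
       } : vector<4xi1> -> vector<4xf32>
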
index bbde7bc..184ca68 100644 (file)
@@ -31,7 +31,9 @@ def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
       /*args=*/(ins),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return mlir::isa<mlir::vector::MaskingOpInterface>($_op->getParentOp());
+        mlir::Operation *parentOp = $_op->getParentOp();
+        return parentOp &&
+               mlir::isa<mlir::vector::MaskingOpInterface>(parentOp);
     }]>,
     InterfaceMethod<
       /*desc=*/"Returns the MaskingOpInterface masking this operation.",
@@ -54,18 +56,14 @@ def MaskableOpInterface : OpInterface<"MaskableOpInterface"> {
         return false;
     }]>,
     InterfaceMethod<
-      /*desc=*/"Returns the mask type expected by this operation. It requires "
-               "the operation to be vectorized.",
-      /*retTy=*/"mlir::VectorType",
+      /*desc=*/"Returns the mask type expected by this operation. Mostly used"
+               " for verification purposes. It requires the operation to be "
+               "vectorized.",
+      /*retTy=*/"mlir::Type",
       /*methodName=*/"getExpectedMaskType",
       /*args=*/(ins),
       /*methodBody=*/"",
-      /*defaultImplementation=*/[{
-      // Default implementation is only aimed for operations that implement the
-      // `getVectorType()` method.
-        return $_op.getVectorType().cloneWith(/*shape=*/std::nullopt,
-          IntegerType::get($_op.getContext(), /*width=*/1));
-    }]>,
+      /*defaultImplementation=*/"">,
   ];
 }
 
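Making getExpectedMaskType op-specific matters for the transfer ops: their
expected mask is shaped like the memory access, not like the result vector.
For instance, in the 2-D transpose test added below, a read producing
vector<4x8xf32> through a transposing permutation map must be masked with
vector<8x4xi1>:

  %0 = vector.mask %mask {
         vector.transfer_read %t[%c0, %c0], %pad
           {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
           : tensor<?x?xf32>, vector<4x8xf32>
       } : vector<8x4xi1> -> vector<4x8xf32>
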
index bf89b01..d0c06f6 100644 (file)
@@ -22,6 +22,12 @@ std::unique_ptr<Pass> createVectorBufferizePass();
 /// Creates an instance of the `vector.mask` lowering pass.
 std::unique_ptr<Pass> createLowerVectorMaskPass();
 
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void populateVectorMaskLoweringPatternsForSideEffectingOps(
+    RewritePatternSet &patterns);
+
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
index dc6badd..79738d1 100644 (file)
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/Dialect/Transform/IR/TransformUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -1825,7 +1826,8 @@ struct VectorizationPattern : public RewritePattern {
     LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
     if (!linalgOp)
       return rewriter.notifyMatchFailure(op, "expected Linalg Op");
-    return vectorize(rewriter, linalgOp, vectorizeNDExtract);
+    return vectorize(rewriter, linalgOp, /*inputVectorSizes=*/{},
+                     vectorizeNDExtract);
   }
 
 private:
@@ -1874,6 +1876,85 @@ transform::VectorizeOp::applyToOne(Operation *target,
 }
 
 //===----------------------------------------------------------------------===//
+// MaskedVectorizeOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::MaskedVectorizeOp::apply(
+    mlir::transform::TransformResults &transformResults,
+    mlir::transform::TransformState &state) {
+  IRRewriter rewriter(getContext());
+  ArrayRef<Operation *> targets = state.getPayloadOps(getTarget());
+  if (targets.empty())
+    return DiagnosedSilenceableFailure::success();
+
+  SmallVector<int64_t> vectorSizes;
+  for (OpFoldResult sz : getMixedVectorSizes()) {
+    if (sz.is<Attribute>()) {
+      auto attr = sz.get<Attribute>();
+      vectorSizes.push_back(attr.cast<IntegerAttr>().getInt());
+      continue;
+    }
+
+    ArrayRef<Operation *> szPayloads = state.getPayloadOps(sz.get<Value>());
+    if (szPayloads.size() != 1) {
+      auto diag = this->emitOpError(
+          "requires vector size handle that is mapped to 1 payload op");
+      diag.attachNote(sz.get<Value>().getLoc())
+          << "mapped to " << szPayloads.size() << " payload ops";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    Operation *szPayloadOp = szPayloads[0];
+    if (szPayloadOp->getNumResults() != 1 ||
+        !szPayloadOp->getResult(0).getType().isIndex()) {
+      auto diag = this->emitOpError(
+          "requires vector size payload op with 1 index result");
+      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    IntegerAttr attr;
+    if (!matchPattern(szPayloadOp->getResult(0), m_Constant(&attr))) {
+      auto diag = this->emitOpError("requires constant vector size");
+      diag.attachNote(szPayloadOp->getLoc()) << "vector size payload op";
+      return DiagnosedSilenceableFailure::definiteFailure();
+    }
+
+    vectorSizes.push_back(attr.getInt());
+  }
+
+  // TODO: Check that the correct number of vectorSizes was provided.
+
+  for (Operation *target : targets) {
+    auto linalgOp = dyn_cast<LinalgOp>(target);
+    if (!linalgOp) {
+      Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+      diag << "cannot vectorize non-Linalg op";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+
+    if (failed(linalg::vectorize(rewriter, linalgOp, vectorSizes))) {
+      Diagnostic diag(target->getLoc(), DiagnosticSeverity::Error);
+      diag << "failed to vectorize op";
+      return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
+    }
+  }
+
+  return DiagnosedSilenceableFailure::success();
+}
+
+void transform::MaskedVectorizeOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  onlyReadsHandle(getVectorSizes(), effects);
+}
+
+SmallVector<OpFoldResult> MaskedVectorizeOp::getMixedVectorSizes() {
+  OpBuilder b(getContext());
+  return getMixedValues(getStaticVectorSizes(), getVectorSizes(), b);
+}
+
+//===----------------------------------------------------------------------===//
 // Transform op registration
 //===----------------------------------------------------------------------===//
 
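A dynamic vector size travels through a handle, which apply() above requires
to map to exactly one payload op with a single constant index result. A
hypothetical use (not among the tests in this patch; the mixed static/dynamic
list syntax follows custom<DynamicIndexList>):

  %sz = transform.structured.match ops{["arith.constant"]} in %arg1
  transform.structured.masked_vectorize %0 vector_sizes [%sz, 8]
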
index a7c3c00..89140d4 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Analysis/SliceAnalysis.h"
-#include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
+#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
 #include "mlir/IR/AffineExpr.h"
-#include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
-#include "mlir/Pass/Pass.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/RegionUtils.h"
-#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/Sequence.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
@@ -65,6 +60,266 @@ static OpType getSingleOpOfType(Block &block) {
   return res;
 }
 
+/// Contains the vectorization state and related methods used across the
+/// vectorization process of a given operation.
+struct VectorizationState {
+  VectorizationState(RewriterBase &rewriter) : rewriterGuard(rewriter) {}
+
+  /// Initializes the vectorization state, including the computation of the
+  /// canonical vector shape for vectorization.
+  LogicalResult initState(RewriterBase &rewriter, LinalgOp linalgOp,
+                          ArrayRef<int64_t> inputVectorSizes);
+
+  /// Returns the canonical vector shape used to vectorize the iteration space.
+  ArrayRef<int64_t> getCanonicalVecShape() const { return canonicalVecShape; }
+
+  /// Masks an operation with the canonical vector mask if the operation needs
+  /// masking. Returns the masked operation or the original operation if masking
+  /// is not needed. If provided, the canonical mask for this operation is
+  /// permuted using `maybeMaskingMap`.
+  Operation *maskOperation(RewriterBase &rewriter, Operation *opToMask,
+                           LinalgOp linalgOp,
+                           Optional<AffineMap> maybeMaskingMap = std::nullopt);
+
+private:
+  /// Initializes the iteration space static sizes using the Linalg op
+  /// information. This may become more complicated in the future.
+  void initIterSpaceStaticSizes(LinalgOp linalgOp) {
+    iterSpaceStaticSizes.append(linalgOp.getStaticLoopRanges());
+  }
+
+  /// Generates 'tensor.dim' operations for all the dynamic dimensions of the
+  /// iteration space to be vectorized and stores them in
+  /// `iterSpaceDynamicSizes`.
+  LogicalResult precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
+                                                LinalgOp linalgOp);
+
+  /// Create or retrieve an existing mask value to mask `opToMask` in the
+  /// canonical vector iteration space. If `maybeMaskingMap` is provided, the
+  /// mask is permuted using that permutation map. If a new mask is created, it
+  /// will be cached for future users.
+  Value getOrCreateMaskFor(RewriterBase &rewriter, Operation *opToMask,
+                           LinalgOp linalgOp,
+                           Optional<AffineMap> maybeMaskingMap);
+
+  // Holds the compile-time static sizes of the iteration space to vectorize.
+  // Dynamic dimensions are represented using ShapedType::kDynamicSize.
+  SmallVector<int64_t> iterSpaceStaticSizes;
+
+  /// Holds the runtime sizes of the iteration space to vectorize. Static
+  /// dimensions are represented with an empty value.
+  SmallVector<Value> iterSpaceDynamicSizes;
+
+  /// Holds the canonical vector shape used to vectorize the iteration space.
+  SmallVector<int64_t> canonicalVecShape;
+
+  /// Holds the active masks for permutations of the canonical vector iteration
+  /// space.
+  DenseMap<AffineMap, Value> activeMaskCache;
+
+  /// Global vectorization guard for the incoming rewriter. It's initialized
+  /// when the vectorization state is initialized.
+  OpBuilder::InsertionGuard rewriterGuard;
+};
+
+/// Generates 'tensor.dim' operations for all the dynamic dimensions of the
+/// iteration space to be vectorized and stores them in
+/// `iterSpaceDynamicSizes`.
+LogicalResult
+VectorizationState::precomputeIterSpaceDynamicSizes(RewriterBase &rewriter,
+                                                    LinalgOp linalgOp) {
+  // TODO: Support 0-d vectors.
+  for (int vecDim = 0, end = canonicalVecShape.size(); vecDim < end; ++vecDim) {
+    if (!ShapedType::isDynamic(iterSpaceStaticSizes[vecDim])) {
+      // Add an empty value for static dimensions.
+      iterSpaceDynamicSizes.push_back(Value());
+      continue;
+    }
+
+    // Find an operand defined on this dimension of the iteration space to
+    // extract the runtime dimension size.
+    Value operand;
+    unsigned operandDimPos;
+    if (failed(linalgOp.mapIterationSpaceDimToOperandDim(vecDim, operand,
+                                                         operandDimPos)))
+      return failure();
+
+    Value dynamicDim = linalgOp.hasTensorSemantics()
+                           ? (Value)rewriter.create<tensor::DimOp>(
+                                 linalgOp.getLoc(), operand, operandDimPos)
+                           : (Value)rewriter.create<memref::DimOp>(
+                                 linalgOp.getLoc(), operand, operandDimPos);
+    iterSpaceDynamicSizes.push_back(dynamicDim);
+  }
+
+  return success();
+}
+
+/// Initializes the vectorization state, including the computation of the
+/// canonical vector shape for vectorization.
+// TODO: Move this to the constructor when we can remove the failure cases.
+LogicalResult
+VectorizationState::initState(RewriterBase &rewriter, LinalgOp linalgOp,
+                              ArrayRef<int64_t> inputVectorSizes) {
+  // Initialize the insertion point.
+  rewriter.setInsertionPoint(linalgOp);
+
+  if (!inputVectorSizes.empty()) {
+    // Get the canonical vector shape from the input vector sizes provided. This
+    // path should be taken to vectorize code with dynamic shapes and when using
+    // vector sizes greater than the iteration space sizes.
+    canonicalVecShape.append(inputVectorSizes.begin(), inputVectorSizes.end());
+  } else {
+    // Compute the canonical vector shape from the operation shape. If there are
+    // dynamic shapes, the operation won't be vectorized.
+    canonicalVecShape = linalgOp.getStaticLoopRanges();
+  }
+
+  LDBG("Canonical vector shape: ");
+  LLVM_DEBUG(llvm::interleaveComma(canonicalVecShape, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  // Initialize iteration space static sizes.
+  initIterSpaceStaticSizes(linalgOp);
+
+  // Extract and register the runtime value of any potential dynamic shape
+  // needed to compute a mask during vectorization.
+  if (failed(precomputeIterSpaceDynamicSizes(rewriter, linalgOp)))
+    return failure();
+
+  if (ShapedType::isDynamicShape(canonicalVecShape))
+    return failure();
+  return success();
+}
+
+/// Create or retrieve an existing mask value to mask `opToMask` in the
+/// canonical vector iteration space. If `maybeMaskingMap` is provided, the mask
+/// is permuted using that permutation map. If a new mask is created, it will be
+/// cached for future users.
+Value VectorizationState::getOrCreateMaskFor(
+    RewriterBase &rewriter, Operation *opToMask, LinalgOp linalgOp,
+    Optional<AffineMap> maybeMaskingMap) {
+  // No mask is needed if the operation is not maskable.
+  auto maskableOp = dyn_cast<vector::MaskableOpInterface>(opToMask);
+  if (!maskableOp)
+    return Value();
+
+  assert(!maskableOp.isMasked() &&
+         "Masking an operation that is already masked");
+
+  // If no masking map was provided, use an identity map with the loop dims.
+  assert((!maybeMaskingMap || *maybeMaskingMap) &&
+         "Unexpected null mask permutation map");
+  AffineMap maskingMap =
+      maybeMaskingMap ? *maybeMaskingMap
+                      : AffineMap::getMultiDimIdentityMap(
+                            linalgOp.getNumLoops(), rewriter.getContext());
+
+  LDBG("Masking map: " << maskingMap << "\n");
+
+  // Return the active mask for the masking map of this operation if it was
+  // already created.
+  auto activeMaskIt = activeMaskCache.find(maskingMap);
+  if (activeMaskIt != activeMaskCache.end()) {
+    Value mask = activeMaskIt->second;
+    LDBG("Reusing mask: " << mask << "\n");
+    return mask;
+  }
+
+  // Compute permuted projection of the iteration space to be masked and the
+  // corresponding mask shape. If the resulting iteration space dimensions are
+  // static and identical to the mask shape, masking is not needed for this
+  // operation.
+  // TODO: Improve this check. Only projected permutation indexing maps are
+  // supported.
+  SmallVector<int64_t> permutedStaticSizes =
+      applyPermutationMap(maskingMap, ArrayRef<int64_t>(iterSpaceStaticSizes));
+  SmallVector<int64_t> maskShape =
+      applyPermutationMap(maskingMap, ArrayRef<int64_t>(canonicalVecShape));
+  LDBG("Mask shape: ");
+  LLVM_DEBUG(llvm::interleaveComma(maskShape, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  if (permutedStaticSizes == maskShape) {
+    LDBG("Masking is not needed for masking map: " << maskingMap << "\n");
+    activeMaskCache[maskingMap] = Value();
+    return Value();
+  }
+
+  // Compute the mask upper bound values by combining the permuted iteration
+  // space static sizes and the dynamic values.
+  SmallVector<Value> permutedDynamicSizes =
+      applyPermutationMap(maskingMap, ArrayRef<Value>(iterSpaceDynamicSizes));
+  SmallVector<Value> upperBounds;
+  for (auto [staticBound, dynBound] :
+       llvm::zip(permutedStaticSizes, permutedDynamicSizes))
+    upperBounds.push_back(ShapedType::isDynamic(staticBound)
+                              ? dynBound
+                              : rewriter.create<arith::ConstantIndexOp>(
+                                    linalgOp.getLoc(), staticBound));
+
+  assert(!maskShape.empty() && !upperBounds.empty() &&
+         "Masked 0-d vectors are not supported yet");
+
+  // Create the mask based on the dimension size values.
+  auto maskType = VectorType::get(maskShape, rewriter.getI1Type());
+  Value mask = rewriter.create<vector::CreateMaskOp>(linalgOp.getLoc(),
+                                                     maskType, upperBounds);
+  LDBG("Creating new mask: " << mask << "\n");
+  activeMaskCache[maskingMap] = mask;
+  return mask;
+}
+
+/// Masks an operation with the canonical vector mask if the operation needs
+/// masking. Returns the masked operation or the original operation if masking
+/// is not needed. If provided, the canonical mask for this operation is
+/// permuted using `maybeMaskingMap`.
+Operation *
+VectorizationState::maskOperation(RewriterBase &rewriter, Operation *opToMask,
+                                  LinalgOp linalgOp,
+                                  Optional<AffineMap> maybeMaskingMap) {
+  LDBG("Trying to mask: " << *opToMask << "\n");
+
+  // Create or retrieve mask for this operation.
+  Value mask =
+      getOrCreateMaskFor(rewriter, opToMask, linalgOp, maybeMaskingMap);
+
+  if (!mask) {
+    LDBG("No mask required\n");
+    return opToMask;
+  }
+
+  // Wrap the operation with a new `vector.mask` and update the def-use chain.
+  assert(opToMask && "Expected a valid operation to mask");
+  auto opResults = opToMask->getResultTypes();
+  auto createRegionMask = [opToMask](OpBuilder &builder, Location loc) {
+    Block *insBlock = builder.getInsertionBlock();
+    // TODO: Look for (or add) a utility that creates a block, moves an op into
+    // it, and sets the insertion point, so we can avoid this manual splice.
+    // The conversion pattern rewriter may already provide one.
+    insBlock->getOperations().splice(
+        insBlock->begin(), opToMask->getBlock()->getOperations(), opToMask);
+    builder.create<vector::YieldOp>(loc, opToMask->getResults());
+  };
+  // TODO: Allow multiple results in vector.mask.
+  auto maskOp =
+      opResults.empty()
+          ? rewriter.create<vector::MaskOp>(opToMask->getLoc(), mask,
+                                            createRegionMask)
+          : rewriter.create<vector::MaskOp>(opToMask->getLoc(),
+                                            opToMask->getResultTypes().front(),
+                                            mask, createRegionMask);
+
+  Operation *maskOpTerminator = &maskOp.getMaskRegion().front().back();
+
+  for (auto [resIdx, resVal] : llvm::enumerate(opToMask->getResults()))
+    rewriter.replaceAllUsesExcept(resVal, maskOp.getResult(resIdx),
+                                  maskOpTerminator);
+
+  LDBG("Masked operation: " << *maskOp << "\n");
+  return maskOp;
+}
+
 /// Given an indexing `map` coming from a LinalgOp indexing, restricted to a
 /// projectedPermutation, compress the unused dimensions to serve as a
 /// permutation_map for a vector transfer operation.
@@ -204,35 +459,44 @@ static SmallVector<bool> getReductionMask(LinalgOp linalgOp) {
 /// Return the produced value or null if no value is produced.
 // Note: this is a true builder that notifies the OpBuilder listener.
 // TODO: Consider moving as a static helper on the ReduceOp.
-static Value buildVectorWrite(OpBuilder &b, Value value,
-                              OpOperand *outputOperand) {
-  Operation *write;
+static Value buildVectorWrite(RewriterBase &rewriter, Value value,
+                              OpOperand *outputOperand,
+                              VectorizationState &state) {
   Location loc = value.getLoc();
   auto linalgOp = cast<LinalgOp>(outputOperand->getOwner());
-  ArrayRef<int64_t> shape = linalgOp.getShape(outputOperand);
-  auto vectorType = VectorType::get(
-      shape, getElementTypeOrSelf(outputOperand->get().getType()));
+  AffineMap opOperandMap = linalgOp.getMatchingIndexingMap(outputOperand);
+  auto vectorType =
+      VectorType::get(opOperandMap.compose(state.getCanonicalVecShape()),
+                      getElementTypeOrSelf(outputOperand->get().getType()));
+
+  Operation *write;
   if (vectorType.getRank() > 0) {
-    // 0-d case is still special: do not invert the reindexing map.
-    AffineMap map =
-        reindexIndexingMap(linalgOp.getMatchingIndexingMap(outputOperand));
-    SmallVector<int64_t> transposeShape =
-        applyPermutationMap(inversePermutation(map), vectorType.getShape());
-    assert(!transposeShape.empty() && "unexpected empty transpose shape");
-    vectorType = VectorType::get(transposeShape, vectorType.getElementType());
+    AffineMap writeMap = reindexIndexingMap(opOperandMap);
     SmallVector<Value> indices(linalgOp.getRank(outputOperand),
-                               b.create<arith::ConstantIndexOp>(loc, 0));
-    value = broadcastIfNeeded(b, value, vectorType.getShape());
-    write = b.create<vector::TransferWriteOp>(
-        loc, value, outputOperand->get(), indices, map);
+                               rewriter.create<arith::ConstantIndexOp>(loc, 0));
+    value = broadcastIfNeeded(rewriter, value, vectorType.getShape());
+    write = rewriter.create<vector::TransferWriteOp>(
+        loc, value, outputOperand->get(), indices, writeMap);
   } else {
+    // 0-d case is still special: do not invert the reindexing map.
     if (!value.getType().isa<VectorType>())
-      value = b.create<vector::BroadcastOp>(loc, vectorType, value);
+      value = rewriter.create<vector::BroadcastOp>(loc, vectorType, value);
     assert(value.getType() == vectorType && "incorrect type");
-    write = b.create<vector::TransferWriteOp>(
+    write = rewriter.create<vector::TransferWriteOp>(
         loc, value, outputOperand->get(), ValueRange{});
   }
-  LDBG("vectorized op: " << *write);
+
+  write = state.maskOperation(rewriter, write, linalgOp, opOperandMap);
+
+  // If masked, set in-bounds to true. Masking guarantees that the access will
+  // be in-bounds.
+  if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(write)) {
+    auto maskedWriteOp = cast<vector::TransferWriteOp>(maskOp.getMaskableOp());
+    SmallVector<bool> inBounds(maskedWriteOp.getVectorType().getRank(), true);
+    maskedWriteOp.setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+  }
+
+  LDBG("vectorized op: " << *write << "\n");
   if (!write->getResults().empty())
     return write->getResult(0);
   return Value();
@@ -259,20 +523,22 @@ using CustomVectorizationHook = std::function<VectorizationResult(
 /// CustomVectorizationHook.
 static VectorizationResult
 vectorizeLinalgYield(RewriterBase &rewriter, Operation *op,
-                     const BlockAndValueMapping &bvm, LinalgOp linalgOp,
-                     SmallVectorImpl<Value> &newResults) {
+                     const BlockAndValueMapping &bvm, VectorizationState &state,
+                     LinalgOp linalgOp, SmallVectorImpl<Value> &newResults) {
   auto yieldOp = dyn_cast<linalg::YieldOp>(op);
   if (!yieldOp)
     return VectorizationResult{VectorizationStatus::Failure, nullptr};
-  for (const auto &outputs : llvm::enumerate(yieldOp.getValues())) {
+  for (const auto &output : llvm::enumerate(yieldOp.getValues())) {
     // TODO: Scan for an opportunity for reuse.
     // TODO: use a map.
-    Value vectorValue = bvm.lookup(outputs.value());
-    Value newResult = buildVectorWrite(
-        rewriter, vectorValue, linalgOp.getDpsInitOperand(outputs.index()));
+    Value vectorValue = bvm.lookup(output.value());
+    Value newResult =
+        buildVectorWrite(rewriter, vectorValue,
+                         linalgOp.getDpsInitOperand(output.index()), state);
     if (newResult)
       newResults.push_back(newResult);
   }
+
   return VectorizationResult{VectorizationStatus::NoReplace, nullptr};
 }
 
@@ -464,7 +730,7 @@ static VectorizationResult
 vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
                const BlockAndValueMapping &bvm,
                ArrayRef<CustomVectorizationHook> customVectorizationHooks) {
-  LDBG("vectorize op " << *op);
+  LDBG("vectorize op " << *op << "\n");
 
   // 1. Try to apply any CustomVectorizationHook.
   if (!customVectorizationHooks.empty()) {
@@ -561,8 +827,10 @@ vectorizeOneOp(RewriterBase &rewriter, LinalgOp linalgOp, Operation *op,
 /// This is not deemed a problem as we expect canonicalizations and foldings to
 /// aggressively clean up the useless work.
 static LogicalResult
-vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
+vectorizeAsLinalgGeneric(RewriterBase &rewriter, VectorizationState &state,
+                         LinalgOp linalgOp,
                          SmallVectorImpl<Value> &newResults) {
+  LDBG("Vectorizing operation as linalg generic\n");
   Block *block = linalgOp.getBlock();
 
   // 2. Values defined above the region can only be broadcast for now. Make them
@@ -575,11 +843,6 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
   if (linalgOp.getNumDpsInits() == 0)
     return failure();
 
-  // TODO: the common vector shape is equal to the static loop sizes only when
-  // all indexing maps are projected permutations. For convs and stencils the
-  // logic will need to evolve.
-  SmallVector<int64_t> commonVectorShape = linalgOp.computeStaticLoopSizes();
-
   // 3. Turn all BBArgs into vector.transfer_read / load.
   Location loc = linalgOp.getLoc();
   Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
@@ -589,35 +852,60 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
       bvm.map(bbarg, opOperand->get());
       continue;
     }
-    VectorType readType;
-    AffineMap map;
-    // TODO: can we keep this simplification?
-    // if (linalgOp.getShape(&opOperand).empty()) {
-    //   readType = VectorType::get({}, bbarg.getType());
-    // } else {
-    if (opOperand->getOperandNumber() < linalgOp.getNumDpsInputs()) {
-      map = inverseAndBroadcastProjectedPermutation(
-          linalgOp.getMatchingIndexingMap(opOperand));
-      readType = VectorType::get(commonVectorShape,
-                                 getElementTypeOrSelf(opOperand->get()));
+
+    // 3.a. Convert the indexing map for this input/output to a transfer read
+    // permutation map and masking map.
+    AffineMap indexingMap = linalgOp.getMatchingIndexingMap(opOperand);
+
+    // Remove the zeros from the indexing map to use it as the masking map.
+    SmallVector<int64_t> zeroPos;
+    auto results = indexingMap.getResults();
+    for (auto result : llvm::enumerate(results)) {
+      if (result.value().isa<AffineConstantExpr>()) {
+        zeroPos.push_back(result.index());
+      }
+    }
+    AffineMap maskingMap = indexingMap.dropResults(zeroPos);
+
+    AffineMap readMap;
+    SmallVector<int64_t> readVecShape;
+    if (linalgOp.isDpsInput(opOperand)) {
+      // 3.a.i. For input reads we use the canonical vector shape.
+      readMap = inverseAndBroadcastProjectedPermutation(indexingMap);
+      readVecShape = llvm::to_vector(state.getCanonicalVecShape());
     } else {
-      map = inversePermutation(
-          reindexIndexingMap(linalgOp.getMatchingIndexingMap(opOperand)));
-      readType = VectorType::get(map.compose(linalgOp.getShape(opOperand)),
-                                 getElementTypeOrSelf(opOperand->get()));
+      // 3.a.ii. For output reads (iteration-carried dependence, e.g.,
+      // reductions), the vector shape is computed by mapping the canonical
+      // vector shape to the output domain and back to the canonical domain.
+      readMap = inversePermutation(reindexIndexingMap(indexingMap));
+      readVecShape =
+          readMap.compose(indexingMap.compose(state.getCanonicalVecShape()));
     }
-    // }
 
-    auto shape = linalgOp.getShape(opOperand);
-    SmallVector<Value> indices(shape.size(), zero);
-    Value readValue = rewriter.create<vector::TransferReadOp>(
-        loc, readType, opOperand->get(), indices, map);
-    // Not all ops support 0-d vectors, extract the scalar for now.
+    auto readType =
+        VectorType::get(readVecShape, getElementTypeOrSelf(opOperand->get()));
+    SmallVector<Value> indices(linalgOp.getShape(opOperand).size(), zero);
+
+    Operation *read = rewriter.create<vector::TransferReadOp>(
+        loc, readType, opOperand->get(), indices, readMap);
+    read = state.maskOperation(rewriter, read, linalgOp, maskingMap);
+    Value readValue = read->getResult(0);
+
+    // 3.b. If masked, set in-bounds to true. Masking guarantees that the access
+    // will be in-bounds.
+    if (auto maskOp = dyn_cast<vector::MaskingOpInterface>(read)) {
+      SmallVector<bool> inBounds(readType.getRank(), true);
+      cast<vector::TransferReadOp>(maskOp.getMaskableOp())
+          .setInBoundsAttr(rewriter.getBoolArrayAttr(inBounds));
+    }
+
+    // 3.c. Not all ops support 0-d vectors, extract the scalar for now.
     // TODO: remove this.
     if (readValue.getType().cast<VectorType>().getRank() == 0)
       readValue = rewriter.create<vector::ExtractElementOp>(loc, readValue);
 
-    LDBG("new vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue);
+    LDBG("New vectorized bbarg(" << bbarg.getArgNumber() << "): " << readValue
+                                 << "\n");
     bvm.map(bbarg, readValue);
     bvm.map(opOperand->get(), readValue);
   }
@@ -627,7 +915,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
   CustomVectorizationHook vectorizeYield =
       [&](Operation *op,
           const BlockAndValueMapping &bvm) -> VectorizationResult {
-    return vectorizeLinalgYield(rewriter, op, bvm, linalgOp, newResults);
+    return vectorizeLinalgYield(rewriter, op, bvm, state, linalgOp, newResults);
   };
   hooks.push_back(vectorizeYield);
 
@@ -652,12 +940,14 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
     VectorizationResult result =
         vectorizeOneOp(rewriter, linalgOp, &op, bvm, hooks);
     if (result.status == VectorizationStatus::Failure) {
-      LDBG("failed to vectorize: " << op);
+      LDBG("failed to vectorize: " << op << "\n");
       return failure();
     }
     if (result.status == VectorizationStatus::NewOp) {
-      LDBG("new vector op: " << *result.newOp;);
-      bvm.map(op.getResults(), result.newOp->getResults());
+      Operation *maybeMaskedOp =
+          state.maskOperation(rewriter, result.newOp, linalgOp);
+      LDBG("New vector op: " << *maybeMaskedOp << "\n");
+      bvm.map(op.getResults(), maybeMaskedOp->getResults());
     }
   }
 
@@ -668,7 +958,7 @@ vectorizeAsLinalgGeneric(RewriterBase &rewriter, LinalgOp linalgOp,
 // ops that may not commute (e.g. linear reduction + non-linear instructions).
 static LogicalResult reductionPreconditions(LinalgOp op) {
   if (llvm::none_of(op.getIteratorTypesArray(), isReductionIterator)) {
-    LDBG("reduction precondition failed: no reduction iterator");
+    LDBG("reduction precondition failed: no reduction iterator\n");
     return failure();
   }
   for (OpOperand *opOperand : op.getDpsInitOperands()) {
@@ -678,20 +968,69 @@ static LogicalResult reductionPreconditions(LinalgOp op) {
 
     Operation *reduceOp = matchLinalgReduction(opOperand);
     if (!reduceOp || !getCombinerOpKind(reduceOp)) {
-      LDBG("reduction precondition failed: reduction detection failed");
+      LDBG("reduction precondition failed: reduction detection failed\n");
       return failure();
     }
   }
   return success();
 }
 
-static LogicalResult vectorizeStaticLinalgOpPrecondition(
-    linalg::LinalgOp op,
-    ArrayRef<CustomVectorizationPrecondition> customPreconditions,
-    bool vectorizeNDExtract) {
+static LogicalResult vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op) {
+  // TODO: Masking only supports dynamic generic ops without reductions for now.
+  if (!isElementwise(op) &&
+      llvm::any_of(op.getIteratorTypesArray(), [](utils::IteratorType itType) {
+        return itType != utils::IteratorType::parallel;
+      }))
+    return failure();
+
+  // TODO: 0-d vectors are not supported yet.
+  if (llvm::any_of(op.getIndexingMapsArray(), [](AffineMap map) {
+        return map.isEmpty() || map.getResults().empty();
+      }))
+    return failure();
+
+  LDBG("Dynamically-shaped op meets vectorization pre-conditions\n");
+  return success();
+}
+
+LogicalResult
+mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
+                                            ArrayRef<int64_t> inputVectorSizes,
+                                            bool vectorizeNDExtract) {
+  // Check API contract for input vector sizes.
+  if (!inputVectorSizes.empty()) {
+    assert(inputVectorSizes.size() == linalgOp.getNumLoops() &&
+           "Input vector sizes don't match the number of loops");
+    assert(!ShapedType::isDynamicShape(inputVectorSizes) &&
+           "Input vector sizes can't have dynamic dimensions");
+    assert(llvm::all_of(
+               llvm::zip(linalgOp.getStaticLoopRanges(), inputVectorSizes),
+               [](std::tuple<int64_t, int64_t> sizePair) {
+                 int64_t staticSize = std::get<0>(sizePair);
+                 int64_t inputSize = std::get<1>(sizePair);
+                 return ShapedType::isDynamic(staticSize) ||
+                        staticSize <= inputSize;
+               }) &&
+           "Input vector sizes must be smaller or equal than iteration space "
+           "static sizes");
+  }
+
+  // TODO: Masking is only supported for dynamic shapes so input vector sizes
+  // must be empty if the op is not dynamic.
+  if (!linalgOp.hasDynamicShape() && !inputVectorSizes.empty())
+    return failure();
+
+  if (linalgOp.hasDynamicShape() &&
+      failed(vectorizeDynamicLinalgOpPrecondition(linalgOp)))
+    return failure();
+
+  SmallVector<CustomVectorizationPrecondition> customPreconditions;
+
+  // Register CustomVectorizationPrecondition for extractOp.
+  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
 
   // All types in the body should be a supported element type for VectorType.
-  for (Operation &innerOp : op->getRegion(0).front()) {
+  for (Operation &innerOp : linalgOp->getRegion(0).front()) {
     // Check if any custom hook can vectorize the inner op.
     if (llvm::any_of(
             customPreconditions,
@@ -712,50 +1051,52 @@ static LogicalResult vectorizeStaticLinalgOpPrecondition(
       return failure();
     }
   }
-  if (isElementwise(op))
+  if (isElementwise(linalgOp))
     return success();
   // TODO: isaConvolutionOpInterface that can also infer from generic features.
   // But we will still need stride/dilation attributes that will be annoying to
   // reverse-engineer...
-  if (isa<ConvolutionOpInterface>(op.getOperation()))
+  if (isa<ConvolutionOpInterface>(linalgOp.getOperation()))
     return success();
   // TODO: the common vector shape is equal to the static loop sizes only when
   // all indexing maps are projected permutations. For convs and stencils the
   // logic will need to evolve.
-  if (!allIndexingsAreProjectedPermutation(op)) {
-    LDBG("precondition failed: not projected permutations");
+  if (!allIndexingsAreProjectedPermutation(linalgOp)) {
+    LDBG("precondition failed: not projected permutations\n");
     return failure();
   }
-  if (failed(reductionPreconditions(op))) {
-    LDBG("precondition failed: reduction preconditions");
+  if (failed(reductionPreconditions(linalgOp))) {
+    LDBG("precondition failed: reduction preconditions\n");
     return failure();
   }
   return success();
 }
 
-LogicalResult
-mlir::linalg::vectorizeLinalgOpPrecondition(LinalgOp linalgOp,
-                                            bool vectorizeNDExtract) {
-  // All types must be static shape to go to vector.
-  if (linalgOp.hasDynamicShape()) {
-    LDBG("precondition failed: dynamic shape");
-    return failure();
-  }
-
-  SmallVector<CustomVectorizationPrecondition> customPreconditions;
-
-  // Register CustomVectorizationPrecondition for extractOp.
-  customPreconditions.push_back(tensorExtractVectorizationPrecondition);
-
-  return vectorizeStaticLinalgOpPrecondition(linalgOp, customPreconditions,
-                                             vectorizeNDExtract);
-}
-
+/// Emit a suitable vector form for a Linalg op. If provided, `inputVectorSizes`
+/// are used to vectorize this operation. `inputVectorSizes` must match the rank
+/// of the iteration space of the operation, and each size must be greater than
+/// or equal to its counterpart iteration space size, if static.
+/// `inputVectorSizes` also enables the vectorization of operations with dynamic
+/// shapes.
 LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
+                                      ArrayRef<int64_t> inputVectorSizes,
                                       bool vectorizeNDExtract) {
-  if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
+  LDBG("Attempting to vectorize:\n" << linalgOp << "\n");
+  LDBG("Input vector sizes: ");
+  LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
+  LLVM_DEBUG(llvm::dbgs() << "\n");
+
+  if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                           vectorizeNDExtract)))
     return failure();
 
+  // Initialize vectorization state.
+  VectorizationState state(rewriter);
+  if (failed(state.initState(rewriter, linalgOp, inputVectorSizes))) {
+    LDBG("Vectorization state couldn't be initialized\n");
+    return failure();
+  }
+
   SmallVector<Value> results;
   // TODO: isaConvolutionOpInterface that can also infer from generic
   // features. Will require stride/dilation attributes inference.
@@ -763,10 +1104,16 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, LinalgOp linalgOp,
   if (succeeded(convOr)) {
     llvm::append_range(results, (*convOr)->getResults());
   } else {
-    if (failed(vectorizeLinalgOpPrecondition(linalgOp, vectorizeNDExtract)))
+    if (failed(vectorizeLinalgOpPrecondition(linalgOp, inputVectorSizes,
+                                             vectorizeNDExtract)))
       return failure();
-    LDBG("Vectorize generic by broadcasting to a common shape: " << linalgOp);
-    if (failed(vectorizeAsLinalgGeneric(rewriter, linalgOp, results)))
+    LDBG("Vectorize generic by broadcasting to the canonical vector shape\n");
+    // TODO: 'vectorize' takes in a 'RewriterBase' which is up-casted to
+    // 'OpBuilder' when it is passed over to some methods like
+    // 'vectorizeAsLinalgGeneric'. This is highly problematic: if we erase an op
+    // within these methods, the actual rewriter won't be notified and we will
+    // end up with read-after-free issues!
+    if (failed(vectorizeAsLinalgGeneric(rewriter, state, linalgOp, results)))
       return failure();
   }
 
@@ -1262,7 +1609,7 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
   if (firstOp->getBlock() != secondOp->getBlock() ||
       !firstOp->isBeforeInBlock(secondOp)) {
     LDBG("interleavedUses precondition failed, firstOp: "
-         << *firstOp << ", second op: " << *secondOp);
+         << *firstOp << ", second op: " << *secondOp << "\n");
     return true;
   }
   for (auto v : values) {
@@ -1275,7 +1622,7 @@ static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
           (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
         continue;
       LDBG(" found interleaved op " << *owner << ", firstOp: " << *firstOp
-                                    << ", second op: " << *secondOp);
+                                    << ", second op: " << *secondOp << "\n");
       return true;
     }
   }
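
The masking-map-keyed cache above is easiest to see in the 2-D broadcast test
added below: the broadcast operand's indexing map (0, d1) loses its constant
result and is masked with a 1-D mask, while all operands using the identity
map share a single cached 2-D mask (an elided sketch; names are illustrative):

  %m1 = vector.create_mask %dim1 : vector<8xi1>           // masking map (d0, d1) -> (d1)
  %a = vector.mask %m1 { vector.transfer_read ... } : vector<8xi1> -> vector<4x8xf32>
  %m2 = vector.create_mask %dim0, %dim1 : vector<4x8xi1>  // identity masking map
  %b = vector.mask %m2 { vector.transfer_read ... } : vector<4x8xi1> -> vector<4x8xf32>
  // %m2 is reused, not re-created, for the remaining read and the write.
  %w = vector.mask %m2 { vector.transfer_write ... } : vector<4x8xi1> -> tensor<?x?xf32>
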
index 8dfa96f..596f642 100644 (file)
@@ -5,6 +5,8 @@ add_mlir_dialect_library(MLIRVectorDialect
   ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Vector/IR
 
   DEPENDS
+  MLIRMaskableOpInterfaceIncGen
+  MLIRMaskingOpInterfaceIncGen
   MLIRVectorOpsIncGen
   MLIRVectorOpsEnumsIncGen
 
index 18dae28..4c772c2 100644 (file)
@@ -447,6 +447,15 @@ void ReductionOp::print(OpAsmPrinter &p) {
   p << " : " << getVector().getType() << " into " << getDest().getType();
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation.
+Type ReductionOp::getExpectedMaskType() {
+  auto vecType = getVectorType();
+  return vecType.cloneWith(std::nullopt,
+                           IntegerType::get(vecType.getContext(), /*width=*/1));
+}
+
 Value mlir::vector::getVectorReductionOp(arith::AtomicRMWKind op,
                                          OpBuilder &builder, Location loc,
                                          Value vector) {
@@ -3461,6 +3470,14 @@ LogicalResult TransferReadOp::verify() {
                               [&](Twine t) { return emitOpError(t); });
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes. It requires the operation to be vectorized.
+Type TransferReadOp::getExpectedMaskType() {
+  return inferTransferReadMaskType(getVectorType(), getPermutationMap());
+}
+
 template <typename TransferOp>
 static bool isInBounds(TransferOp op, int64_t resultIdx, int64_t indicesIdx) {
   // TODO: support more aggressive createOrFold on:
@@ -3903,6 +3920,14 @@ LogicalResult TransferWriteOp::verify() {
                               [&](Twine t) { return emitOpError(t); });
 }
 
+// MaskableOpInterface methods.
+
+/// Returns the mask type expected by this operation. Mostly used for
+/// verification purposes.
+Type TransferWriteOp::getExpectedMaskType() {
+  return inferTransferWriteMaskType(getVectorType(), getPermutationMap());
+}
+
 /// Fold:
 /// ```
 ///    %t1 = ...
@@ -5377,9 +5402,10 @@ LogicalResult MaskOp::verify() {
         "expects result type to match maskable operation result type");
 
   // Mask checks.
-  if (getMask().getType() != maskableOp.getExpectedMaskType())
-    return emitOpError("expects a ") << maskableOp.getExpectedMaskType()
-                                     << " mask for the maskable operation";
+  Type expectedMaskType = maskableOp.getExpectedMaskType();
+  if (getMask().getType() != expectedMaskType)
+    return emitOpError("expects a ")
+           << expectedMaskType << " mask for the maskable operation";
 
   // Passthru checks.
   Value passthru = getPassthru();
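
Nothing in this patch masks a reduction yet, but this hook is what lets
MaskOp::verify() above check such cases: per the new
ReductionOp::getExpectedMaskType(), a masked reduction over vector<8xf32>
must carry a vector<8xi1> mask (a hedged sketch, not from the tests):

  %r = vector.mask %m { vector.reduction <add>, %v : vector<8xf32> into f32 }
       : vector<8xi1> -> f32
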
index aa79e54..b225662 100644 (file)
@@ -109,15 +109,6 @@ public:
   }
 };
 
-/// Populates instances of `MaskOpRewritePattern` to lower masked operations
-/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
-/// not its nested `MaskableOpInterface`.
-void populateVectorMaskLoweringPatternsForSideEffectingOps(
-    RewritePatternSet &patterns) {
-  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
-      patterns.getContext());
-}
-
 struct LowerVectorMaskPass
     : public vector::impl::LowerVectorMaskPassBase<LowerVectorMaskPass> {
   using Base::Base;
@@ -141,6 +132,15 @@ struct LowerVectorMaskPass
 
 } // namespace
 
+/// Populates instances of `MaskOpRewritePattern` to lower masked operations
+/// with `vector.mask`. Patterns should rewrite the `vector.mask` operation and
+/// not its nested `MaskableOpInterface`.
+void vector::populateVectorMaskLoweringPatternsForSideEffectingOps(
+    RewritePatternSet &patterns) {
+  patterns.add<MaskedTransferReadOpPattern, MaskedTransferWriteOpPattern>(
+      patterns.getContext());
+}
+
 std::unique_ptr<Pass> mlir::vector::createLowerVectorMaskPass() {
   return std::make_unique<LowerVectorMaskPass>();
 }
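
The populate function now defines the declared vector:: entry point outside
the anonymous namespace; the intended effect of its patterns (a hedged
description, assuming they fold the surrounding vector.mask into the
transfer's native mask operand) is a rewrite along these lines:

  %r = vector.mask %m { vector.transfer_read %t[%c0], %pad {in_bounds = [true]}
         : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
  // --- becomes ---
  %r = vector.transfer_read %t[%c0], %pad, %m {in_bounds = [true]}
       : tensor<?xf32>, vector<4xf32>
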
index a9a536a..96c81d1 100644 (file)
@@ -1,7 +1,5 @@
 // RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
 
-// -----
-
 // CHECK-LABEL: contraction_dot
 func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memref<f32>) {
 
@@ -130,7 +128,7 @@ transform.sequence failures(propagate) {
 
 // CHECK-LABEL: func @generic_output_transpose
 func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
-                         %C: memref<32x8xf32>) {
+                                    %C: memref<32x8xf32>) {
   //       CHECK: vector.transfer_read %{{.*}} : memref<8x16xf32>, vector<8x32x16xf32>
   //       CHECK: vector.transfer_read %{{.*}} : memref<16x32xf32>, vector<8x32x16xf32>
   //       CHECK: %[[ACC:.*]] = vector.transfer_read %{{.*}} : memref<32x8xf32>, vector<8x32xf32>
@@ -1608,3 +1606,147 @@ transform.sequence failures(propagate) {
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1
 }
+
+// -----
+
+func.func @vectorize_dynamic_identity(%arg0: tensor<?xf32>,
+                                      %arg1: tensor<?xf32>,
+                                      %arg2: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>],
+                   iterator_types = ["parallel"] }
+    ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+    outs(%arg2 : tensor<?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_identity
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK:           %[[VAL_8:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_10:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_7]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_13:.*]] = arith.addf %[[VAL_8]], %[[VAL_10]] : vector<4xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.mask %[[VAL_7]] { vector.transfer_write %{{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4]
+}
+
+// -----
+
+func.func @vectorize_dynamic_1d_broadcast(%arg0: tensor<?xf32>,
+                                          %arg1: tensor<?xf32>,
+                                          %arg2: tensor<?xf32>) -> tensor<?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0) -> (0)>,
+                                         affine_map<(d0) -> (d0)>,
+                                         affine_map<(d0) -> (d0)>],
+                        iterator_types = ["parallel"] }
+    ins(%arg0, %arg1 : tensor<?xf32>, tensor<?xf32>)
+    outs(%arg2 : tensor<?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?xf32>
+  return %0 : tensor<?xf32>
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_1d_broadcast
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?xf32>
+// CHECK:           %[[VAL_7:.*]] = vector.transfer_read %{{.*}} {permutation_map = #{{.*}}} : tensor<?xf32>, vector<4xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.create_mask %[[VAL_4]] : vector<4xi1>
+// CHECK:           %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true]} : tensor<?xf32>, vector<4xf32> } : vector<4xi1> -> vector<4xf32>
+// CHECK:           %[[VAL_13:.*]] = arith.addf %[[VAL_7]], %[[VAL_10]] : vector<4xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.mask %{{.*}} { vector.transfer_write %[[VAL_13]], {{.*}} {in_bounds = [true]} : vector<4xf32>, tensor<?xf32> } : vector<4xi1> -> tensor<?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4]
+}
+
+// -----
+
+func.func @vectorize_dynamic_2d_transpose(%arg0: tensor<?x?xf32>,
+                                          %arg1: tensor<?x?xf32>,
+                                          %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (d1, d0)>,
+                                         affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0, d1)>],
+                        iterator_types = ["parallel", "parallel"] }
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+    return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_2d_transpose
+// CHECK:           %[[VAL_3:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_6:.*]] = tensor.dim %{{.*}}, %[[VAL_5]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.create_mask %[[VAL_6]], %[[VAL_4]] : vector<8x4xi1>
+// CHECK:           %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<4x8xf32> } : vector<8x4xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<4x8xi1>
+// CHECK:           %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.addf %[[VAL_10]], %[[VAL_13]] : vector<4x8xf32>
+// CHECK:           %[[VAL_17:.*]] = vector.mask %[[VAL_12]] { vector.transfer_write %[[VAL_16]], %{{.*}} {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32> } : vector<4x8xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+
+// -----
+
+func.func @vectorize_dynamic_generic_2d_broadcast(%arg0: tensor<?x?xf32>,
+                                                  %arg1: tensor<?x?xf32>,
+                                                  %arg2: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %0 = linalg.generic { indexing_maps = [affine_map<(d0, d1) -> (0, d1)>,
+                                         affine_map<(d0, d1) -> (d0, d1)>,
+                                         affine_map<(d0, d1) -> (d0, d1)>],
+                        iterator_types = ["parallel", "parallel"] }
+    ins(%arg0, %arg1 : tensor<?x?xf32>, tensor<?x?xf32>)
+    outs(%arg2 : tensor<?x?xf32>) {
+    ^bb(%in0: f32, %in1: f32, %out: f32) :
+      %0 = arith.addf %in0, %in1 : f32
+      linalg.yield %0 : f32
+    } -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   @vectorize_dynamic_generic_2d_broadcast
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_4:.*]] = tensor.dim %{{.*}}, %[[VAL_3]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_6:.*]] = tensor.dim %{{.*}}, %[[VAL_5]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.create_mask %[[VAL_6]] : vector<8xi1>
+// CHECK:           %[[VAL_10:.*]] = vector.mask %[[VAL_9]] { vector.transfer_read %{{.*}} {in_bounds = [true, true], permutation_map = #{{.*}}} : tensor<?x?xf32>, vector<4x8xf32> } : vector<8xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_12:.*]] = vector.create_mask %[[VAL_4]], %[[VAL_6]] : vector<4x8xi1>
+// CHECK:           %[[VAL_13:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_15:.*]] = vector.mask %[[VAL_12]] { vector.transfer_read %{{.*}} {in_bounds = [true, true]} : tensor<?x?xf32>, vector<4x8xf32> } : vector<4x8xi1> -> vector<4x8xf32>
+// CHECK:           %[[VAL_16:.*]] = arith.addf %[[VAL_10]], %[[VAL_13]] : vector<4x8xf32>
+// CHECK:           %[[VAL_18:.*]] = vector.mask %[[VAL_12]] { vector.transfer_write %{{.*}} {in_bounds = [true, true]} : vector<4x8xf32>, tensor<?x?xf32> } : vector<4x8xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+^bb1(%arg1: !pdl.operation):
+  %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+  transform.structured.masked_vectorize %0 vector_sizes [4, 8]
+}
+
index 0023e45..b18a15f 100644 (file)
@@ -8308,6 +8308,7 @@ cc_library(
         ":LinalgPassIncGen",
         ":LinalgStructuredOpsIncGen",
         ":LinalgUtils",
+        ":MaskableOpInterface",
         ":MathDialect",
         ":MemRefDialect",
         ":MemRefTransforms",