[mlir][linalg] Vectorize tensor.extract using contiguous loads
authorAndrzej Warzynski <andrzej.warzynski@gmail.com>
Wed, 1 Mar 2023 15:44:01 +0000 (15:44 +0000)
committerAndrzej Warzynski <andrzej.warzynski@gmail.com>
Thu, 2 Mar 2023 09:20:37 +0000 (09:20 +0000)
This patch implements vectorization of tensor.extract for n-D tensor (n
>= 2) using contiguous load operations, i.e. `vector.transfer_read`. This
is a follow-up of https://reviews.llvm.org/D137660 in which gather loads
were used, i.e. `vector.gather`.

It is always safe to use gather load operations when the underlying
memory pattern is contiguous, but not vice-verse. At the moment, the
following conditions have to be met for contiguous loads to be
generated:
  1. The _output tensor_ must be a 1-D vector with the trailing dim > 1,
     e.g. `tensor<1x1x4xi32`,
  2. The trailing dim in the _input tensor_ must be > 1, e.g.
     `tensor<1x1x4i32>` would be fine, but not `tensor<1x4x1xi32>`.
If these conditions are not satisfied, gather loads are generated
instead.

Condition 1 guarantees that the iteration space of the corresponding
`linalg.generic` Op is relatively simple. That makes analysing the
indices for `tensor.extract` rather straightforward.

Condition 2 is mostly there to avoid weird vectorisation patterns
resulting in vectors like: `vector<1x1x1xi32>`. In practice, tensors
like `tensor<1x4x1xi32>` should be collapsed to `tensor<1x4xi32>` before
vectorisation, but that's beyond the scope of this patch.

If needed, both conditions can be relaxed. I've not been able to find a
good motivating example for these, hence skipping. For reference,
`tosa.resize` (lowered to Linalg) was the driving example used here.

As a bonus, the test from "vectorization-unsupported.mlir" is moved to
"vectorization.mlir" with proper CHECK lines added.

NOTE: This relands 89b144ece330b363713bec369d2d89dc85f715f5 (added extra
test, refined comments and variable names).

Differential Revision: https://reviews.llvm.org/D141998

Co-authored-by: Diego Caballero <diegocaballero@google.com>
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/vectorization-unsupported.mlir [deleted file]
mlir/test/Dialect/Linalg/vectorization.mlir

index 11ba371..29cabd4 100644 (file)
@@ -611,11 +611,11 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
 
   const size_t numIndices = extractOp.getIndices().size();
   for (size_t i = 1; i < numIndices; i++) {
+    Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
+
     auto dimSize = broadcastIfNeeded(
         rewriter,
-        rewriter.create<arith::ConstantIndexOp>(
-            loc,
-            extractOp.getTensor().getType().cast<ShapedType>().getDimSize(i)),
+        rewriter.create<tensor::DimOp>(loc, extractOp.getTensor(), dimIdx),
         indexVecType.getShape());
 
     offset = rewriter.create<arith::MulIOp>(loc, offset, dimSize);
@@ -630,6 +630,159 @@ static Value calculateGatherOffset(RewriterBase &rewriter,
   return offset;
 }
 
+enum VectorMemoryAccessKind {
+  // TODO: ScalarBroadcast,
+  Contiguous,
+  Gather
+};
+
+/// Check whether /p val can be used for calculating an index for a contiguous
+/// load operation. This means that /p val should either:
+///   * be invariant with respect to /p linalgOp, or
+///   * increment by 1 with every loop iterator (when /p shouldBeConstant is
+///     false).
+/// Parameters /p trailingLoopDim and /p shouldBeConstant are used to analyze
+/// `linalg.index` ops.
+static bool isContiguousLoadIdx(LinalgOp &linalgOp, Value &val,
+                                size_t trailingLoopDim, bool shouldBeConstant) {
+  auto *block = linalgOp.getBlock();
+
+  // Bail out if this is a block argument for this linalg.generic Op.
+  // TODO: We could try analysing the corresponding affine map here.
+  if (val.dyn_cast<BlockArgument>())
+    return llvm::all_of(block->getArguments(),
+                        [&val](Value v) { return (v != val); });
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  // We know that we are reading into a 1-D tensor like this:
+  // `tensor<1x1x4xi32`. Given this assumption, the following Op:
+  //    * `%idx = `linalg.index dim : index`,
+  // will either:
+  //    1. produce a constant when `dim` _is not_ the trailing loop dim, or
+  //    2. increment with stride one when `dim` _is_ the trailing loop dim.
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return shouldBeConstant ? (indexOp.getDim() != trailingLoopDim)
+                            : (indexOp.getDim() == trailingLoopDim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  // Values define outside `linalgOp`.
+  if (!ancestor)
+    return true;
+
+  // Values defined inside `linalgOp`, which are constant.
+  if (dyn_cast<arith::ConstantOp>(ancestor))
+    return true;
+
+  // Conservatively reject Ops that could lead to non-contiguous accesses.
+  if (!isa<arith::AddIOp, arith::SubIOp, linalg::IndexOp>(ancestor))
+    return false;
+
+  bool result = true;
+  for (auto op : ancestor->getOperands())
+    result &=
+        isContiguousLoadIdx(linalgOp, op, trailingLoopDim, shouldBeConstant);
+
+  return result;
+}
+
+/// Check whether the calculation of \p val is based on linalg.index Op with
+/// the dim attribute matching \p dim.
+static bool isBasedOnIndexOp(LinalgOp &linalgOp, Value &val, size_t dim) {
+  auto *block = linalgOp.getBlock();
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  if (val.isa<BlockArgument>())
+    return false;
+
+  Operation *defOp = val.getDefiningOp();
+  assert(defOp && "This is neither a block argument nor an operation result");
+
+  if (auto indexOp = dyn_cast<linalg::IndexOp>(defOp))
+    return (indexOp.getDim() == dim);
+
+  auto *ancestor = block->findAncestorOpInBlock(*defOp);
+
+  if (!ancestor)
+    return false;
+
+  bool result = false;
+  for (auto op : ancestor->getOperands())
+    result |= isBasedOnIndexOp(linalgOp, op, dim);
+
+  return result;
+}
+
+/// Check whether \p extractOp would be a gather or a contiguous load Op after
+/// vectorising \p linalgOp. Note that it is always safe to use gather load
+/// operations for contiguous loads (albeit slow), but not vice-versa. When in
+/// doubt, bail out and assume that \p extractOp is a gather load.
+static VectorMemoryAccessKind
+getTensorExtractMemoryAccessPattern(tensor::ExtractOp extractOp,
+                                    LinalgOp &linalgOp) {
+
+  auto targetShape = linalgOp.getStaticLoopRanges();
+
+  // Assume that it's a gather load when reading _into_:
+  //    * an n-D vector, like`tensor<1x2x4xi32` or`tensor<2x1x4xi32>`, or
+  //    * a 1-D vector with the trailing dim equal 1, e.g. `tensor<1x4x1xi32`.
+  // TODO: Relax these conditions.
+  if ((llvm::count_if(targetShape,
+                      [](int64_t dimSize) { return dimSize > 1; }) != 1) ||
+      targetShape.back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  auto inputShape = extractOp.getTensor().getType().cast<ShapedType>();
+
+  // Assume that it's a gather load when reading _from_ a tensor for which the
+  // trailing dimension is 1, e.g. `tensor<1x4x1xi32>`.
+  // TODO: Relax this condition.
+  if (inputShape.getShape().back() == 1)
+    return VectorMemoryAccessKind::Gather;
+
+  // The trailing loop dim is needed when analyzing ops like:
+  //     * %idx = `linalg.index <dim> : index`.
+  auto trailingLoopDim = targetShape.size() - 1;
+
+  bool isContiguous = true;
+
+  // Iterate over all indices. Analyze the way each index is calculated and
+  // decide whether it is suitable for a contiguous load (e.g. loop invariant).
+  auto indices = extractOp.getIndices();
+  for (auto [i, indexVal] : llvm::enumerate(indices)) {
+    if (inputShape.getShape()[i] == 1) {
+      // This index will always be equal 0, so it is a loop-invariant constant.
+      continue;
+    }
+
+    // Should this index be loop invariant?
+    //  * _no_ if this is the trailing index,
+    //  * _yes_ otherwise.
+    auto extractOpBottomIdx = indices.size() - 1;
+    bool loopInvariantIndex = (i != extractOpBottomIdx);
+
+    isContiguous &= isContiguousLoadIdx(linalgOp, indexVal, trailingLoopDim,
+                                        loopInvariantIndex);
+  }
+
+  // The trailing index in the extract Op must increment with every iteration,
+  // which means that it must be based on a loop index. Given the assumption
+  // on the output tensor, only the trailing loop index is not constant, so
+  // that's what we need to check against.
+  auto extractOpTrailingIdx = indices.back();
+  isContiguous &=
+      isBasedOnIndexOp(linalgOp, extractOpTrailingIdx, trailingLoopDim);
+
+  if (isContiguous) {
+    LDBG("Found contigous load: " << extractOp);
+    return VectorMemoryAccessKind::Contiguous;
+  }
+
+  return VectorMemoryAccessKind::Gather;
+}
+
 /// Helper function to vectorize the tensor.extract operations. Returns
 /// VectorizationStatus::NewOp to signal the vectorization algorithm that it
 /// should map the produced operations. This function is meant to be used as a
@@ -660,15 +813,64 @@ vectorizeTensorExtract(RewriterBase &rewriter, VectorizationState &state,
       extractOp.getIndices().size(),
       rewriter.create<arith::ConstantIndexOp>(loc, 0));
 
-  Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+  VectorMemoryAccessKind memAccessKind =
+      getTensorExtractMemoryAccessPattern(extractOp, linalgOp);
+
+  // 1. Handle gather access
+  if (memAccessKind == VectorMemoryAccessKind::Gather) {
+    Value offset = calculateGatherOffset(rewriter, extractOp, bvm, targetShape);
+
+    // Generate the gather load
+    Operation *gatherOp = rewriter.create<vector::GatherOp>(
+        loc, resultType, extractOp.getTensor(), baseIndices, offset,
+        maskConstantOp, passThruConstantOp);
+    gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+
+    LDBG("Vectorised as gather load: " << extractOp);
+    return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  }
+
+  // 2. Handle contiguous access.
+  SmallVector<Value> transferReadIdxs;
+  auto resTrailingDim = resultType.getShape().back();
+  auto zero = rewriter.create<arith::ConstantOp>(
+      loc, rewriter.getI32Type(), rewriter.getZeroAttr(rewriter.getI32Type()));
+
+  // Collect indices for `vector.transfer_read`. At this point, the indices will
+  // either be scalars or would have been broadcast to vectors matching the
+  // result type. For indices that are vectors, there are two options:
+  //    * for non-trailing indices, all elements are identical (contiguous
+  //      loads are identified by looking for non-trailing indices that are
+  //      invariant with respect to the corresponding linalg.generic), or
+  //    * for trailing indices, the index vector will contain values with stride
+  //      one, but for `vector.transfer_read` only the first (i.e. 0th) index is
+  //      needed.
+  // This means that
+  //   * for scalar indices - just re-use it,
+  //   * for vector indices (e.g. `vector<1x1x4xindex>`) - extract the bottom
+  //    (0th) element and use that.
+  for (size_t i = 0; i < extractOp.getIndices().size(); i++) {
+    auto idx = bvm.lookup(extractOp.getIndices()[i]);
+    if (idx.getType().isIndex()) {
+      transferReadIdxs.push_back(idx);
+      continue;
+    }
+
+    auto indexAs1dVector = rewriter.create<vector::ShapeCastOp>(
+        loc, VectorType::get({resTrailingDim}, rewriter.getIndexType()),
+        bvm.lookup(extractOp.getIndices()[i]));
+    transferReadIdxs.push_back(
+        rewriter.create<vector::ExtractElementOp>(loc, indexAs1dVector, zero));
+  }
+
+  // `tensor.extract_element` is always in-bounds, hence the following holds.
+  SmallVector<bool> inBounds(resultType.getRank(), true);
 
-  // Generate the gather load
-  Operation *gatherOp = rewriter.create<vector::GatherOp>(
-      loc, resultType, extractOp.getTensor(), baseIndices, offset,
-      maskConstantOp, passThruConstantOp);
-  gatherOp = state.maskOperation(rewriter, gatherOp, linalgOp);
+  auto transferReadOp = rewriter.create<vector::TransferReadOp>(
+      loc, resultType, extractOp.getTensor(), transferReadIdxs, inBounds);
 
-  return VectorizationResult{VectorizationStatus::NewOp, gatherOp};
+  LDBG("Vectorised as contiguous load: " << extractOp);
+  return VectorizationResult{VectorizationStatus::NewOp, transferReadOp};
 }
 
 /// Emit reduction operations if the shapes of the value to reduce is different
diff --git a/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir b/mlir/test/Dialect/Linalg/vectorization-unsupported.mlir
deleted file mode 100644 (file)
index aa1d6e3..0000000
+++ /dev/null
@@ -1,29 +0,0 @@
-// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file -verify-diagnostics
-
-// Masked vectorisation of `tensor.extract`:
-//   * requires the `{ vectorize_nd_extract }` attribute,
-//   * has not been implemented yet (hence the attribute is absent).
-// TOOD: Implement masked vectorization for `tensor.extract`
-
-#map1 = affine_map<(d0, d1) -> (d0, d1)>
-func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
-  %c0 = arith.constant 1 : index
-  %c1 = arith.constant 2 : index
-  // expected-error@+1 {{failed to vectorize op}}
-  %2 = linalg.generic {
-    indexing_maps = [#map1],
-    iterator_types = ["parallel", "parallel"]
-  } outs(%arg1 : tensor<?x?xf32>) {
-  ^bb0(%arg3: f32):
-    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
-    linalg.yield %7 : f32
-  } -> tensor<?x?xf32>
-  return %2 : tensor<?x?xf32>
-}
-
-
-transform.sequence failures(propagate) {
- ^bb1(%arg1: !pdl.operation):
-   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
-   transform.structured.masked_vectorize %0 vector_sizes [3, 3]
- }
index a43fd7a..41e5290 100644 (file)
@@ -1584,7 +1584,6 @@ func.func @vectorize_nd_tensor_extract_constant_idx(%arg0: tensor<3x3xf32>, %arg
     iterator_types = ["parallel", "parallel", "parallel"]
   } outs(%arg2 : tensor<1x1x3xf32>) {
   ^bb0(%arg4: f32):
-    %3 = linalg.index 2 : index
     %7 = tensor.extract %arg0[%c0, %c1] : tensor<3x3xf32>
     linalg.yield %7 : f32
   } -> tensor<1x1x3xf32>
@@ -1613,7 +1612,7 @@ transform.sequence failures(propagate) {
 // -----
 
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
-func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
+func.func @vectorize_nd_tensor_extract_transfer_read_basic(%arg0: tensor<3x3x3xf32>, %arg2: tensor<1x1x3xf32>) -> tensor<1x1x3xf32> {
   %1 = linalg.generic {
     indexing_maps = [#map1],
     iterator_types = ["parallel", "parallel", "parallel"]
@@ -1628,16 +1627,19 @@ func.func @vectorize_nd_tensor_extract_idx_from_iteration_index(%arg0: tensor<3x
   return %1 : tensor<1x1x3xf32>
 }
 
-// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_idx_from_iteration_index
+// CHECK-LABEL: func.func @vectorize_nd_tensor_extract_transfer_read_basic
 // CHECK-SAME: %[[ARG0:.*]]: tensor<3x3x3xf32>
 // CHECK-SAME: %[[ARG1:.*]]: tensor<1x1x3xf32>
-// CHECK:   %[[INDICES:.*]] = arith.constant dense<[0, 1, 2]> : vector<3xindex>
-// CHECK:   %[[MASK:.*]] = arith.constant dense<true> : vector<1x1x3xi1>
-// CHECK:   %[[PASSTHRU:.*]] = arith.constant dense<0.000000e+00> : vector<1x1x3xf32>
+// CHECK:   %[[CST:.*]] = arith.constant dense<0> : vector<1x1x3xindex>
+// CHECK:   %[[C0_i32:.*]] = arith.constant 0 : i32
 // CHECK:   %[[C0:.*]] = arith.constant 0 : index
-// CHECK:   %[[B:.*]] = vector.broadcast %[[INDICES]] : vector<3xindex> to vector<1x1x3xindex>
-// CHECK:   %[[GATHER:.*]] = vector.gather %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] [%[[B]]], %[[MASK]], %[[PASSTHRU]] : tensor<3x3x3xf32>, vector<1x1x3xindex>, vector<1x1x3xi1>, vector<1x1x3xf32> into vector<1x1x3xf32>
-// CHECK:   vector.transfer_write %[[GATHER]]
+// CHECK:   %[[CST_0:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:   %[[IDX_VEC0:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK:   %[[IDX1:.*]] = vector.extractelement %[[IDX_VEC0]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK:   %[[IDX_VEC:.*]] = vector.shape_cast %[[CST]] : vector<1x1x3xindex> to vector<3xindex>
+// CHECK:   %[[IDX2:.*]] = vector.extractelement %[[IDX_VEC]][%[[C0_i32]] : i32] : vector<3xindex>
+// CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[IDX1]], %[[IDX2]], %[[C0:.*]]], %[[CST_0]] {in_bounds = [true, true, true]} : tensor<3x3x3xf32>, vector<1x1x3xf32>
+// CHECK:   vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<1x1x3xf32>, tensor<1x1x3xf32>
 
 transform.sequence failures(propagate) {
 ^bb1(%arg1: !pdl.operation):
@@ -1646,6 +1648,56 @@ transform.sequence failures(propagate) {
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
 
+ // -----
+
+func.func @vectorize_nd_tensor_extract_transfer_read_complex(%6: tensor<45x80x16xf32>, %arg0: index, %arg2: index, %arg1: index, %arg4: index, %extracted_slice : tensor<1x4xf32>) -> tensor<1x4xf32> {
+  %c79 = arith.constant 79 : index
+  %25 = linalg.generic {
+    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%extracted_slice : tensor<1x4xf32>) {
+  ^bb0(%out: f32):
+    %26 = linalg.index 0 : index
+    %27 = arith.addi %arg0, %26 : index
+    %28 = arith.addi %27, %arg2 : index
+    %29 = linalg.index 1 : index
+    %30 = arith.addi %arg1, %29 : index
+    %31 = arith.addi %30, %arg4 : index
+    %extracted = tensor.extract %6[%28, %c79, %31] : tensor<45x80x16xf32>
+    linalg.yield %extracted : f32
+  } -> tensor<1x4xf32>
+  return %25 : tensor<1x4xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_transfer_read_complex
+// CHECK-SAME:      %[[VAL_0:.*]]: tensor<45x80x16xf32>,
+// CHECK-SAME:      {{.*}}: index,
+// CHECK-SAME:      %[[VAL_5:.*]]: tensor<1x4xf32>) -> tensor<1x4xf32> {
+// CHECK:           %[[VAL_6:.*]] = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex>
+// CHECK:           %[[VAL_7:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_10:.*]] = arith.constant 79 : index
+// CHECK:           %[[VAL_11:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
+// CHECK:           %[[VAL_12:.*]] = vector.broadcast %{{.*}} : index to vector<1x4xindex>
+// CHECK:           %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : vector<1x4xindex>
+// CHECK:           %[[VAL_14:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
+// CHECK:           %[[VAL_15:.*]] = arith.addi %[[VAL_14]], %[[VAL_6]] : vector<4xindex>
+// CHECK:           %[[VAL_16:.*]] = vector.broadcast %{{.*}} : index to vector<4xindex>
+// CHECK:           %[[VAL_17:.*]] = arith.addi %[[VAL_15]], %[[VAL_16]] : vector<4xindex>
+// CHECK:           %[[VAL_18:.*]] = vector.shape_cast %[[VAL_13]] : vector<1x4xindex> to vector<4xindex>
+// CHECK:           %[[VAL_19:.*]] = vector.extractelement %[[VAL_18]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_20:.*]] = vector.extractelement %[[VAL_17]]{{\[}}%[[VAL_7]] : i32] : vector<4xindex>
+// CHECK:           %[[VAL_21:.*]] = vector.transfer_read %[[VAL_0]]{{\[}}%[[VAL_19]], %[[VAL_10]], %[[VAL_20]]], %[[VAL_8]] {in_bounds = [true, true]} : tensor<45x80x16xf32>, vector<1x4xf32>
+// CHECK:           %[[VAL_22:.*]] = vector.transfer_write %[[VAL_21]], %[[VAL_5]]{{\[}}%[[VAL_9]], %[[VAL_9]]] {in_bounds = [true, true]} : vector<1x4xf32>, tensor<1x4xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d2)>
@@ -1693,6 +1745,51 @@ transform.sequence failures(propagate) {
   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
 }
+// -----
+
+#map = affine_map<(d0) -> (d0)>
+func.func @vectorize_nd_tensor_extract_contiguous_and_gather(%arg0: tensor<6xf32>, %arg1: tensor<5xi32>) -> tensor<5xf32> {
+ %c5 = arith.constant 5 : index
+ %c0 = arith.constant 0 : index
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.generic {indexing_maps = [#map], iterator_types = ["parallel"]} outs(%0 : tensor<5xf32>) {
+ ^bb0(%out: f32):
+   %2 = linalg.index 0 : index
+   %extracted = tensor.extract %arg1[%2] : tensor<5xi32>
+   %3 = arith.index_cast %extracted : i32 to index
+   %4 = arith.maxsi %3, %c0 : index
+   %5 = arith.minsi %4, %c5 : index
+   %extracted_0 = tensor.extract %arg0[%5] : tensor<6xf32>
+   linalg.yield %extracted_0 : f32
+ } -> tensor<5xf32>
+ return %1 : tensor<5xf32>
+}
+
+// CHECK-LABEL:   func.func @vectorize_nd_tensor_extract_contiguous_and_gather(
+// CHECK-SAME:                    %[[VAL_0:.*]]: tensor<6xf32>
+// CHECK-SAME:                    %[[VAL_1:.*]]: tensor<5xi32>
+// CHECK:           %[[VAL_2:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 0 : i32
+// CHECK:           %[[VAL_4:.*]] = arith.constant dense<0> : vector<5xindex>
+// CHECK:           %[[VAL_5:.*]] = arith.constant dense<5> : vector<5xindex>
+// CHECK:           %[[VAL_6:.*]] = arith.constant dense<true> : vector<5xi1>
+// CHECK:           %[[VAL_7:.*]] = arith.constant dense<0.000000e+00> : vector<5xf32>
+// CHECK:           %[[VAL_8:.*]] = tensor.empty() : tensor<5xf32>
+// CHECK:           %[[VAL_9:.*]] = vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_2]]], %[[VAL_3]] {in_bounds = [true]} : tensor<5xi32>, vector<5xi32>
+// CHECK:           %[[VAL_10:.*]] = arith.index_cast %[[VAL_9]] : vector<5xi32> to vector<5xindex>
+// CHECK:           %[[VAL_11:.*]] = arith.maxsi %[[VAL_10]], %[[VAL_4]] : vector<5xindex>
+// CHECK:           %[[VAL_12:.*]] = arith.minsi %[[VAL_11]], %[[VAL_5]] : vector<5xindex>
+// CHECK:           %[[VAL_13:.*]] = vector.gather %[[VAL_0]]{{\[}}%[[VAL_2]]] {{\[}}%[[VAL_12]]], %[[VAL_6]], %[[VAL_7]] : tensor<6xf32>, vector<5xindex>, vector<5xi1>, vector<5xf32> into vector<5xf32>
+// CHECK:           %[[VAL_14:.*]] = vector.transfer_write %[[VAL_13]], %[[VAL_8]]{{\[}}%[[VAL_2]]] {in_bounds = [true]} : vector<5xf32>, tensor<5xf32>
+// CHECK:           return %[[VAL_14]] : tensor<5xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   %1 = get_closest_isolated_parent %0 : (!pdl.operation) -> !pdl.operation
+   %2 = transform.structured.vectorize %1 { vectorize_nd_extract }
+ }
+
 
 // -----
 
@@ -2119,6 +2216,56 @@ transform.sequence failures(propagate) {
 
 // -----
 
+#map1 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @extract_masked_vectorize(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+  %c0 = arith.constant 1 : index
+  %c1 = arith.constant 2 : index
+  %2 = linalg.generic {
+    indexing_maps = [#map1],
+    iterator_types = ["parallel", "parallel"]
+  } outs(%arg1 : tensor<?x?xf32>) {
+  ^bb0(%arg3: f32):
+    %7 = tensor.extract %arg0[%c0, %c1] : tensor<?x?xf32>
+    linalg.yield %7 : f32
+  } -> tensor<?x?xf32>
+  return %2 : tensor<?x?xf32>
+}
+
+// CHECK-LABEL:   func.func @extract_masked_vectorize(
+// CHECK-SAME:                                        %[[VAL_0:.*]]: tensor<?x?xf32>,
+// CHECK-SAME:                                        %[[VAL_1:.*]]: tensor<?x?xf32>) -> tensor<?x?xf32> {
+// CHECK:           %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_3:.*]] = arith.constant 2 : index
+// CHECK:           %[[VAL_4:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_6:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_7:.*]] = tensor.dim %[[VAL_1]], %[[VAL_6]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_8:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_9:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK:           %[[VAL_10:.*]] = vector.create_mask %[[VAL_5]], %[[VAL_7]] : vector<3x3xi1>
+// CHECK:           %[[VAL_11:.*]] = vector.mask %[[VAL_10]] { vector.transfer_read %[[VAL_1]]{{\[}}%[[VAL_8]], %[[VAL_8]]], %[[VAL_9]] {in_bounds = [true, true]} : tensor<?x?xf32>, vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK:           %[[VAL_12:.*]] = arith.constant dense<true> : vector<3x3xi1>
+// CHECK:           %[[VAL_13:.*]] = arith.constant dense<0.000000e+00> : vector<3x3xf32>
+// CHECK:           %[[VAL_14:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_15:.*]] = arith.constant dense<1> : vector<3x3xindex>
+// CHECK:           %[[VAL_16:.*]] = arith.constant 1 : index
+// CHECK:           %[[VAL_17:.*]] = tensor.dim %[[VAL_0]], %[[VAL_16]] : tensor<?x?xf32>
+// CHECK:           %[[VAL_18:.*]] = vector.broadcast %[[VAL_17]] : index to vector<3x3xindex>
+// CHECK:           %[[VAL_19:.*]] = arith.muli %[[VAL_15]], %[[VAL_18]] : vector<3x3xindex>
+// CHECK:           %[[VAL_20:.*]] = arith.constant dense<2> : vector<3x3xindex>
+// CHECK:           %[[VAL_21:.*]] = arith.addi %[[VAL_20]], %[[VAL_19]] : vector<3x3xindex>
+// CHECK:           %[[VAL_22:.*]] = vector.mask %[[VAL_10]] { vector.gather %[[VAL_0]]{{\[}}%[[VAL_14]], %[[VAL_14]]] {{\[}}%[[VAL_21]]], %[[VAL_12]], %[[VAL_13]] : tensor<?x?xf32>, vector<3x3xindex>, vector<3x3xi1>, vector<3x3xf32> into vector<3x3xf32> } : vector<3x3xi1> -> vector<3x3xf32>
+// CHECK:           %[[VAL_23:.*]] = arith.constant 0 : index
+// CHECK:           %[[VAL_24:.*]] = vector.mask %[[VAL_10]] { vector.transfer_write %[[VAL_22]], %[[VAL_1]]{{\[}}%[[VAL_23]], %[[VAL_23]]] {in_bounds = [true, true]} : vector<3x3xf32>, tensor<?x?xf32> } : vector<3x3xi1> -> tensor<?x?xf32>
+
+transform.sequence failures(propagate) {
+ ^bb1(%arg1: !pdl.operation):
+   %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!pdl.operation) -> !pdl.operation
+   transform.structured.masked_vectorize %0 vector_sizes [3, 3] { vectorize_nd_extract }
+ }
+
+// -----
+
 func.func @do_not_generate_masks(%arg0: tensor<8x32xf32>,
                                  %arg1: tensor<8x32xf32>,
                                  %arg2: tensor<8x32xf32>) -> tensor<8x32xf32> {