[mlir][Linalg][Vector] Add forwarding patterns between linalg.copy and vector.transfer
authorNicolas Vasilache <ntv@google.com>
Fri, 29 May 2020 12:01:15 +0000 (08:01 -0400)
committerNicolas Vasilache <ntv@google.com>
Fri, 29 May 2020 12:08:34 +0000 (08:08 -0400)
This revision adds custom rewrites for patterns that arise during linalg structured
ops vectorization. These patterns allow the composition of linalg promotion,
vectorization and removal of redundant copies.

The patterns are voluntarily limited and restrictive atm.
More robust behavior will be implemented once more powerful side effect modeling and analyses are available on view/subview.

On the transfer_read side, the following pattern is rewritten:
```
   %alloc = ...
   [optional] %view = std.view %alloc ...
   %subView = subview %allocOrView ...
   [optional] linalg.fill(%allocOrView, %cst) ...
   ...
   linalg.copy(%in, %subView) ...
   vector.transfer_read %allocOrView[...], %cst ...
```
into
```
   [unchanged] %alloc = ...
   [unchanged] [optional] %view = std.view %alloc ...
   [unchanged] [unchanged] %subView = subview %allocOrView ...
   ...
   vector.transfer_read %in[...], %cst ...
```

On the transfer_write side, the following pattern is rewriten:
```
   %alloc = ...
   [optional] %view = std.view %alloc ...
   %subView = subview %allocOrView...
   ...
   vector.transfer_write %..., %allocOrView[...]
   linalg.copy(%subView, %out)
```

Differential Revision: https://reviews.llvm.org/D80728

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Dialect/Linalg/forward-vector-transfers.mlir [new file with mode: 0644]
mlir/test/lib/Transforms/TestLinalgTransforms.cpp

index 2da6319..2e06737 100644 (file)
 #include "llvm/ADT/SmallBitVector.h"
 
 namespace mlir {
+namespace vector {
+
+class TransferReadOp;
+class TransferWriteOp;
+
+} // namespace vector
+
 namespace linalg {
 
 struct LinalgTilingOptions;
@@ -438,6 +445,67 @@ private:
 };
 
 //===----------------------------------------------------------------------===//
+// Op-specific patterns.
+//===----------------------------------------------------------------------===//
+/// Match and rewrite for the pattern:
+/// ```
+///    %alloc = ...
+///    [optional] %view = std.view %alloc ...
+///    %subView = subview %allocOrView ...
+///    [optional] linalg.fill(%allocOrView, %cst) ...
+///    ...
+///    linalg.copy(%in, %subView) ...
+///    vector.transfer_read %allocOrView[...], %cst ...
+/// ```
+/// into
+/// ```
+///    [unchanged] %alloc = ...
+///    [unchanged] [optional] %view = std.view %alloc ...
+///    [unchanged] [unchanged] %subView = subview %allocOrView ...
+///    ...
+///    vector.transfer_read %in[...], %cst ...
+/// ```
+/// Where there is no interleaved use between linalg.copy and transfer_read as
+/// well as no interleaved use between linalg.fill and linalg.copy (if
+/// linalg.fill is specified).
+/// This is a custom rewrite to forward partial reads (with optional fills) to
+/// vector.transfer_read.
+struct LinalgCopyVTRForwardingPattern
+    : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp xferOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+/// Match and rewrite for the pattern:
+/// ```
+///    %alloc = ...
+///    [optional] %view = std.view %alloc ...
+///    %subView = subview %allocOrView...
+///    ...
+///    vector.transfer_write %..., %allocOrView[...]
+///    linalg.copy(%subView, %out)
+/// ```
+/// into
+/// ```
+///    [unchanged] %alloc = ...
+///    [unchanged] [optional] %view = std.view %alloc ...
+///    [unchanged] %subView = subview %allocOrView...
+///    ...
+///    vector.transfer_write %..., %out[...]
+/// ```
+/// Where there is no interleaved use between transfer_write and linalg.copy.
+/// This is a custom rewrite to forward partial writes to vector.transfer_write.
+struct LinalgCopyVTWForwardingPattern
+    : public OpRewritePattern<vector::TransferWriteOp> {
+  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferWriteOp xferOp,
+                                PatternRewriter &rewriter) const override;
+};
+
+//===----------------------------------------------------------------------===//
 // Support for staged pattern application.
 //===----------------------------------------------------------------------===//
 /// Helper function to allow applying rewrite patterns, interleaved with more
index f27baa3..8fa0aa3 100644 (file)
@@ -103,12 +103,13 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
       llvm_unreachable("Unexpected conv with padding");
   }
 
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
+  (void)dbgPref;
   edsc::ScopedContext scope(builder, op->getLoc());
   if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
     // Vectorize fill as a vector.broadcast.
-    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                         "]: Rewrite linalg.fill as vector.broadcast: "
-                      << *op << ":\n");
+    LLVM_DEBUG(dbgs() << dbgPref
+                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
     Value memref = vector_type_cast(fillOp.getOutputBuffer(0));
     Value dst = std_load(memref);
     Value res = vector_broadcast(dst.getType(), fillOp.value());
@@ -117,9 +118,8 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
   }
 
   // Vectorize other ops as vector contraction (currently only matmul).
-  LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE
-                       "]: Rewrite linalg op as vector.contract: "
-                    << *op << ":\n");
+  LLVM_DEBUG(dbgs() << dbgPref
+                    << "Rewrite linalg op as vector.contract: " << *op);
   auto linalgOp = cast<linalg::LinalgOp>(op);
   Value a = std_load(vector_type_cast(linalgOp.getInput(0)));
   Value b = std_load(vector_type_cast(linalgOp.getInput(1)));
@@ -129,3 +129,168 @@ void mlir::linalg::vectorizeLinalgOp(OpBuilder &builder, Operation *op) {
                               linalgOp.iterator_types());
   std_store(res, memref);
 }
+
+/// Check whether there is any interleaved use of any `values` between `firstOp`
+/// and `secondOp`. Conservatively return `true` if any op or value is in a
+/// different block.
+static bool mayExistInterleavedUses(Operation *firstOp, Operation *secondOp,
+                                    ValueRange values) {
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: ";
+  (void)dbgPref;
+  if (firstOp->getBlock() != secondOp->getBlock() ||
+      !firstOp->isBeforeInBlock(secondOp)) {
+    LLVM_DEBUG(llvm::dbgs()
+               << dbgPref << "interleavedUses precondition failed, firstOp: "
+               << *firstOp << ", second op: " << *secondOp);
+    return true;
+  }
+  for (auto v : values) {
+    for (auto &u : v.getUses()) {
+      Operation *owner = u.getOwner();
+      if (owner == firstOp || owner == secondOp)
+        continue;
+      // TODO: this is too conservative, use dominance info in the future.
+      if (owner->getBlock() == firstOp->getBlock() &&
+          (owner->isBeforeInBlock(firstOp) || secondOp->isBeforeInBlock(owner)))
+        continue;
+      LLVM_DEBUG(llvm::dbgs()
+                 << dbgPref << " found interleaved op " << *owner
+                 << ", firstOp: " << *firstOp << ", second op: " << *secondOp);
+      return true;
+    }
+  }
+  return false;
+}
+
+/// Return the unique subview use of `v` if it is indeed unique, null otherwise.
+static SubViewOp getSubViewUseIfUnique(Value v) {
+  SubViewOp subViewOp;
+  for (auto &u : v.getUses()) {
+    if (auto newSubViewOp = dyn_cast<SubViewOp>(u.getOwner())) {
+      if (subViewOp)
+        return SubViewOp();
+      subViewOp = newSubViewOp;
+    }
+  }
+  return subViewOp;
+}
+
+/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
+/// when available.
+LogicalResult LinalgCopyVTRForwardingPattern::matchAndRewrite(
+    vector::TransferReadOp xferOp, PatternRewriter &rewriter) const {
+
+  // Transfer into `view`.
+  Value viewOrAlloc = xferOp.memref();
+  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
+      !viewOrAlloc.getDefiningOp<AllocOp>())
+    return failure();
+
+  StringRef dbgPref = "\n[" DEBUG_TYPE "]: VTRForwarding: ";
+  (void)dbgPref;
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << viewOrAlloc);
+
+  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
+  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
+  if (!subViewOp)
+    return failure();
+  Value subView = subViewOp.getResult();
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with subView " << subView);
+
+  // Find the copy into `subView` without interleaved uses.
+  CopyOp copyOp;
+  for (auto &u : subView.getUses()) {
+    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
+      if (newCopyOp.getOutputBuffer(0) != subView)
+        continue;
+      LLVM_DEBUG(llvm::dbgs() << dbgPref << "copy candidate " << *newCopyOp);
+      if (mayExistInterleavedUses(newCopyOp, xferOp, {viewOrAlloc, subView}))
+        continue;
+      copyOp = newCopyOp;
+      break;
+    }
+  }
+  if (!copyOp)
+    return failure();
+  LLVM_DEBUG(llvm::dbgs() << dbgPref << "with copy " << *copyOp);
+
+  // Find the fill into `viewOrAlloc` without interleaved uses before the copy.
+  FillOp maybeFillOp;
+  for (auto &u : viewOrAlloc.getUses()) {
+    if (auto newFillOp = dyn_cast<FillOp>(u.getOwner())) {
+      if (newFillOp.getOutputBuffer(0) != viewOrAlloc)
+        continue;
+      LLVM_DEBUG(llvm::dbgs() << dbgPref << "fill candidate " << *newFillOp);
+      if (mayExistInterleavedUses(newFillOp, copyOp, {viewOrAlloc, subView}))
+        continue;
+      maybeFillOp = newFillOp;
+      break;
+    }
+  }
+  // Ensure padding matches.
+  if (maybeFillOp && xferOp.padding() != maybeFillOp.value())
+    return failure();
+  if (maybeFillOp)
+    LLVM_DEBUG(llvm::dbgs() << dbgPref << "with maybeFillOp " << *maybeFillOp);
+
+  // `in` is the subview that linalg.copy reads. Replace it.
+  Value in = copyOp.getInput(0);
+
+  Value res = rewriter.create<vector::TransferReadOp>(
+      xferOp.getLoc(), xferOp.getVectorType(), in, xferOp.indices(),
+      xferOp.permutation_map(), xferOp.padding(),
+      xferOp.masked() ? *xferOp.masked() : ArrayAttr());
+
+  if (maybeFillOp)
+    rewriter.eraseOp(maybeFillOp);
+  rewriter.eraseOp(copyOp);
+  rewriter.replaceOp(xferOp, res);
+
+  return success();
+}
+
+/// TODO: use interfaces, side-effects and aliasing analysis as appropriate,
+/// when available.
+LogicalResult LinalgCopyVTWForwardingPattern::matchAndRewrite(
+    vector::TransferWriteOp xferOp, PatternRewriter &rewriter) const {
+  // Transfer into `viewOrAlloc`.
+  Value viewOrAlloc = xferOp.memref();
+  if (!viewOrAlloc.getDefiningOp<ViewOp>() &&
+      !viewOrAlloc.getDefiningOp<AllocOp>())
+    return failure();
+
+  // Ensure there is exactly one subview of `viewOrAlloc` defining `subView`.
+  SubViewOp subViewOp = getSubViewUseIfUnique(viewOrAlloc);
+  if (!subViewOp)
+    return failure();
+  Value subView = subViewOp.getResult();
+
+  // Find the copy from `subView` without interleaved uses.
+  CopyOp copyOp;
+  for (auto &u : subViewOp.getResult().getUses()) {
+    if (auto newCopyOp = dyn_cast<CopyOp>(u.getOwner())) {
+      if (newCopyOp.getInput(0) != subView)
+        continue;
+      if (mayExistInterleavedUses(xferOp, newCopyOp, {viewOrAlloc, subView}))
+        continue;
+      copyOp = newCopyOp;
+      break;
+    }
+  }
+  if (!copyOp)
+    return failure();
+
+  // `out` is the subview copied into that we replace.
+  Value out = copyOp.getOutputBuffer(0);
+
+  // Forward vector.transfer into copy.
+  rewriter.create<vector::TransferWriteOp>(
+      xferOp.getLoc(), xferOp.vector(), out, xferOp.indices(),
+      xferOp.permutation_map(),
+      xferOp.masked() ? *xferOp.masked() : ArrayAttr());
+
+  rewriter.eraseOp(copyOp);
+  rewriter.eraseOp(xferOp);
+
+  return success();
+}
diff --git a/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir b/mlir/test/Dialect/Linalg/forward-vector-transfers.mlir
new file mode 100644 (file)
index 0000000..7f56234
--- /dev/null
@@ -0,0 +1,153 @@
+// RUN: mlir-opt %s -allow-unregistered-dialect -test-linalg-transform-patterns=test-vector-transfer-forwarding-patterns | FileCheck %s
+
+// CHECK-LABEL: testAllocRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testAllocRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %alloc[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testAllocFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testAllocFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  linalg.fill(%alloc, %f0): memref<32 x f32>, f32
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %alloc[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testViewRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testViewRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %view[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<128 x i8>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testViewFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.fill
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_read %[[ARG0]]
+func @testViewFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.fill(%view, %f0): memref<32 x f32>, f32
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  %0 = vector.transfer_read %view[%c0], %f0: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<128 x i8>
+  return %0: vector<32 x f32>
+}
+
+// CHECK-LABEL: testAllocWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ARG1]]
+func @testAllocWrite(%vec: vector<32 x f32>, %out: memref<? x f32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %alloc[%c0] : vector<32 x f32>, memref<32 x f32>
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<? x f32>
+  dealloc %alloc : memref<32 x f32>
+  return
+}
+
+// CHECK-LABEL: testViewWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: linalg.copy
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ARG1]]
+func @testViewWrite(%vec: vector<32 x f32>, %out: memref<? x f32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<128 x i8>
+  %view = view %alloc[%c0][] : memref<128 x i8> to memref<32 x f32>
+  %subview = subview %view[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %view[%c0] : vector<32 x f32>, memref<32 x f32>
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<? x f32>
+  dealloc %alloc : memref<128 x i8>
+  return
+}
+
+///===--------------------------------------------------------------------===///
+// Negative tests
+///===--------------------------------------------------------------------===///
+
+// This should fail the rewrite due to mismatching fill and transfer read value.
+// CHECK-LABEL: failAllocFillRead
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: vector.transfer_read %[[ARG0]]
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: linalg.copy
+//       CHECK: vector.transfer_read %[[ALLOC]]
+func @failAllocFillRead(%in: memref<? x f32>) -> vector<32 x f32> {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %f1 = constant 1.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  linalg.fill(%alloc, %f0): memref<32 x f32>, f32
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  linalg.copy(%in, %subview): memref<? x f32>, memref<16 x f32>
+  "some_interleaved_use"(%subview) : (memref<16 x f32>) -> ()
+  %0 = vector.transfer_read %alloc[%c0], %f1: memref<32 x f32>, vector<32 x f32>
+  dealloc %alloc : memref<32 x f32>
+  return %0: vector<32 x f32>
+}
+
+// This should fail the rewrite due to some interleaved use.
+// CHECK-LABEL: failAllocWrite
+//  CHECK-SAME: %[[ARG0:[0-9a-zA-Z]*]]: vector
+//  CHECK-SAME: %[[ARG1:[0-9a-zA-Z]*]]: memref
+//   CHECK-NOT: vector.transfer_write %[[ARG0]], %[[ARG1]]
+//       CHECK: %[[ALLOC:.*]] = alloc
+//       CHECK: vector.transfer_write %[[ARG0]], %[[ALLOC]]
+//       CHECK: linalg.copy
+func @failAllocWrite(%vec: vector<32 x f32>, %out: memref<? x f32>) {
+  %c0 = constant 0: index
+  %f0 = constant 0.0: f32
+  %alloc = alloc() : memref<32 x f32>
+  %subview = subview %alloc[0][16][1] : memref<32 x f32> to memref<16 x f32>
+  vector.transfer_write %vec, %alloc[%c0] : vector<32 x f32>, memref<32 x f32>
+  "some_interleaved_use"(%subview) : (memref<16 x f32>) -> ()
+  linalg.copy(%subview, %out): memref<16 x f32>, memref<? x f32>
+  dealloc %alloc : memref<32 x f32>
+  return
+}
index c38494f..31189f4 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
@@ -48,6 +49,11 @@ struct TestLinalgTransforms
   Option<bool> testPromotionOptions{*this, "test-linalg-promotion-options",
                                     llvm::cl::desc("Test promotion options"),
                                     llvm::cl::init(false)};
+  Option<bool> testVectorTransferForwardingPatterns{
+      *this, "test-vector-transfer-forwarding-patterns",
+      llvm::cl::desc(
+          "Test a fused pass that forwards linalg.copy to vector.transfer"),
+      llvm::cl::init(false)};
 };
 } // end anonymous namespace
 
@@ -167,19 +173,6 @@ static void applyPatterns(FuncOp funcOp) {
   });
 }
 
-static OwningRewritePatternList
-getMatmulToVectorCanonicalizationPatterns(MLIRContext *context) {
-  OwningRewritePatternList patterns;
-  AffineApplyOp::getCanonicalizationPatterns(patterns, context);
-  AffineMinOp::getCanonicalizationPatterns(patterns, context);
-  AffineMaxOp::getCanonicalizationPatterns(patterns, context);
-  AllocOp::getCanonicalizationPatterns(patterns, context);
-  SubViewOp::getCanonicalizationPatterns(patterns, context);
-  ViewOp::getCanonicalizationPatterns(patterns, context);
-  MatmulOp::getCanonicalizationPatterns(patterns, context);
-  return patterns;
-}
-
 static void fillL1TilingAndMatmulToVectorPatterns(
     FuncOp funcOp, StringRef startMarker,
     SmallVectorImpl<OwningRewritePatternList> &patternsVector) {
@@ -261,40 +254,58 @@ void fillPromotionCallBackPatterns(MLIRContext *context,
       LinalgMarker({"PROMOTE"}));
 }
 
+static void
+applyMatmulToVectorPatterns(FuncOp funcOp,
+                            bool testMatmulToVectorPatterns1dTiling,
+                            bool testMatmulToVectorPatterns2dTiling) {
+  MLIRContext *ctx = funcOp.getContext();
+  SmallVector<OwningRewritePatternList, 4> stage1Patterns;
+  if (testMatmulToVectorPatterns1dTiling) {
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, "START", stage1Patterns);
+  } else if (testMatmulToVectorPatterns2dTiling) {
+    stage1Patterns.emplace_back(
+        LinalgTilingPattern<MatmulOp>(ctx,
+                                      LinalgTilingOptions()
+                                          .setTileSizes({768, 264, 768})
+                                          .setInterchange({1, 2, 0}),
+                                      LinalgMarker({"START"}, "L2")));
+    fillL1TilingAndMatmulToVectorPatterns(funcOp, "L2", stage1Patterns);
+  }
+  OwningRewritePatternList stage2Patterns =
+      getLinalgTilingCanonicalizationPatterns(ctx);
+  applyStagedPatterns(funcOp, stage1Patterns, stage2Patterns);
+}
+
+static void applyVectorTransferForwardingPatterns(FuncOp funcOp) {
+  OwningRewritePatternList forwardPattern;
+  forwardPattern.insert<LinalgCopyVTRForwardingPattern>(funcOp.getContext());
+  forwardPattern.insert<LinalgCopyVTWForwardingPattern>(funcOp.getContext());
+  applyPatternsAndFoldGreedily(funcOp, forwardPattern);
+}
+
 /// Apply transformations specified as patterns.
 void TestLinalgTransforms::runOnFunction() {
-  if (testPatterns) {
-    applyPatterns(getFunction());
-    return;
-  }
+  auto lambda = [&](void *) {
+    getFunction().walk([](LinalgOp op) {
+      op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
+    });
+  };
+  std::unique_ptr<void, decltype(lambda)> cleanupGuard{(void *)1, lambda};
+
   if (testPromotionOptions) {
     OwningRewritePatternList patterns;
     fillPromotionCallBackPatterns(&getContext(), patterns);
     applyPatternsAndFoldGreedily(getFunction(), patterns);
-  } else {
-    SmallVector<OwningRewritePatternList, 4> stage1Patterns;
-    if (testMatmulToVectorPatterns1dTiling) {
-      fillL1TilingAndMatmulToVectorPatterns(getFunction(), "START",
-                                            stage1Patterns);
-    } else if (testMatmulToVectorPatterns2dTiling) {
-      stage1Patterns.emplace_back(
-          LinalgTilingPattern<MatmulOp>(&getContext(),
-                                        LinalgTilingOptions()
-                                            .setTileSizes({768, 264, 768})
-                                            .setInterchange({1, 2, 0}),
-                                        LinalgMarker({"START"}, "L2")));
-      fillL1TilingAndMatmulToVectorPatterns(getFunction(), "L2",
-                                            stage1Patterns);
-    }
-    OwningRewritePatternList stage2Patterns =
-        getMatmulToVectorCanonicalizationPatterns(&getContext());
-    applyStagedPatterns(getFunction(), stage1Patterns, stage2Patterns);
+    return;
   }
-
-  // Drop the marker.
-  getFunction().walk([](LinalgOp op) {
-    op.removeAttr(LinalgTransforms::kLinalgTransformMarker);
-  });
+  if (testPatterns)
+    return applyPatterns(getFunction());
+  if (testMatmulToVectorPatterns1dTiling || testMatmulToVectorPatterns2dTiling)
+    return applyMatmulToVectorPatterns(getFunction(),
+                                       testMatmulToVectorPatterns1dTiling,
+                                       testMatmulToVectorPatterns2dTiling);
+  if (testVectorTransferForwardingPatterns)
+    return applyVectorTransferForwardingPatterns(getFunction());
 }
 
 namespace mlir {