[mlir][linalg] Retire Linalg's Vectorization Pattern
authorGuray Ozen <guray.ozen@gmail.com>
Thu, 15 Sep 2022 08:39:13 +0000 (10:39 +0200)
committerGuray Ozen <guray.ozen@gmail.com>
Thu, 15 Sep 2022 09:23:46 +0000 (11:23 +0200)
This revision retires the LinalgCodegenStrategy vectorization pattern. Please see the context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785.
This revision improves the transform dialect's VectorizeOp in different ways below:
- Adds LinalgDialect as a dependent dialect. When `transform.structured.vectorize` vectorizes `tensor.pad`, it generates `linalg.init_tensor`. In this case, linalg dialect must be registered.
- Inserts CopyVectorizationPattern in order to vectorize `memref.copy`.
- Creates two attributes: `disable_multi_reduction_to_contract_patterns` and `disable_transfer_permutation_map_lowering_patterns`. They are limiting the power of vectorization and are currently intended for testing purposes.

It also removes some of the "CHECK: vector.transfer_write" in the vectorization.mlir test. They are redundant writes, at the end of the code there is a rewrite to the same place. Transform dialect no longer generates them.

Depends on D133684 that retires the LinalgCodegenStrategy vectorization pass.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133699

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
mlir/test/Dialect/Linalg/vectorization.mlir
mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp

index 46d70a6..79c0e62 100644 (file)
@@ -767,6 +767,10 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
     Note that this transformation is invalidating the handles to any payload IR
     operation that is contained inside the vectorization target.
 
+    `disable_multi_reduction_to_contract_patterns` and 
+    `disable_transfer_permutation_map_lowering_patterns` limits the power of 
+    vectorization. They are currently intended for testing purposes.
+
     #### Return modes:
     
     This operation produces `definiteFailure` if vectorization fails for any
@@ -776,7 +780,9 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
   }];
 
   let arguments = (ins PDL_Operation:$target,
-                   DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding);
+                   DefaultValuedAttr<BoolAttr, "false">:$vectorize_padding,                   
+                   DefaultValuedAttr<BoolAttr, "false">:$disable_multi_reduction_to_contract_patterns,                   
+                   DefaultValuedAttr<BoolAttr, "false">:$disable_transfer_permutation_map_lowering_patterns);
   let results = (outs PDL_Operation:$transformed);
 
   let assemblyFormat = "$target attr-dict";
index 9586815..43185a2 100644 (file)
@@ -928,31 +928,6 @@ struct LinalgVectorizationOptions {};
 
 /// `filter` controls LinalgTransformMarker matching and update when specified.
 /// See `vectorizeLinalgOp` for more details.
-struct LinalgVectorizationPattern : public OpInterfaceRewritePattern<LinalgOp> {
-  /// Construct a generic pattern applied to all LinalgOp that verify `filter`.
-  LinalgVectorizationPattern(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      PatternBenefit benefit = 1);
-
-  /// Construct a pattern specifically applied to `opName`.
-  LinalgVectorizationPattern(
-      StringRef opName, MLIRContext *context,
-      LinalgVectorizationOptions options = LinalgVectorizationOptions(),
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1);
-
-  LogicalResult matchAndRewrite(LinalgOp linalgOp,
-                                PatternRewriter &rewriter) const override;
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-};
-
-/// `filter` controls LinalgTransformMarker matching and update when specified.
-/// See `vectorizeLinalgOp` for more details.
 struct CopyVectorizationPattern : public OpRewritePattern<memref::CopyOp> {
   using OpRewritePattern<memref::CopyOp>::OpRewritePattern;
 
@@ -1335,18 +1310,6 @@ public:
                      const LinalgTransformationFilter &f) {}
 };
 
-template <typename OpTy, typename... OpTypes>
-class VectorizationPatterns<OpTy, OpTypes...> {
-public:
-  static void insert(RewritePatternSet &patterns,
-                     const LinalgVectorizationOptions &options,
-                     const LinalgTransformationFilter &f) {
-    patterns.add<LinalgVectorizationPattern>(OpTy::getOperationName(),
-                                             patterns.getContext(), options, f);
-    VectorizationPatterns<OpTypes...>::insert(patterns, options, f);
-  }
-};
-
 template <typename... OpTypes>
 class TilingPatterns;
 
index 29b13e2..93b1274 100644 (file)
@@ -1166,6 +1166,22 @@ LogicalResult TileToForeachThreadOp::verify() {
 // VectorizeOp
 //===----------------------------------------------------------------------===//
 
+namespace {
+/// This is an helper only to call vectorize via a pattern inside of
+/// VectorizeOp::applyToOne.
+struct VectorizationPattern : public RewritePattern {
+  explicit VectorizationPattern(MLIRContext *context)
+      : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const override {
+    LinalgOp linalgOp = dyn_cast<LinalgOp>(op);
+    if (!linalgOp)
+      return failure();
+    return vectorize(rewriter, linalgOp);
+  }
+};
+} // namespace
+
 DiagnosedSilenceableFailure
 transform::VectorizeOp::applyToOne(Operation *target,
                                    SmallVectorImpl<Operation *> &results,
@@ -1178,15 +1194,22 @@ transform::VectorizeOp::applyToOne(Operation *target,
 
   MLIRContext *ctx = getContext();
   RewritePatternSet patterns(ctx);
-  patterns.add<LinalgVectorizationPattern>(ctx);
+  patterns.add<VectorizationPattern>(ctx);
+
+  if (!getDisableTransferPermutationMapLoweringPatterns())
+    vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
+
+  if (!getDisableMultiReductionToContractPatterns())
+    vector::populateVectorReductionToContractPatterns(patterns);
 
-  vector::populateVectorTransferPermutationMapLoweringPatterns(patterns);
-  vector::populateVectorReductionToContractPatterns(patterns);
   patterns.add<linalg::LinalgCopyVTRForwardingPattern,
                linalg::LinalgCopyVTWForwardingPattern>(ctx,
                                                        /*benefit=*/2);
   vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx);
   vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx);
+
+  patterns.add<CopyVectorizationPattern>(ctx);
+
   if (getVectorizePadding())
     linalg::populatePadOpVectorizationPatterns(patterns);
 
@@ -1212,7 +1235,7 @@ public:
 
   void init() {
     declareDependentDialect<pdl::PDLDialect>();
-
+    declareDependentDialect<LinalgDialect>();
     declareGeneratedDialect<AffineDialect>();
     declareGeneratedDialect<arith::ArithmeticDialect>();
     declareGeneratedDialect<scf::SCFDialect>();
index 31136dd..b00d923 100644 (file)
@@ -590,25 +590,6 @@ LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite(
   return success();
 }
 
-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
-    MLIRContext *context, LinalgTransformationFilter f,
-    LinalgVectorizationOptions options, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(std::move(f)) {}
-
-mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern(
-    StringRef opName, MLIRContext *context, LinalgVectorizationOptions options,
-    LinalgTransformationFilter f, PatternBenefit benefit)
-    : OpInterfaceRewritePattern<LinalgOp>(context, benefit),
-      filter(f.addOpNameFilter(opName)) {}
-
-LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite(
-    LinalgOp linalgOp, PatternRewriter &rewriter) const {
-  if (failed(filter.checkAndNotify(rewriter, linalgOp)))
-    return failure();
-  return vectorize(rewriter, linalgOp);
-}
-
 LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite(
     memref::CopyOp copyOp, PatternRewriter &rewriter) const {
   return vectorizeCopy(rewriter, copyOp);
index 2295305..6ac2fdb 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s
 
 // -----
 
@@ -12,6 +12,16 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memre
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.dot"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: contraction_matvec
@@ -24,6 +34,16 @@ func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, %
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: contraction_matmul
@@ -35,6 +55,16 @@ func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf3
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: contraction_batch_matmul
@@ -47,6 +77,16 @@ func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 #matmul_trait = {
@@ -80,6 +120,16 @@ func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 #matmul_transpose_out_trait = {
@@ -113,6 +163,16 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
@@ -133,6 +193,16 @@ func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tenso
   return %1 : tensor<128x12x32xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 #matmul_trait = {
@@ -166,6 +236,16 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @vectorization_test_2
@@ -179,6 +259,16 @@ func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_scalar_input
@@ -196,6 +286,16 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) {
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types
@@ -213,6 +313,16 @@ func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomp
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_fill
@@ -223,6 +333,16 @@ func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) {
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_fill
@@ -234,6 +354,16 @@ func.func @test_vectorize_fill_scalar(%A : memref<f32>, %arg0 : f32) {
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_copy
@@ -244,6 +374,16 @@ func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) {
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_copy_scalar
@@ -257,6 +397,15 @@ func.func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["memref.copy"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
 // -----
 
 // CHECK-LABEL: func @test_vectorize_trailing_index
@@ -278,6 +427,16 @@ func.func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) {
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @test_vectorize_inner_index
@@ -300,6 +459,16 @@ func.func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) {
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @generic_vectorize
@@ -378,6 +547,16 @@ func.func @generic_vectorize(%arg0: memref<4x256xf32>,
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  {disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @generic_vectorize_tensor
@@ -462,6 +641,16 @@ func.func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>,
     tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)>
@@ -499,6 +688,16 @@ func.func @generic_vectorize_broadcast_transpose(
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  {disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // Test different input maps.
@@ -535,6 +734,16 @@ func.func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>,
   return
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  {disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @matmul_tensors
@@ -560,6 +769,16 @@ func.func @matmul_tensors(
   return %0 : tensor<8x12xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @pad_static(
@@ -581,6 +800,17 @@ func.func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4
   return %0 : tensor<2x3x4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @pad_static_source(
@@ -602,6 +832,18 @@ func.func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tenso
   return %0 : tensor<2x6x4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 // CHECK-LABEL: func @pad_static_dynamic(
@@ -630,6 +872,18 @@ func.func @pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: in
   return %0 : tensor<6x?x?x?xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 // CHECK-LABEL: func @pad_and_transfer_read
@@ -652,6 +906,17 @@ func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
   return %1 : vector<7x9xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 { vectorize_padding = true }
+  }
+}
+
 // -----
 
 func.func private @make_vector() -> vector<7x9xf32>
@@ -678,6 +943,17 @@ func.func @pad_and_transfer_write_static(
   return %3 : tensor<5x6xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4  { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 func.func private @make_vector() -> vector<7x9xf32>
@@ -707,6 +983,17 @@ func.func @pad_and_transfer_write_dynamic_static(
   return %3 : tensor<?x6xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 func.func private @make_vector() -> tensor<12x13xf32>
@@ -733,6 +1020,17 @@ func.func @pad_and_insert_slice_source(
   return %r : tensor<12x13xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4  { vectorize_padding = true }
+  }
+}
+
+
 // -----
 
 func.func private @make_vector() -> tensor<12x13xf32>
@@ -753,6 +1051,16 @@ func.func @pad_and_insert_slice_dest(
   return %r : tensor<1x12x13xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @pad_tensor_non_const_pad_value
@@ -782,6 +1090,17 @@ func.func @pad_tensor_non_const_pad_value(%arg0: tensor<5x6xf32>) -> tensor<12x1
   return %0 : tensor<12x13xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["tensor.pad"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4  { vectorize_padding = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL: func @sum_exp
@@ -809,6 +1128,17 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>)
   return %0 : tensor<4x16xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$M1:.*]] =  affine_map<(d0, d1) -> (d1, d0, 0, 0)>
@@ -846,13 +1176,23 @@ func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output
   return %0 : tensor<5x2xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4  { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL:   func @red_max_2d(
 func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.multi_reduction <maxf>, {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   %ident = arith.constant -3.40282e+38 : f32
@@ -869,13 +1209,23 @@ func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 { vectorize_padding = true }
+  }
+}
+
 // -----
 
 // CHECK-LABEL:   func @red_min_2d(
 func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32>
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
   // CHECK: vector.multi_reduction <minf>, {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -893,12 +1243,22 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-LABEL:   func @red_mul_2d(
 func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   // CHECK: linalg.init_tensor [4] : tensor<4xf32>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32>
   // CHECK: vector.multi_reduction <mul>, {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32>
   // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32>
@@ -916,12 +1276,22 @@ func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> {
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-LABEL:   func @red_or_2d(
 func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <or>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -939,12 +1309,22 @@ func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   return %red : tensor<4xi1>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-LABEL:   func @red_and_2d(
 func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <and>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -962,12 +1342,22 @@ func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   return %red : tensor<4xi1>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-LABEL:   func @red_xor_2d(
 func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   // CHECK: linalg.init_tensor [4] : tensor<4xi1>
-  // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
   // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1>
   // CHECK: vector.multi_reduction <xor>, {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1>
   // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1>
@@ -985,6 +1375,17 @@ func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> {
   return %red : tensor<4xi1>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)>
@@ -1011,6 +1412,17 @@ func.func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) ->
   return %red : tensor<4x4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 // CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)>
@@ -1041,6 +1453,21 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>
   return %red : tensor<4xf32>
 }
 
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.fill"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+    
+    %3 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %4 = get_closest_isolated_parent %3
+    %5 = transform.structured.vectorize %4 
+  }
+}
+
 // -----
 
 //  CHECK-LABEL: func @reduce_1d(
@@ -1054,8 +1481,6 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   //      CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor<f32>
   %0 = linalg.init_tensor [] : tensor<f32>
 
-  //      CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][]
-  // CHECK-SAME:   : vector<f32>, tensor<f32>
   %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32>
   //      CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]]
   // CHECK-SAME:   : tensor<32xf32>, vector<32xf32>
@@ -1063,7 +1488,7 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   //      CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[f0]] [0]
   // CHECK-SAME:   : vector<32xf32> to f32
   //      CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32>
-  //      CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][]
+  //      CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][]
   // CHECK-SAME:   : vector<f32>, tensor<f32>
   %2 = linalg.generic {
          indexing_maps = [affine_map<(d0) -> (d0)>,
@@ -1079,6 +1504,16 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> {
   return %2 : tensor<f32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 
 // -----
 
@@ -1103,6 +1538,16 @@ func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf
   return %result : tensor<6x6x3x3xf32>
 }
 
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1 
+  }
+}
+
 // -----
 
 // Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values.
@@ -1134,3 +1579,13 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<2x4x8xf32>,
 //   CHECK-DAG:   %[[ADD:.+]] = vector.multi_reduction <add>, %[[MUL]], %[[V2]]
 //   CHECK-DAG:   vector.transfer_write %[[MUL]], %[[ARG2]]
 //   CHECK-DAG:   vector.transfer_write %[[ADD]], %[[ARG3]]
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  transform.sequence %arg0 failures(propagate) {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = transform.structured.match ops{["linalg.generic"]} in %arg1
+    %1 = get_closest_isolated_parent %0
+    %2 = transform.structured.vectorize %1  { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true }
+  }
+}
\ No newline at end of file
index 576082a..3949544 100644 (file)
@@ -225,9 +225,6 @@ static void applyPatterns(func::FuncOp funcOp) {
   //===--------------------------------------------------------------------===//
   // Linalg to vector contraction patterns.
   //===--------------------------------------------------------------------===//
-  patterns.add<LinalgVectorizationPattern>(
-      ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE"))
-               .addOpFilter<MatmulOp, FillOp, GenericOp>());
   patterns.add<CopyVectorizationPattern>(ctx);
 
   (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
@@ -441,9 +438,6 @@ static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) {
 static void applyLinalgToVectorPatterns(func::FuncOp funcOp) {
   RewritePatternSet patterns(funcOp.getContext());
   auto *ctx = funcOp.getContext();
-  patterns.add<LinalgVectorizationPattern>(
-      ctx, LinalgTransformationFilter()
-               .addOpFilter<ContractionOpInterface, FillOp, GenericOp>());
   patterns.add<CopyVectorizationPattern>(ctx);
   populatePadOpVectorizationPatterns(patterns);
   populateConvolutionVectorizationPatterns(patterns);