From: Guray Ozen Date: Thu, 15 Sep 2022 08:39:13 +0000 (+0200) Subject: [mlir][linalg] Retire Linalg's Vectorization Pattern X-Git-Tag: upstream/17.0.6~33471 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=5279e11f063db6a0cc87ccf9e0e1c7b1b31aa7cf;p=platform%2Fupstream%2Fllvm.git [mlir][linalg] Retire Linalg's Vectorization Pattern This revision retires the LinalgCodegenStrategy vectorization pattern. Please see the context: https://discourse.llvm.org/t/psa-retire-linalg-filter-based-patterns/63785. This revision improves the transform dialect's VectorizeOp in different ways below: - Adds LinalgDialect as a dependent dialect. When `transform.structured.vectorize` vectorizes `tensor.pad`, it generates `linalg.init_tensor`. In this case, linalg dialect must be registered. - Inserts CopyVectorizationPattern in order to vectorize `memref.copy`. - Creates two attributes: `disable_multi_reduction_to_contract_patterns` and `disable_transfer_permutation_map_lowering_patterns`. They are limiting the power of vectorization and are currently intended for testing purposes. It also removes some of the "CHECK: vector.transfer_write" in the vectorization.mlir test. They are redundant writes, at the end of the code there is a rewrite to the same place. Transform dialect no longer generates them. Depends on D133684 that retires the LinalgCodegenStrategy vectorization pass. Reviewed By: nicolasvasilache Differential Revision: https://reviews.llvm.org/D133699 --- diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 46d70a6..79c0e62 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -767,6 +767,10 @@ def VectorizeOp : Op:$vectorize_padding); + DefaultValuedAttr:$vectorize_padding, + DefaultValuedAttr:$disable_multi_reduction_to_contract_patterns, + DefaultValuedAttr:$disable_transfer_permutation_map_lowering_patterns); let results = (outs PDL_Operation:$transformed); let assemblyFormat = "$target attr-dict"; diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 9586815..43185a2 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -928,31 +928,6 @@ struct LinalgVectorizationOptions {}; /// `filter` controls LinalgTransformMarker matching and update when specified. /// See `vectorizeLinalgOp` for more details. -struct LinalgVectorizationPattern : public OpInterfaceRewritePattern { - /// Construct a generic pattern applied to all LinalgOp that verify `filter`. - LinalgVectorizationPattern( - MLIRContext *context, - LinalgTransformationFilter f = LinalgTransformationFilter(), - LinalgVectorizationOptions options = LinalgVectorizationOptions(), - PatternBenefit benefit = 1); - - /// Construct a pattern specifically applied to `opName`. - LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, - LinalgVectorizationOptions options = LinalgVectorizationOptions(), - LinalgTransformationFilter f = LinalgTransformationFilter(), - PatternBenefit benefit = 1); - - LogicalResult matchAndRewrite(LinalgOp linalgOp, - PatternRewriter &rewriter) const override; - -private: - /// LinalgTransformMarker handles special attribute manipulations. - LinalgTransformationFilter filter; -}; - -/// `filter` controls LinalgTransformMarker matching and update when specified. -/// See `vectorizeLinalgOp` for more details. struct CopyVectorizationPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1335,18 +1310,6 @@ public: const LinalgTransformationFilter &f) {} }; -template -class VectorizationPatterns { -public: - static void insert(RewritePatternSet &patterns, - const LinalgVectorizationOptions &options, - const LinalgTransformationFilter &f) { - patterns.add(OpTy::getOperationName(), - patterns.getContext(), options, f); - VectorizationPatterns::insert(patterns, options, f); - } -}; - template class TilingPatterns; diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 29b13e2..93b1274 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -1166,6 +1166,22 @@ LogicalResult TileToForeachThreadOp::verify() { // VectorizeOp //===----------------------------------------------------------------------===// +namespace { +/// This is an helper only to call vectorize via a pattern inside of +/// VectorizeOp::applyToOne. +struct VectorizationPattern : public RewritePattern { + explicit VectorizationPattern(MLIRContext *context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + LinalgOp linalgOp = dyn_cast(op); + if (!linalgOp) + return failure(); + return vectorize(rewriter, linalgOp); + } +}; +} // namespace + DiagnosedSilenceableFailure transform::VectorizeOp::applyToOne(Operation *target, SmallVectorImpl &results, @@ -1178,15 +1194,22 @@ transform::VectorizeOp::applyToOne(Operation *target, MLIRContext *ctx = getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns.add(ctx); + + if (!getDisableTransferPermutationMapLoweringPatterns()) + vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); + + if (!getDisableMultiReductionToContractPatterns()) + vector::populateVectorReductionToContractPatterns(patterns); - vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); - vector::populateVectorReductionToContractPatterns(patterns); patterns.add(ctx, /*benefit=*/2); vector::TransferReadOp::getCanonicalizationPatterns(patterns, ctx); vector::TransferWriteOp::getCanonicalizationPatterns(patterns, ctx); + + patterns.add(ctx); + if (getVectorizePadding()) linalg::populatePadOpVectorizationPatterns(patterns); @@ -1212,7 +1235,7 @@ public: void init() { declareDependentDialect(); - + declareDependentDialect(); declareGeneratedDialect(); declareGeneratedDialect(); declareGeneratedDialect(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index 31136dd..b00d923 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -590,25 +590,6 @@ LogicalResult mlir::linalg::LinalgPeelingPattern::matchAndRewrite( return success(); } -mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( - MLIRContext *context, LinalgTransformationFilter f, - LinalgVectorizationOptions options, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(std::move(f)) {} - -mlir::linalg::LinalgVectorizationPattern::LinalgVectorizationPattern( - StringRef opName, MLIRContext *context, LinalgVectorizationOptions options, - LinalgTransformationFilter f, PatternBenefit benefit) - : OpInterfaceRewritePattern(context, benefit), - filter(f.addOpNameFilter(opName)) {} - -LogicalResult mlir::linalg::LinalgVectorizationPattern::matchAndRewrite( - LinalgOp linalgOp, PatternRewriter &rewriter) const { - if (failed(filter.checkAndNotify(rewriter, linalgOp))) - return failure(); - return vectorize(rewriter, linalgOp); -} - LogicalResult mlir::linalg::CopyVectorizationPattern::matchAndRewrite( memref::CopyOp copyOp, PatternRewriter &rewriter) const { return vectorizeCopy(rewriter, copyOp); diff --git a/mlir/test/Dialect/Linalg/vectorization.mlir b/mlir/test/Dialect/Linalg/vectorization.mlir index 2295305..6ac2fdb 100644 --- a/mlir/test/Dialect/Linalg/vectorization.mlir +++ b/mlir/test/Dialect/Linalg/vectorization.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt %s -test-linalg-transform-patterns=test-linalg-to-vector-patterns -split-input-file | FileCheck %s +// RUN: mlir-opt %s -test-transform-dialect-interpreter -split-input-file | FileCheck %s // ----- @@ -12,6 +12,16 @@ func.func @contraction_dot(%A: memref<1584xf32>, %B: memref<1584xf32>, %C: memre return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.dot"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // CHECK-LABEL: contraction_matvec @@ -24,6 +34,16 @@ func.func @contraction_matvec(%A: memref<1584x1584xf32>, %B: memref<1584xf32>, % return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matvec"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // CHECK-LABEL: contraction_matmul @@ -35,6 +55,16 @@ func.func @contraction_matmul(%A: memref<1584x1584xf32>, %B: memref<1584x1584xf3 return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // CHECK-LABEL: contraction_batch_matmul @@ -47,6 +77,16 @@ func.func @contraction_batch_matmul(%A: memref<1584x1584x1584xf32>, %B: memref<1 return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.batch_matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- #matmul_trait = { @@ -80,6 +120,16 @@ func.func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>, return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- #matmul_transpose_out_trait = { @@ -113,6 +163,16 @@ func.func @generic_output_transpose(%A: memref<8x16xf32>, %B: memref<16x32xf32>, return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- #map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> @@ -133,6 +193,16 @@ func.func @generic_interchanged_transpose(%arg0: tensor<12x128x32xf32>) -> tenso return %1 : tensor<128x12x32xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- #matmul_trait = { @@ -166,6 +236,16 @@ func.func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32 return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @vectorization_test_2 @@ -179,6 +259,16 @@ func.func @vectorization_test_2(%A: memref<8x16xf32>, %B: memref<16x32xf32>, return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true } + } +} + // ----- // CHECK-LABEL: func @test_vectorize_scalar_input @@ -196,6 +286,16 @@ func.func @test_vectorize_scalar_input(%A : memref<8x16xf32>, %arg0 : f32) { return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_do_not_vectorize_unsupported_element_types @@ -213,6 +313,16 @@ func.func @test_do_not_vectorize_unsupported_element_types(%A : memref<8x16xcomp return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_fill @@ -223,6 +333,16 @@ func.func @test_vectorize_fill(%A : memref<8x16xf32>, %arg0 : f32) { return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_fill @@ -234,6 +354,16 @@ func.func @test_vectorize_fill_scalar(%A : memref, %arg0 : f32) { return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_copy @@ -244,6 +374,16 @@ func.func @test_vectorize_copy(%A : memref<8x16xf32>, %B : memref<8x16xf32>) { return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_copy_scalar @@ -257,6 +397,15 @@ func.func @test_vectorize_copy_scalar(%A : memref, %B : memref) { return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["memref.copy"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} // ----- // CHECK-LABEL: func @test_vectorize_trailing_index @@ -278,6 +427,16 @@ func.func @test_vectorize_trailing_index(%arg0: memref<1x2x4x8xindex>) { return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @test_vectorize_inner_index @@ -300,6 +459,16 @@ func.func @test_vectorize_inner_index(%arg0: memref<1x2x4x8xindex>) { return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // CHECK-LABEL: func @generic_vectorize @@ -378,6 +547,16 @@ func.func @generic_vectorize(%arg0: memref<4x256xf32>, return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @generic_vectorize_tensor @@ -462,6 +641,16 @@ func.func @generic_vectorize_tensor(%arg0: tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32>, tensor<4x256xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, 0, 0, d1)> @@ -499,6 +688,16 @@ func.func @generic_vectorize_broadcast_transpose( return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // Test different input maps. @@ -535,6 +734,16 @@ func.func @vectorization_transpose(%A: memref<14x7xf32>, %B: memref<16x14xf32>, return } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 {disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @matmul_tensors @@ -560,6 +769,16 @@ func.func @matmul_tensors( return %0 : tensor<8x12xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @pad_static( @@ -581,6 +800,17 @@ func.func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4 return %0 : tensor<2x3x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + // ----- // CHECK-LABEL: func @pad_static_source( @@ -602,6 +832,18 @@ func.func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tenso return %0 : tensor<2x6x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + + // ----- // CHECK-LABEL: func @pad_static_dynamic( @@ -630,6 +872,18 @@ func.func @pad_static_dynamic(%arg0: tensor<1x2x2x?xf32>, %low: index, %high: in return %0 : tensor<6x?x?x?xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + + // ----- // CHECK-LABEL: func @pad_and_transfer_read @@ -652,6 +906,17 @@ func.func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> { return %1 : vector<7x9xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { vectorize_padding = true } + } +} + // ----- func.func private @make_vector() -> vector<7x9xf32> @@ -678,6 +943,17 @@ func.func @pad_and_transfer_write_static( return %3 : tensor<5x6xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + + // ----- func.func private @make_vector() -> vector<7x9xf32> @@ -707,6 +983,17 @@ func.func @pad_and_transfer_write_dynamic_static( return %3 : tensor } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + + // ----- func.func private @make_vector() -> tensor<12x13xf32> @@ -733,6 +1020,17 @@ func.func @pad_and_insert_slice_source( return %r : tensor<12x13xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + + // ----- func.func private @make_vector() -> tensor<12x13xf32> @@ -753,6 +1051,16 @@ func.func @pad_and_insert_slice_dest( return %r : tensor<1x12x13xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @pad_tensor_non_const_pad_value @@ -782,6 +1090,17 @@ func.func @pad_tensor_non_const_pad_value(%arg0: tensor<5x6xf32>) -> tensor<12x1 return %0 : tensor<12x13xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["tensor.pad"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + // ----- // CHECK-LABEL: func @sum_exp @@ -809,6 +1128,17 @@ func.func @sum_exp(%input: tensor<4x16x8xf32>, %output: tensor<4x16xf32>) return %0 : tensor<4x16xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M1:.*]] = affine_map<(d0, d1) -> (d1, d0, 0, 0)> @@ -846,13 +1176,23 @@ func.func @sum_exp_2(%input: tensor<3x2xf32>, %input_2: tensor<5x4xf32>, %output return %0 : tensor<5x2xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} + // ----- // CHECK-LABEL: func @red_max_2d( func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: %[[CMINF:.+]] = arith.constant dense<-3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> - // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.multi_reduction , {{.*}}, %[[CMINF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> %ident = arith.constant -3.40282e+38 : f32 @@ -869,13 +1209,23 @@ func.func @red_max_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 { vectorize_padding = true } + } +} + // ----- // CHECK-LABEL: func @red_min_2d( func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: %[[CMAXF:.+]] = arith.constant dense<3.402820e+38> : vector<4xf32> // CHECK: linalg.init_tensor [4] : tensor<4xf32> - // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}}, %[[CMAXF]] [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> @@ -893,12 +1243,22 @@ func.func @red_min_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_mul_2d( func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { // CHECK: linalg.init_tensor [4] : tensor<4xf32> - // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xf32>, vector<4x4xf32> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xf32> to vector<4xf32> // CHECK: vector.transfer_write {{.*}} : vector<4xf32>, tensor<4xf32> @@ -916,12 +1276,22 @@ func.func @red_mul_2d(%arg0: tensor<4x4xf32>) -> tensor<4xf32> { return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_or_2d( func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> - // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> @@ -939,12 +1309,22 @@ func.func @red_or_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { return %red : tensor<4xi1> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_and_2d( func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> - // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> @@ -962,12 +1342,22 @@ func.func @red_and_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { return %red : tensor<4xi1> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @red_xor_2d( func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { // CHECK: linalg.init_tensor [4] : tensor<4xi1> - // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> // CHECK: vector.transfer_read {{.*}} : tensor<4x4xi1>, vector<4x4xi1> // CHECK: vector.multi_reduction , {{.*}}, {{.*}} [1] : vector<4x4xi1> to vector<4xi1> // CHECK: vector.transfer_write {{.*}} : vector<4xi1>, tensor<4xi1> @@ -985,6 +1375,17 @@ func.func @red_xor_2d(%arg0: tensor<4x4xi1>) -> tensor<4xi1> { return %red : tensor<4xi1> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M5:.*]] = affine_map<(d0, d1) -> (d0, 0)> @@ -1011,6 +1412,17 @@ func.func @explicit_broadcast(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32>) -> return %red : tensor<4x4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-DAG: #[[$M6:.*]] = affine_map<(d0, d1) -> (d0, 0)> @@ -1041,6 +1453,21 @@ func.func @fused_broadcast_red_2d(%arg0: tensor<4x4xf32>, %arg1: tensor<4x1xf32> return %red : tensor<4xf32> } + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.fill"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + + %3 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %4 = get_closest_isolated_parent %3 + %5 = transform.structured.vectorize %4 + } +} + // ----- // CHECK-LABEL: func @reduce_1d( @@ -1054,8 +1481,6 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { // CHECK: %[[init:.*]] = linalg.init_tensor [] : tensor %0 = linalg.init_tensor [] : tensor - // CHECK: %[[f:.*]] = vector.transfer_write %[[vF0]], %[[init]][] - // CHECK-SAME: : vector, tensor %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor) -> tensor // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> @@ -1063,7 +1488,7 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { // CHECK: %[[red:.*]] = vector.multi_reduction , %[[r]], %[[f0]] [0] // CHECK-SAME: : vector<32xf32> to f32 // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector - // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[f]][] + // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][] // CHECK-SAME: : vector, tensor %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, @@ -1079,6 +1504,16 @@ func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor { return %2 : tensor } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- @@ -1103,6 +1538,16 @@ func.func @not_projected_permutation(%arg0: tensor<8x8xf32>) -> tensor<6x6x3x3xf return %result : tensor<6x6x3x3xf32> } +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 + } +} + // ----- // Check vectorization can handle cases where outputs are a mix of reduced and non-reduced values. @@ -1134,3 +1579,13 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<2x4x8xf32>, // CHECK-DAG: %[[ADD:.+]] = vector.multi_reduction , %[[MUL]], %[[V2]] // CHECK-DAG: vector.transfer_write %[[MUL]], %[[ARG2]] // CHECK-DAG: vector.transfer_write %[[ADD]], %[[ARG3]] + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + transform.sequence %arg0 failures(propagate) { + ^bb1(%arg1: !pdl.operation): + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 + %1 = get_closest_isolated_parent %0 + %2 = transform.structured.vectorize %1 { disable_multi_reduction_to_contract_patterns = true, disable_transfer_permutation_map_lowering_patterns = true } + } +} \ No newline at end of file diff --git a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp index 576082a..3949544 100644 --- a/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp +++ b/mlir/test/lib/Dialect/Linalg/TestLinalgTransforms.cpp @@ -225,9 +225,6 @@ static void applyPatterns(func::FuncOp funcOp) { //===--------------------------------------------------------------------===// // Linalg to vector contraction patterns. //===--------------------------------------------------------------------===// - patterns.add( - ctx, LinalgTransformationFilter(StringAttr::get(ctx, "VECTORIZE")) - .addOpFilter()); patterns.add(ctx); (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns)); @@ -441,9 +438,6 @@ static void applyVectorTransferForwardingPatterns(func::FuncOp funcOp) { static void applyLinalgToVectorPatterns(func::FuncOp funcOp) { RewritePatternSet patterns(funcOp.getContext()); auto *ctx = funcOp.getContext(); - patterns.add( - ctx, LinalgTransformationFilter() - .addOpFilter()); patterns.add(ctx); populatePadOpVectorizationPatterns(patterns); populateConvolutionVectorizationPatterns(patterns);