From cc7f52432bca6938d748c6730943586f879f841f Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 6 Jun 2023 11:19:59 +0200 Subject: [PATCH] [mlir][transform] Use separate ops instead of PatternRegistry * Remove `transform::PatternRegistry`. * Add a new op for each currently registered pattern set. * Change names of vector dialect pattern selector ops, so that they are consistent with the remaining code base. * Remove redundant `transform.vector.extract_address_computations` op. Differential Revision: https://reviews.llvm.org/D152249 --- .../Linalg/TransformOps/LinalgTransformOps.td | 47 ++++++++ .../MemRef/TransformOps/MemRefTransformOps.td | 120 +++++++++++++-------- .../Dialect/SCF/TransformOps/SCFTransformOps.td | 12 +++ .../Tensor/TransformOps/TensorTransformOps.h | 2 +- .../Tensor/TransformOps/TensorTransformOps.td | 69 ++++++++++++ .../mlir/Dialect/Tensor/Transforms/Transforms.h | 6 +- .../mlir/Dialect/Transform/IR/TransformOps.h | 43 -------- .../mlir/Dialect/Transform/IR/TransformOps.td | 13 +-- .../Vector/TransformOps/VectorTransformOps.td | 46 +++++--- .../Linalg/TransformOps/DialectExtension.cpp | 16 --- .../Linalg/TransformOps/LinalgTransformOps.cpp | 25 +++++ .../MemRef/TransformOps/MemRefTransformOps.cpp | 72 +++++-------- .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 16 +-- .../Tensor/TransformOps/TensorTransformOps.cpp | 54 ++++++---- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 49 --------- .../Vector/TransformOps/VectorTransformOps.cpp | 118 +++++--------------- mlir/test/Dialect/LLVM/transform-e2e.mlir | 16 +-- .../transform-op-matmul-to-outerproduct.mlir | 2 +- .../MemRef/extract-address-computations.mlir | 52 ++++++--- mlir/test/Dialect/Transform/ops-invalid.mlir | 17 +-- .../Transform/test-pattern-application.mlir | 46 +++----- mlir/test/Dialect/Vector/transform-vector.mlir | 16 +-- .../vector-broadcast-lowering-transforms.mlir | 2 +- .../Vector/vector-contract-matvec-transforms.mlir | 2 +- .../Vector/vector-contract-to-dot-transforms.mlir | 2 +- ...r-contract-to-matrix-intrinsics-transforms.mlir | 4 +- ...vector-contract-to-outerproduct-transforms.mlir | 2 +- ...ctor-contract-to-parallel-arith-transforms.mlir | 2 +- .../Vector/vector-mask-lowering-transforms.mlir | 4 +- .../Vector/vector-multi-reduction-lowering.mlir | 2 +- .../vector-multi-reduction-outer-lowering.mlir | 2 +- .../vector-outerproduct-lowering-transforms.mlir | 4 +- .../vector-shape-cast-lowering-transforms.mlir | 2 +- .../vector-transfer-drop-unit-dims-patterns.mlir | 12 +-- ...transfer-full-partial-split-copy-transform.mlir | 6 +- .../Vector/vector-transfer-full-partial-split.mlir | 8 +- .../vector-transfer-to-vector-load-store.mlir | 4 +- .../Dialect/Vector/vector-transpose-lowering.mlir | 12 +-- .../Dialect/Vector/CPU/test-shuffle16x16.mlir | 2 +- .../Transform/TestTransformDialectExtension.cpp | 9 -- 40 files changed, 479 insertions(+), 459 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 65a2510..a7356c0 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -27,6 +27,53 @@ def TransformParamTypeOrAnyHandle : Type< "transform 'param' type or any handle type">; //===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +def ApplyEraseUnnecessaryInputsPatternsOp : Op]> { + let description = [{ + Collects patterns that promote inputs to outputs and remove unused inputs of + `linalg.generic` ops. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldUnitExtentDimsViaReshapesPatternsOp : Op]> { + let description = [{ + Collects patterns to fold unit-extent dimensions in operands/results of + linalg ops on tensors via reassociative reshape ops. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldUnitExtentDimsViaSlicesPatternsOp : Op]> { + let description = [{ + Collects patterns to fold unit-extent dimensions in operands/results of + linalg ops on tensors via rank-reducing slices. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyTilingCanonicalizationPatternsOp : Op]> { + let description = [{ + Collects canonicalization patterns relevant to apply after tiling patterns. + }]; + + let assemblyFormat = "attr-dict"; +} + +//===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 04dfe1f..35a1d84 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -15,6 +15,81 @@ include "mlir/Dialect/Transform/IR/TransformTypes.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/IR/OpBase.td" +def ApplyExpandOpsPatternsOp : Op]> { + let description = [{ + Collects patterns to rewrite ops within the memref dialect. + + - Converts `atomic_rmw` that cannot be lowered to a simple atomic op with + AtomicRMWOpLowering pattern, e.g. with "minf" or "maxf" attributes, to + `memref.generic_atomic_rmw` with the expanded code. + - Converts `memref.reshape` that has a target shape of a statically-known + size to `memref.reinterpret_cast`. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyExpandStridedMetadataPatternsOp : Op]> { + let description = [{ + Collects patterns for expanding memref operations that modify the metadata + (sizes, offset, strides) of a memref into easier to analyze constructs. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyExtractAddressComputationsPatternsOp : Op]> { + let description = [{ + Collects patterns for extracting address computations from operations + with memory accesses such that these memory accesses use only a base + pointer. + + For instance, + ```mlir + memref.load %base[%off0, ...] + ``` + + Will be rewritten in: + ```mlir + %new_base = memref.subview %base[%off0,...][1,...][1,...] + memref.load %new_base[%c0,...] + ``` + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldMemrefAliasOpsPatternsOp : Op]> { + let description = [{ + Collects patterns for folding memref aliasing ops (memref.subview) into + consumer load/store ops (affine.load, memref.load, nvgpu.ldmatrix, + vector.load, vector.transfer_read, affine.store, memref.store, etc.) and + other ops (e.g., memref.subview). + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyResolveRankedShapedTypeResultDimsPatternsOp : Op]> { + let description = [{ + Collects patterns that resolve `memref.dim` operations with values that are + defined by operations that implement the `ReifyRankedShapedTypeOpInterface`, + in terms of shapes of its input operands. + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_MemRefAllocOp : Transform_ConcreteOpType<"memref.alloc">; def MemRefMultiBufferOp : Op { - let summary = "Extract address computations from memory accesses"; - let description = [{ - Transformation that extracts address computations from instructions - with memory accesses such that these memory accesses use only a base - pointer. - - For instance, - ```mlir - memref.load %base[%off0, ...] - ``` - - Will be rewritten in: - ```mlir - %new_base = memref.subview %base[%off0,...][1,...][1,...] - memref.load %new_base[%c0,...] - ``` - - Note: The current implementation requires that the input operation - is "isolated from above". - - #### Return modes - - This operation produces `definiteFailure` if the extraction fails for any - reason. - The operation always returns the handle to the target op that is expected - to be isolated from above. - }]; - - let arguments = (ins TransformHandleTypeInterface:$target); - let results = (outs TransformHandleTypeInterface:$transformed); - - let assemblyFormat = "$target attr-dict `:` functional-type(operands, results)"; - - let extraClassDeclaration = [{ - ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::Operation *target, - ::mlir::transform::ApplyToEachResultList &transformResults, - ::mlir::transform::TransformState &state); - }]; -} - def MemRefMakeLoopIndependentOp : Op]> { + let description = [{ + Collects patterns for canonicalizing operations inside SCF loop bodies. + At the moment, only affine.min/max computations with iteration variables, + loop bounds and loop steps are canonicalized. + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">; def GetParentForOp : Op]> { + let description = [{ + Indicates that redundant tensor.insert_slice rank reductions should be + dropped. E.g., cases where a tensor.extract_slice rank reduction immediately + follows an inverse tensor.insert_slice rank expansion. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldTensorEmptyPatternsOp : Op]> { + let description = [{ + Indicates that reassociative reshapes (tensor.collapse_shape / + tensor.expand_shape) should be folded with inverse rank expansions / rank + reductions (via tensor.insert_slice / tensor.extract_slice). + }]; + + let assemblyFormat = "attr-dict"; +} +def ApplyFoldIntoPackAndUnpackPatternsOp : Op]> { + let description = [{ + Indicates that operations like tensor.pad and tensor.extract_slice should + be folded into tensor.pack and tensor.unpack operations, respectively. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyFoldTensorSubsetOpsPatternsOp : Op]> { + let description = [{ + Indicates that tensor.empty should be folded with tensor.extract_slice, + tensor.expand_shape and tensor.collapse_shape. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyMergeConsecutiveInsertExtractSlicePatternsOp : Op]> { + let description = [{ + Indicates that consecutive tensor.extract_slice/tensor.insert_slice ops + should be merged into a single op. These patterns are not canonicalizations + because the bufferization is sensitive to IR structure. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyReassociativeReshapeFoldingPatternsOp : Op]> { + let description = [{ + Indicates that reassociative reshapes (tensor.collapse_shape / + tensor.expand_shape) should be folded with inverse rank expansions / rank + reductions (via tensor.insert_slice / tensor.extract_slice). + }]; + + let assemblyFormat = "attr-dict"; +} + def Transform_TensorPadOp : Transform_ConcreteOpType<"tensor.pad">; def MakeLoopIndependentOp diff --git a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h index a87ee1b..cadf3ad 100644 --- a/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Tensor/Transforms/Transforms.h @@ -39,8 +39,8 @@ FailureOr replaceExtractSliceWithTiledProducer( void populateFoldTensorSubsetOpPatterns(RewritePatternSet &patterns); /// Collects patterns to merge consecutive tensor.insert_slice/extract_slice -/// into one. These patterns are in in this separate entry point because the -/// bufferization is sensitive over IR structure, particularly those +/// into one. These patterns are in this separate entry point because the +/// bufferization is sensitive to IR structure, particularly those /// tensor.extract_slice and tensor.insert_slice ops for creating the slices. void populateMergeConsecutiveInsertExtractSlicePatterns( RewritePatternSet &patterns); @@ -55,7 +55,7 @@ void populateDropRedundantInsertSliceRankExpansionPatterns( void populateReassociativeReshapeFoldingPatterns(RewritePatternSet &patterns); /// Populates `patterns` with patterns that fold tensor.empty with -/// tensor.[extract_slice|cast|expand_shape|collapse_shape]. +/// tensor.[extract_slice|expand_shape|collapse_shape]. void populateFoldTensorEmptyPatterns(RewritePatternSet &patterns); /// Populates `patterns` with patterns that fold operations like `tensor.pad` diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h index e738baf..5625d19 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -26,7 +26,6 @@ namespace mlir { namespace transform { -class ApplyPatternsOp; enum class FailurePropagationMode : uint32_t; class FailurePropagationModeAttr; @@ -152,51 +151,9 @@ private: int64_t errorCounter = 0; }; -/// The PatternRegistry stores callbacks to functions that populate a -/// `RewritePatternSet`. Registered patterns can be applied with the -/// "transform.apply_patterns" op. -class PatternRegistry : public TransformDialectData { -public: - PatternRegistry(MLIRContext *ctx) : TransformDialectData(ctx), builder(ctx) {} - - /// A function that populates a `RewritePatternSet`. - using PopulatePatternsFn = std::function; - /// A function that populates a `RewritePatternSet` with a specified benefit. - using PopulatePatternsWithBenefitFn = - std::function; - - /// Registers patterns with the specified identifier. The identifier should - /// be prefixed with the dialect to which the patterns belong. - void registerPatterns(StringRef identifier, PopulatePatternsFn &&fn); - - /// Registers patterns with the specified identifier. The identifier should - /// be prefixed with the dialect to which the patterns belong. The pattern - /// benefit is currently ignored. - void registerPatterns(StringRef identifier, - PopulatePatternsWithBenefitFn &&fn); - -protected: - friend class ApplyPatternsOp; - - /// Returns "true" if patterns are registered with the specified identifier. - bool hasPatterns(StringAttr identifier) const; - - /// Populates the given pattern set with the specified patterns. - void populatePatterns(StringAttr identifier, - RewritePatternSet &patternSet) const; - -private: - /// A builder for creating StringAttrs. - Builder builder; - - DenseMap patterns; -}; - } // namespace transform } // namespace mlir -MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry) - #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 5b7e6ca..6c75e4b 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -138,12 +138,9 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns", targeted op itself. The patterns that should be applied are specified in the graph region of - this op. They must implement the `PatternDescriptorOpInterface`. - - (Deprecated) In addition, patterns that were registered in the transform - dialect's `PatternRegistry` are available. "canonicalization" is a special - set of patterns that refers to all canonicalization patterns of all loaded - dialects. + this op. They must implement the `PatternDescriptorOpInterface`. The order + in which patterns are applied is unspecified; i.e., the ordering of ops in + the region of this op is irrelevant. This transform only reads the target handle and modifies the payload. If a pattern erases or replaces a tracked op, the mapping is updated accordingly. @@ -161,12 +158,12 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns", }]; let arguments = (ins - TransformHandleTypeInterface:$target, ArrayAttr:$patterns, + TransformHandleTypeInterface:$target, DefaultValuedAttr:$fail_on_payload_replacement_not_found); let results = (outs); let regions = (region MaxSizedRegion<1>:$region); - let assemblyFormat = "$patterns `to` $target $region attr-dict `:` type($target)"; + let assemblyFormat = "`to` $target $region attr-dict `:` type($target)"; let hasVerifier = 1; let extraClassDeclaration = [{ diff --git a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td index 87fcbb0..c1891ff 100644 --- a/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td +++ b/mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td @@ -52,7 +52,7 @@ def ApplyTransferPermutationPatternsOp : Op]> { let description = [{ @@ -67,7 +67,7 @@ def LowerBroadcastOp : Op]> { let description = [{ @@ -86,7 +86,7 @@ def LowerContractionOp : Op]> { let description = [{ @@ -100,7 +100,7 @@ def LowerMasksOp : Op]> { let description = [{ @@ -114,7 +114,7 @@ def LowerMaskedTransfersOp : Op]> { let description = [{ @@ -129,7 +129,7 @@ def MaterializeMasksOp : Op]> { let description = [{ @@ -149,7 +149,7 @@ def LowerMultiReductionOp : Op]> { let description = [{ @@ -163,7 +163,29 @@ def LowerOuterProductOp : Op]> { + let description = [{ + Indicates that vector.gather operations should be lowered to + finer-grained vector primitives. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyLowerScanPatternsOp : Op]> { + let description = [{ + Indicates that vector.scan operations should be lowered to + finer-grained vector primitives. + }]; + + let assemblyFormat = "attr-dict"; +} + +def ApplyLowerShapeCastPatternsOp : Op]> { let description = [{ @@ -177,7 +199,7 @@ def LowerShapeCastOp : Op]> { let description = [{ @@ -196,7 +218,7 @@ def LowerTransferOp : Op]> { let description = [{ @@ -223,7 +245,7 @@ def LowerTransposeOp : Op]> { let description = [{ @@ -244,7 +266,7 @@ def SplitTransferFullPartialOp : Op]> { let description = [{ diff --git a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp index ab28ad3..2bf2a56 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/DialectExtension.cpp @@ -49,22 +49,6 @@ public: #define GET_OP_LIST #include "mlir/Dialect/Linalg/TransformOps/LinalgMatchOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns( - "linalg.erase_unnecessary_inputs", - linalg::populateEraseUnnecessaryInputsPatterns); - registry.registerPatterns( - "linalg.fold_unit_extent_dims_via_slices", - linalg::populateFoldUnitExtentDimsViaSlicesPatterns); - registry.registerPatterns( - "linalg.fold_unit_extent_dims_via_reshapes", - linalg::populateFoldUnitExtentDimsViaReshapesPatterns); - registry.registerPatterns( - "linalg.tiling_canonicalization", - linalg::populateLinalgTilingCanonicalizationPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 133ce91..4934a60 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -142,8 +142,33 @@ static DiagnosedSilenceableFailure unpackSingleIndexResultPayloadOperations( } //===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyEraseUnnecessaryInputsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateEraseUnnecessaryInputsPatterns(patterns); +} + +void transform::ApplyFoldUnitExtentDimsViaReshapesPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns); +} + +void transform::ApplyFoldUnitExtentDimsViaSlicesPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateFoldUnitExtentDimsViaSlicesPatterns(patterns); +} + +void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + linalg::populateLinalgTilingCanonicalizationPatterns(patterns); +} + +//===----------------------------------------------------------------------===// // BufferizeToAllocationOp //===----------------------------------------------------------------------===// + DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply(transform::TransformResults &results, transform::TransformState &state) { diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 7b63613..0f84fe4 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -16,7 +16,6 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" -#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Interfaces/LoopLikeInterface.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -28,6 +27,35 @@ using namespace mlir; #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") //===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyExpandOpsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateExpandOpsPatterns(patterns); +} + +void transform::ApplyExpandStridedMetadataPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateExpandStridedMetadataPatterns(patterns); +} + +void transform::ApplyExtractAddressComputationsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateExtractAddressComputationsPatterns(patterns); +} + +void transform::ApplyFoldMemrefAliasOpsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + memref::populateFoldMemRefAliasOpPatterns(patterns); +} + +void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + memref::populateResolveRankedShapedTypeResultDimsPatterns(patterns); +} + +//===----------------------------------------------------------------------===// // MemRefMultiBufferOp //===----------------------------------------------------------------------===// @@ -73,31 +101,6 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( } //===----------------------------------------------------------------------===// -// MemRefExtractAddressComputationsOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::MemRefExtractAddressComputationsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, - transform::TransformState &state) { - if (!target->hasTrait()) { - auto diag = this->emitOpError("requires isolated-from-above targets"); - diag.attachNote(target->getLoc()) << "non-isolated target"; - return DiagnosedSilenceableFailure::definiteFailure(); - } - - MLIRContext *ctx = getContext(); - RewritePatternSet patterns(ctx); - memref::populateExtractAddressComputationsPatterns(patterns); - - if (failed(applyPatternsAndFoldGreedily(target, std::move(patterns)))) - return emitDefaultDefiniteFailure(target); - - results.push_back(target); - return DiagnosedSilenceableFailure::success(); -} - -//===----------------------------------------------------------------------===// // MemRefMakeLoopIndependentOp //===----------------------------------------------------------------------===// @@ -162,23 +165,6 @@ public: #define GET_OP_LIST #include "mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("memref.expand_ops", - memref::populateExpandOpsPatterns); - registry.registerPatterns("memref.fold_memref_alias_ops", - memref::populateFoldMemRefAliasOpPatterns); - registry.registerPatterns( - "memref.resolve_ranked_shaped_type_result_dims", - memref::populateResolveRankedShapedTypeResultDimsPatterns); - registry.registerPatterns( - "memref.expand_strided_metadata", - memref::populateExpandStridedMetadataPatterns); - registry.registerPatterns( - "memref.extract_address_computations", - memref::populateExtractAddressComputationsPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index 4c64c54..edd1d32 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -23,6 +23,15 @@ using namespace mlir; using namespace mlir::affine; //===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + scf::populateSCFForLoopCanonicalizationPatterns(patterns); +} + +//===----------------------------------------------------------------------===// // GetParentForOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure @@ -309,13 +318,6 @@ public: #define GET_OP_LIST #include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns( - "scf.for_loop_canonicalization", - scf::populateSCFForLoopCanonicalizationPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index 1ab7883..18517cf 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -80,6 +80,40 @@ void tensor::registerFindPayloadReplacementOpInterfaceExternalModels( } //===----------------------------------------------------------------------===// +// Apply...PatternsOp +//===----------------------------------------------------------------------===// + +void transform::ApplyDropRedundantInsertSliceRankExpansionPatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + tensor::populateDropRedundantInsertSliceRankExpansionPatterns(patterns); +} + +void transform::ApplyFoldTensorEmptyPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateFoldTensorEmptyPatterns(patterns); +} + +void transform::ApplyFoldIntoPackAndUnpackPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateFoldIntoPackAndUnpackPatterns(patterns); +} + +void transform::ApplyFoldTensorSubsetOpsPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateFoldTensorSubsetOpPatterns(patterns); +} + +void transform::ApplyMergeConsecutiveInsertExtractSlicePatternsOp:: + populatePatterns(RewritePatternSet &patterns) { + tensor::populateMergeConsecutiveInsertExtractSlicePatterns(patterns); +} + +void transform::ApplyReassociativeReshapeFoldingPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + tensor::populateReassociativeReshapeFoldingPatterns(patterns); +} + +//===----------------------------------------------------------------------===// // MakeLoopIndependentOp //===----------------------------------------------------------------------===// @@ -144,26 +178,6 @@ public: #define GET_OP_LIST #include "mlir/Dialect/Tensor/TransformOps/TensorTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("tensor.fold_tensor_subset_ops", - tensor::populateFoldTensorSubsetOpPatterns); - registry.registerPatterns( - "tensor.merge_consecutive_insert_extract_slice", - tensor::populateMergeConsecutiveInsertExtractSlicePatterns); - registry.registerPatterns( - "tensor.drop_redundant_insert_slice_rank_expansion", - tensor::populateDropRedundantInsertSliceRankExpansionPatterns); - registry.registerPatterns( - "tensor.reassociative_reshape_folding", - tensor::populateReassociativeReshapeFoldingPatterns); - registry.registerPatterns("tensor.fold_tensor_empty", - tensor::populateFoldTensorEmptyPatterns); - registry.registerPatterns( - "tensor.fold_into_pack_and_unpack", - tensor::populateFoldIntoPackAndUnpackPatterns); - }); } }; } // namespace diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 212deb1..e202956 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -32,8 +32,6 @@ using namespace mlir; -MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PatternRegistry) - static ParseResult parseSequenceOpOperands( OpAsmParser &parser, std::optional &root, Type &rootType, @@ -218,37 +216,6 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( } //===----------------------------------------------------------------------===// -// PatternRegistry -//===----------------------------------------------------------------------===// - -void transform::PatternRegistry::registerPatterns(StringRef identifier, - PopulatePatternsFn &&fn) { - StringAttr attr = builder.getStringAttr(identifier); - assert(!patterns.contains(attr) && "patterns identifier is already in use"); - patterns.try_emplace(attr, std::move(fn)); -} - -void transform::PatternRegistry::registerPatterns( - StringRef identifier, PopulatePatternsWithBenefitFn &&fn) { - StringAttr attr = builder.getStringAttr(identifier); - assert(!patterns.contains(attr) && "patterns identifier is already in use"); - patterns.try_emplace(attr, [f = std::move(fn)](RewritePatternSet &patternSet) { - f(patternSet, /*benefit=*/1); - }); -} - -void transform::PatternRegistry::populatePatterns( - StringAttr identifier, RewritePatternSet &patternSet) const { - auto it = patterns.find(identifier); - assert(it != patterns.end() && "patterns not registered in registry"); - it->second(patternSet); -} - -bool transform::PatternRegistry::hasPatterns(StringAttr identifier) const { - return patterns.contains(identifier); -} - -//===----------------------------------------------------------------------===// // AlternativesOp //===----------------------------------------------------------------------===// @@ -440,11 +407,6 @@ transform::ApplyPatternsOp::applyToOne(Operation *target, // Gather all specified patterns. MLIRContext *ctx = target->getContext(); RewritePatternSet patterns(ctx); - const auto ®istry = getContext() - ->getLoadedDialect() - ->getExtraData(); - for (Attribute attr : getPatterns()) - registry.populatePatterns(attr.cast(), patterns); if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { cast(&op).populatePatterns( @@ -495,17 +457,6 @@ transform::ApplyPatternsOp::applyToOne(Operation *target, } LogicalResult transform::ApplyPatternsOp::verify() { - const auto ®istry = getContext() - ->getLoadedDialect() - ->getExtraData(); - for (Attribute attr : getPatterns()) { - auto strAttr = attr.dyn_cast(); - if (!strAttr) - return emitOpError() << "expected " << getPatternsAttrName() - << " to be an array of strings"; - if (!registry.hasPatterns(strAttr)) - return emitOpError() << "patterns not registered: " << strAttr.strref(); - } if (!getRegion().empty()) { for (Operation &op : getRegion().front()) { if (!isa(&op)) { diff --git a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp index 44caaec..7e09794 100644 --- a/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp +++ b/mlir/lib/Dialect/Vector/TransformOps/VectorTransformOps.cpp @@ -24,7 +24,7 @@ using namespace mlir::vector; using namespace mlir::transform; //===----------------------------------------------------------------------===// -// ApplyRankReducingSubviewPatternsOp +// Apply...PatternsOp //===----------------------------------------------------------------------===// void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( @@ -32,29 +32,17 @@ void transform::ApplyRankReducingSubviewPatternsOp::populatePatterns( vector::populateVectorTransferDropUnitDimsPatterns(patterns); } -//===----------------------------------------------------------------------===// -// ApplyTransferPermutationPatternsOp -//===----------------------------------------------------------------------===// - void transform::ApplyTransferPermutationPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorTransferPermutationMapLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerBroadcastOp -//===----------------------------------------------------------------------===// - -void transform::LowerBroadcastOp::populatePatterns( +void transform::ApplyLowerBroadcastPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorBroadcastLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerContractionOp -//===----------------------------------------------------------------------===// - -void transform::LowerContractionOp::populatePatterns( +void transform::ApplyLowerContractionPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorTransformsOptions(getLoweringStrategy()); @@ -63,37 +51,23 @@ void transform::LowerContractionOp::populatePatterns( /*disableOuterProductLowering=*/true); } -//===----------------------------------------------------------------------===// -// LowerMasksOp -//===----------------------------------------------------------------------===// - -void transform::LowerMasksOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyLowerMasksPatternsOp::populatePatterns( + RewritePatternSet &patterns) { populateVectorMaskOpLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerMaskedTransfersOp -//===----------------------------------------------------------------------===// - -void transform::LowerMaskedTransfersOp::populatePatterns( +void transform::ApplyLowerMaskedTransfersPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorMaskLoweringPatternsForSideEffectingOps(patterns); } -//===----------------------------------------------------------------------===// -// MaterializeMasksOp -//===----------------------------------------------------------------------===// - -void transform::MaterializeMasksOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyMaterializeMasksPatternsOp::populatePatterns( + RewritePatternSet &patterns) { populateVectorMaskMaterializationPatterns(patterns, /*force32BitVectorIndices=*/false); } -//===----------------------------------------------------------------------===// -// LowerMultiReductionOp -//===----------------------------------------------------------------------===// - -void transform::LowerMultiReductionOp::populatePatterns( +void transform::ApplyLowerMultiReductionPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorMultiReductionLowering(getLoweringStrategy()); @@ -101,38 +75,33 @@ void transform::LowerMultiReductionOp::populatePatterns( patterns, vectorTransformOptions.vectorMultiReductionLowering); } -//===----------------------------------------------------------------------===// -// LowerOuterProductOp -//===----------------------------------------------------------------------===// - -void transform::LowerOuterProductOp::populatePatterns( +void transform::ApplyLowerOuterProductPatternsOp::populatePatterns( RewritePatternSet &patterns) { populateVectorOuterProductLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerShapeCastOp -//===----------------------------------------------------------------------===// +void transform::ApplyLowerGatherPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorGatherLoweringPatterns(patterns); +} -void transform::LowerShapeCastOp::populatePatterns( +void transform::ApplyLowerScanPatternsOp::populatePatterns( RewritePatternSet &patterns) { - vector::populateVectorShapeCastLoweringPatterns(patterns); + vector::populateVectorScanLoweringPatterns(patterns); } -//===----------------------------------------------------------------------===// -// LowerTransferOp -//===----------------------------------------------------------------------===// +void transform::ApplyLowerShapeCastPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + vector::populateVectorShapeCastLoweringPatterns(patterns); +} -void transform::LowerTransferOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyLowerTransferPatternsOp::populatePatterns( + RewritePatternSet &patterns) { vector::populateVectorTransferLoweringPatterns(patterns, getMaxTransferRank()); } -//===----------------------------------------------------------------------===// -// LowerTransposeOp -//===----------------------------------------------------------------------===// - -void transform::LowerTransposeOp::populatePatterns( +void transform::ApplyLowerTransposePatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::populateVectorTransposeLoweringPatterns( patterns, vector::VectorTransformsOptions().setVectorTransposeLowering( @@ -148,22 +117,15 @@ void transform::LowerTransposeOp::populatePatterns( } } -//===----------------------------------------------------------------------===// -// SplitTransferFullPartialOp -//===----------------------------------------------------------------------===// - -void transform::SplitTransferFullPartialOp::populatePatterns( +void transform::ApplySplitTransferFullPartialPatternsOp::populatePatterns( RewritePatternSet &patterns) { vector::VectorTransformsOptions vectorTransformOptions; vectorTransformOptions.setVectorTransferSplit(getSplitTransferStrategy()); populateVectorTransferFullPartialPatterns(patterns, vectorTransformOptions); } -//===----------------------------------------------------------------------===// -// TransferToScfOp -//===----------------------------------------------------------------------===// - -void transform::TransferToScfOp::populatePatterns(RewritePatternSet &patterns) { +void transform::ApplyTransferToScfPatternsOp::populatePatterns( + RewritePatternSet &patterns) { VectorTransferToSCFOptions vectorTransferToSCFOptions = VectorTransferToSCFOptions() .enableFullUnroll(getFullUnroll()) @@ -189,34 +151,6 @@ public: #define GET_OP_LIST #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.cpp.inc" >(); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("vector.outer_product_lowering", - populateVectorOuterProductLoweringPatterns); - registry.registerPatterns("vector.broadcast_lowering", - populateVectorBroadcastLoweringPatterns); - registry.registerPatterns("vector.mask_op_lowering", - populateVectorMaskOpLoweringPatterns); - registry.registerPatterns("vector.shape_cast_lowering", - populateVectorShapeCastLoweringPatterns); - registry.registerPatterns( - "vector.transfer_lowering", - [&](RewritePatternSet &set, PatternBenefit benefit) { - return populateVectorTransferLoweringPatterns( - set, /*maxTransferRank=*/std::nullopt, benefit); - }); - registry.registerPatterns( - "vector.transfer_permutation_map_lowering", - populateVectorTransferPermutationMapLoweringPatterns); - registry.registerPatterns("vector.scan_lowering", - populateVectorScanLoweringPatterns); - registry.registerPatterns("vector.vector_gather_lowering", - populateVectorGatherLoweringPatterns); - registry.registerPatterns( - "vector.mask_lowering_for_side_effecting_ops", - populateVectorMaskLoweringPatternsForSideEffectingOps); - }); } }; } // namespace diff --git a/mlir/test/Dialect/LLVM/transform-e2e.mlir b/mlir/test/Dialect/LLVM/transform-e2e.mlir index 3fda274..049e893b 100644 --- a/mlir/test/Dialect/LLVM/transform-e2e.mlir +++ b/mlir/test/Dialect/LLVM/transform-e2e.mlir @@ -27,35 +27,35 @@ transform.sequence failures(propagate) { // TODO: group these lower-level controls into various properly named vector // lowering TD macros. - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" } : !transform.any_op } diff --git a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir index f3140a9..6ebcdd4 100644 --- a/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-matmul-to-outerproduct.mlir @@ -32,7 +32,7 @@ transform.sequence failures(propagate) { %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = get_closest_isolated_parent %0 : (!transform.any_op) -> !transform.any_op %2 = transform.structured.vectorize %1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %2 { + transform.apply_patterns to %2 { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op } diff --git a/mlir/test/Dialect/MemRef/extract-address-computations.mlir b/mlir/test/Dialect/MemRef/extract-address-computations.mlir index 5064f60..40393ff 100644 --- a/mlir/test/Dialect/MemRef/extract-address-computations.mlir +++ b/mlir/test/Dialect/MemRef/extract-address-computations.mlir @@ -24,9 +24,11 @@ func.func @test_load(%base : memref<2x16x16xf32>, %offset : index) -> f32 { transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op // Verify that the returned handle is usable. - transform.test_print_remark_at_operand %1, "transformed" : !transform.any_op + transform.test_print_remark_at_operand %0, "transformed" : !transform.any_op } // ----- @@ -50,7 +52,9 @@ func.func @test_load_nontemporal(%base : memref<2x16x16xf32>, %offset : index) - transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -79,7 +83,9 @@ func.func @test_store(%base : memref<2x16x16xf32>, %offset : index) -> () { transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -105,7 +111,9 @@ func.func @test_store_nontemporal(%base : memref<2x16x16xf32>, %offset : index) transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -159,7 +167,9 @@ func.func @testWithLoop(%base : memref>) transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -197,7 +207,9 @@ func.func @test_ldmatrix(%base : memref<4x32x32xf16, 3>, transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -231,7 +243,9 @@ func.func @test_ldmatrix(%base : memref, transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -266,7 +280,9 @@ func.func @test_transfer_read_op(%base : memref, transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -294,7 +310,9 @@ func.func @test_transfer_read_op_with_tensor(%base : tensor, transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -328,7 +346,9 @@ func.func @test_transfer_write_op(%base : memref, transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } // ----- @@ -363,8 +383,11 @@ func.func @test_transfer_write_op_with_strides(%base : memref !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } + // ----- // Same as test_transfer_write_op but with tensors. @@ -389,5 +412,8 @@ func.func @test_transfer_write_op_with_tensor(%base : tensor, transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.memref.extract_address_computations %0 : (!transform.any_op) -> !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.memref.extract_address_computations + } : !transform.any_op } + diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 427b1b1..8aa7614 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -677,23 +677,8 @@ module attributes { transform.with_named_sequence } { transform.sequence failures(propagate) { ^bb0(%arg0: !transform.any_op): - // expected-error @below {{patterns not registered: transform.invalid_pattern_identifier}} - transform.apply_patterns ["transform.invalid_pattern_identifier"] to %arg0 {} : !transform.any_op -} - -// ----- - -transform.sequence failures(propagate) { -^bb0(%arg0: !transform.any_op): - // expected-error @below {{expected "patterns" to be an array of strings}} - transform.apply_patterns [3, 9] to %arg0 {} : !transform.any_op -} - -// ----- -transform.sequence failures(propagate) { -^bb0(%arg0: !transform.any_op): // expected-error @below {{expected children ops to implement PatternDescriptorOpInterface}} - transform.apply_patterns [] to %arg0 { + transform.apply_patterns to %arg0 { // expected-note @below {{op without interface}} transform.named_sequence @foo() } : !transform.any_op diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir index 55bb083..ca277ab 100644 --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir @@ -15,29 +15,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op - // Add an attribute to %1, which is now mapped to a new op. - transform.annotate %1 "annotated" : !transform.any_op -} - -// ----- - -// CHECK-LABEL: func @update_tracked_op_mapping_region() -// CHECK: "test.container"() ({ -// CHECK: %0 = "test.foo"() {annotated} : () -> i32 -// CHECK: }) : () -> () -func.func @update_tracked_op_mapping_region() { - "test.container"() ({ - %0 = "test.foo"() {replace_with_new_op = "test.foo"} : () -> (i32) - }) : () -> () - return -} - -transform.sequence failures(propagate) { -^bb1(%arg1: !transform.any_op): - %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op - %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %0 { + transform.apply_patterns to %0 { transform.apply_patterns.transform.test_patterns } : !transform.any_op // Add an attribute to %1, which is now mapped to a new op. @@ -60,7 +38,9 @@ transform.sequence failures(propagate) { %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // expected-error @below {{tracking listener failed to find replacement op}} - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op // %1 must be used in some way. If no replacement payload op could be found, // an error is thrown only if the handle is not dead. transform.annotate %1 "annotated" : !transform.any_op @@ -84,7 +64,9 @@ transform.sequence failures(propagate) { %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op // No error because %1 is dead. - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op } // ----- @@ -104,7 +86,9 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} {fail_on_payload_replacement_not_found = false}: !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } {fail_on_payload_replacement_not_found = false} : !transform.any_op transform.annotate %1 "annotated" : !transform.any_op } @@ -120,7 +104,9 @@ func.func @patterns_apply_only_to_target_body() { transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op } // ----- @@ -142,7 +128,9 @@ transform.sequence failures(propagate) { %0 = transform.structured.match ops{["test.container"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["test.erase_op"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.test_print_remark_at_operand %1, "matched op" : !transform.any_op - transform.apply_patterns ["transform.test"] to %0 {} : !transform.any_op + transform.apply_patterns to %0 { + transform.apply_patterns.transform.test_patterns + } : !transform.any_op transform.test_print_remark_at_operand %1, "op was deleted" : !transform.any_op } @@ -162,7 +150,7 @@ transform.sequence failures(propagate) { ^bb1(%arg1: !transform.any_op): %0 = transform.structured.match ops{["tensor.dim"]} in %arg1 : (!transform.any_op) -> !transform.any_op %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %1 { + transform.apply_patterns to %1 { transform.apply_patterns.canonicalization } : !transform.any_op transform.test_print_remark_at_operand %0, "op was replaced" : !transform.any_op diff --git a/mlir/test/Dialect/Vector/transform-vector.mlir b/mlir/test/Dialect/Vector/transform-vector.mlir index 7bdbaf9..15f423f 100644 --- a/mlir/test/Dialect/Vector/transform-vector.mlir +++ b/mlir/test/Dialect/Vector/transform-vector.mlir @@ -30,35 +30,35 @@ transform.sequence failures(propagate) { // TODO: group these lower-level controls into various properly named vector // lowering TD macros. - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.transfer_to_scf max_transfer_rank = 1 full_unroll = true } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 1 } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir index c71b6d6..2d3c88d 100644 --- a/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-broadcast-lowering-transforms.mlir @@ -167,7 +167,7 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_broadcast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir index f897ed0..d8365f5 100644 --- a/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-matvec-transforms.mlir @@ -210,7 +210,7 @@ func.func @redpar_vecmattrans2x2(%arg0: memref>, %arg1: memref !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "dot" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir index 2de9578..e5582e3 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-matrix-intrinsics-transforms.mlir @@ -48,11 +48,11 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "matmulintrinsics" } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir index a53064a..aee4cf6 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-outerproduct-transforms.mlir @@ -347,7 +347,7 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "outerproduct" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir index 5a7502f..147f3ae 100644 --- a/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-contract-to-parallel-arith-transforms.mlir @@ -56,7 +56,7 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_contraction lowering_strategy = "parallelarith" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir index d0d53a8..d425e9c 100644 --- a/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-mask-lowering-transforms.mlir @@ -96,7 +96,7 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_masks } : !transform.any_op } @@ -127,7 +127,7 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_masked_transfers } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir index 884cf28..5a67c26 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-lowering.mlir @@ -267,7 +267,7 @@ func.func @vector_multi_reduction_parallel_middle(%arg0: vector<3x4x5xf32>, %acc transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerreduction" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir index eee55e8..9f0b3a1 100644 --- a/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-multi-reduction-outer-lowering.mlir @@ -190,7 +190,7 @@ func.func @vector_multi_reduction_to_scalar(%arg0: vector<2x3xf32>, %acc: f32) - transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_multi_reduction lowering_strategy = "innerparallel" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir index 907bdb4..f6dd803 100644 --- a/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-outerproduct-lowering-transforms.mlir @@ -140,11 +140,11 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_outerproduct } : !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_broadcast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir index 92b3e80..9ad0bbc 100644 --- a/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir +++ b/mlir/test/Dialect/Vector/vector-shape-cast-lowering-transforms.mlir @@ -154,7 +154,7 @@ transform.sequence failures(propagate) { %f = transform.structured.match ops{["func.func"]} in %module_op : (!transform.any_op) -> !transform.any_op - transform.apply_patterns [] to %f { + transform.apply_patterns to %f { transform.apply_patterns.vector.lower_shape_cast } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir index 749afa2..d64f888 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir @@ -17,7 +17,7 @@ func.func @transfer_read_rank_reducing( transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -39,7 +39,7 @@ func.func @transfer_write_rank_reducing(%arg : memref<1x1x3x2xi8, strided<[6, 6, transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -63,7 +63,7 @@ func.func @transfer_read_and_vector_rank_reducing( transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -87,7 +87,7 @@ func.func @transfer_write_and_vector_rank_reducing( transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -111,7 +111,7 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d( transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } @@ -135,7 +135,7 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d( transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.apply_rank_reducing_subview_patterns } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir index 424c53c..ecb2926 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-full-partial-split-copy-transform.mlir @@ -108,7 +108,7 @@ func.func @split_vector_transfer_read_strided_2d( transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "linalg-copy" } : !transform.any_op } @@ -169,7 +169,7 @@ func.func @split_vector_transfer_write_2d(%V: vector<4x8xf32>, %A: memref, %A: memref, %lb : index, %ub : transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.split_transfer_full_partial split_transfer_strategy = "vector-transfer" } : !transform.any_op } diff --git a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir index e136219..ad9bec9 100644 --- a/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir +++ b/mlir/test/Dialect/Vector/vector-transfer-to-vector-load-store.mlir @@ -240,7 +240,7 @@ func.func @transfer_broadcasting_complex(%mem : memref<10x20x30x8x8xf32>, %i : i transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op @@ -362,7 +362,7 @@ func.func @transfer_write_broadcast_unit_dim( transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transfer max_transfer_rank = 99 transform.apply_patterns.vector.apply_transfer_permutation_patterns } : !transform.any_op diff --git a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir index 0865f97a..f8b3605 100644 --- a/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir +++ b/mlir/test/Dialect/Vector/vector-transpose-lowering.mlir @@ -76,7 +76,7 @@ func.func @transpose1023_1x1x8x8xf32(%arg0: vector<1x1x8x8xf32>) -> vector<1x1x8 transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "eltwise" } : !transform.any_op } @@ -99,7 +99,7 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_1d" } : !transform.any_op } @@ -118,7 +118,7 @@ func.func @transpose(%arg0: vector<2x4xf32>) -> vector<4x2xf32> { transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "flat_transpose" } : !transform.any_op } @@ -605,7 +605,7 @@ func.func @transpose210_1x8x8xf32(%arg0: vector<1x8x8xf32>) -> vector<8x8x1xf32> transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose avx2_lowering_strategy = true } : !transform.any_op } @@ -683,7 +683,7 @@ func.func @transpose_shuffle16x16xf32(%arg0: vector<16x16xf32>) -> vector<16x16x transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op } @@ -762,7 +762,7 @@ func.func @transpose021_shuffle16x16xf32(%arg0: vector<1x16x16xf32>) -> vector<1 transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op } diff --git a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir index 798a224..74d47de 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/test-shuffle16x16.mlir @@ -31,7 +31,7 @@ func.func @entry() { transform.sequence failures(propagate) { ^bb1(%module_op: !transform.any_op): - transform.apply_patterns [] to %module_op { + transform.apply_patterns to %module_op { transform.apply_patterns.vector.lower_transpose lowering_strategy = "shuffle_16x16" } : !transform.any_op } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 9243723..fc6f323 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -784,10 +784,6 @@ void mlir::test::ApplyTestPatternsOp::populatePatterns( } namespace { -void populateTestPatterns(RewritePatternSet &patterns) { - patterns.insert(patterns.getContext()); -} - /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL /// types for operands and results. @@ -825,11 +821,6 @@ public: constraints.try_emplace("verbose_constraint", verboseConstraint); hooks.mergeInPDLMatchHooks(std::move(constraints)); }); - - addDialectDataInitializer( - [&](transform::PatternRegistry ®istry) { - registry.registerPatterns("transform.test", populateTestPatterns); - }); } }; } // namespace -- 2.7.4