From c63d2b2c714195d150ebec0ad4c712df36814c06 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Tue, 20 Jun 2023 10:48:40 +0200 Subject: [PATCH] [mlir][transform] Add TransformRewriter All `apply` functions now have a `TransformRewriter &` parameter. This rewriter should be used to modify the IR. It has a `TrackingListener` attached and updates the internal handle-payload mappings based on rewrites. Implementations no longer need to create their own `TrackingListener` and `IRRewriter`. Error checking is integrated into `applyTransform`. Tracking listener errors are reported only for ops with the `ReportTrackingListenerFailuresOpTrait` trait attached, allowing for a gradual migration. Furthermore, errors can be silenced with an op attribute. Additional API will be added to `TransformRewriter` in subsequent revisions. This revision just adds an "empty" `TransformRewriter` class and updates all `apply` implementations. Differential Revision: https://reviews.llvm.org/D152427 --- mlir/docs/Tutorials/transform/Ch1.md | 13 ++ mlir/docs/Tutorials/transform/Ch2.md | 51 +++-- mlir/docs/Tutorials/transform/Ch3.md | 39 ++-- mlir/examples/transform/Ch2/lib/MyExtension.cpp | 18 +- mlir/examples/transform/Ch3/include/MyExtension.td | 2 + mlir/examples/transform/Ch3/lib/MyExtension.cpp | 16 +- .../TransformOps/BufferizationTransformOps.td | 1 + .../Dialect/GPU/TransformOps/GPUTransformOps.td | 2 + .../Linalg/TransformOps/LinalgTransformOps.td | 111 +++++++--- .../MemRef/TransformOps/MemRefTransformOps.td | 1 + .../Dialect/SCF/TransformOps/SCFTransformOps.td | 5 + .../Tensor/TransformOps/TensorTransformOps.td | 1 + .../mlir/Dialect/Transform/IR/MatchInterfaces.h | 6 +- .../mlir/Dialect/Transform/IR/TransformDialect.td | 7 +- .../Dialect/Transform/IR/TransformInterfaces.h | 60 ++++-- .../Dialect/Transform/IR/TransformInterfaces.td | 9 + .../mlir/Dialect/Transform/IR/TransformOps.td | 16 +- .../Affine/TransformOps/AffineTransformOps.cpp | 5 +- 
.../TransformOps/BufferizationTransformOps.cpp | 18 +- .../Dialect/GPU/TransformOps/GPUTransformOps.cpp | 14 +- .../Linalg/TransformOps/LinalgTransformOps.cpp | 223 +++++++++------------ .../MemRef/TransformOps/MemRefTransformOps.cpp | 6 +- .../Dialect/SCF/TransformOps/SCFTransformOps.cpp | 27 +-- .../Tensor/TransformOps/TensorTransformOps.cpp | 4 +- mlir/lib/Dialect/Transform/IR/TransformDialect.cpp | 7 + .../Dialect/Transform/IR/TransformInterfaces.cpp | 91 ++++++++- mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 90 +++++---- .../Transform/PDLExtension/PDLExtensionOps.cpp | 6 +- .../Transform/test-pattern-application.mlir | 2 +- .../Transform/TestTransformDialectExtension.cpp | 89 +++++--- .../Transform/TestTransformDialectExtension.td | 12 +- 31 files changed, 602 insertions(+), 350 deletions(-) diff --git a/mlir/docs/Tutorials/transform/Ch1.md b/mlir/docs/Tutorials/transform/Ch1.md index 988117a..3db3931 100644 --- a/mlir/docs/Tutorials/transform/Ch1.md +++ b/mlir/docs/Tutorials/transform/Ch1.md @@ -362,3 +362,16 @@ test/Examples/transform/Ch1/invalidation-2.mlir:24:13: note: nested payload op Note that the “add” elementwise operation is indicated as payload ancestor because it was used to produce the tile loop, and the loop therefore has its location. Finally, we would like to replace the call to the outlined function with a call to the microkernel. Unfortunately, the Transform dialect doesn’t have support for this transformation (and cannot have if the call is rewritten to a custom, out-of-tree operation). Therefore, we need to define new transform operations. The next chapters will describe how this can be done. + +## Tracking IR Modifications + +The transform dialect automatically tracks all IR changes that are made as part +of transform ops. (Implementations must use the provided rewriter to modify IR.) +If a payload op is erased, it is automatically removed from all handles that it +is currently associated with. 
If a payload op is replaced, the transform dialect +tries to find the replacement op and updates all handles accordingly. If a +multi-result op is replaced with values that are defined by multiple ops, or if +an op is replaced with an op of a different type, an error is produced. This is +because it is unclear whether the direct replacements actually represent the +computation of the original op. There are ways to customize this behavior. More +details can be found at the documentation of `transform::TrackingListener`. diff --git a/mlir/docs/Tutorials/transform/Ch2.md b/mlir/docs/Tutorials/transform/Ch2.md index e3d6cf7..04f79bf 100644 --- a/mlir/docs/Tutorials/transform/Ch2.md +++ b/mlir/docs/Tutorials/transform/Ch2.md @@ -189,8 +189,13 @@ def ChangeCallTargetOp : Op(payloadOp); @@ -231,7 +238,7 @@ To finalize the definition of the transform operation, we need to implement the diag.attachNote(payloadOp->getLoc()) << "offending payload"; return diag; } - + updateCallee(call, getNewTarget()); } diff --git a/mlir/docs/Tutorials/transform/Ch3.md b/mlir/docs/Tutorials/transform/Ch3.md index 64837a6..bf533c2 100644 --- a/mlir/docs/Tutorials/transform/Ch3.md +++ b/mlir/docs/Tutorials/transform/Ch3.md @@ -10,26 +10,26 @@ A transform operation that applies to each payload operation individually and re // Define the new operation. By convention, prefix its name with the name of the dialect extension, "my.". The full operation name will be further prefixed with "transform.". def ChangeCallTargetOp : Op]> { - // Provide a brief and a full description. It is recommended that the latter describes + // Provide a brief and a full description. It is recommended that the latter describes // the effects on the operands and how the operation processes various failure modes. 
let summary = "Changes the callee of a call operation to the specified one"; let description = [{ - For each `func.call` payload operation associated with the handle, changes its + For each `func.call` payload operation associated with the handle, changes its callee to be the symbol whose name is provided as an attribute to this operation. - Generates a silenceable failure if the operand is associated with payload operations + Generates a silenceable failure if the operand is associated with payload operations that are not `func.call`. Only reads the operand. }]; - // The arguments include the handle to the payload operations and the attribute that - // specifies the new callee. The handle must implement TransformHandleTypeInterface. - // We use a string attribute as the symbol may not exist in the transform IR so the - // verification may fail. + // The arguments include the handle to the payload operations and the attribute that + // specifies the new callee. The handle must implement TransformHandleTypeInterface. + // We use a string attribute as the symbol may not exist in the transform IR so the + // verification may fail. let arguments = (ins Transform_ConcreteOpType<"func.call">:$call, StrAttr:$new_target); @@ -43,6 +43,7 @@ def ChangeCallTargetOp : Op]> { @@ -166,7 +167,7 @@ def CallToOp : Opresult_type_begin(), call->result_type_end()); state.operands.assign(call->operand_begin(), call->operand_end()); - mlir::OpBuilder builder(call); - mlir::Operation *replacement = builder.create(state); - call->replaceAllUsesWith(replacement->getResults()); - call->erase(); + mlir::Operation *replacement = rewriter.create(state); + rewriter.replaceOp(call, replacement->getResults()); return replacement; } // See above for the signature description. 
mlir::DiagnosedSilenceableFailure mlir::transform::CallToOp::applyToOne( - mlir::CallOpInterface call, mlir::transform::ApplyToEachResultList &results, + mlir::transform::TransformRewriter &rewriter, mlir::CallOpInterface call, + mlir::transform::ApplyToEachResultList &results, mlir::transform::TransformState &state) { // Dispatch to the actual transformation. - Operation *replacement = replaceCallWithOp(call); + Operation *replacement = replaceCallWithOp(rewriter, call); // Associate the payload operation produced by the rewrite with the result // handle of this transform operation. diff --git a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td index 49b5ef6..807a63d 100644 --- a/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td +++ b/mlir/include/mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.td @@ -152,6 +152,7 @@ def EmptyTensorToAllocTensorOp let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::tensor::EmptyOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td index 4216231..4e47638 100644 --- a/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/GPUTransformOps.td @@ -126,6 +126,7 @@ def MapNestedForallToThreads : }]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -189,6 +190,7 @@ def MapForallToBlocks : }]; let extraClassDeclaration = [{ 
::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 3e2cb78..0157fc2 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -80,7 +80,8 @@ def ApplyTilingCanonicalizationPatternsOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ This transform materializes an allocation for the targeted tensor value. It replaces all original uses of the target with the newly allocated buffer, @@ -133,7 +134,8 @@ def DecomposeOp : Op { + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Decomposes named complex operations, such as higher-dimensional (depthwise) convolutions, into combinations of lower-dimensional equivalents @@ -155,6 +157,7 @@ def DecomposeOp : Op]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tiles the operations pointed to by the target handle and fuses their producers greedily using the options provided as attributes. @@ -192,7 +196,8 @@ def FuseIntoContainingOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let summary = "Fuse a producer into a containing operation."; let description = [{ @@ -247,7 +252,8 @@ def FuseIntoContainingOp : def GeneralizeOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Transforms a named structured operation into the generic form with the explicit attached region. 
@@ -270,6 +276,7 @@ def GeneralizeOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Interchanges the iterators of the operations pointed to by the target handle using the iterator interchange attribute. @@ -313,6 +321,7 @@ def InterchangeOp : Op { + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Rewrite a tensor.pack into tensor.pad + tensor.expand_shape + linalg.transpose. @@ -350,6 +360,7 @@ def LowerPackOp : Op { + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Lower a tensor.unpack into empty + linalg.transpose + tensor.collapse_shape + tensor.extract_slice. @@ -389,6 +401,7 @@ def LowerUnPackOp : Op, - TransformOpInterface, TransformEachOpTrait]> { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Emits the IR computing the tile sizes `s1` and `s2` such that: @@ -531,6 +545,7 @@ def MultiTileSizesOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Pack a LinalgOp by applying a data tiling transformation on the op and packing the operands according to the `packed_sizes` specification. @@ -632,7 +648,8 @@ def PackOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Target a Linalg op and rewrite it into packed LinalgOp form by trying to infer whether a known suboperation is embedded @@ -740,7 +757,8 @@ def PackGreedilyOp : Op]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Apply a transposition to a single `tensor.pack` (resp. `tensor.unpack`) and update the `linalg.generic` op that consumes (resp. produces) the operation. 
@@ -803,7 +821,8 @@ def PackTransposeOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Pads the operations pointed to by the target handle using the options provides as operation attributes. @@ -837,6 +856,7 @@ def PadOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Helper transform used to hoist a tensor.pad target operation. This operation creates the packing loop nest required by the hoist_pad operation and makes @@ -931,6 +952,7 @@ def HoistPadOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Promotes the specified operands of the target into a separate memory buffer. @@ -978,6 +1001,7 @@ def PromoteOp : Op, - DeclareOpInterfaceMethods] # GraphRegionNoTerminator.traits> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait] # GraphRegionNoTerminator.traits> { let description = [{ Replace all `target` payload ops with the single op that is contained in this op's region. All targets must have zero arguments and must be isolated @@ -1018,7 +1043,8 @@ def ReplaceOp : Op { + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that ops of a specific kind in the given function should be scalarized (i.e. their dynamic dimensions tiled by 1). @@ -1051,6 +1077,7 @@ def ScalarizeOp : Op { + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Rewrite a supported tensor operation that is not in destination-passing style into a form that is in destination-passing style. 
@@ -1098,6 +1126,7 @@ def RewriteInDestinationPassingStyleOp : Op< let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1110,7 +1139,8 @@ def RewriteInDestinationPassingStyleOp : Op< def SplitOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be split into two complementary parts, which combined cover the entire iteration domain of the original op. @@ -1147,7 +1177,8 @@ def SplitOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be transformed with the `splitReduction` transformation and split factor provided as attribute. @@ -1309,6 +1340,7 @@ def SplitReductionOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be transformed with the `tileReduction` transformation with the tile size provided as attribute. @@ -1413,6 +1446,7 @@ def TileReductionUsingScfOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tile a PartialReductionOpInterface op to a tiled `scf.forall` doing partial reduction. 
@@ -1522,6 +1557,7 @@ def TileReductionUsingForallOp : let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::linalg::LinalgOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -1535,7 +1571,8 @@ def TileReductionUsingForallOp : def TileOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be tiled with the given sizes. This transform generates a loop nest with a smaller ("tiled") target @@ -1616,7 +1653,7 @@ def TileToForallOp : Op, - TransformOpInterface]> { + TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> { let description = [{ Tile a TilingInterface op to a tiled `scf.forall`. @@ -1726,6 +1763,7 @@ def TileToForallOp : let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure apply( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::transform::TransformResults &transformResults, ::mlir::transform::TransformState &state); @@ -1740,7 +1778,8 @@ def TileToForallOp : def TileToScfForOp : Op, - DeclareOpInterfaceMethods]> { + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op should be tiled with the given sizes. This transform generates a loop nest with a smaller ("tiled") target @@ -1807,7 +1846,8 @@ def TileToScfForOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Indicates that the given `target` op all the ops it contains should be vectorized with the configuration specified by the attributes of this op. 
@@ -1862,6 +1902,7 @@ def VectorizeOp : Op, - TransformOpInterface]> { + TransformOpInterface, ReportTrackingListenerFailuresOpTrait]> { let description = [{ Vectorize the target ops, which must be Linalg ops, with masked vectors of the specified size. @@ -1914,6 +1955,7 @@ def MaskedVectorizeOp : Op { + TransformEachOpTrait, TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Hoist vector.transfer_read / vector.transfer_write pairs out of immediately enclosing scf::ForOp iteratively, if the following conditions are true: @@ -1958,7 +2001,8 @@ def HoistRedundantVectorTransfersOp : ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( - ::mlir::func::FuncOp target, + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::func::FuncOp target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); }]; @@ -1973,7 +2017,8 @@ def ConvertConv2DToImg2ColOp : Op { + TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Convert linalg.conv_2d_xxx into linalg.generic (for img2col packing) and linalg.matmul. 
@@ -2034,6 +2079,7 @@ def ConvertConv2DToImg2ColOp : Op, TransformEachOpTrait, - TransformOpInterface]> { + TransformOpInterface, + ReportTrackingListenerFailuresOpTrait]> { let description = [{ Hoists supported tensor subset extract/insert operation pairs out of immediately enclosing loop iteratively, if the following conditions @@ -2082,6 +2129,7 @@ def HoistRedundantTensorSubsetsOp : let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -2120,6 +2168,7 @@ def InsertSliceToCopyOp : ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td index 35a1d84..86a3586b 100644 --- a/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td +++ b/mlir/include/mlir/Dialect/MemRef/TransformOps/MemRefTransformOps.td @@ -158,6 +158,7 @@ def MemRefMakeLoopIndependentOp let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td index d128c0e..b9d2dd1 100644 --- a/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td +++ b/mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td @@ -126,6 +126,7 @@ def LoopPeelOp : Op(this->getOperation()).getOperandHandle(); auto payload = 
state.getPayloadOps(operandHandle); @@ -90,7 +91,8 @@ public: return success(); } - DiagnosedSilenceableFailure apply(TransformResults &results, + DiagnosedSilenceableFailure apply(TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { Value operandHandle = cast(this->getOperation()).getOperandHandle(); ValueRange payload = state.getPayloadValues(operandHandle); diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index 0539187..6277580 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -27,11 +27,16 @@ def Transform_Dialect : Dialect { constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName = "transform.with_named_sequence"; - /// Names of the attribute attachable to an operation so it can be + /// Name of the attribute attachable to an operation so it can be /// identified as root by the default interpreter pass. constexpr const static ::llvm::StringLiteral kTargetTagAttrName = "transform.target_tag"; + /// Name of the attribute attachable to an operation, indicating that + /// TrackingListener failures should be silenced. + constexpr const static ::llvm::StringLiteral + kSilenceTrackingFailuresAttrName = "transform.silence_tracking_failures"; + /// Names of the attributes indicating whether an argument of an external /// transform dialect symbol is consumed or only read. 
constexpr const static ::llvm::StringLiteral diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 1afac0a..5a22ae5 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -18,11 +18,11 @@ #include "mlir/Support/LogicalResult.h" namespace mlir { - namespace transform { class TransformOpInterface; class TransformResults; +class TransformRewriter; class TransformState; using Param = Attribute; @@ -856,8 +856,7 @@ class TrackingListener : public RewriterBase::Listener, public TransformState::Extension { public: /// Create a new TrackingListener for usage in the specified transform op. - explicit TrackingListener(TransformState &state, TransformOpInterface op) - : TransformState::Extension(state), transformOp(op) {} + TrackingListener(TransformState &state, TransformOpInterface op); protected: /// Return a replacement payload op for the given op, which is going to be @@ -937,6 +936,9 @@ private: /// The transform op in which this TrackingListener is used. TransformOpInterface transformOp; + + /// The handles that are consumed by the transform op. + DenseSet consumedHandles; }; /// A specialized listener that keeps track of cases in which no replacement @@ -968,6 +970,28 @@ private: int64_t errorCounter = 0; }; +/// This is a special rewriter to be used in transform op implementations, +/// providing additional helper functions to update the transform state, etc. +// TODO: Helper functions will be added in a subsequent change. +class TransformRewriter : public RewriterBase { +protected: + friend class TransformState; + + /// Create a new TransformRewriter. + explicit TransformRewriter(MLIRContext *ctx, + ErrorCheckingTrackingListener *listener); + +public: + /// Return "true" if the tracking listener had failures. 
+ bool hasTrackingFailures() const; + + /// Silence all tracking failures that have been encountered so far. + void silenceTrackingFailure(); + +private: + ErrorCheckingTrackingListener *const listener; +}; + /// This trait is supposed to be attached to Transform dialect operations that /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some @@ -1064,7 +1088,8 @@ public: /// 5. If any `applyToOne` return silenceableFailure, the transformation is /// considered silenceable. /// 6. Otherwise the transformation is considered successful. - DiagnosedSilenceableFailure apply(TransformResults &transformResults, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state); /// Checks that the op matches the expectations of this trait. @@ -1213,6 +1238,15 @@ public: } }; +/// `TrackingListener` failures are reported only for ops that have this trait. +/// The purpose of this trait is to give users more time to update their custom +/// transform ops to use the provided `TransformRewriter` for all IR +/// modifications. This trait will eventually be removed, and failures will be +/// reported for all transform ops. +template +class ReportTrackingListenerFailuresOpTrait + : public OpTrait::TraitBase {}; + /// A single result of applying a transform op with `ApplyEachOpTrait` to a /// single payload operation. using ApplyToEachResult = MappedValue; @@ -1307,14 +1341,14 @@ void setApplyToOneResults(Operation *transformOp, /// `targets` contains operations of the same class and a silenceable failure /// is reported if it does not. 
template -DiagnosedSilenceableFailure -applyTransformToEach(TransformOpTy transformOp, Range &&targets, - SmallVectorImpl &results, - TransformState &state) { +DiagnosedSilenceableFailure applyTransformToEach( + TransformOpTy transformOp, TransformRewriter &rewriter, Range &&targets, + SmallVectorImpl &results, TransformState &state) { using OpTy = typename llvm::function_traits< - decltype(&TransformOpTy::applyToOne)>::template arg_t<0>; + decltype(&TransformOpTy::applyToOne)>::template arg_t<1>; static_assert(std::is_convertible::value, "expected transform function to take an operation"); + OpBuilder::InsertionGuard g(rewriter); SmallVector silenceableStack; unsigned expectedNumResults = transformOp->getNumResults(); @@ -1331,8 +1365,9 @@ applyTransformToEach(TransformOpTy transformOp, Range &&targets, ApplyToEachResultList partialResults; partialResults.reserve(expectedNumResults); Location specificOpLoc = specificOp->getLoc(); + rewriter.setInsertionPoint(specificOp); DiagnosedSilenceableFailure res = - transformOp.applyToOne(specificOp, partialResults, state); + transformOp.applyToOne(rewriter, specificOp, partialResults, state); if (res.isDefiniteFailure()) return DiagnosedSilenceableFailure::definiteFailure(); @@ -1367,7 +1402,8 @@ LogicalResult checkNestedConsumption(Location loc, template mlir::DiagnosedSilenceableFailure mlir::transform::TransformEachOpTrait::apply( - TransformResults &transformResults, TransformState &state) { + TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { Value handle = this->getOperation()->getOperand(0); auto targets = state.getPayloadOps(handle); @@ -1403,7 +1439,7 @@ mlir::transform::TransformEachOpTrait::apply( // corresponding results entry. SmallVector results; DiagnosedSilenceableFailure result = detail::applyTransformToEach( - cast(this->getOperation()), targets, results, state); + cast(this->getOperation()), rewriter, targets, results, state); // Step 3. 
Propagate the definite failure if any and bail out. if (result.isDefiniteFailure()) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td index 07f8ebc..eaab057 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.td @@ -41,10 +41,15 @@ def TransformOpInterface : OpInterface<"TransformOpInterface"> { transformation represented by the current op is targeted. Returns a special status object indicating whether the transformation succeeded or failed, and, if it failed, whether the failure is recoverable or not. + + IR must be created, modified and deleted with the provided rewriter. + Implementations are responsible for setting the insertion point of the + rewriter to the desired location. }], /*returnType=*/"::mlir::DiagnosedSilenceableFailure", /*name=*/"apply", /*arguments=*/(ins + "::mlir::transform::TransformRewriter &":$rewriter, "::mlir::transform::TransformResults &":$transformResults, "::mlir::transform::TransformState &":$state )>, @@ -194,6 +199,10 @@ def ParamProducerTransformOpTrait : NativeOpTrait<"ParamProducerTransformOpTrait let cppNamespace = "::mlir::transform"; } +def ReportTrackingListenerFailuresOpTrait : NativeOpTrait<"ReportTrackingListenerFailuresOpTrait"> { let cppNamespace = "::mlir::transform"; } + def FindPayloadReplacementOpInterface : OpInterface<"FindPayloadReplacementOpInterface"> { let description = [{ diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 6f9aaa6..a5ed64e 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -129,7 +129,8 @@ def AnnotateOp : TransformDialectOp<"annotate", def ApplyPatternsOp : TransformDialectOp<"apply_patterns", [TransformOpInterface, TransformEachOpTrait, - 
DeclareOpInterfaceMethods] + DeclareOpInterfaceMethods, + ReportTrackingListenerFailuresOpTrait] # GraphRegionNoTerminator.traits> { let summary = "Greedily applies patterns to the body of the targeted op"; let description = [{ @@ -149,8 +150,8 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns", considered "payload op replacements". Furthermore, only if the replacement values are defined by the same op and that op has the same type as the original op, the mapping is updated. Otherwise, this transform fails - silently unless `fail_on_payload_replacement_not_found` is set to "false". - More details can be found at the documentation site of `TrackingListener`. + silently. More details can be found at the documentation site of + `TrackingListener`. This transform also fails silently if the pattern application did not converge within the default number of iterations/rewrites of the greedy @@ -158,8 +159,7 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns", }]; let arguments = (ins - TransformHandleTypeInterface:$target, - DefaultValuedAttr:$fail_on_payload_replacement_not_found); + TransformHandleTypeInterface:$target); let results = (outs); let regions = (region MaxSizedRegion<1>:$region); @@ -171,12 +171,12 @@ def ApplyPatternsOp : TransformDialectOp<"apply_patterns", OpBuilder<(ins "Value":$target, CArg<"function_ref", "nullptr">: - $bodyBuilder, - CArg<"bool", "true">:$failOnPayloadReplacementNotFound)>, + $bodyBuilder)>, ]; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -222,6 +222,7 @@ def ApplyRegisteredPassOp : TransformDialectOp<"apply_registered_pass", let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, 
::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -238,6 +239,7 @@ def CastOp : TransformDialectOp<"cast", let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); diff --git a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp index 5a7f092..0b4570c 100644 --- a/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp +++ b/mlir/lib/Dialect/Affine/TransformOps/AffineTransformOps.cpp @@ -63,7 +63,8 @@ struct SimplifyAffineMinMaxOp : public OpRewritePattern { } // namespace DiagnosedSilenceableFailure -SimplifyBoundedAffineOpsOp::apply(TransformResults &results, +SimplifyBoundedAffineOpsOp::apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { // Get constraints for bounded values. SmallVector lbs; @@ -127,6 +128,8 @@ SimplifyBoundedAffineOpsOp::apply(TransformResults &results, SimplifyAffineMinMaxOp>(getContext(), cstr); FrozenRewritePatternSet frozenPatterns(std::move(patterns)); GreedyRewriteConfig config; + config.listener = + static_cast(rewriter.getListener()); config.strictMode = GreedyRewriteStrictness::ExistingAndNewOps; // Apply the simplification pattern to a fixpoint. 
if (failed(applyOpPatternsAndFold(targets, frozenPatterns, config))) { diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index 9d45442..9c23ad6 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -26,7 +26,8 @@ using namespace mlir::transform; //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::OneShotBufferizeOp::apply(TransformResults &transformResults, +transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = getAllowReturnAllocs(); @@ -71,10 +72,9 @@ void transform::EliminateEmptyTensorsOp::getEffects( modifiesPayload(effects); } -DiagnosedSilenceableFailure -transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, - TransformState &state) { - IRRewriter rewriter(getContext()); +DiagnosedSilenceableFailure transform::EliminateEmptyTensorsOp::apply( + transform::TransformRewriter &rewriter, TransformResults &transformResults, + TransformState &state) { OneShotBufferizationOptions options; options.allowReturnAllocs = true; @@ -95,11 +95,9 @@ transform::EliminateEmptyTensorsOp::apply(TransformResults &transformResults, // EmptyTensorToAllocTensorOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -EmptyTensorToAllocTensorOp::applyToOne(tensor::EmptyOp target, - ApplyToEachResultList &results, - transform::TransformState &state) { - IRRewriter rewriter(target->getContext()); +DiagnosedSilenceableFailure EmptyTensorToAllocTensorOp::applyToOne( + transform::TransformRewriter &rewriter, tensor::EmptyOp target, + ApplyToEachResultList &results, 
transform::TransformState &state) { rewriter.setInsertionPoint(target); auto alloc = rewriter.replaceOpWithNewOp( target, target.getType(), target.getDynamicSizes()); diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index e9027d1..28b645e 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -398,7 +398,7 @@ createGpuLaunch(RewriterBase &rewriter, Location loc, /// Alter kernel configuration of the given kernel. static DiagnosedSilenceableFailure -alterGpuLaunch(IRRewriter &rewriter, LaunchOp gpuLaunch, +alterGpuLaunch(RewriterBase &rewriter, LaunchOp gpuLaunch, TransformOpInterface transformOp, std::optional gridDimX = std::nullopt, std::optional gridDimY = std::nullopt, @@ -661,12 +661,10 @@ mlir::transform::gpu::findTopLevelForallOp(Operation *target, return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -transform::MapForallToBlocks::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); - IRRewriter rewriter(getContext()); auto transformOp = cast(getOperation()); if (!getGenerateGpuLaunch() && !gpuLaunch) { @@ -856,7 +854,8 @@ DiagnosedSilenceableFailure mlir::transform::gpu::mapNestedForallToThreadsImpl( } DiagnosedSilenceableFailure transform::MapNestedForallToThreads::applyToOne( - Operation *target, ApplyToEachResultList &results, TransformState &state) { + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, TransformState &state) { LaunchOp gpuLaunch = dyn_cast(target); auto transformOp = cast(getOperation()); @@ -877,7 +876,6 @@ DiagnosedSilenceableFailure 
transform::MapNestedForallToThreads::applyToOne( // Set the GPU launch configuration for the block dims early, this is not // subject to IR inspection. - IRRewriter rewriter(getContext()); diag = alterGpuLaunch(rewriter, gpuLaunch, transformOp, std::nullopt, std::nullopt, std::nullopt, blockDims[0], blockDims[1], blockDims[2]); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 4a9fc8e..0424763 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -169,13 +169,11 @@ void transform::ApplyTilingCanonicalizationPatternsOp::populatePatterns( // BufferizeToAllocationOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::BufferizeToAllocationOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::BufferizeToAllocationOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { Attribute memorySpace = getMemorySpace().has_value() ? 
getMemorySpace().value() : Attribute(); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto transformed = llvm::to_vector( llvm::map_range(state.getPayloadValues(getTarget()), [&](Value v) { return linalg::bufferizeToAllocation(rewriter, v, memorySpace); @@ -196,7 +194,8 @@ void transform::BufferizeToAllocationOp::getEffects( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::DecomposeOp::applyToOne(LinalgOp target, +transform::DecomposeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { #define DOWNSCALE(trans) \ @@ -286,7 +285,8 @@ static LogicalResult applyTilingToAll( } DiagnosedSilenceableFailure -transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, +transform::FuseOp::apply(transform::TransformRewriter &rewriter, + mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { SmallVector tileSizes = extractFromI64ArrayAttr(getTileSizes()); SmallVector tileInterchange = @@ -297,8 +297,6 @@ transform::FuseOp::apply(mlir::transform::TransformResults &transformResults, tilingOptions = tilingOptions.setTileSizes(tileSizes); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LogicalResult result = applyTilingToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), transformResults, @@ -721,7 +719,8 @@ static void forciblyReplaceReferencedPayloadOperation(TransformState &state, } DiagnosedSilenceableFailure -transform::FuseIntoContainingOp::apply(transform::TransformResults &results, +transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, 
transform::TransformState &state) { SmallVector fusedOps; auto producerOps = state.getPayloadOps(getProducerOp()); @@ -764,8 +763,6 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results, return failure(); }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); while (!remainingProducers.empty()) { auto nextProducer = getNextProducer(); if (failed(nextProducer)) { @@ -842,7 +839,8 @@ void transform::FuseIntoContainingOp::getEffects( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GeneralizeOp::applyToOne(LinalgOp target, +transform::GeneralizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Exit early if no transformation is needed. @@ -850,8 +848,6 @@ transform::GeneralizeOp::applyToOne(LinalgOp target, results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr generic = generalizeNamedOp(rewriter, target); if (succeeded(generic)) { @@ -866,7 +862,8 @@ transform::GeneralizeOp::applyToOne(LinalgOp target, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::InterchangeOp::applyToOne(GenericOp target, +transform::InterchangeOp::applyToOne(transform::TransformRewriter &rewriter, + GenericOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { ArrayRef interchangeVector = getIteratorInterchange(); @@ -875,8 +872,6 @@ transform::InterchangeOp::applyToOne(GenericOp target, results.push_back(target); return DiagnosedSilenceableFailure::success(); } - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = 
interchangeGenericOp(rewriter, target, SmallVector(interchangeVector.begin(), @@ -904,10 +899,9 @@ LogicalResult transform::InterchangeOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( - tensor::PackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::PackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = lowerPack(rewriter, target); if (failed(res)) { @@ -925,10 +919,9 @@ DiagnosedSilenceableFailure transform::LowerPackOp::applyToOne( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::LowerUnPackOp::applyToOne( - tensor::UnPackOp target, transform::ApplyToEachResultList &transformResults, + transform::TransformRewriter &rewriter, tensor::UnPackOp target, + transform::ApplyToEachResultList &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = lowerUnPack(rewriter, target); if (failed(res)) { @@ -964,7 +957,8 @@ void transform::MatchOp::build(OpBuilder &builder, OperationState &result, } DiagnosedSilenceableFailure -transform::MatchOp::apply(transform::TransformResults &results, +transform::MatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::StringSet<> strs; if (getOps().has_value()) @@ -1053,8 +1047,8 @@ static ParseResult parseMultitileSizesTypes(OpAsmParser &parser, } DiagnosedSilenceableFailure transform::MultiTileSizesOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, - TransformState 
&state) { + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, TransformState &state) { if (isa(getLowSize().getType())) { if (target.hasDynamicShape()) { auto diag = emitSilenceableError() @@ -1155,7 +1149,8 @@ SmallVector transform::PackOp::getMixedPackedSizes() { } DiagnosedSilenceableFailure -transform::PackOp::apply(transform::TransformResults &transformResults, +transform::PackOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); // If nothing to pack, propagate success. @@ -1184,8 +1179,6 @@ transform::PackOp::apply(transform::TransformResults &transformResults, DiagnosedSilenceableFailure status = unpackSingleIndexResultPayloadOperations( state, *this, packedSizes, getMixedPackedSizes()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(linalgOp); FailureOr maybeResult = pack(rewriter, linalgOp, packedSizes); if (failed(maybeResult)) @@ -1364,11 +1357,10 @@ packMatmulGreedily(RewriterBase &rewriter, LinalgOp linalgOp, } DiagnosedSilenceableFailure -PackGreedilyOp::apply(transform::TransformResults &transformResults, +PackGreedilyOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); for (Operation *op : state.getPayloadOps(getTarget())) { auto linalgOp = dyn_cast(op); if (!linalgOp) @@ -1464,7 +1456,8 @@ bool isValidPackingPermutation( } DiagnosedSilenceableFailure -transform::PackTransposeOp::apply(transform::TransformResults &transformResults, +transform::PackTransposeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { auto 
packOrUnpackOps = state.getPayloadOps(getTargetPackOrUnPackOp()); auto linalgOps = state.getPayloadOps(getTargetLinalgOp()); @@ -1542,8 +1535,6 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults, assert(packOp && linalgOp && "unexpected null op"); // Step 3. Actually transpose the ops. - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr res = packTranspose( rewriter, packOp, linalgOp, unPackOp, getOuterPerm(), getInnerPerm()); // Preconditions have been checked, it is an error to fail here. @@ -1568,7 +1559,8 @@ transform::PackTransposeOp::apply(transform::TransformResults &transformResults, //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PadOp::applyToOne(LinalgOp target, +transform::PadOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { // Convert the integer packing flags to booleans. 
@@ -1616,8 +1608,6 @@ transform::PadOp::applyToOne(LinalgOp target, transposePaddings.push_back( extractFromI64ArrayAttr(cast(transposeVector))); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); LinalgOp paddedOp; SmallVector paddingDimensions = extractFromI64ArrayAttr(getPaddingDimensions()); @@ -1684,6 +1674,7 @@ LogicalResult transform::PadOp::verify() { //===---------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { auto targetOps = state.getPayloadOps(getTarget()); @@ -1700,8 +1691,6 @@ DiagnosedSilenceableFailure transform::HoistPadBuildPackingLoopNestOp::apply( if (!padOp || !loopOp) return emitDefiniteFailure() << "requires exactly 2 non-null handles"; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = linalg::detail::buildPackingLoopNest(rewriter, padOp, loopOp, getTranspose()); @@ -1740,13 +1729,12 @@ void transform::HoistPadBuildPackingLoopNestOp::getEffects( } DiagnosedSilenceableFailure -transform::HoistPadOp::applyToOne(tensor::PadOp target, +transform::HoistPadOp::applyToOne(transform::TransformRewriter &rewriter, + tensor::PadOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { tensor::PadOp hoistedPadOp; SmallVector transposeOps; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr result = hoistPaddingOnTensors(rewriter, target, getNumLoops(), getTranspose(), hoistedPadOp, transposeOps); @@ -1779,7 +1767,8 @@ LogicalResult transform::HoistPadOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PromoteOp::applyToOne(LinalgOp target, 
+transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { LinalgPromotionOptions promotionOptions; @@ -1829,8 +1818,6 @@ transform::PromoteOp::applyToOne(LinalgOp target, if (failed(promoteSubviewsPrecondition(target, promotionOptions))) return emitDefaultDefiniteFailure(target); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr res = promoteSubViews(rewriter, target, promotionOptions); if (failed(res)) @@ -1844,7 +1831,8 @@ transform::PromoteOp::applyToOne(LinalgOp target, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplaceOp::apply(TransformResults &transformResults, +transform::ReplaceOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { auto payload = state.getPayloadOps(getTarget()); @@ -1859,8 +1847,6 @@ transform::ReplaceOp::apply(TransformResults &transformResults, } // Clone and replace. 
- TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); Operation *pattern = &getBodyRegion().front().front(); SmallVector replacements; for (Operation *target : payload) { @@ -1904,7 +1890,8 @@ LogicalResult transform::ReplaceOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ScalarizeOp::applyToOne(LinalgOp target, +transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, + LinalgOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::SCFTilingOptions tilingOptions; @@ -1916,8 +1903,6 @@ transform::ScalarizeOp::applyToOne(LinalgOp target, AffineMap map = target.getShapesToLoopsMap(); if (!map) return tileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector shapeSizes = affine::makeComposedFoldedMultiResultAffineApply(rewriter, loc, map, allShapeSizes); @@ -1931,8 +1916,6 @@ transform::ScalarizeOp::applyToOne(LinalgOp target, return tileSizes; }); SmallVector emptyTileSizes; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeTilingResult = tileUsingSCFForOp( rewriter, cast(target.getOperation()), tilingOptions); @@ -1956,11 +1939,10 @@ transform::ScalarizeOp::applyToOne(LinalgOp target, DiagnosedSilenceableFailure transform::RewriteInDestinationPassingStyleOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { SmallVector res; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr maybeResult = TypeSwitch>(target) @@ -1978,13 +1960,12 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne( // 
SplitOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure SplitOp::apply(TransformResults &results, - TransformState &state) { +DiagnosedSilenceableFailure +SplitOp::apply(transform::TransformRewriter &rewriter, + TransformResults &results, TransformState &state) { // Collect the dynamic split points if provided. SmallVector payload = llvm::to_vector(state.getPayloadOps(getTarget())); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); SmallVector splitPoints; splitPoints.reserve(payload.size()); if (getDynamicSplitPoint()) { @@ -2184,15 +2165,14 @@ void transform::SplitReductionOp::build( } DiagnosedSilenceableFailure transform::SplitReductionOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { ControlSplitReductionFn splitFn = [&](LinalgOp) { return linalg::SplitReductionOptions{int64_t(getSplitFactor()), unsigned(getInsertSplitDimension()), bool(getInnerParallel())}; }; - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr splitResult = (getUseScalingAlgorithm()) @@ -2230,10 +2210,9 @@ void transform::TileReductionUsingScfOp::build( } DiagnosedSilenceableFailure transform::TileReductionUsingScfOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); FailureOr result = scf::tileReductionUsingScf( rewriter, cast(target.getOperation()), @@ -2274,10 +2253,9 @@ void transform::TileReductionUsingForallOp::build( } DiagnosedSilenceableFailure 
transform::TileReductionUsingForallOp::applyToOne( - LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); SmallVector numThreads = getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads())); @@ -2363,7 +2341,8 @@ void transform::TileOp::build(OpBuilder &builder, OperationState &result, } DiagnosedSilenceableFailure -transform::TileOp::apply(TransformResults &transformResults, +transform::TileOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2478,8 +2457,6 @@ transform::TileOp::apply(TransformResults &transformResults, } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr maybeTilingResult = tileUsingSCFForOp(rewriter, tilingInterface, tilingOptions); if (failed(maybeTilingResult)) @@ -2720,45 +2697,44 @@ DiagnosedSilenceableFailure transform::tileToForallOpImpl( } DiagnosedSilenceableFailure -transform::TileToForallOp::apply(transform::TransformResults &transformResults, +transform::TileToForallOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &transformResults, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); - auto transformOp = cast(getOperation()); - - // Result payload ops. - SmallVector tileOps; - SmallVector tiledOps; - - // Unpack handles. - SmallVector mixedNumThreads; - DiagnosedSilenceableFailure status = - getPackedNumThreads() - ? 
unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getPackedNumThreads()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedNumThreads, getMixedNumThreads()); - if (!status.succeeded()) - return status; - SmallVector mixedTileSizes; - status = getPackedTileSizes() - ? unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getPackedTileSizes()) - : unpackSingleIndexResultPayloadOperations( - state, transformOp, mixedTileSizes, getMixedTileSizes()); - if (!status.succeeded()) - return status; - - for (Operation *target : state.getPayloadOps(getTarget())) { - linalg::ForallTilingResult tilingResult; - DiagnosedSilenceableFailure diag = tileToForallOpImpl( - rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, - getMapping(), tilingResult); - if (!diag.succeeded()) + auto transformOp = cast(getOperation()); + + // Result payload ops. + SmallVector tileOps; + SmallVector tiledOps; + + // Unpack handles. + SmallVector mixedNumThreads; + DiagnosedSilenceableFailure status = + getPackedNumThreads() + ? unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getPackedNumThreads()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedNumThreads, getMixedNumThreads()); + if (!status.succeeded()) + return status; + SmallVector mixedTileSizes; + status = getPackedTileSizes() + ? 
unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getPackedTileSizes()) + : unpackSingleIndexResultPayloadOperations( + state, transformOp, mixedTileSizes, getMixedTileSizes()); + if (!status.succeeded()) + return status; + + for (Operation *target : state.getPayloadOps(getTarget())) { + linalg::ForallTilingResult tilingResult; + DiagnosedSilenceableFailure diag = tileToForallOpImpl( + rewriter, state, transformOp, target, mixedNumThreads, mixedTileSizes, + getMapping(), tilingResult); + if (!diag.succeeded()) return diag; tileOps.push_back(tilingResult.tileOp); tiledOps.push_back(tilingResult.tiledOp); - } + } transformResults.set(cast(getForallOp()), tileOps); transformResults.set(cast(getTiledOp()), tiledOps); @@ -2833,7 +2809,8 @@ void transform::TileToScfForOp::build(OpBuilder &builder, } DiagnosedSilenceableFailure -transform::TileToScfForOp::apply(TransformResults &transformResults, +transform::TileToScfForOp::apply(transform::TransformRewriter &rewriter, + TransformResults &transformResults, TransformState &state) { ArrayRef tileSizes = getStaticSizes(); @@ -2902,8 +2879,6 @@ transform::TileToScfForOp::apply(TransformResults &transformResults, } tilingOptions.setInterchange(getInterchange()); - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); FailureOr tilingResult = tileUsingSCFForOp(rewriter, tilingInterfaceOp, tilingOptions); if (failed(tilingResult)) @@ -3047,7 +3022,8 @@ private: } // namespace DiagnosedSilenceableFailure -transform::VectorizeOp::applyToOne(Operation *target, +transform::VectorizeOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->hasTrait()) { @@ -3094,10 +3070,9 @@ transform::VectorizeOp::applyToOne(Operation *target, // MaskedVectorizeOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure 
transform::MaskedVectorizeOp::apply( + transform::TransformRewriter &rewriter, mlir::transform::TransformResults &transformResults, mlir::transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto targets = state.getPayloadOps(getTarget()); if (std::empty(targets)) return DiagnosedSilenceableFailure::success(); @@ -3175,7 +3150,8 @@ SmallVector MaskedVectorizeOp::getMixedVectorSizes() { DiagnosedSilenceableFailure transform::HoistRedundantVectorTransfersOp::applyToOne( - func::FuncOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, func::FuncOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // WARNING: This hoisting does not model parallelism and is generally // incorrect when used on distributed loops with memref semantics! @@ -3190,10 +3166,9 @@ transform::HoistRedundantVectorTransfersOp::applyToOne( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( - linalg::LinalgOp target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, linalg::LinalgOp target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); rewriter.setInsertionPoint(target); auto maybeTransformed = TypeSwitch>>( @@ -3225,10 +3200,9 @@ DiagnosedSilenceableFailure transform::ConvertConv2DToImg2ColOp::applyToOne( DiagnosedSilenceableFailure transform::HoistRedundantTensorSubsetsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); auto 
forOp = dyn_cast(target); if (forOp) { linalg::hoistRedundantSubsetExtractInsert(rewriter, forOp); @@ -3291,11 +3265,10 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target, } DiagnosedSilenceableFailure transform::InsertSliceToCopyOp::applyToOne( - Operation *targetOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *targetOp, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(targetOp->getContext(), &listener); rewriter.setInsertionPoint(targetOp); if (auto target = dyn_cast(targetOp)) return doit(rewriter, target, results, state); diff --git a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp index 0f84fe4..a46f2a3 100644 --- a/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp +++ b/mlir/lib/Dialect/MemRef/TransformOps/MemRefTransformOps.cpp @@ -60,10 +60,10 @@ void transform::ApplyResolveRankedShapedTypeResultDimsPatternsOp:: //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &transformResults, transform::TransformState &state) { SmallVector results; - IRRewriter rewriter(getContext()); for (Operation *op : state.getPayloadOps(getTarget())) { bool canApplyMultiBuffer = true; auto target = cast(op); @@ -105,7 +105,8 @@ DiagnosedSilenceableFailure transform::MemRefMultiBufferOp::apply( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. 
SmallVector ivs; @@ -123,7 +124,6 @@ DiagnosedSilenceableFailure transform::MemRefMakeLoopIndependentOp::applyToOne( } // Rewrite IR. - IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto allocaOp = dyn_cast(target)) { replacement = memref::replaceWithIndependentOp(rewriter, allocaOp, ivs); diff --git a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp index edd1d32..43c7d7e 100644 --- a/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp +++ b/mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp @@ -35,7 +35,8 @@ void transform::ApplyForLoopCanonicalizationPatternsOp::populatePatterns( // GetParentForOp //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetParentForOp::apply(transform::TransformResults &results, +transform::GetParentForOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -92,7 +93,8 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b, } DiagnosedSilenceableFailure -transform::LoopOutlineOp::apply(transform::TransformResults &results, +transform::LoopOutlineOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector functions; SmallVector calls; @@ -100,7 +102,6 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results, for (Operation *target : state.getPayloadOps(getTarget())) { Location location = target->getLoc(); Operation *symbolTableOp = SymbolTable::getNearestSymbolTable(target); - IRRewriter rewriter(getContext()); scf::ExecuteRegionOp exec = wrapInExecuteRegion(rewriter, target); if (!exec) { DiagnosedSilenceableFailure diag = emitSilenceableError() @@ -135,11 +136,11 @@ 
transform::LoopOutlineOp::apply(transform::TransformResults &results, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopPeelOp::applyToOne(scf::ForOp target, +transform::LoopPeelOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::ForOp result; - IRRewriter rewriter(target->getContext()); // This helper returns failure when peeling does not occur (i.e. when the IR // is not modified). This is not a failure for the op as the postcondition: // "the loop trip count is divisible by the step" @@ -192,7 +193,8 @@ loopScheduling(scf::ForOp forOp, } DiagnosedSilenceableFailure -transform::LoopPipelineOp::applyToOne(scf::ForOp target, +transform::LoopPipelineOp::applyToOne(transform::TransformRewriter &rewriter, + scf::ForOp target, transform::ApplyToEachResultList &results, transform::TransformState &state) { scf::PipeliningOption options; @@ -203,7 +205,6 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target, getReadLatency()); }; scf::ForLoopPipeliningPattern pattern(options, target->getContext()); - IRRewriter rewriter(getContext()); rewriter.setInsertionPoint(target); FailureOr patternResult = scf::pipelineForLoop(rewriter, target, options); @@ -219,7 +220,8 @@ transform::LoopPipelineOp::applyToOne(scf::ForOp target, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::LoopUnrollOp::applyToOne(Operation *op, +transform::LoopUnrollOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -241,7 +243,8 @@ transform::LoopUnrollOp::applyToOne(Operation *op, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure 
-transform::LoopCoalesceOp::applyToOne(Operation *op, +transform::LoopCoalesceOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *op, transform::ApplyToEachResultList &results, transform::TransformState &state) { LogicalResult result(failure()); @@ -276,12 +279,10 @@ static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, } DiagnosedSilenceableFailure transform::TakeAssumedBranchOp::applyToOne( - scf::IfOp ifOp, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, scf::IfOp ifOp, + transform::ApplyToEachResultList &results, transform::TransformState &state) { - TrackingListener listener(state, *this); - IRRewriter rewriter(ifOp->getContext(), &listener); rewriter.setInsertionPoint(ifOp); - Region ®ion = getTakeElseBranch() ? ifOp.getElseRegion() : ifOp.getThenRegion(); if (!llvm::hasSingleElement(region)) { diff --git a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp index b4b8229..6ee4bfa 100644 --- a/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp +++ b/mlir/lib/Dialect/Tensor/TransformOps/TensorTransformOps.cpp @@ -123,7 +123,8 @@ void transform::ApplyRewriteTensorOpsAsConstantPatternsOp::populatePatterns( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { // Gather IVs. SmallVector ivs; @@ -141,7 +142,6 @@ DiagnosedSilenceableFailure transform::MakeLoopIndependentOp::applyToOne( } // Rewrite IR. 
- IRRewriter rewriter(target->getContext()); FailureOr replacement = failure(); if (auto padOp = dyn_cast(target)) { replacement = tensor::buildIndependentOp(rewriter, padOp, ivs); diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp index d4c4327..d7205ec 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -157,6 +157,13 @@ LogicalResult transform::TransformDialect::verifyOperationAttribute( } return success(); } + if (attribute.getName().getValue() == kSilenceTrackingFailuresAttrName) { + if (!llvm::isa(attribute.getValue())) { + return op->emitError() + << attribute.getName() << " must be a unit attribute"; + } + return success(); + } return emitError(op->getLoc()) << "unknown attribute: " << attribute.getName(); } diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 679a860..32637b3 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" @@ -896,12 +897,41 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { return diag; } + // Prepare rewriter and listener. + transform::ErrorCheckingTrackingListener trackingListener(*this, transform); + transform::TransformRewriter rewriter(transform->getContext(), + &trackingListener); + // Compute the result but do not short-circuit the silenceable failure case as // we still want the handles to propagate properly so the "suppress" mode can // proceed on a best effort basis. 
transform::TransformResults results(transform->getNumResults()); - DiagnosedSilenceableFailure result(transform.apply(results, *this)); + DiagnosedSilenceableFailure result(transform.apply(rewriter, results, *this)); compactOpHandles(); + + // Error handling: fail if transform or listener failed. + DiagnosedSilenceableFailure trackingFailure = + trackingListener.checkAndResetError(); + if (!transform->hasTrait() || + transform->hasAttr( + transform::TransformDialect::kSilenceTrackingFailuresAttrName)) { + // Only report failures for ReportTrackingListenerFailuresOpTrait ops. Also + // do not report failures if the above mentioned attribute is set. + if (trackingFailure.isSilenceableFailure()) + (void)trackingFailure.silence(); + trackingFailure = DiagnosedSilenceableFailure::success(); + } + if (!trackingFailure.succeeded()) { + if (result.succeeded()) { + result = std::move(trackingFailure); + } else { + // Transform op errors have precedence, report those first. + if (result.isSilenceableFailure()) + result.attachNote() << "tracking listener also failed: " + << trackingFailure.getMessage(); + (void)trackingFailure.silence(); + } + } if (result.isDefiniteFailure()) return result; @@ -1161,6 +1191,16 @@ bool transform::TransformResults::isSet(unsigned resultNumber) const { // TrackingListener //===----------------------------------------------------------------------===// +transform::TrackingListener::TrackingListener(TransformState &state, + TransformOpInterface op) + : TransformState::Extension(state), transformOp(op) { + if (op) { + for (OpOperand *opOperand : transformOp.getConsumedHandleOpOperands()) { + consumedHandles.insert(opOperand->get()); + } + } +} + Operation *transform::TrackingListener::getCommonDefiningOp(ValueRange values) { Operation *defOp = nullptr; for (Value v : values) { @@ -1267,15 +1307,34 @@ void transform::TrackingListener::notifyOperationReplaced( // Op is not tracked. 
return; } + + // Helper function to check if the current transform op consumes any handle + // that is mapped to `op`. + // + // Note: If a handle was consumed, there shouldn't be any alive users, so it + // is not really necessary to check for consumed handles. However, in case + // there are indeed alive handles that were consumed (which is undefined + // behavior) and a replacement op could not be found, we want to fail with a + // nicer error message: "op uses a handle invalidated..." instead of "could + // not find replacement op". This nicer error is produced later. + auto handleWasConsumed = [&] { + return llvm::any_of(opHandles, + [&](Value h) { return consumedHandles.contains(h); }); + }; + + // Helper function to check if the handle is alive. auto hasAliveUser = [&]() { - for (Value v : opHandles) + for (Value v : opHandles) { for (Operation *user : v.getUsers()) - if (!happensBefore(user, transformOp)) + if (user != transformOp && !happensBefore(user, transformOp)) return true; + } return false; }; - if (!hasAliveUser()) { - // The op is tracked but the corresponding handles are dead. + + if (!hasAliveUser() || handleWasConsumed()) { + // The op is tracked but the corresponding handles are dead or were + // consumed. Drop the op from the mapping. (void)replacePayloadOp(op, nullptr); return; } @@ -1327,6 +1386,28 @@ void transform::ErrorCheckingTrackingListener::notifyPayloadReplacementNotFound( } //===----------------------------------------------------------------------===// +// TransformRewriter +//===----------------------------------------------------------------------===// + +transform::TransformRewriter::TransformRewriter( + MLIRContext *ctx, ErrorCheckingTrackingListener *listener) + : RewriterBase(ctx), listener(listener) { + setListener(listener); +} + +bool transform::TransformRewriter::hasTrackingFailures() const { + return listener->failed(); +} + +/// Silence all tracking failures that have been encountered so far.
+void transform::TransformRewriter::silenceTrackingFailure() { + if (hasTrackingFailures()) { + DiagnosedSilenceableFailure status = listener->checkAndResetError(); + (void)status.silence(); + } +} + +//===----------------------------------------------------------------------===// // Utilities for TransformEachOpTrait. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 3049930..72cf663 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -1,4 +1,4 @@ -//===- TransformDialect.cpp - Transform dialect operations ----------------===// +//===- TransformOps.cpp - Transform dialect operations --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" + #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -115,7 +116,8 @@ static void forwardEmptyOperands(Block *block, transform::TransformState &state, } DiagnosedSilenceableFailure -transform::AlternativesOp::apply(transform::TransformResults &results, +transform::AlternativesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector originals; if (Value scopeHandle = getScope()) @@ -222,7 +224,8 @@ LogicalResult transform::AlternativesOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::AnnotateOp::apply(transform::TransformResults &results, 
+transform::AnnotateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector targets = llvm::to_vector(state.getPayloadOps(getTarget())); @@ -258,10 +261,9 @@ void transform::AnnotateOp::getEffects( // ApplyPatternsOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::ApplyPatternsOp::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::ApplyPatternsOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. Even // more so for the ApplyPatternsOp because the GreedyPatternRewriteDriver @@ -282,9 +284,9 @@ transform::ApplyPatternsOp::applyToOne(Operation *target, } // Configure the GreedyPatternRewriteDriver. - ErrorCheckingTrackingListener listener(state, *this); GreedyRewriteConfig config; - config.listener = &listener; + config.listener = + static_cast(rewriter.getListener()); LogicalResult result = failure(); if (target->hasTrait()) { @@ -312,14 +314,6 @@ transform::ApplyPatternsOp::applyToOne(Operation *target, << "greedy pattern application failed"; } - // Check listener state for tracking errors. 
- if (listener.failed()) { - DiagnosedSilenceableFailure status = listener.checkAndResetError(); - if (getFailOnPayloadReplacementNotFound()) - return status; - (void)status.silence(); - } - return DiagnosedSilenceableFailure::success(); } @@ -346,12 +340,8 @@ void transform::ApplyPatternsOp::getEffects( void transform::ApplyPatternsOp::build( OpBuilder &builder, OperationState &result, Value target, - function_ref bodyBuilder, - bool failOnPayloadReplacementNotFound) { + function_ref bodyBuilder) { result.addOperands(target); - result.getOrAddProperties() - .fail_on_payload_replacement_not_found = - builder.getBoolAttr(failOnPayloadReplacementNotFound); OpBuilder::InsertionGuard g(builder); Region *region = result.addRegion(); @@ -377,10 +367,9 @@ void transform::ApplyCanonicalizationPatternsOp::populatePatterns( // ApplyRegisteredPassOp //===----------------------------------------------------------------------===// -DiagnosedSilenceableFailure -transform::ApplyRegisteredPassOp::applyToOne(Operation *target, - ApplyToEachResultList &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure transform::ApplyRegisteredPassOp::applyToOne( + transform::TransformRewriter &rewriter, Operation *target, + ApplyToEachResultList &results, transform::TransformState &state) { // Make sure that this transform is not applied to itself. Modifying the // transform IR while it is being interpreted is generally dangerous. 
Even // more so when applying passes because they may perform a wide range of IR @@ -420,7 +409,8 @@ transform::ApplyRegisteredPassOp::applyToOne(Operation *target, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::CastOp::applyToOne(Operation *target, ApplyToEachResultList &results, +transform::CastOp::applyToOne(transform::TransformRewriter &rewriter, + Operation *target, ApplyToEachResultList &results, transform::TransformState &state) { results.push_back(target); return DiagnosedSilenceableFailure::success(); @@ -482,7 +472,8 @@ matchBlock(Block &block, Operation *op, transform::TransformState &state, } DiagnosedSilenceableFailure -transform::ForeachMatchOp::apply(transform::TransformResults &results, +transform::ForeachMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> matchActionPairs; @@ -780,7 +771,8 @@ LogicalResult transform::ForeachMatchOp::verifySymbolUses( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ForeachOp::apply(transform::TransformResults &results, +transform::ForeachOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector> resultOps(getNumResults(), {}); @@ -869,6 +861,7 @@ LogicalResult transform::ForeachOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SetVector parents; for (Operation *target : state.getPayloadOps(getTarget())) { @@ -892,7 +885,8 @@ DiagnosedSilenceableFailure transform::GetClosestIsolatedParentOp::apply( 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetConsumersOfResult::apply(transform::TransformResults &results, +transform::GetConsumersOfResult::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); auto payloadOps = state.getPayloadOps(getTarget()); @@ -917,7 +911,8 @@ transform::GetConsumersOfResult::apply(transform::TransformResults &results, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetDefiningOp::apply(transform::TransformResults &results, +transform::GetDefiningOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector definingOps; for (Value v : state.getPayloadValues(getTarget())) { @@ -938,7 +933,8 @@ transform::GetDefiningOp::apply(transform::TransformResults &results, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetProducerOfOperand::apply(transform::TransformResults &results, +transform::GetProducerOfOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t operandNumber = getOperandNumber(); SmallVector producers; @@ -966,7 +962,8 @@ transform::GetProducerOfOperand::apply(transform::TransformResults &results, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::GetResultOp::apply(transform::TransformResults &results, +transform::GetResultOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t resultNumber = getResultNumber(); SmallVector opResults; @@ -1017,7 +1014,8 @@ applySequenceBlock(Block &block, 
transform::FailurePropagationMode mode, } DiagnosedSilenceableFailure -transform::IncludeOp::apply(transform::TransformResults &results, +transform::IncludeOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto callee = SymbolTable::lookupNearestSymbolFrom( getOperation(), getTarget()); @@ -1155,7 +1153,8 @@ DiagnosedSilenceableFailure transform::MatchOperationNameOp::matchOperation( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MatchParamCmpIOp::apply(transform::TransformResults &results, +transform::MatchParamCmpIOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto signedAPIntAsString = [&](APInt value) { std::string str; @@ -1241,7 +1240,8 @@ void transform::MatchParamCmpIOp::getEffects( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ParamConstantOp::apply(transform::TransformResults &results, +transform::ParamConstantOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setParams(cast(getParam()), {getValue()}); return DiagnosedSilenceableFailure::success(); @@ -1252,7 +1252,8 @@ transform::ParamConstantOp::apply(transform::TransformResults &results, //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::MergeHandlesOp::apply(transform::TransformResults &results, +transform::MergeHandlesOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector operations; for (Value operand : getHandles()) @@ -1295,7 +1296,8 @@ OpFoldResult transform::MergeHandlesOp::fold(FoldAdaptor adaptor) { 
//===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::NamedSequenceOp::apply(transform::TransformResults &results, +transform::NamedSequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Nothing to do here. return DiagnosedSilenceableFailure::success(); @@ -1461,7 +1463,8 @@ void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, } DiagnosedSilenceableFailure -transform::SplitHandleOp::apply(transform::TransformResults &results, +transform::SplitHandleOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { int64_t numPayloadOps = llvm::range_size(state.getPayloadOps(getHandle())); auto produceNumOpsError = [&]() { @@ -1522,7 +1525,8 @@ LogicalResult transform::SplitHandleOp::verify() { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::ReplicateOp::apply(transform::TransformResults &results, +transform::ReplicateOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { unsigned numRepetitions = llvm::range_size(state.getPayloadOps(getPattern())); for (const auto &en : llvm::enumerate(getHandles())) { @@ -1562,7 +1566,8 @@ void transform::ReplicateOp::getEffects( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::SequenceOp::apply(transform::TransformResults &results, +transform::SequenceOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { // Map the entry block argument to the list of operations. 
auto scope = state.make_region_scope(*getBodyBlock()->getParent()); @@ -1852,7 +1857,8 @@ void transform::PrintOp::build(OpBuilder &builder, OperationState &result, } DiagnosedSilenceableFailure -transform::PrintOp::apply(transform::TransformResults &results, +transform::PrintOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { llvm::outs() << "[[[ IR printer: "; if (getName().has_value()) diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp index 5126d79..5585afc 100644 --- a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -138,7 +138,8 @@ transform::PDLMatchHooks::getPDLConstraintHooks() const { //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::PDLMatchOp::apply(transform::TransformResults &results, +transform::PDLMatchOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); assert(extension && @@ -167,7 +168,8 @@ void transform::PDLMatchOp::getEffects( //===----------------------------------------------------------------------===// DiagnosedSilenceableFailure -transform::WithPDLPatternsOp::apply(transform::TransformResults &results, +transform::WithPDLPatternsOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { TransformOpInterface transformOp = nullptr; for (Operation &nested : getBody().front()) { diff --git a/mlir/test/Dialect/Transform/test-pattern-application.mlir b/mlir/test/Dialect/Transform/test-pattern-application.mlir index 5d417f0..062ec6e 100644 --- a/mlir/test/Dialect/Transform/test-pattern-application.mlir +++ b/mlir/test/Dialect/Transform/test-pattern-application.mlir 
@@ -88,7 +88,7 @@ transform.sequence failures(propagate) { %1 = transform.structured.match ops{["test.foo"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.apply_patterns to %0 { transform.apply_patterns.transform.test_patterns - } {fail_on_payload_replacement_not_found = false} : !transform.any_op + } {transform.silence_tracking_failures} : !transform.any_op transform.annotate %1 "annotated" : !transform.any_op } diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 835cbb3a..8f53b9c 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -45,7 +45,8 @@ public: return llvm::StringLiteral("transform.test_transform_op"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { InFlightDiagnostic remark = emitRemark() << "applying transformation"; if (Attribute message = getMessage()) @@ -98,7 +99,8 @@ public: "transform.test_transform_unrestricted_op_no_interface"); } - DiagnosedSilenceableFailure apply(transform::TransformResults &results, + DiagnosedSilenceableFailure apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -110,6 +112,7 @@ public: DiagnosedSilenceableFailure mlir::test::TestProduceSelfHandleOrForwardOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (getOperation()->getNumOperands() != 0) { results.set(cast(getResult()), @@ -129,6 +132,7 @@ void mlir::test::TestProduceSelfHandleOrForwardOperandOp::getEffects( DiagnosedSilenceableFailure 
mlir::test::TestProduceValueHandleToSelfOperand::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), getIn()); return DiagnosedSilenceableFailure::success(); @@ -143,7 +147,8 @@ void mlir::test::TestProduceValueHandleToSelfOperand::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToResult::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (target->getNumResults() <= getNumber()) return emitSilenceableError() << "payload has no result #" << getNumber(); @@ -160,7 +165,8 @@ void mlir::test::TestProduceValueHandleToResult::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceValueHandleToArgumentOfParentBlock::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { if (!target->getBlock()) return emitSilenceableError() << "payload has no parent block"; @@ -183,7 +189,8 @@ bool mlir::test::TestConsumeOperand::allowsRepeatedHandleOperands() { } DiagnosedSilenceableFailure -mlir::test::TestConsumeOperand::apply(transform::TransformResults &results, +mlir::test::TestConsumeOperand::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -197,6 +204,7 @@ void mlir::test::TestConsumeOperand::getEffects( } DiagnosedSilenceableFailure mlir::test::TestConsumeOperandOfOpKindOrFail::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); assert(llvm::hasSingleElement(payload) && 
"expected a single target op"); @@ -237,6 +245,7 @@ void mlir::test::TestSucceedIfOperandOfOpKind::getEffects( } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto payload = state.getPayloadOps(getOperand()); for (Operation *op : payload) @@ -252,6 +261,7 @@ void mlir::test::TestPrintRemarkAtOperandOp::getEffects( } DiagnosedSilenceableFailure mlir::test::TestPrintRemarkAtOperandValue::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { ArrayRef values = state.getPayloadValues(getIn()); for (Value value : values) { @@ -277,15 +287,16 @@ void mlir::test::TestPrintRemarkAtOperandValue::getEffects( transform::onlyReadsPayload(effects); } -DiagnosedSilenceableFailure -mlir::test::TestAddTestExtensionOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestAddTestExtensionOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { state.addExtension(getMessageAttr()); return DiagnosedSilenceableFailure::success(); } DiagnosedSilenceableFailure mlir::test::TestCheckIfTestExtensionPresentOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) { @@ -316,6 +327,7 @@ void mlir::test::TestCheckIfTestExtensionPresentOp::getEffects( } DiagnosedSilenceableFailure mlir::test::TestRemapOperandPayloadToSelfOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { auto *extension = state.getExtension(); if (!extension) @@ -337,14 +349,15 @@ void mlir::test::TestRemapOperandPayloadToSelfOp::getEffects( } DiagnosedSilenceableFailure 
mlir::test::TestRemoveTestExtensionOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { state.removeExtension(); return DiagnosedSilenceableFailure::success(); } -DiagnosedSilenceableFailure -mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestReversePayloadOpsOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { auto payloadOps = state.getPayloadOps(getTarget()); auto reversedOps = llvm::to_vector(llvm::reverse(payloadOps)); results.set(llvm::cast(getResult()), reversedOps); @@ -352,6 +365,7 @@ mlir::test::TestReversePayloadOpsOp::apply(transform::TransformResults &results, } DiagnosedSilenceableFailure mlir::test::TestTransformOpWithRegions::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -361,6 +375,7 @@ void mlir::test::TestTransformOpWithRegions::getEffects( DiagnosedSilenceableFailure mlir::test::TestBranchingTransformOpTerminator::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { return DiagnosedSilenceableFailure::success(); } @@ -369,6 +384,7 @@ void mlir::test::TestBranchingTransformOpTerminator::getEffects( SmallVectorImpl &effects) {} DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { emitRemark() << getRemark(); for (Operation *op : state.getPayloadOps(getTarget())) @@ -386,7 +402,8 @@ void mlir::test::TestEmitRemarkAndEraseOperandOp::getEffects( } DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( - Operation *target, 
transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -395,7 +412,8 @@ DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfResultsOp::applyToOne( DiagnosedSilenceableFailure mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { static int count = 0; if (count++ == 0) { @@ -407,7 +425,8 @@ mlir::test::TestWrongNumberOfMultiResultsOp::applyToOne( DiagnosedSilenceableFailure mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(OpBuilder(target).create(opState)); @@ -417,7 +436,8 @@ mlir::test::TestCorrectNumberOfMultiResultsOp::applyToOne( DiagnosedSilenceableFailure mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, transform::TransformState &state) { OperationState opState(target->getLoc(), "foo"); results.push_back(nullptr); @@ -427,7 +447,8 @@ mlir::test::TestMixedNullAndNonNullResultsOp::applyToOne( DiagnosedSilenceableFailure mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( - Operation *target, transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + transform::ApplyToEachResultList &results, 
transform::TransformState &state) { if (target->hasAttr("target_me")) return DiagnosedSilenceableFailure::success(); @@ -436,6 +457,7 @@ mlir::test::TestMixedSuccessAndSilenceableOp::applyToOne( DiagnosedSilenceableFailure mlir::test::TestPrintNumberOfAssociatedPayloadIROps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { if (!getHandle()) emitRemark() << 0; @@ -449,7 +471,8 @@ void mlir::test::TestPrintNumberOfAssociatedPayloadIROps::getEffects( } DiagnosedSilenceableFailure -mlir::test::TestCopyPayloadOp::apply(transform::TransformResults &results, +mlir::test::TestCopyPayloadOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getCopy()), state.getPayloadOps(getHandle())); @@ -498,6 +521,7 @@ void mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::getEffects( DiagnosedSilenceableFailure mlir::test::TestReportNumberOfTrackedHandlesNestedUnder::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { int64_t count = 0; for (Operation *op : state.getPayloadOps(getTarget())) { @@ -520,7 +544,8 @@ void mlir::test::TestPrintParamOp::getEffects( } DiagnosedSilenceableFailure -mlir::test::TestPrintParamOp::apply(transform::TransformResults &results, +mlir::test::TestPrintParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { std::string str; llvm::raw_string_ostream os(str); @@ -537,7 +562,8 @@ mlir::test::TestPrintParamOp::apply(transform::TransformResults &results, } DiagnosedSilenceableFailure -mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, +mlir::test::TestAddToParamOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { SmallVector values(/*Size=*/1, 
/*Value=*/0); if (Value param = getParam()) { @@ -559,6 +585,7 @@ mlir::test::TestAddToParamOp::apply(transform::TransformResults &results, DiagnosedSilenceableFailure mlir::test::TestProduceParamWithNumberOfTestOps::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Builder builder(getContext()); SmallVector result = llvm::to_vector( @@ -577,6 +604,7 @@ mlir::test::TestProduceParamWithNumberOfTestOps::apply( DiagnosedSilenceableFailure mlir::test::TestProduceIntegerParamWithTypeOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { Attribute zero = IntegerAttr::get(getType(), 0); results.setParams(llvm::cast(getResult()), zero); @@ -599,7 +627,8 @@ void mlir::test::TestProduceTransformParamOrForwardOperandOp::getEffects( DiagnosedSilenceableFailure mlir::test::TestProduceTransformParamOrForwardOperandOp::applyToOne( - Operation *target, ::transform::ApplyToEachResultList &results, + transform::TransformRewriter &rewriter, Operation *target, + ::transform::ApplyToEachResultList &results, ::transform::TransformState &state) { Builder builder(getContext()); if (getFirstResultIsParam()) { @@ -625,6 +654,7 @@ void mlir::test::TestProduceNullPayloadOp::getEffects( } DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { SmallVector null({nullptr}); results.set(llvm::cast(getOut()), null); @@ -632,6 +662,7 @@ DiagnosedSilenceableFailure mlir::test::TestProduceNullPayloadOp::apply( } DiagnosedSilenceableFailure mlir::test::TestProduceEmptyPayloadOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(cast(getOut()), {}); return DiagnosedSilenceableFailure::success(); @@ -642,9 +673,9 @@ void 
mlir::test::TestProduceNullParamOp::getEffects( transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullParamOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullParamOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setParams(llvm::cast(getOut()), Attribute()); return DiagnosedSilenceableFailure::success(); } @@ -654,9 +685,9 @@ void mlir::test::TestProduceNullValueOp::getEffects( transform::producesHandle(getOut(), effects); } -DiagnosedSilenceableFailure -mlir::test::TestProduceNullValueOp::apply(transform::TransformResults &results, - transform::TransformState &state) { +DiagnosedSilenceableFailure mlir::test::TestProduceNullValueOp::apply( + transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { results.setValues(llvm::cast(getOut()), Value()); return DiagnosedSilenceableFailure::success(); } @@ -676,6 +707,7 @@ void mlir::test::TestRequiredMemoryEffectsOp::getEffects( } DiagnosedSilenceableFailure mlir::test::TestRequiredMemoryEffectsOp::apply( + transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { results.set(llvm::cast(getOut()), state.getPayloadOps(getIn())); return DiagnosedSilenceableFailure::success(); @@ -694,10 +726,9 @@ void mlir::test::TestDummyPayloadOp::getEffects( } DiagnosedSilenceableFailure -mlir::test::TestTrackedRewriteOp::apply(transform::TransformResults &results, +mlir::test::TestTrackedRewriteOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, transform::TransformState &state) { - transform::ErrorCheckingTrackingListener listener(state, *this); - IRRewriter rewriter(getContext(), &listener); int64_t numIterations = 0; // `getPayloadOps` returns an iterator 
that skips ops that are erased in the diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index 7611165..fda6190 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -72,6 +72,7 @@ def TestProduceValueHandleToResult let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -89,6 +90,7 @@ def TestProduceValueHandleToArgumentOfParentBlock let cppNamespace = "::mlir::test"; let extraClassDeclaration = [{ ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, ::mlir::Operation *target, ::mlir::transform::ApplyToEachResultList &results, ::mlir::transform::TransformState &state); @@ -116,6 +118,7 @@ def TestConsumeOperandEach : Op