From 30f22429d38944e126db75296a1ffc6c12c7b87a Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Wed, 20 Apr 2022 12:57:23 +0200 Subject: [PATCH] [mlir] Connect Transform dialect to PDL This introduces a pair of ops to the Transform dialect that connect it to PDL patterns. Transform dialect relies on PDL for matching the Payload IR ops that are about to be transformed. For this purpose, it provides a container op for patterns, a "pdl_match" op and transform interface implementations that call into the pattern matching infrastructure. To enable the caching of compiled patterns, this also provides the extension mechanism for TransformState. Extensions allow one to store additional information in the TransformState and thus communicate it between different Transform dialect operations when they are applied. They can be added and removed when applying transform ops. An extension containing a symbol table in which the pattern names are resolved and a pattern compilation cache is introduced as the first client. Depends On D123664 Reviewed By: Mogball Differential Revision: https://reviews.llvm.org/D124007 --- .../Dialect/Transform/IR/TransformDialect.h | 26 +++ .../Dialect/Transform/IR/TransformDialect.td | 36 ++- .../Transform/IR/TransformInterfaces.h | 135 +++++++++++ .../mlir/Dialect/Transform/IR/TransformOps.h | 1 + .../mlir/Dialect/Transform/IR/TransformOps.td | 79 ++++++- mlir/lib/Dialect/Transform/IR/CMakeLists.txt | 1 + .../Dialect/Transform/IR/TransformDialect.cpp | 12 + .../Transform/IR/TransformInterfaces.cpp | 58 +++++ .../lib/Dialect/Transform/IR/TransformOps.cpp | 210 ++++++++++++++++-- mlir/test/Dialect/Transform/ops-invalid.mlir | 67 +++++- mlir/test/Dialect/Transform/ops.mlir | 20 ++ .../Dialect/Transform/test-interpreter.mlir | 28 +++ .../TestTransformDialectExtension.cpp | 40 +++- .../TestTransformDialectExtension.td | 8 + .../llvm-project-overlay/mlir/BUILD.bazel | 2 + 15 files changed, 692 insertions(+), 31 deletions(-) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h index 628a46535f33..d1607b57622f 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -9,9 +9,13 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" +#include "llvm/ADT/StringMap.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h.inc" @@ -57,6 +61,7 @@ public: loader(context); for (const Initializer &init : opInitializers) init(transformDialect); + transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns)); } protected: @@ -88,9 +93,30 @@ protected: [](MLIRContext *context) { context->loadDialect(); }); } + /// Injects the named constraint to make it available for use with the + /// PDLMatchOp in the transform dialect. + void registerPDLMatchConstraintFn(StringRef name, + PDLConstraintFunction &&fn) { + pdlMatchConstraintFns.try_emplace(name, + std::forward(fn)); + } + template + void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) { + pdlMatchConstraintFns.try_emplace( + name, ::mlir::detail::pdl_function_builder::buildConstraintFn( + std::forward(fn))); + } + private: SmallVector opInitializers; SmallVector dialectLoaders; + + /// A list of constraints that should be made availble to PDL patterns + /// processed by PDLMatchOp in the Transform dialect. + /// + /// Declared as mutable so its contents can be moved in the `apply` const + /// method, which is only called once. + mutable llvm::StringMap pdlMatchConstraintFns; }; } // namespace transform diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index aca6497bcb9c..d695b850474a 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -84,6 +84,13 @@ def Transform_Dialect : Dialect { `LoopTransformDialectExtension` in the cases above. Unprefixed operation names are reserved for ops defined directly in the Transform dialect. + Overall, Transform IR ops are expected to be contained in a single top-level + op. Such top-level ops specifie how to apply the transformations described + by operations they contain, e.g., `transform.sequence` executes + transformations one by one and fails if any of them fails. Such ops are + expected to have the `PossibleTopLevelTransformOpTrait` and may be used + without arguments. + ## Intended Use and Integrations The transformation control infrastructure provided by this dialect is @@ -163,13 +170,32 @@ def Transform_Dialect : Dialect { let cppNamespace = "::mlir::transform"; let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed; + let dependentDialects = [ + "::mlir::pdl::PDLDialect", + "::mlir::pdl_interp::PDLInterpDialect", + ]; + let extraClassDeclaration = [{ - // Make addOperations available to the TransformDialectExtension class. + /// Returns the named PDL constraint functions available in the dialect + /// as a map from their name to the function. + const ::llvm::StringMap<::mlir::PDLConstraintFunction> & + getPDLConstraintHooks() const; + private: + // Make addOperations available to the TransformDialectExtension class. using ::mlir::Dialect::addOperations; template friend class TransformDialectExtension; + + /// Takes ownership of the named PDL constraint function from the given + /// map and makes them available for use by the operations in the dialect. + void mergeInPDLMatchHooks( + ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns); + + /// A container for PDL constraint function that can be used by + /// operations in this dialect. + PDLPatternModule pdlMatchHooks; }]; } @@ -178,4 +204,12 @@ def Transform_Dialect : Dialect { class TransformDialectOp traits = []> : Op; +// Trait for operations that may be top-level operations in Transform IR. +// Operations must have one single-block region and must be usable without +// operands. See the C++ definition of the trait for more information. +def PossibleTopLevelTransformOpTrait + : NativeOpTrait<"PossibleTopLevelTransformOpTrait"> { + let cppNamespace = "::mlir::transform"; +} + #endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index f109ad599b84..49d3bcd8be45 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -140,6 +140,89 @@ public: }; friend class RegionScope; + /// Base class for TransformState extensions that allow TransformState to + /// contain user-specified information in the state object. Clients are + /// expected to derive this class, add the desired fields, and make the + /// derived class compatible with the MLIR TypeID mechanism: + /// + /// ```mlir + /// class MyExtension final : public TransformState::Extension { + /// public: + /// MyExtension(TranfsormState &state, int myData) + /// : Extension(state) {...} + /// private: + /// int mySupplementaryData; + /// }; + /// ``` + /// + /// Instances of this and derived classes are not expected to be created by + /// the user, instead they are directly constructed within a TransformState. A + /// TransformState can only contain one extension with the given TypeID. + /// Extensions can be obtained from a TransformState instance, and can be + /// removed when they are no longer required. + /// + /// ```mlir + /// transformState.addExtension(/*myData=*/42); + /// MyExtension *ext = transformState.getExtension(); + /// ext->doSomething(); + /// ``` + class Extension { + // Allow TransformState to allocate Extensions. + friend class TransformState; + + public: + /// Base virtual destructor. + // Out-of-line definition ensures symbols are emitted in a single object + // file. + virtual ~Extension(); + + protected: + /// Constructs an extension of the given TransformState object. + Extension(TransformState &state) : state(state) {} + + private: + /// Back-reference to the state that is being extended. + TransformState &state; + }; + + /// Adds a new Extension of the type specified as template parameter, + /// constructing it with the arguments provided. The extension is owned by the + /// TransformState. It is expected that the state does not already have an + /// extension of the same type. Extension constructors are expected to take + /// a reference to TransformState as first argument, automatically supplied + /// by this call. + template + Ty &addExtension(Args &&...args) { + static_assert( + std::is_base_of::value, + "only an class derived from TransformState::Extension is allowed here"); + auto ptr = std::make_unique(*this, std::forward(args)...); + auto result = extensions.try_emplace(TypeID::get(), std::move(ptr)); + assert(result.second && "extension already added"); + return *static_cast(result.first->second.get()); + } + + /// Returns the extension of the specified type. + template + Ty *getExtension() { + static_assert( + std::is_base_of::value, + "only an class derived from TransformState::Extension is allowed here"); + auto iter = extensions.find(TypeID::get()); + if (iter == extensions.end()) + return nullptr; + return static_cast(iter->second.get()); + } + + /// Removes the extension of the specified type. + template + void removeExtension() { + static_assert( + std::is_base_of::value, + "only an class derived from TransformState::Extension is allowed here"); + extensions.erase(TypeID::get()); + } + private: /// Identifier for storing top-level value in the `operations` mapping. static constexpr Value kTopLevelValue = Value(); @@ -196,6 +279,10 @@ private: /// the region in which the transform IR values are defined. llvm::SmallDenseMap mappings; + /// Extensions attached to the TransformState, identified by the TypeID of + /// their type. Only one extension of any given type is allowed. + DenseMap> extensions; + /// The top-level operation that contains all payload IR, typically a module. Operation *topLevel; @@ -241,6 +328,54 @@ TransformState::RegionScope TransformState::make_region_scope(Region ®ion) { return RegionScope(*this, region); } +namespace detail { +/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait +/// to either the list of operations associated with its operand or the root of +/// the payload IR, depending on what is available in the context. +LogicalResult +mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, + Operation *op); + +/// Verification hook for PossibleTopLevelTransformOpTrait. +LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); +} // namespace detail + +/// This trait is supposed to be attached to Transform dialect operations that +/// can be standalone top-level transforms. Such operations typically contain +/// other Transform dialect operations that can be executed following some +/// control flow logic specific to the current operation. The operations with +/// this trait are expected to have exactly one single-block region with one +/// argument of PDL Operation type. The operations are also expected to be valid +/// without operands, in which case they are considered top-level, and with one +/// or more arguments, in which case they are considered nested. Top-level +/// operations have the block argument of the entry block in the Transform IR +/// correspond to the root operation of Payload IR. Nested operations have the +/// block argument of the entry block in the Transform IR correspond to a list +/// of Payload IR operations mapped to the first operand of the Transform IR +/// operation. The operation must implement TransformOpInterface. +template +class PossibleTopLevelTransformOpTrait + : public OpTrait::TraitBase { +public: + /// Verifies that `op` satisfies the invariants of this trait. Not expected to + /// be called directly. + static LogicalResult verifyTrait(Operation *op) { + return detail::verifyPossibleTopLevelTransformOpTrait(op); + } + + /// Returns the single block of the op's only region. + Block *getBodyBlock() { return &this->getOperation()->getRegion(0).front(); } + + /// Sets up the mapping between the entry block of the only region of this op + /// and the relevant list of Payload IR operations in the given state. The + /// state is expected to be already scoped at the region of this operation. + /// Returns failure if the mapping failed, e.g., the value is already mapped. + LogicalResult mapBlockArguments(TransformState &state) { + return detail::mapPossibleTopLevelTransformOpBlockArguments( + state, this->getOperation()); + } +}; + } // namespace transform } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h index a12b5abd8ffc..9714b77ad968 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -13,6 +13,7 @@ #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/OpDefinition.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.h.inc" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index 246de281568b..489197fe46f6 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -10,12 +10,40 @@ #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/PDL/IR/PDLTypes.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +def PDLMatchOp : TransformDialectOp<"pdl_match", + [DeclareOpInterfaceMethods]> { + let summary = "Finds ops that match the named PDL pattern"; + let description = [{ + Find Payload IR ops nested within the Payload IR op associated with the + operand that match the PDL pattern identified by its name. The pattern is + expected to be defined in the closest surrounding `WithPDLPatternsOp`. + + Produces a Transform IR value associated with the list of Payload IR ops + that matched the pattern. The order of results in the list is that of the + Operation::walk, clients are advised not to rely on a specific order though. + If the operand is assocaited with multiple Payload IR ops, finds matching + ops nested within each of those and produces a single list containing all + of the matched ops. + + The tranfsormation is considered successful regardless of whether some + Payload IR ops actually matched the pattern and only fails if the pattern + could not be looked up or compiled. + }]; + + let arguments = (ins PDL_Operation:$root, SymbolRefAttr:$pattern_name); + let results = (outs PDL_Operation:$matched); + + let assemblyFormat = "$pattern_name `in` $root attr-dict"; +} + def SequenceOp : TransformDialectOp<"sequence", [DeclareOpInterfaceMethods, OpAsmOpInterface, + PossibleTopLevelTransformOpTrait, SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> { let summary = "Contains a sequence of other transform ops to apply"; let description = [{ @@ -48,13 +76,60 @@ def SequenceOp : TransformDialectOp<"sequence", let extraClassDeclaration = [{ /// Allow the dialect prefix to be omitted. static StringRef getDefaultDialect() { return "transform"; } + }]; + + let hasVerifier = 1; +} - Block *getBodyBlock() { - return &getBody().front(); +def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns", + [DeclareOpInterfaceMethods, NoTerminator, + OpAsmOpInterface, PossibleTopLevelTransformOpTrait, SymbolTable]> { + let summary = "Contains PDL patterns available for use in transforms"; + let description = [{ + This op contains a set of named PDL patterns that are available for the + Transform dialect operations to be used for pattern matching. For example, + PDLMatchOp can be used to produce a Transform IR value associated with all + Payload IR operations that match the pattern as follows: + + ```mlir + transform.with_pdl_patterns { + ^bb0(%arg0: !pdl.operation): + pdl.pattern @my_pattern : benefit(1) { + %0 = pdl.operation //... + // Regular PDL goes here. + pdl.rewrite %0 with "transform.dialect" + } + + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %1 = pdl_match @my_pattern in %arg1 + // Use %1 as handle + } } + ``` + + Note that the pattern is expected to finish with a `pdl.rewrite` terminator + that points to the custom rewriter named "transform.dialect". The rewriter + actually does nothing, but the transform application will keep track of the + operations that matched the pattern. + + This op is expected to contain `pdl.pattern` operations and exactly one + another Transform dialect operation that gets executed with all patterns + available. This op is a possible top-level Transform IR op, the argument of + its entry block corresponds to either the root op of the payload IR or the + ops associated with its operand when provided. }]; + let arguments = (ins Optional:$root); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "($root^)? attr-dict-with-keyword regions"; + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Allow the dialect prefix to be omitted. + static StringRef getDefaultDialect() { return "transform"; } + }]; } def YieldOp : TransformDialectOp<"yield", [Terminator]> { diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt index 760ce9364b0a..a5ac053c9119 100644 --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -11,4 +11,5 @@ add_mlir_dialect_library(MLIRTransformDialect MLIRIR MLIRPDL MLIRPDLInterp + MLIRRewrite ) diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp index a566cb91ee75..513f8736237a 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -19,3 +19,15 @@ void transform::TransformDialect::initialize() { #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" >(); } + +void transform::TransformDialect::mergeInPDLMatchHooks( + llvm::StringMap &&constraintFns) { + // Steal the constraint functions form the given map. + for (auto &it : constraintFns) + pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); +} + +const llvm::StringMap & +transform::TransformDialect::getPDLConstraintHooks() const { + return pdlMatchHooks.getConstraintFunctions(); +} diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 7df299a94cfb..2c9a2870dd61 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Operation.h" #include "llvm/ADT/ScopeExit.h" @@ -117,6 +118,8 @@ transform::TransformState::applyTransform(TransformOpInterface transform) { return success(); } +transform::TransformState::Extension::~Extension() = default; + //===----------------------------------------------------------------------===// // TransformResults //===----------------------------------------------------------------------===// @@ -145,6 +148,61 @@ transform::TransformResults::get(unsigned resultNumber) const { return segments[resultNumber]; } +//===----------------------------------------------------------------------===// +// Utilities for PossibleTopLevelTransformOpTrait. +//===----------------------------------------------------------------------===// + +LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( + TransformState &state, Operation *op) { + SmallVector targets; + if (op->getNumOperands() != 0) + llvm::append_range(targets, state.getPayloadOps(op->getOperand(0))); + else + targets.push_back(state.getTopLevel()); + + return state.mapBlockArguments(op->getRegion(0).front().getArgument(0), + targets); +} + +LogicalResult +transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) { + // Attaching this trait without the interface is a misuse of the API, but it + // cannot be caught via a static_assert because interface registration is + // dynamic. + assert(isa(op) && + "should implement TransformOpInterface to have " + "PossibleTopLevelTransformOpTrait"); + + if (op->getNumRegions() != 1) + return op->emitOpError() << "expects one region"; + + Region *bodyRegion = &op->getRegion(0); + if (!llvm::hasNItems(*bodyRegion, 1)) + return op->emitOpError() << "expects a single-block region"; + + Block *body = &bodyRegion->front(); + if (body->getNumArguments() != 1 || + !body->getArgumentTypes()[0].isa()) { + return op->emitOpError() + << "expects the entry block to have one argument of type " + << pdl::OperationType::get(op->getContext()); + } + + if (auto *parent = + op->getParentWithTrait()) { + if (op->getNumOperands() == 0) { + InFlightDiagnostic diag = + op->emitOpError() + << "expects the root operation to be provided for a nested op"; + diag.attachNote(parent->getLoc()) + << "nested in another possible top-level op"; + return diag; + } + } + + return success(); +} + //===----------------------------------------------------------------------===// // Generated interface implementation. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index 3018e3b5b68b..c68ba11e3a06 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -7,26 +7,143 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/IR/Builders.h" - #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" using namespace mlir; #define GET_OP_CLASSES #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" -LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, +//===----------------------------------------------------------------------===// +// PatternApplicatorExtension +//===----------------------------------------------------------------------===// + +namespace { +/// A simple pattern rewriter that can be constructed from a context. This is +/// necessary to apply patterns to a specific op locally. +class TrivialPatternRewriter : public PatternRewriter { +public: + explicit TrivialPatternRewriter(MLIRContext *context) + : PatternRewriter(context) {} +}; + +/// A TransformState extension that keeps track of compiled PDL pattern sets. +/// This is intended to be used along the WithPDLPatterns op. The extension +/// can be constructed given an operation that has a SymbolTable trait and +/// contains pdl::PatternOp instances. The patterns are compiled lazily and one +/// by one when requested; this behavior is subject to change. +class PatternApplicatorExtension : public transform::TransformState::Extension { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) + + /// Creates the extension for patterns contained in `patternContainer`. + explicit PatternApplicatorExtension(transform::TransformState &state, + Operation *patternContainer) + : Extension(state), patterns(patternContainer) {} + + /// Appends to `results` the operations contained in `root` that matched the + /// PDL pattern with the given name. Note that `root` may or may not be the + /// operation that contains PDL patterns. Reports an error if the pattern + /// cannot be found. Note that when no operations are matched, this still + /// succeeds as long as the pattern exists. + LogicalResult findAllMatches(StringRef patternName, Operation *root, + SmallVectorImpl &results); + +private: + /// Map from the pattern name to a singleton set of rewrite patterns that only + /// contains the pattern with this name. Populated when the pattern is first + /// requested. + // TODO: reconsider the efficiency of this storage when more usage data is + // available. Storing individual patterns in a set and triggering compilation + // for each of them has overhead. So does compiling a large set of patterns + // only to apply a handlful of them. + llvm::StringMap compiledPatterns; + + /// A symbol table operation containing the relevant PDL patterns. + SymbolTable patterns; +}; + +LogicalResult PatternApplicatorExtension::findAllMatches( + StringRef patternName, Operation *root, + SmallVectorImpl &results) { + auto it = compiledPatterns.find(patternName); + if (it == compiledPatterns.end()) { + auto patternOp = patterns.lookup(patternName); + if (!patternOp) + return failure(); + + OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); + patternOp->moveBefore(pdlModuleOp->getBody(), + pdlModuleOp->getBody()->end()); + PDLPatternModule patternModule(std::move(pdlModuleOp)); + + // Merge in the hooks owned by the dialect. Make a copy as they may be + // also used by the following operations. + auto *dialect = + root->getContext()->getLoadedDialect(); + for (const auto &pair : dialect->getPDLConstraintHooks()) + patternModule.registerConstraintFunction(pair.first(), pair.second); + + // Register a noop rewriter because PDL requires patterns to end with some + // rewrite call. + patternModule.registerRewriteFunction( + "transform.dialect", [](PatternRewriter &, Operation *) {}); + + it = compiledPatterns + .try_emplace(patternOp.getName(), std::move(patternModule)) + .first; + } + + PatternApplicator applicator(it->second); + TrivialPatternRewriter rewriter(root->getContext()); + applicator.applyDefaultCostModel(); + root->walk([&](Operation *op) { + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + results.push_back(op); + }); + + return success(); +} +} // namespace + +//===----------------------------------------------------------------------===// +// PDLMatchOp +//===----------------------------------------------------------------------===// + +LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results, transform::TransformState &state) { + auto *extension = state.getExtension(); + assert(extension && + "expected PatternApplicatorExtension to be attached by the parent op"); SmallVector targets; - if (getRoot()) - llvm::append_range(targets, state.getPayloadOps(getRoot())); - else - targets.push_back(state.getTopLevel()); + for (Operation *root : state.getPayloadOps(getRoot())) { + if (failed(extension->findAllMatches( + getPatternName().getLeafReference().getValue(), root, targets))) { + return emitOpError() << "could not find pattern '" << getPatternName() + << "'"; + } + } + results.set(getResult().cast(), targets); + return success(); +} +//===----------------------------------------------------------------------===// +// SequenceOp +//===----------------------------------------------------------------------===// + +LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, + transform::TransformState &state) { // Map the entry block argument to the list of operations. auto scope = state.make_region_scope(*getBodyBlock()->getParent()); - if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets))) + if (failed(mapBlockArguments(state))) return failure(); // Apply the sequenced ops one by one. @@ -48,23 +165,6 @@ LogicalResult transform::SequenceOp::apply(transform::TransformResults &results, } LogicalResult transform::SequenceOp::verify() { - if (getBodyBlock()->getNumArguments() != 1 || - !getBodyBlock()->getArgumentTypes()[0].isa()) { - return emitOpError() - << "expected the entry block to have one argument of type " - << pdl::OperationType::get(getContext()); - } - - if (auto parent = getOperation()->getParentOfType()) { - if (!getRoot()) { - InFlightDiagnostic diag = - emitOpError() - << "expected the root operation to be provided for a nested sequence"; - diag.attachNote(parent.getLoc()) << "nested in another sequence"; - return diag; - } - } - for (Operation &child : *getBodyBlock()) { if (!isa(child) && &child != &getBodyBlock()->back()) { @@ -99,3 +199,65 @@ LogicalResult transform::SequenceOp::verify() { } return success(); } + +//===----------------------------------------------------------------------===// +// WithPDLPatternsOp +//===----------------------------------------------------------------------===// + +LogicalResult +transform::WithPDLPatternsOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + OwningOpRef pdlModuleOp = + ModuleOp::create(getOperation()->getLoc()); + TransformOpInterface transformOp = nullptr; + for (Operation &nested : getBody().front()) { + if (!isa(nested)) { + transformOp = cast(nested); + break; + } + } + + state.addExtension(getOperation()); + auto guard = llvm::make_scope_exit( + [&]() { state.removeExtension(); }); + + auto scope = state.make_region_scope(getBody()); + if (failed(mapBlockArguments(state))) + return failure(); + return state.applyTransform(transformOp); +} + +LogicalResult transform::WithPDLPatternsOp::verify() { + Block *body = getBodyBlock(); + Operation *topLevelOp = nullptr; + for (Operation &op : body->getOperations()) { + if (isa(op)) + continue; + + if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { + if (topLevelOp) { + InFlightDiagnostic diag = + emitOpError() << "expects only one non-pattern op in its body"; + diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; + diag.attachNote(op.getLoc()) << "second non-pattern op"; + return diag; + } + topLevelOp = &op; + continue; + } + + InFlightDiagnostic diag = + emitOpError() + << "expects only pattern and top-level transform ops in its body"; + diag.attachNote(op.getLoc()) << "offending op"; + return diag; + } + + if (auto parent = getOperation()->getParentOfType()) { + InFlightDiagnostic diag = emitOpError() << "cannot be nested"; + diag.attachNote(parent.getLoc()) << "parent operation"; + return diag; + } + + return success(); +} diff --git a/mlir/test/Dialect/Transform/ops-invalid.mlir b/mlir/test/Dialect/Transform/ops-invalid.mlir index 614628107834..61ed760d700f 100644 --- a/mlir/test/Dialect/Transform/ops-invalid.mlir +++ b/mlir/test/Dialect/Transform/ops-invalid.mlir @@ -1,15 +1,15 @@ // RUN: mlir-opt %s -split-input-file -verify-diagnostics -// expected-error @below {{expected the entry block to have one argument of type '!pdl.operation'}} +// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}} transform.sequence { } // ----- -// expected-note @below {{nested in another sequence}} +// expected-note @below {{nested in another possible top-level op}} transform.sequence { ^bb0(%arg0: !pdl.operation): - // expected-error @below {{expected the root operation to be provided for a nested sequence}} + // expected-error @below {{expects the root operation to be provided for a nested op}} transform.sequence { ^bb1(%arg1: !pdl.operation): } @@ -50,3 +50,64 @@ transform.sequence { // expected-note @below {{terminator}} transform.yield } : !pdl.operation + +// ----- + +// expected-note @below {{nested in another possible top-level op}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-error @below {{expects the root operation to be provided for a nested op}} + transform.sequence { + ^bb1(%arg1: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expects only one non-pattern op in its body}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-note @below {{first non-pattern op}} + transform.sequence { + ^bb1(%arg1: !pdl.operation): + } + // expected-note @below {{second non-pattern op}} + transform.sequence { + ^bb1(%arg1: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expects only pattern and top-level transform ops in its body}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-note @below {{offending op}} + "test.something"() : () -> () +} + +// ----- + +// expected-note @below {{parent operation}} +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // expected-error @below {{op cannot be nested}} + transform.with_pdl_patterns %arg0 { + ^bb1(%arg1: !pdl.operation): + } +} + +// ----- + +// expected-error @below {{expects one region}} +"transform.test_transform_unrestricted_op_no_interface"() : () -> () + +// ----- + +// expected-error @below {{expects a single-block region}} +"transform.test_transform_unrestricted_op_no_interface"() ({ +^bb0(%arg0: !pdl.operation): + "test.potential_terminator"() : () -> () +^bb1: + "test.potential_terminator"() : () -> () +}) : () -> () diff --git a/mlir/test/Dialect/Transform/ops.mlir b/mlir/test/Dialect/Transform/ops.mlir index c3aab426aad2..34ee62e0bbc7 100644 --- a/mlir/test/Dialect/Transform/ops.mlir +++ b/mlir/test/Dialect/Transform/ops.mlir @@ -10,3 +10,23 @@ transform.sequence { ^bb1(%arg1: !pdl.operation): } } + +// CHECK: transform.with_pdl_patterns +// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation): +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + // CHECK: sequence %[[ARG]] + sequence %arg0 { + ^bb1(%arg1: !pdl.operation): + } +} + +// CHECK: transform.sequence +// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation): +transform.sequence { +^bb0(%arg0: !pdl.operation): + // CHECK: with_pdl_patterns %[[ARG]] + with_pdl_patterns %arg0 { + ^bb1(%arg1: !pdl.operation): + } +} diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index a6ceeea82a5c..2b2416480af1 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -69,3 +69,31 @@ transform.sequence { // expected-remark @below {{succeeded}} test_consume_operand_if_matches_param_or_fail %0[42] } + +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !pdl.operation): + sequence %arg0 { + ^bb0(%arg1: !pdl.operation): + %0 = pdl_match @some in %arg1 + test_print_remark_at_operand %0, "matched" + } + + pdl.pattern @some : benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.rewrite %0 with "transform.dialect" + } + + pdl.pattern @other : benefit(1) { + %0 = pdl.operation "test.other_op" + pdl.rewrite %0 with "transform.dialect" + } +} + +// expected-remark @below {{matched}} +"test.some_op"() : () -> () +"test.other_op"() : () -> () +// expected-remark @below {{matched}} +"test.some_op"() : () -> () + diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 4aed0aae1e77..c3bbb5a66c61 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -22,7 +22,8 @@ using namespace mlir; namespace { /// Simple transform op defined outside of the dialect. Just emits a remark when -/// applied. +/// applied. This op is defined in C++ to test that C++ definitions also work +/// for op injection into the Transform dialect. class TestTransformOp : public Op { public: @@ -63,6 +64,33 @@ public: printer << " " << getMessage(); } }; + +/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait +/// in cases where it is attached to ops that do not comply with the trait +/// requirements. This op cannot be defined in ODS because ODS generates strict +/// verifiers that overalp with those in the trait and run earlier. +class TestTransformUnrestrictedOpNoInterface + : public Op { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestTransformUnrestrictedOpNoInterface) + + using Op::Op; + + static ArrayRef getAttributeNames() { return {}; } + + static constexpr llvm::StringLiteral getOperationName() { + return llvm::StringLiteral( + "transform.test_transform_unrestricted_op_no_interface"); + } + + LogicalResult apply(transform::TransformResults &results, + transform::TransformState &state) { + return success(); + } +}; } // namespace LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply( @@ -97,6 +125,15 @@ LogicalResult mlir::test::TestConsumeOperandIfMatchesParamOrFail::apply( return success(); } +LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply( + transform::TransformResults &results, transform::TransformState &state) { + ArrayRef payload = state.getPayloadOps(getOperand()); + for (Operation *op : payload) + op->emitRemark() << getMessage(); + + return success(); +} + namespace { /// Test extension of the Transform dialect. Registers additional ops and /// declares PDL as dependent dialect since the additional ops are using PDL @@ -108,6 +145,7 @@ public: TestTransformDialectExtension() { declareDependentDialect(); registerTransformOps(); diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td index c263409c618d..4596780ac131 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td @@ -38,4 +38,12 @@ def TestConsumeOperandIfMatchesParamOrFail let cppNamespace = "::mlir::test"; } +def TestPrintRemarkAtOperandOp + : Op]> { + let arguments = (ins PDL_Operation:$operand, StrAttr:$message); + let assemblyFormat = "$operand `,` $message attr-dict"; + let cppNamespace = "::mlir::test"; +} + #endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 346e7f7d16a7..2b70cd7afb03 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7772,6 +7772,8 @@ cc_library( deps = [ ":IR", ":PDLDialect", + ":PDLInterpDialect", + ":Rewrite", ":Support", ":TransformDialectIncGen", ":TransformDialectInterfacesIncGen", -- 2.34.1