From 94d608d410267db693aa85070263e2b4ef0be913 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 22 May 2023 14:36:58 +0000 Subject: [PATCH] [mlir] move PDL-related transform ops into an extension The initial bring-up of the Transform dialect relied on PDL to provide the default handle type (`!pdl.operation`) and the matching capability. Both are now provided natively by the Transform dialect removing the reason to have a hard dependency on the PDL dialect and its interpreter. Move PDL-related transform operations into a separate extension. This requires us to introduce a dialect state extension mechanism into the Transform dialect so it no longer needs to know about PDL constraint functions that may be injected by extensions similarly to operations and types. This mechanism will be reused to connect pattern application drivers and the Transform dialect. This completes the restructuring of the Transform dialect to remove overrilance on PDL. Note to downstreams: flow that are using `!pdl.operation` with Transform dialect operations will now require `transform::PDLExtension` to be applied to the transform dialect in order to provide the transform handle type interface for `!pdl.operation`. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D151104 --- mlir/include/mlir/Dialect/Transform/CMakeLists.txt | 1 + .../mlir/Dialect/Transform/IR/TransformDialect.h | 118 +++++++-- .../mlir/Dialect/Transform/IR/TransformDialect.td | 39 ++- .../Dialect/Transform/IR/TransformInterfaces.h | 35 ++- .../mlir/Dialect/Transform/IR/TransformOps.h | 1 - .../mlir/Dialect/Transform/IR/TransformOps.td | 86 ------- .../Dialect/Transform/PDLExtension/CMakeLists.txt | 6 + .../Dialect/Transform/PDLExtension/PDLExtension.h | 16 ++ .../Transform/PDLExtension/PDLExtensionOps.h | 49 ++++ .../Transform/PDLExtension/PDLExtensionOps.td | 104 ++++++++ mlir/include/mlir/InitAllDialects.h | 2 + mlir/lib/Dialect/Transform/CMakeLists.txt | 1 + mlir/lib/Dialect/Transform/IR/CMakeLists.txt | 2 - mlir/lib/Dialect/Transform/IR/TransformDialect.cpp | 29 --- .../Dialect/Transform/IR/TransformInterfaces.cpp | 55 +++++ mlir/lib/Dialect/Transform/IR/TransformOps.cpp | 266 +-------------------- .../Dialect/Transform/PDLExtension/CMakeLists.txt | 13 + .../Transform/PDLExtension/PDLExtension.cpp | 69 ++++++ .../Transform/PDLExtension/PDLExtensionOps.cpp | 234 ++++++++++++++++++ mlir/python/CMakeLists.txt | 10 + .../mlir/dialects/TransformPDLExtensionOps.td | 20 ++ mlir/python/mlir/dialects/_transform_ops_ext.py | 42 ---- .../dialects/_transform_pdl_extension_ops_ext.py | 55 +++++ mlir/python/mlir/dialects/transform/pdl.py | 5 + mlir/test/Dialect/Transform/test-interpreter.mlir | 27 --- .../test/Dialect/Transform/test-pdl-extension.mlir | 47 ++++ mlir/test/lib/Dialect/Transform/CMakeLists.txt | 1 + .../Transform/TestTransformDialectExtension.cpp | 19 ++ mlir/test/python/dialects/transform.py | 12 +- .../python/dialects/transform_structured_ext.py | 7 +- utils/bazel/llvm-project-overlay/mlir/BUILD.bazel | 52 +++- .../llvm-project-overlay/mlir/python/BUILD.bazel | 21 ++ .../llvm-project-overlay/mlir/test/BUILD.bazel | 2 + 33 files changed, 929 insertions(+), 517 deletions(-) create mode 100644 mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt create mode 100644 mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h create mode 100644 mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h create mode 100644 mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td create mode 100644 mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt create mode 100644 mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp create mode 100644 mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp create mode 100644 mlir/python/mlir/dialects/TransformPDLExtensionOps.td create mode 100644 mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py create mode 100644 mlir/python/mlir/dialects/transform/pdl.py create mode 100644 mlir/test/Dialect/Transform/test-pdl-extension.mlir diff --git a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt index 9f57627..d9fbaee 100644 --- a/mlir/include/mlir/Dialect/Transform/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Transform/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(IR) +add_subdirectory(PDLExtension) add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h index 36712add..e156602 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.h @@ -12,12 +12,52 @@ #include "mlir/IR/Dialect.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Support/LLVM.h" +#include "mlir/Support/TypeID.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/StringMap.h" #include namespace mlir { namespace transform { + +namespace detail { +/// Concrete base class for CRTP TransformDialectDataBase. Must not be used +/// directly. +class TransformDialectDataBase { +public: + virtual ~TransformDialectDataBase() = default; + + /// Returns the dynamic type ID of the subclass. + TypeID getTypeID() const { return typeID; } + +protected: + /// Must be called by the subclass with the appropriate type ID. + explicit TransformDialectDataBase(TypeID typeID) : typeID(typeID) {} + +private: + /// The type ID of the subclass. + const TypeID typeID; +}; +} // namespace detail + +/// Base class for additional data owned by the Transform dialect. Extensions +/// may communicate with each other using this data. The data object is +/// identified by the TypeID of the specific data subclass, querying the data of +/// the same subclass returns a reference to the same object. When a Transform +/// dialect extension is initialized, it can populate the data in the specific +/// subclass. When a Transform op is applied, it can read (but not mutate) the +/// data in the specific subclass, including the data provided by other +/// extensions. +/// +/// This follows CRTP: derived classes must list themselves as template +/// argument. +template +class TransformDialectData : public detail::TransformDialectDataBase { +protected: + /// Forward the TypeID of the derived class to the base. + TransformDialectData() : TransformDialectDataBase(TypeID::get()) {} +}; + #ifndef NDEBUG namespace detail { /// Asserts that the operations provided as template arguments implement the @@ -85,9 +125,8 @@ public: for (const DialectLoader &loader : generatedDialectLoaders) loader(context); - for (const Initializer &init : opInitializers) + for (const Initializer &init : initializers) init(transformDialect); - transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns)); } protected: @@ -100,6 +139,41 @@ protected: static_cast(this)->init(); } + /// Registers a custom initialization step to be performed when the extension + /// is applied to the dialect while loading. This is discouraged in favor of + /// more specific calls `declareGeneratedDialect`, `addDialectDataInitializer` + /// etc. `Func` must be convertible to the `void (MLIRContext *)` form. It + /// will be called during the extension initialization and given the current + /// MLIR context. This may be used to attach additional interfaces that cannot + /// be attached elsewhere. + template + void addCustomInitializationStep(Func &&func) { + std::function initializer = func; + dialectLoaders.push_back( + [init = std::move(initializer)](MLIRContext *ctx) { init(ctx); }); + } + + /// Registers the given function as one of the initializers for the + /// dialect-owned data of the kind specified as template argument. The + /// function must be convertible to the `void (DataTy &)` form. It will be + /// called during the extension initialization and will be given a mutable + /// reference to `DataTy`. The callback is expected to append data to the + /// given storage, and is not allowed to remove or destructively mutate the + /// existing data. The order in which callbacks from different extensions are + /// executed is unspecified so the callbacks may not rely on data being + /// already present. `DataTy` must be a class deriving `TransformDialectData`. + template + void addDialectDataInitializer(Func &&func) { + static_assert(std::is_base_of_v, + "only classes deriving TransformDialectData are accepted"); + + std::function initializer = func; + initializers.push_back( + [init = std::move(initializer)](TransformDialect *transformDialect) { + init(transformDialect->getOrCreateExtraData()); + }); + } + /// Hook for derived classes to inject constructor behavior. void init() {} @@ -108,7 +182,7 @@ protected: /// implementations must be already available when the operation is injected. template void registerTransformOps() { - opInitializers.push_back([](TransformDialect *transformDialect) { + initializers.push_back([](TransformDialect *transformDialect) { transformDialect->addOperationsChecked(); }); } @@ -120,7 +194,7 @@ protected: /// `StringRef` that is unique across all injected types. template void registerTypes() { - opInitializers.push_back([](TransformDialect *transformDialect) { + initializers.push_back([](TransformDialect *transformDialect) { transformDialect->addTypesChecked(); }); } @@ -151,22 +225,10 @@ protected: [](MLIRContext *context) { context->loadDialect(); }); } - /// Injects the named constraint to make it available for use with the - /// PDLMatchOp in the transform dialect. - void registerPDLMatchConstraintFn(StringRef name, - PDLConstraintFunction &&fn) { - pdlMatchConstraintFns.try_emplace(name, - std::forward(fn)); - } - template - void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) { - pdlMatchConstraintFns.try_emplace( - name, ::mlir::detail::pdl_function_builder::buildConstraintFn( - std::forward(fn))); - } - private: - SmallVector opInitializers; + /// Callbacks performing extension initialization, e.g., registering ops, + /// types and defining the additional data. + SmallVector initializers; /// Callbacks loading the dependent dialects, i.e. the dialect needed for the /// extension ops. @@ -176,13 +238,6 @@ private: /// applying the transformations. SmallVector generatedDialectLoaders; - /// A list of constraints that should be made available to PDL patterns - /// processed by PDLMatchOp in the Transform dialect. - /// - /// Declared as mutable so its contents can be moved in the `apply` const - /// method, which is only called once. - mutable llvm::StringMap pdlMatchConstraintFns; - /// Indicates that the extension is in build-only mode. bool buildOnly; }; @@ -232,6 +287,17 @@ void TransformDialect::addTypeIfNotRegistered() { #endif // NDEBUG } +template +DataTy &TransformDialect::getOrCreateExtraData() { + TypeID typeID = TypeID::get(); + auto it = extraData.find(typeID); + if (it != extraData.end()) + return static_cast(*it->getSecond()); + + auto emplaced = extraData.try_emplace(typeID, std::make_unique()); + return static_cast(*emplaced.first->getSecond()); +} + /// A wrapper for transform dialect extensions that forces them to be /// constructed in the build-only mode. template diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td index 160f1ff..0539187 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformDialect.td @@ -18,36 +18,31 @@ def Transform_Dialect : Dialect { let name = "transform"; let cppNamespace = "::mlir::transform"; - let dependentDialects = [ - "::mlir::pdl::PDLDialect", - "::mlir::pdl_interp::PDLInterpDialect", - ]; - let hasOperationAttrVerify = 1; let usePropertiesForAttributes = 1; let extraClassDeclaration = [{ /// Name of the attribute attachable to the symbol table operation /// containing named sequences. This is used to trigger verification. - constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kWithNamedSequenceAttrName = "transform.with_named_sequence"; /// Names of the attribute attachable to an operation so it can be /// identified as root by the default interpreter pass. - constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kTargetTagAttrName = "transform.target_tag"; /// Names of the attributes indicating whether an argument of an external /// transform dialect symbol is consumed or only read. - constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kArgConsumedAttrName = "transform.consumed"; - constexpr const static llvm::StringLiteral + constexpr const static ::llvm::StringLiteral kArgReadOnlyAttrName = "transform.readonly"; - /// Returns the named PDL constraint functions available in the dialect - /// as a map from their name to the function. - const ::llvm::StringMap<::mlir::PDLConstraintFunction> & - getPDLConstraintHooks() const; + template + const DataTy &getExtraData() const { + return *static_cast(extraData.at(::mlir::TypeID::get()).get()); + } /// Parses a type registered by this dialect or one of its extensions. ::mlir::Type parseType(::mlir::DialectAsmParser &parser) const override; @@ -92,23 +87,27 @@ def Transform_Dialect : Dialect { /// mnemonic. [[noreturn]] void reportDuplicateTypeRegistration(StringRef mnemonic); + /// Registers dialect types with the context. void initializeTypes(); + // Give extensions access to injection functions. template friend class TransformDialectExtension; - /// Takes ownership of the named PDL constraint function from the given - /// map and makes them available for use by the operations in the dialect. - void mergeInPDLMatchHooks( - ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns); + /// Gets a mutable reference to extra data of the kind specified as + /// template argument. Allocates the data on the first call. + template + DataTy &getOrCreateExtraData(); //===----------------------------------------------------------------===// // Data fields //===----------------------------------------------------------------===// - /// A container for PDL constraint function that can be used by - /// operations in this dialect. - ::mlir::PDLPatternModule pdlMatchHooks; + /// Additional data associated with and owned by the dialect. Accessible + /// to extensions. + ::llvm::DenseMap<::mlir::TypeID, std::unique_ptr< + ::mlir::transform::detail::TransformDialectDataBase>> + extraData; /// A map from type mnemonic to its parsing function for the remainder of /// the syntax. The parser has access to the mnemonic, so it is used for diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h index 6730552..77d0c7d 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h @@ -38,6 +38,14 @@ mapPossibleTopLevelTransformOpBlockArguments(TransformState &state, /// Verification hook for PossibleTopLevelTransformOpTrait. LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op); +/// Populates `effects` with side effects implied by +/// PossibleTopLevelTransformOpTrait for the given operation. The operation may +/// have an optional `root` operand, indicating it is not in fact top-level. It +/// is also expected to have a single-block body. +void getPotentialTopLevelEffects( + Operation *operation, Value root, Block &body, + SmallVectorImpl &effects); + /// Verification hook for TransformOpInterface. LogicalResult verifyTransformOpInterface(Operation *op); @@ -753,15 +761,16 @@ TransformState::make_isolated_region_scope(Region ®ion) { /// can be standalone top-level transforms. Such operations typically contain /// other Transform dialect operations that can be executed following some /// control flow logic specific to the current operation. The operations with -/// this trait are expected to have at least one single-block region with one -/// argument of PDL Operation type. The operations are also expected to be valid -/// without operands, in which case they are considered top-level, and with one -/// or more arguments, in which case they are considered nested. Top-level -/// operations have the block argument of the entry block in the Transform IR -/// correspond to the root operation of Payload IR. Nested operations have the -/// block argument of the entry block in the Transform IR correspond to a list -/// of Payload IR operations mapped to the first operand of the Transform IR -/// operation. The operation must implement TransformOpInterface. +/// this trait are expected to have at least one single-block region with at +/// least one argument of type implementing TransformHandleTypeInterface. The +/// operations are also expected to be valid without operands, in which case +/// they are considered top-level, and with one or more arguments, in which case +/// they are considered nested. Top-level operations have the block argument of +/// the entry block in the Transform IR correspond to the root operation of +/// Payload IR. Nested operations have the block argument of the entry block in +/// the Transform IR correspond to a list of Payload IR operations mapped to the +/// first operand of the Transform IR operation. The operation must implement +/// TransformOpInterface. template class PossibleTopLevelTransformOpTrait : public OpTrait::TraitBase { @@ -777,6 +786,14 @@ public: return &this->getOperation()->getRegion(region).front(); } + /// Populates `effects` with side effects implied by this trait. + void getPotentialTopLevelEffects( + SmallVectorImpl &effects) { + detail::getPotentialTopLevelEffects( + this->getOperation(), cast(this->getOperation()).getRoot(), + *getBodyBlock(), effects); + } + /// Sets up the mapping between the entry block of the given region of this op /// and the relevant list of Payload IR operations in the given state. The /// state is expected to be already scoped at the region of this operation. diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h index 7a0f802..543eba9 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.h @@ -9,7 +9,6 @@ #ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H #define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS_H -#include "mlir/Dialect/PDL/IR/PDLTypes.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" diff --git a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td index aa88e49..a313d28 100644 --- a/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td +++ b/mlir/include/mlir/Dialect/Transform/IR/TransformOps.td @@ -575,37 +575,6 @@ def ParamConstantOp : Op, - DeclareOpInterfaceMethods]> { - let summary = "Finds ops that match the named PDL pattern"; - let description = [{ - Find Payload IR ops nested within the Payload IR op associated with the - operand that match the PDL pattern identified by its name. The pattern is - expected to be defined in the closest surrounding `WithPDLPatternsOp`. - - Produces a Transform IR value associated with the list of Payload IR ops - that matched the pattern. The order of results in the list is that of the - Operation::walk, clients are advised not to rely on a specific order though. - If the operand is associated with multiple Payload IR ops, finds matching - ops nested within each of those and produces a single list containing all - of the matched ops. - - The transformation is considered successful regardless of whether some - Payload IR ops actually matched the pattern and only fails if the pattern - could not be looked up or compiled. - }]; - - let arguments = (ins - Arg:$root, - SymbolRefAttr:$pattern_name); - let results = (outs - Res:$matched); - - let assemblyFormat = "$pattern_name `in` $root attr-dict `:` " - "functional-type(operands, results)"; -} - def PrintOp : TransformDialectOp<"print", [DeclareOpInterfaceMethods, DeclareOpInterfaceMethods]> { @@ -753,61 +722,6 @@ def SequenceOp : TransformDialectOp<"sequence", let hasVerifier = 1; } -def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns", - [DeclareOpInterfaceMethods, NoTerminator, - OpAsmOpInterface, PossibleTopLevelTransformOpTrait, - DeclareOpInterfaceMethods, - SymbolTable]> { - let summary = "Contains PDL patterns available for use in transforms"; - let description = [{ - This op contains a set of named PDL patterns that are available for the - Transform dialect operations to be used for pattern matching. For example, - PDLMatchOp can be used to produce a Transform IR value associated with all - Payload IR operations that match the pattern as follows: - - ```mlir - transform.with_pdl_patterns { - ^bb0(%arg0: !transform.any_op): - pdl.pattern @my_pattern : benefit(1) { - %0 = pdl.operation //... - // Regular PDL goes here. - pdl.rewrite %0 with "transform.dialect" - } - - sequence %arg0 failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %1 = pdl_match @my_pattern in %arg1 - // Use %1 as handle - } - } - ``` - - Note that the pattern is expected to finish with a `pdl.rewrite` terminator - that points to the custom rewriter named "transform.dialect". The rewriter - actually does nothing, but the transform application will keep track of the - operations that matched the pattern. - - This op is expected to contain `pdl.pattern` operations and exactly one - another Transform dialect operation that gets executed with all patterns - available. This op is a possible top-level Transform IR op, the argument of - its entry block corresponds to either the root op of the payload IR or the - ops associated with its operand when provided. - }]; - - let arguments = (ins - Arg, "Root operation of the Payload IR" - >:$root); - let regions = (region SizedRegion<1>:$body); - let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions"; - - let hasVerifier = 1; - - let extraClassDeclaration = [{ - /// Allow the dialect prefix to be omitted. - static StringRef getDefaultDialect() { return "transform"; } - }]; -} - def YieldOp : TransformDialectOp<"yield", [Terminator, DeclareOpInterfaceMethods]> { let summary = "Yields operation handles from a transform IR region"; diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt new file mode 100644 index 0000000..6af6b83 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/CMakeLists.txt @@ -0,0 +1,6 @@ +set(LLVM_TARGET_DEFINITIONS PDLExtensionOps.td) +mlir_tablegen(PDLExtensionOps.h.inc -gen-op-decls) +mlir_tablegen(PDLExtensionOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTransformDialectPDLExtensionOpsIncGen) + +add_mlir_doc(PDLExtensionOps PDLExtensionOps Dialects/ -gen-op-doc) diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h new file mode 100644 index 0000000..0891521 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtension.h @@ -0,0 +1,16 @@ +//===- PDLExtension.h - PDL extension for Transform dialect -----*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +namespace mlir { +class DialectRegistry; + +namespace transform { +/// Registers the PDL extension of the Transform dialect in the given registry. +void registerPDLExtension(DialectRegistry &dialectRegistry); +} // namespace transform +} // namespace mlir diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h new file mode 100644 index 0000000..a159c30 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h @@ -0,0 +1,49 @@ +//===- PDLExtensionOps.h - PDL extension for Transform dialect --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H +#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H + +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc" + +namespace mlir { +namespace transform { +/// PDL constraint callbacks that can be used by the PDL extension of the +/// Transform dialect. These are owned by the Transform dialect and can be +/// populated by extensions. +class PDLMatchHooks : public TransformDialectData { +public: + /// Takes ownership of the named PDL constraint function from the given + /// map and makes them available for use by the operations in the dialect. + void + mergeInPDLMatchHooks(llvm::StringMap &&constraintFns); + + /// Returns the named PDL constraint functions available in the dialect + /// as a map from their name to the function. + const llvm::StringMap<::mlir::PDLConstraintFunction> & + getPDLConstraintHooks() const; + +private: + /// A container for PDL constraint function that can be used by + /// operations in this dialect. + PDLPatternModule pdlMatchHooks; +}; +} // namespace transform +} // namespace mlir + +MLIR_DECLARE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks) + +#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS_H diff --git a/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td new file mode 100644 index 0000000..16107b3 --- /dev/null +++ b/mlir/include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td @@ -0,0 +1,104 @@ +//===- TransformOps.td - Transform dialect operations ------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS +#define MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS + +include "mlir/Dialect/Transform/IR/TransformDialect.td" +include "mlir/Dialect/Transform/IR/TransformInterfaces.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/OpAsmInterface.td" +include "mlir/IR/SymbolInterfaces.td" + +def PDLMatchOp : TransformDialectOp<"pdl_match", + [DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods]> { + let summary = "Finds ops that match the named PDL pattern"; + let description = [{ + Find Payload IR ops nested within the Payload IR op associated with the + operand that match the PDL pattern identified by its name. The pattern is + expected to be defined in the closest surrounding `WithPDLPatternsOp`. + + Produces a Transform IR value associated with the list of Payload IR ops + that matched the pattern. The order of results in the list is that of the + Operation::walk, clients are advised not to rely on a specific order though. + If the operand is associated with multiple Payload IR ops, finds matching + ops nested within each of those and produces a single list containing all + of the matched ops. + + The transformation is considered successful regardless of whether some + Payload IR ops actually matched the pattern and only fails if the pattern + could not be looked up or compiled. + }]; + + let arguments = (ins + Arg:$root, + SymbolRefAttr:$pattern_name); + let results = (outs + Res:$matched); + + let assemblyFormat = "$pattern_name `in` $root attr-dict `:` " + "functional-type(operands, results)"; +} + +def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns", + [DeclareOpInterfaceMethods, NoTerminator, + OpAsmOpInterface, PossibleTopLevelTransformOpTrait, + DeclareOpInterfaceMethods, + SymbolTable]> { + let summary = "Contains PDL patterns available for use in transforms"; + let description = [{ + This op contains a set of named PDL patterns that are available for the + Transform dialect operations to be used for pattern matching. For example, + PDLMatchOp can be used to produce a Transform IR value associated with all + Payload IR operations that match the pattern as follows: + + ```mlir + transform.with_pdl_patterns { + ^bb0(%arg0: !transform.any_op): + pdl.pattern @my_pattern : benefit(1) { + %0 = pdl.operation //... + // Regular PDL goes here. + pdl.rewrite %0 with "transform.dialect" + } + + sequence %arg0 failures(propagate) { + ^bb0(%arg1: !transform.any_op): + %1 = pdl_match @my_pattern in %arg1 + // Use %1 as handle + } + } + ``` + + Note that the pattern is expected to finish with a `pdl.rewrite` terminator + that points to the custom rewriter named "transform.dialect". The rewriter + actually does nothing, but the transform application will keep track of the + operations that matched the pattern. + + This op is expected to contain `pdl.pattern` operations and exactly one + another Transform dialect operation that gets executed with all patterns + available. This op is a possible top-level Transform IR op, the argument of + its entry block corresponds to either the root op of the payload IR or the + ops associated with its operand when provided. + }]; + + let arguments = (ins + Arg, "Root operation of the Payload IR" + >:$root); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = "($root^ `:` type($root))? attr-dict-with-keyword regions"; + + let hasVerifier = 1; + + let extraClassDeclaration = [{ + /// Allow the dialect prefix to be omitted. + static StringRef getDefaultDialect() { return "transform"; } + }]; +} + +#endif // MLIR_DIALECT_TRANSFORM_PDLEXTENSION_PDLEXTENSIONOPS diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index b00de3f..e307b23 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -76,6 +76,7 @@ #include "mlir/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h" #include "mlir/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.h" @@ -135,6 +136,7 @@ inline void registerAllDialects(DialectRegistry ®istry) { memref::registerTransformDialectExtension(registry); scf::registerTransformDialectExtension(registry); tensor::registerTransformDialectExtension(registry); + transform::registerPDLExtension(registry); vector::registerTransformDialectExtension(registry); // Register all external models. diff --git a/mlir/lib/Dialect/Transform/CMakeLists.txt b/mlir/lib/Dialect/Transform/CMakeLists.txt index 31167e6..9e144eb 100644 --- a/mlir/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/CMakeLists.txt @@ -1,3 +1,4 @@ add_subdirectory(IR) +add_subdirectory(PDLExtension) add_subdirectory(Transforms) add_subdirectory(Utils) diff --git a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt index 2fed20f..4fb2751 100644 --- a/mlir/lib/Dialect/Transform/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/Transform/IR/CMakeLists.txt @@ -14,8 +14,6 @@ add_mlir_dialect_library(MLIRTransformDialect LINK_LIBS PUBLIC MLIRIR MLIRParser - MLIRPDLDialect - MLIRPDLInterpDialect MLIRRewrite MLIRSideEffectInterfaces MLIRTransforms diff --git a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp index 6780c3b..d075994 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformDialect.cpp @@ -8,8 +8,6 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Analysis/CallGraph.h" -#include "mlir/Dialect/PDL/IR/PDL.h" -#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/IR/TransformTypes.h" @@ -51,18 +49,6 @@ void transform::detail::checkImplementsTransformHandleTypeInterface( } #endif // NDEBUG -namespace { -struct PDLOperationTypeTransformHandleTypeInterfaceImpl - : public transform::TransformHandleTypeInterface::ExternalModel< - PDLOperationTypeTransformHandleTypeInterfaceImpl, - pdl::OperationType> { - DiagnosedSilenceableFailure - checkPayload(Type type, Location loc, ArrayRef payload) const { - return DiagnosedSilenceableFailure::success(); - } -}; -} // namespace - void transform::TransformDialect::initialize() { // Using the checked versions to enable the same assertions as for the ops // from extensions. @@ -71,21 +57,6 @@ void transform::TransformDialect::initialize() { #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" >(); initializeTypes(); - - pdl::OperationType::attachInterface< - PDLOperationTypeTransformHandleTypeInterfaceImpl>(*getContext()); -} - -void transform::TransformDialect::mergeInPDLMatchHooks( - llvm::StringMap &&constraintFns) { - // Steal the constraint functions from the given map. - for (auto &it : constraintFns) - pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); -} - -const llvm::StringMap & -transform::TransformDialect::getPDLConstraintHooks() const { - return pdlMatchHooks.getConstraintFunctions(); } Type transform::TransformDialect::parseType(DialectAsmParser &parser) const { diff --git a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp index 5685187e..37caa60 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp @@ -1242,6 +1242,61 @@ void transform::detail::forwardTerminatorOperands( // Utilities for PossibleTopLevelTransformOpTrait. //===----------------------------------------------------------------------===// +/// Appends to `effects` the memory effect instances on `target` with the same +/// resource and effect as the ones the operation `iface` having on `source`. +static void +remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, + SmallVectorImpl &effects) { + SmallVector nestedEffects; + iface.getEffectsOnValue(source, nestedEffects); + for (const auto &effect : nestedEffects) + effects.emplace_back(effect.getEffect(), target, effect.getResource()); +} + +/// Appends to `effects` the same effects as the operations of `block` have on +/// block arguments but associated with `operands.` +static void +remapArgumentEffects(Block &block, ValueRange operands, + SmallVectorImpl &effects) { + for (Operation &op : block) { + auto iface = dyn_cast(&op); + if (!iface) + continue; + + for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) { + remapEffects(iface, source, target, effects); + } + + SmallVector nestedEffects; + iface.getEffectsOnResource(transform::PayloadIRResource::get(), + nestedEffects); + llvm::append_range(effects, nestedEffects); + } +} + +void transform::detail::getPotentialTopLevelEffects( + Operation *operation, Value root, Block &body, + SmallVectorImpl &effects) { + transform::onlyReadsHandle(operation->getOperands(), effects); + transform::producesHandle(operation->getResults(), effects); + + if (!root) { + for (Operation &op : body) { + auto iface = dyn_cast(&op); + if (!iface) + continue; + + SmallVector nestedEffects; + iface.getEffects(effects); + } + return; + } + + // Carry over all effects on arguments of the entry block as those on the + // operands, this is the same value just remapped. + remapArgumentEffects(body, operation->getOperands(), effects); +} + LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments( TransformState &state, Operation *op, Region ®ion) { SmallVector targets; diff --git a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp index ad00170..a3b55a4 100644 --- a/mlir/lib/Dialect/Transform/IR/TransformOps.cpp +++ b/mlir/lib/Dialect/Transform/IR/TransformOps.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Transform/IR/TransformOps.h" -#include "mlir/Dialect/PDL/IR/PDLOps.h" #include "mlir/Dialect/Transform/IR/MatchInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformAttrs.h" #include "mlir/Dialect/Transform/IR/TransformDialect.h" @@ -17,8 +16,6 @@ #include "mlir/IR/FunctionImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Rewrite/PatternApplicator.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopeExit.h" #include "llvm/ADT/SmallPtrSet.h" @@ -53,99 +50,6 @@ static ParseResult parseForeachMatchSymbols(OpAsmParser &parser, #include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc" //===----------------------------------------------------------------------===// -// PatternApplicatorExtension -//===----------------------------------------------------------------------===// - -namespace { -/// A TransformState extension that keeps track of compiled PDL pattern sets. -/// This is intended to be used along the WithPDLPatterns op. The extension -/// can be constructed given an operation that has a SymbolTable trait and -/// contains pdl::PatternOp instances. The patterns are compiled lazily and one -/// by one when requested; this behavior is subject to change. -class PatternApplicatorExtension : public transform::TransformState::Extension { -public: - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) - - /// Creates the extension for patterns contained in `patternContainer`. - explicit PatternApplicatorExtension(transform::TransformState &state, - Operation *patternContainer) - : Extension(state), patterns(patternContainer) {} - - /// Appends to `results` the operations contained in `root` that matched the - /// PDL pattern with the given name. Note that `root` may or may not be the - /// operation that contains PDL patterns. Reports an error if the pattern - /// cannot be found. Note that when no operations are matched, this still - /// succeeds as long as the pattern exists. - LogicalResult findAllMatches(StringRef patternName, Operation *root, - SmallVectorImpl &results); - -private: - /// Map from the pattern name to a singleton set of rewrite patterns that only - /// contains the pattern with this name. Populated when the pattern is first - /// requested. - // TODO: reconsider the efficiency of this storage when more usage data is - // available. Storing individual patterns in a set and triggering compilation - // for each of them has overhead. So does compiling a large set of patterns - // only to apply a handlful of them. - llvm::StringMap compiledPatterns; - - /// A symbol table operation containing the relevant PDL patterns. - SymbolTable patterns; -}; - -LogicalResult PatternApplicatorExtension::findAllMatches( - StringRef patternName, Operation *root, - SmallVectorImpl &results) { - auto it = compiledPatterns.find(patternName); - if (it == compiledPatterns.end()) { - auto patternOp = patterns.lookup(patternName); - if (!patternOp) - return failure(); - - // Copy the pattern operation into a new module that is compiled and - // consumed by the PDL interpreter. - OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); - auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); - builder.clone(*patternOp); - PDLPatternModule patternModule(std::move(pdlModuleOp)); - - // Merge in the hooks owned by the dialect. Make a copy as they may be - // also used by the following operations. - auto *dialect = - root->getContext()->getLoadedDialect(); - for (const auto &[name, constraintFn] : dialect->getPDLConstraintHooks()) - patternModule.registerConstraintFunction(name, constraintFn); - - // Register a noop rewriter because PDL requires patterns to end with some - // rewrite call. - patternModule.registerRewriteFunction( - "transform.dialect", [](PatternRewriter &, Operation *) {}); - - it = compiledPatterns - .try_emplace(patternOp.getName(), std::move(patternModule)) - .first; - } - - PatternApplicator applicator(it->second); - // We want to discourage direct use of PatternRewriter in APIs but In this - // very specific case, an IRRewriter is not enough. - struct TrivialPatternRewriter : public PatternRewriter { - public: - explicit TrivialPatternRewriter(MLIRContext *context) - : PatternRewriter(context) {} - }; - TrivialPatternRewriter rewriter(root->getContext()); - applicator.applyDefaultCostModel(); - root->walk([&](Operation *op) { - if (succeeded(applicator.matchAndRewrite(op, rewriter))) - results.push_back(op); - }); - - return success(); -} -} // namespace - -//===----------------------------------------------------------------------===// // TrackingListener //===----------------------------------------------------------------------===// @@ -420,10 +324,7 @@ bool transform::CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { assert(outputs.size() == 1 && "expected one output"); return llvm::all_of( std::initializer_list{inputs.front(), outputs.front()}, - [](Type ty) { - return llvm::isa(ty); - }); + [](Type ty) { return isa(ty); }); } //===----------------------------------------------------------------------===// @@ -1031,38 +932,6 @@ transform::IncludeOp::apply(transform::TransformResults &results, return result; } -/// Appends to `effects` the memory effect instances on `target` with the same -/// resource and effect as the ones the operation `iface` having on `source`. -static void -remapEffects(MemoryEffectOpInterface iface, BlockArgument source, Value target, - SmallVectorImpl &effects) { - SmallVector nestedEffects; - iface.getEffectsOnValue(source, nestedEffects); - for (const auto &effect : nestedEffects) - effects.emplace_back(effect.getEffect(), target, effect.getResource()); -} - -/// Appends to `effects` the same effects as the operations of `block` have on -/// block arguments but associated with `operands.` -static void -remapArgumentEffects(Block &block, ValueRange operands, - SmallVectorImpl &effects) { - for (Operation &op : block) { - auto iface = dyn_cast(&op); - if (!iface) - continue; - - for (auto &&[source, target] : llvm::zip(block.getArguments(), operands)) { - remapEffects(iface, source, target, effects); - } - - SmallVector nestedEffects; - iface.getEffectsOnResource(transform::PayloadIRResource::get(), - nestedEffects); - llvm::append_range(effects, nestedEffects); - } -} - static DiagnosedSilenceableFailure verifyNamedSequenceOp(transform::NamedSequenceOp op); @@ -1474,8 +1343,7 @@ LogicalResult transform::NamedSequenceOp::verify() { void transform::SplitHandleOp::build(OpBuilder &builder, OperationState &result, Value target, int64_t numResultHandles) { result.addOperands(target); - auto pdlOpType = pdl::OperationType::get(builder.getContext()); - result.addTypes(SmallVector(numResultHandles, pdlOpType)); + result.addTypes(SmallVector(numResultHandles, target.getType())); } DiagnosedSilenceableFailure @@ -1536,35 +1404,6 @@ LogicalResult transform::SplitHandleOp::verify() { } //===----------------------------------------------------------------------===// -// PDLMatchOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::PDLMatchOp::apply(transform::TransformResults &results, - transform::TransformState &state) { - auto *extension = state.getExtension(); - assert(extension && - "expected PatternApplicatorExtension to be attached by the parent op"); - SmallVector targets; - for (Operation *root : state.getPayloadOps(getRoot())) { - if (failed(extension->findAllMatches( - getPatternName().getLeafReference().getValue(), root, targets))) { - emitDefiniteFailure() - << "could not find pattern '" << getPatternName() << "'"; - } - } - results.set(llvm::cast(getResult()), targets); - return DiagnosedSilenceableFailure::success(); -} - -void transform::PDLMatchOp::getEffects( - SmallVectorImpl &effects) { - onlyReadsHandle(getRoot(), effects); - producesHandle(getMatched(), effects); - onlyReadsPayload(effects); -} - -//===----------------------------------------------------------------------===// // ReplicateOp //===----------------------------------------------------------------------===// @@ -1776,37 +1615,9 @@ LogicalResult transform::SequenceOp::verify() { return success(); } -/// Populate `effects` with transform dialect memory effects for the potential -/// top-level operation. Such operations have recursive effects from nested -/// operations. When they have an operand, we can additionally remap effects on -/// the block argument to be effects on the operand. -template -static void getPotentialTopLevelEffects( - OpTy operation, SmallVectorImpl &effects) { - transform::onlyReadsHandle(operation->getOperands(), effects); - transform::producesHandle(operation->getResults(), effects); - - if (!operation.getRoot()) { - for (Operation &op : *operation.getBodyBlock()) { - auto iface = dyn_cast(&op); - if (!iface) - continue; - - SmallVector nestedEffects; - iface.getEffects(effects); - } - return; - } - - // Carry over all effects on arguments of the entry block as those on the - // operands, this is the same value just remapped. - remapArgumentEffects(*operation.getBodyBlock(), operation->getOperands(), - effects); -} - void transform::SequenceOp::getEffects( SmallVectorImpl &effects) { - getPotentialTopLevelEffects(*this, effects); + getPotentialTopLevelEffects(effects); } OperandRange transform::SequenceOp::getSuccessorEntryOperands( @@ -1909,77 +1720,6 @@ void transform::SequenceOp::build(OpBuilder &builder, OperationState &state, } //===----------------------------------------------------------------------===// -// WithPDLPatternsOp -//===----------------------------------------------------------------------===// - -DiagnosedSilenceableFailure -transform::WithPDLPatternsOp::apply(transform::TransformResults &results, - transform::TransformState &state) { - TransformOpInterface transformOp = nullptr; - for (Operation &nested : getBody().front()) { - if (!isa(nested)) { - transformOp = cast(nested); - break; - } - } - - state.addExtension(getOperation()); - auto guard = llvm::make_scope_exit( - [&]() { state.removeExtension(); }); - - auto scope = state.make_region_scope(getBody()); - if (failed(mapBlockArguments(state))) - return DiagnosedSilenceableFailure::definiteFailure(); - return state.applyTransform(transformOp); -} - -void transform::WithPDLPatternsOp::getEffects( - SmallVectorImpl &effects) { - getPotentialTopLevelEffects(*this, effects); -} - -LogicalResult transform::WithPDLPatternsOp::verify() { - Block *body = getBodyBlock(); - Operation *topLevelOp = nullptr; - for (Operation &op : body->getOperations()) { - if (isa(op)) - continue; - - if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { - if (topLevelOp) { - InFlightDiagnostic diag = - emitOpError() << "expects only one non-pattern op in its body"; - diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; - diag.attachNote(op.getLoc()) << "second non-pattern op"; - return diag; - } - topLevelOp = &op; - continue; - } - - InFlightDiagnostic diag = - emitOpError() - << "expects only pattern and top-level transform ops in its body"; - diag.attachNote(op.getLoc()) << "offending op"; - return diag; - } - - if (auto parent = getOperation()->getParentOfType()) { - InFlightDiagnostic diag = emitOpError() << "cannot be nested"; - diag.attachNote(parent.getLoc()) << "parent operation"; - return diag; - } - - if (!topLevelOp) { - InFlightDiagnostic diag = emitOpError() - << "expects at least one non-pattern op"; - return diag; - } - - return success(); -} - -//===----------------------------------------------------------------------===// // PrintOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt b/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt new file mode 100644 index 0000000..4a60ed4 --- /dev/null +++ b/mlir/lib/Dialect/Transform/PDLExtension/CMakeLists.txt @@ -0,0 +1,13 @@ +add_mlir_dialect_library(MLIRTransformPDLExtension + PDLExtension.cpp + PDLExtensionOps.cpp + + DEPENDS + MLIRTransformDialectPDLExtensionOpsIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRTransformDialect + MLIRPDLDialect + MLIRPDLInterpDialect +) diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp new file mode 100644 index 0000000..2c770ab --- /dev/null +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtension.cpp @@ -0,0 +1,69 @@ +//===- PDLExtension.cpp - PDL extension for the Transform dialect ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h" +#include "mlir/Dialect/PDL/IR/PDL.h" +#include "mlir/Dialect/PDL/IR/PDLTypes.h" +#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h" +#include "mlir/Dialect/Transform/IR/TransformDialect.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" +#include "mlir/IR/DialectRegistry.h" + +using namespace mlir; + +namespace { +/// Implementation of the TransformHandleTypeInterface for the PDL +/// OperationType. Accepts any payload operation. +struct PDLOperationTypeTransformHandleTypeInterfaceImpl + : public transform::TransformHandleTypeInterface::ExternalModel< + PDLOperationTypeTransformHandleTypeInterfaceImpl, + pdl::OperationType> { + + /// Accept any operation. + DiagnosedSilenceableFailure + checkPayload(Type type, Location loc, ArrayRef payload) const { + return DiagnosedSilenceableFailure::success(); + } +}; +} // namespace + +namespace { +/// PDL extension of the Transform dialect. This provides transform operations +/// that connect to PDL matching as well as interfaces for PDL types to be used +/// with Transform dialect operations. +class PDLExtension : public transform::TransformDialectExtension { +public: + void init() { + registerTransformOps< +#define GET_OP_LIST +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc" + >(); + + addDialectDataInitializer( + [](transform::PDLMatchHooks &) {}); + + // Declare PDL as dependent so we can attach an interface to its type in the + // later step. + declareDependentDialect(); + + // PDLInterp is only relevant if we actually apply the transform IR so + // declare it as generated. + declareGeneratedDialect(); + + // Make PDL OperationType usable as a transform dialect type. + addCustomInitializationStep([](MLIRContext *context) { + pdl::OperationType::attachInterface< + PDLOperationTypeTransformHandleTypeInterfaceImpl>(*context); + }); + } +}; +} // namespace + +void mlir::transform::registerPDLExtension(DialectRegistry &dialectRegistry) { + dialectRegistry.addExtensions(); +} diff --git a/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp new file mode 100644 index 0000000..5126d79 --- /dev/null +++ b/mlir/lib/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp @@ -0,0 +1,234 @@ +//===- PDLExtensionOps.cpp - PDL extension for the Transform dialect ------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" +#include "mlir/Dialect/PDL/IR/PDLOps.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Rewrite/PatternApplicator.h" +#include "llvm/ADT/ScopeExit.h" + +using namespace mlir; + +MLIR_DEFINE_EXPLICIT_TYPE_ID(mlir::transform::PDLMatchHooks) + +#define GET_OP_CLASSES +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc" + +//===----------------------------------------------------------------------===// +// PatternApplicatorExtension +//===----------------------------------------------------------------------===// + +namespace { +/// A TransformState extension that keeps track of compiled PDL pattern sets. +/// This is intended to be used along the WithPDLPatterns op. The extension +/// can be constructed given an operation that has a SymbolTable trait and +/// contains pdl::PatternOp instances. The patterns are compiled lazily and one +/// by one when requested; this behavior is subject to change. +class PatternApplicatorExtension : public transform::TransformState::Extension { +public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension) + + /// Creates the extension for patterns contained in `patternContainer`. + explicit PatternApplicatorExtension(transform::TransformState &state, + Operation *patternContainer) + : Extension(state), patterns(patternContainer) {} + + /// Appends to `results` the operations contained in `root` that matched the + /// PDL pattern with the given name. Note that `root` may or may not be the + /// operation that contains PDL patterns. Reports an error if the pattern + /// cannot be found. Note that when no operations are matched, this still + /// succeeds as long as the pattern exists. + LogicalResult findAllMatches(StringRef patternName, Operation *root, + SmallVectorImpl &results); + +private: + /// Map from the pattern name to a singleton set of rewrite patterns that only + /// contains the pattern with this name. Populated when the pattern is first + /// requested. + // TODO: reconsider the efficiency of this storage when more usage data is + // available. Storing individual patterns in a set and triggering compilation + // for each of them has overhead. So does compiling a large set of patterns + // only to apply a handful of them. + llvm::StringMap compiledPatterns; + + /// A symbol table operation containing the relevant PDL patterns. + SymbolTable patterns; +}; + +LogicalResult PatternApplicatorExtension::findAllMatches( + StringRef patternName, Operation *root, + SmallVectorImpl &results) { + auto it = compiledPatterns.find(patternName); + if (it == compiledPatterns.end()) { + auto patternOp = patterns.lookup(patternName); + if (!patternOp) + return failure(); + + // Copy the pattern operation into a new module that is compiled and + // consumed by the PDL interpreter. + OwningOpRef pdlModuleOp = ModuleOp::create(patternOp.getLoc()); + auto builder = OpBuilder::atBlockEnd(pdlModuleOp->getBody()); + builder.clone(*patternOp); + PDLPatternModule patternModule(std::move(pdlModuleOp)); + + // Merge in the hooks owned by the dialect. Make a copy as they may be + // also used by the following operations. + auto *dialect = + root->getContext()->getLoadedDialect(); + for (const auto &[name, constraintFn] : + dialect->getExtraData() + .getPDLConstraintHooks()) { + patternModule.registerConstraintFunction(name, constraintFn); + } + + // Register a noop rewriter because PDL requires patterns to end with some + // rewrite call. + patternModule.registerRewriteFunction( + "transform.dialect", [](PatternRewriter &, Operation *) {}); + + it = compiledPatterns + .try_emplace(patternOp.getName(), std::move(patternModule)) + .first; + } + + PatternApplicator applicator(it->second); + // We want to discourage direct use of PatternRewriter in APIs but In this + // very specific case, an IRRewriter is not enough. + struct TrivialPatternRewriter : public PatternRewriter { + public: + explicit TrivialPatternRewriter(MLIRContext *context) + : PatternRewriter(context) {} + }; + TrivialPatternRewriter rewriter(root->getContext()); + applicator.applyDefaultCostModel(); + root->walk([&](Operation *op) { + if (succeeded(applicator.matchAndRewrite(op, rewriter))) + results.push_back(op); + }); + + return success(); +} +} // namespace + +//===----------------------------------------------------------------------===// +// PDLMatchHooks +//===----------------------------------------------------------------------===// + +void transform::PDLMatchHooks::mergeInPDLMatchHooks( + llvm::StringMap &&constraintFns) { + // Steal the constraint functions from the given map. + for (auto &it : constraintFns) + pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second)); +} + +const llvm::StringMap & +transform::PDLMatchHooks::getPDLConstraintHooks() const { + return pdlMatchHooks.getConstraintFunctions(); +} + +//===----------------------------------------------------------------------===// +// PDLMatchOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::PDLMatchOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + auto *extension = state.getExtension(); + assert(extension && + "expected PatternApplicatorExtension to be attached by the parent op"); + SmallVector targets; + for (Operation *root : state.getPayloadOps(getRoot())) { + if (failed(extension->findAllMatches( + getPatternName().getLeafReference().getValue(), root, targets))) { + emitDefiniteFailure() + << "could not find pattern '" << getPatternName() << "'"; + } + } + results.set(llvm::cast(getResult()), targets); + return DiagnosedSilenceableFailure::success(); +} + +void transform::PDLMatchOp::getEffects( + SmallVectorImpl &effects) { + onlyReadsHandle(getRoot(), effects); + producesHandle(getMatched(), effects); + onlyReadsPayload(effects); +} + +//===----------------------------------------------------------------------===// +// WithPDLPatternsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure +transform::WithPDLPatternsOp::apply(transform::TransformResults &results, + transform::TransformState &state) { + TransformOpInterface transformOp = nullptr; + for (Operation &nested : getBody().front()) { + if (!isa(nested)) { + transformOp = cast(nested); + break; + } + } + + state.addExtension(getOperation()); + auto guard = llvm::make_scope_exit( + [&]() { state.removeExtension(); }); + + auto scope = state.make_region_scope(getBody()); + if (failed(mapBlockArguments(state))) + return DiagnosedSilenceableFailure::definiteFailure(); + return state.applyTransform(transformOp); +} + +void transform::WithPDLPatternsOp::getEffects( + SmallVectorImpl &effects) { + getPotentialTopLevelEffects(effects); +} + +LogicalResult transform::WithPDLPatternsOp::verify() { + Block *body = getBodyBlock(); + Operation *topLevelOp = nullptr; + for (Operation &op : body->getOperations()) { + if (isa(op)) + continue; + + if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) { + if (topLevelOp) { + InFlightDiagnostic diag = + emitOpError() << "expects only one non-pattern op in its body"; + diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op"; + diag.attachNote(op.getLoc()) << "second non-pattern op"; + return diag; + } + topLevelOp = &op; + continue; + } + + InFlightDiagnostic diag = + emitOpError() + << "expects only pattern and top-level transform ops in its body"; + diag.attachNote(op.getLoc()) << "offending op"; + return diag; + } + + if (auto parent = getOperation()->getParentOfType()) { + InFlightDiagnostic diag = emitOpError() << "cannot be nested"; + diag.attachNote(parent.getLoc()) << "parent operation"; + return diag; + } + + if (!topLevelOp) { + InFlightDiagnostic diag = emitOpError() + << "expects at least one non-pattern op"; + return diag; + } + + return success(); +} diff --git a/mlir/python/CMakeLists.txt b/mlir/python/CMakeLists.txt index b0b4ed9..39dd7b0 100644 --- a/mlir/python/CMakeLists.txt +++ b/mlir/python/CMakeLists.txt @@ -114,6 +114,16 @@ declare_mlir_dialect_python_bindings( DIALECT_NAME linalg DEPENDS LinalgOdsGen) +declare_mlir_dialect_extension_python_bindings( +ADD_TO_PARENT MLIRPythonSources.Dialects +ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" + TD_FILE dialects/TransformPDLExtensionOps.td + SOURCES + dialects/_transform_pdl_extension_ops_ext.py + dialects/transform/pdl.py + DIALECT_NAME transform + EXTENSION_NAME transform_pdl_extension) + declare_mlir_dialect_python_bindings( ADD_TO_PARENT MLIRPythonSources.Dialects ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir" diff --git a/mlir/python/mlir/dialects/TransformPDLExtensionOps.td b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td new file mode 100644 index 0000000..e3e5daf --- /dev/null +++ b/mlir/python/mlir/dialects/TransformPDLExtensionOps.td @@ -0,0 +1,20 @@ +//===-- TransformPDLExtensionOps.td - Binding entry point --*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Entry point of the generated Python bindings for the PDL extension of the +// Transform dialect. +// +//===----------------------------------------------------------------------===// + +#ifndef PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS +#define PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS + +include "mlir/Bindings/Python/Attributes.td" +include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td" + +#endif // PYTHON_BINDINGS_TRANSFORM_PDL_EXTENSION_OPS diff --git a/mlir/python/mlir/dialects/_transform_ops_ext.py b/mlir/python/mlir/dialects/_transform_ops_ext.py index 8651c76..cc4428e 100644 --- a/mlir/python/mlir/dialects/_transform_ops_ext.py +++ b/mlir/python/mlir/dialects/_transform_ops_ext.py @@ -60,26 +60,6 @@ class MergeHandlesOp: ) -class PDLMatchOp: - - def __init__( - self, - result_type: Type, - target: Union[Operation, Value], - pattern_name: Union[Attribute, str], - *, - loc=None, - ip=None, - ): - super().__init__( - result_type, - _get_op_result_or_value(target), - pattern_name, - loc=loc, - ip=ip, - ) - - class ReplicateOp: def __init__( @@ -152,28 +132,6 @@ class SequenceOp: return self.body.arguments[1:] -class WithPDLPatternsOp: - - def __init__(self, - target: Union[Operation, Value, Type], - *, - loc=None, - ip=None): - root = _get_op_result_or_value(target) if not isinstance(target, - Type) else None - root_type = target if isinstance(target, Type) else root.type - super().__init__(root=root, loc=loc, ip=ip) - self.regions[0].blocks.append(root_type) - - @property - def body(self) -> Block: - return self.regions[0].blocks[0] - - @property - def bodyTarget(self) -> Value: - return self.body.arguments[0] - - class YieldOp: def __init__( diff --git a/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py new file mode 100644 index 0000000..c4e4b4b --- /dev/null +++ b/mlir/python/mlir/dialects/_transform_pdl_extension_ops_ext.py @@ -0,0 +1,55 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +try: + from ..ir import * + from ._ods_common import ( + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + ) +except ImportError as e: + raise RuntimeError("Error loading imports from extension module") from e + +from typing import Union + +class PDLMatchOp: + + def __init__( + self, + result_type: Type, + target: Union[Operation, Value], + pattern_name: Union[Attribute, str], + *, + loc=None, + ip=None, + ): + super().__init__( + result_type, + _get_op_result_or_value(target), + pattern_name, + loc=loc, + ip=ip, + ) + + +class WithPDLPatternsOp: + + def __init__(self, + target: Union[Operation, Value, Type], + *, + loc=None, + ip=None): + root = _get_op_result_or_value(target) if not isinstance(target, + Type) else None + root_type = target if isinstance(target, Type) else root.type + super().__init__(root=root, loc=loc, ip=ip) + self.regions[0].blocks.append(root_type) + + @property + def body(self) -> Block: + return self.regions[0].blocks[0] + + @property + def bodyTarget(self) -> Value: + return self.body.arguments[0] diff --git a/mlir/python/mlir/dialects/transform/pdl.py b/mlir/python/mlir/dialects/transform/pdl.py new file mode 100644 index 0000000..b151528 --- /dev/null +++ b/mlir/python/mlir/dialects/transform/pdl.py @@ -0,0 +1,5 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from .._transform_pdl_extension_ops_gen import * diff --git a/mlir/test/Dialect/Transform/test-interpreter.mlir b/mlir/test/Dialect/Transform/test-interpreter.mlir index 11a5568..a885c89 100644 --- a/mlir/test/Dialect/Transform/test-interpreter.mlir +++ b/mlir/test/Dialect/Transform/test-interpreter.mlir @@ -83,33 +83,6 @@ transform.sequence failures(propagate) { // ----- -transform.with_pdl_patterns { -^bb0(%arg0: !transform.any_op): - sequence %arg0 : !transform.any_op failures(propagate) { - ^bb0(%arg1: !transform.any_op): - %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op - test_print_remark_at_operand %0, "matched" : !transform.any_op - } - - pdl.pattern @some : benefit(1) { - %0 = pdl.operation "test.some_op" - pdl.rewrite %0 with "transform.dialect" - } - - pdl.pattern @other : benefit(1) { - %0 = pdl.operation "test.other_op" - pdl.rewrite %0 with "transform.dialect" - } -} - -// expected-remark @below {{matched}} -"test.some_op"() : () -> () -"test.other_op"() : () -> () -// expected-remark @below {{matched}} -"test.some_op"() : () -> () - -// ----- - // expected-remark @below {{parent function}} func.func @foo() { %0 = arith.constant 0 : i32 diff --git a/mlir/test/Dialect/Transform/test-pdl-extension.mlir b/mlir/test/Dialect/Transform/test-pdl-extension.mlir new file mode 100644 index 0000000..b5f9fbf --- /dev/null +++ b/mlir/test/Dialect/Transform/test-pdl-extension.mlir @@ -0,0 +1,47 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter -allow-unregistered-dialect --split-input-file --verify-diagnostics + +transform.with_pdl_patterns { +^bb0(%arg0: !transform.any_op): + sequence %arg0 : !transform.any_op failures(propagate) { + ^bb0(%arg1: !transform.any_op): + %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op + test_print_remark_at_operand %0, "matched" : !transform.any_op + } + + pdl.pattern @some : benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.rewrite %0 with "transform.dialect" + } + + pdl.pattern @other : benefit(1) { + %0 = pdl.operation "test.other_op" + pdl.rewrite %0 with "transform.dialect" + } +} + +// expected-remark @below {{matched}} +"test.some_op"() : () -> () +"test.other_op"() : () -> () +// expected-remark @below {{matched}} +"test.some_op"() : () -> () + + +// ----- + +transform.with_pdl_patterns { +^bb0(%arg0: !transform.any_op): + sequence %arg0 : !transform.any_op failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %0 = pdl_match @some in %arg1 : (!transform.any_op) -> !transform.any_op + } + + pdl.pattern @some : benefit(1) { + %0 = pdl.operation "test.some_op" + pdl.apply_native_constraint "verbose_constraint"(%0 : !pdl.operation) + pdl.rewrite %0 with "transform.dialect" + } +} + +// expected-warning @below {{from PDL constraint}} +"test.some_op"() : () -> () +"test.other_op"() : () -> () diff --git a/mlir/test/lib/Dialect/Transform/CMakeLists.txt b/mlir/test/lib/Dialect/Transform/CMakeLists.txt index b86b8f5..c7e83d3 100644 --- a/mlir/test/lib/Dialect/Transform/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Transform/CMakeLists.txt @@ -21,4 +21,5 @@ add_mlir_library(MLIRTestTransformDialect MLIRPDLDialect MLIRTransformDialect MLIRTransformDialectTransforms + MLIRTransformPDLExtension ) diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp index 50a4c92..2b23b88 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp @@ -17,7 +17,9 @@ #include "mlir/Dialect/Transform/IR/TransformDialect.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" #include "mlir/Dialect/Transform/IR/TransformOps.h" +#include "mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h" #include "mlir/IR/OpImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Compiler.h" @@ -754,6 +756,23 @@ public: #define GET_TYPEDEF_LIST #include "TestTransformDialectExtensionTypes.cpp.inc" >(); + + auto verboseConstraint = [](PatternRewriter &rewriter, + ArrayRef pdlValues) { + for (const PDLValue &pdlValue : pdlValues) { + if (Operation *op = pdlValue.dyn_cast()) { + op->emitWarning() << "from PDL constraint"; + } + } + return success(); + }; + + addDialectDataInitializer( + [&](transform::PDLMatchHooks &hooks) { + llvm::StringMap constraints; + constraints.try_emplace("verbose_constraint", verboseConstraint); + hooks.mergeInPDLMatchHooks(std::move(constraints)); + }); } }; } // namespace diff --git a/mlir/test/python/dialects/transform.py b/mlir/test/python/dialects/transform.py index 5b64582..6b36c02 100644 --- a/mlir/test/python/dialects/transform.py +++ b/mlir/test/python/dialects/transform.py @@ -2,7 +2,7 @@ from mlir.ir import * from mlir.dialects import transform -from mlir.dialects import pdl +from mlir.dialects.transform import pdl as transform_pdl def run(f): @@ -103,13 +103,13 @@ def testNestedSequenceOpWithExtras(): @run def testTransformPDLOps(): - withPdl = transform.WithPDLPatternsOp(transform.AnyOpType.get()) + withPdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) with InsertionPoint(withPdl.body): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [transform.AnyOpType.get()], withPdl.bodyTarget) with InsertionPoint(sequence.body): - match = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher") + match = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "pdl_matcher") transform.YieldOp(match) # CHECK-LABEL: TEST: testTransformPDLOps # CHECK: transform.with_pdl_patterns { @@ -148,13 +148,13 @@ def testMergeHandlesOp(): @run def testReplicateOp(): - with_pdl = transform.WithPDLPatternsOp(transform.AnyOpType.get()) + with_pdl = transform_pdl.WithPDLPatternsOp(transform.AnyOpType.get()) with InsertionPoint(with_pdl.body): sequence = transform.SequenceOp( transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget) with InsertionPoint(sequence.body): - m1 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first") - m2 = transform.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second") + m1 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "first") + m2 = transform_pdl.PDLMatchOp(transform.AnyOpType.get(), sequence.bodyTarget, "second") transform.ReplicateOp(m1, [m2]) transform.YieldOp() # CHECK-LABEL: TEST: testReplicateOp diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index 9684bfb..d2a82b8 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -4,6 +4,7 @@ from mlir.ir import * from mlir.dialects import transform from mlir.dialects import pdl from mlir.dialects.transform import structured +from mlir.dialects.transform import pdl as transform_pdl def run(f): @@ -151,13 +152,13 @@ def testTileZero(): @run def testTileDynamic(): - with_pdl = transform.WithPDLPatternsOp(pdl.OperationType.get()) + with_pdl = transform_pdl.WithPDLPatternsOp(pdl.OperationType.get()) with InsertionPoint(with_pdl.body): sequence = transform.SequenceOp(transform.FailurePropagationMode.PROPAGATE, [], with_pdl.bodyTarget) with InsertionPoint(sequence.body): - m1 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") - m2 = transform.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") + m1 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "first") + m2 = transform_pdl.PDLMatchOp(pdl.OperationType.get(), sequence.bodyTarget, "second") structured.TileOp(sequence.bodyTarget, sizes=[m1, 3, m2, 0]) transform.YieldOp() diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index c9ba651..b36bdf9 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -7495,6 +7495,7 @@ cc_library( ":TosaToLinalg", ":TransformDialect", ":TransformDialectTransforms", + ":TransformPDLExtension", ":Transforms", ":TransformsPassIncGen", ":VectorDialect", @@ -9732,7 +9733,6 @@ td_library( ":ControlFlowInterfacesTdFiles", ":InferTypeOpInterfaceTdFiles", ":OpBaseTdFiles", - ":PDLDialectTdFiles", ":SideEffectInterfacesTdFiles", ], ) @@ -9889,8 +9889,6 @@ cc_library( ":CallOpInterfaces", ":ControlFlowInterfaces", ":IR", - ":PDLDialect", - ":PDLInterpDialect", ":Rewrite", ":SideEffectInterfaces", ":Support", @@ -9907,6 +9905,54 @@ cc_library( ) td_library( + name = "TransformPDLExtensionTdFiles", + srcs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.td"]), + deps = [ + ":PDLDialectTdFiles", + ":TransformDialectTdFiles", + ], +) + +gentbl_cc_library( + name = "TransformPDLExtensionOpsIncGen", + strip_include_prefix = "include", + tbl_outs = [ + ( + [ + "-gen-op-decls", + ], + "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.h.inc", + ), + ( + [ + "-gen-op-defs", + ], + "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.cpp.inc", + ), + ], + tblgen = ":mlir-tblgen", + td_file = "include/mlir/Dialect/Transform/PDLExtension/PDLExtensionOps.td", + deps = [":TransformPDLExtensionTdFiles"], +) + +cc_library( + name = "TransformPDLExtension", + srcs = glob(["lib/Dialect/Transform/PDLExtension/*.cpp"]), + hdrs = glob(["include/mlir/Dialect/Transform/PDLExtension/*.h"]), + deps = [ + ":IR", + ":PDLDialect", + ":PDLInterpDialect", + ":SideEffectInterfaces", + ":Support", + ":TransformDialect", + ":TransformPDLExtensionOpsIncGen", + ":Rewrite", + "//llvm:Support", + ], +) + +td_library( name = "TransformDialectTransformsTdFiles", srcs = glob(["include/mlir/Dialect/Transform/Transforms/*.td"]), deps = [ diff --git a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel index 06a97f4..f6c87ea 100644 --- a/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/python/BUILD.bazel @@ -927,6 +927,26 @@ gentbl_filegroup( ], ) +gentbl_filegroup( + name = "PDLTransformOpsPyGen", + tbl_outs = [ + ( + [ + "-gen-python-op-bindings", + "-bind-dialect=transform", + "-dialect-extension=transform_pdl_extension", + ], + "mlir/dialects/_transform_pdl_extension_ops_gen.py", + ), + ], + tblgen = "//mlir:mlir-tblgen", + td_file = "mlir/dialects/TransformPDLExtensionOps.td", + deps = [ + ":TransformOpsPyTdFiles", + "//mlir:TransformPDLExtensionTdFiles", + ], +) + filegroup( name = "TransformOpsPyFiles", srcs = [ @@ -934,6 +954,7 @@ filegroup( "mlir/dialects/_structured_transform_ops_ext.py", "mlir/dialects/_transform_ops_ext.py", ":LoopTransformOpsPyGen", + ":PDLTransformOpsPyGen", ":StructuredTransformOpsPyGen", ":TransformOpsPyGen", ], diff --git a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel index b4f8ca7..c95aea5 100644 --- a/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/test/BUILD.bazel @@ -317,6 +317,7 @@ gentbl_cc_library( ":TransformDialectTdFiles", "//mlir:PDLDialectTdFiles", "//mlir:TransformDialectTdFiles", + "//mlir:TransformPDLExtension", ], ) @@ -333,6 +334,7 @@ cc_library( "//mlir:Pass", "//mlir:TransformDialect", "//mlir:TransformDialectTransforms", + "//mlir:TransformPDLExtension", ], ) -- 2.7.4