#ifndef MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT_H
+#include "mlir/Dialect/PDL/IR/PDL.h"
+#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Dialect.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/StringMap.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h.inc"
loader(context);
for (const Initializer &init : opInitializers)
init(transformDialect);
+ transformDialect->mergeInPDLMatchHooks(std::move(pdlMatchConstraintFns));
}
protected:
[](MLIRContext *context) { context->loadDialect<DialectTy>(); });
}
+ /// Injects the named constraint to make it available for use with the
+ /// PDLMatchOp in the transform dialect.
+ void registerPDLMatchConstraintFn(StringRef name,
+ PDLConstraintFunction &&fn) {
+ pdlMatchConstraintFns.try_emplace(name,
+ std::forward<PDLConstraintFunction>(fn));
+ }
+ template <typename ConstraintFnTy>
+ void registerPDLMatchConstraintFn(StringRef name, ConstraintFnTy &&fn) {
+ pdlMatchConstraintFns.try_emplace(
+ name, ::mlir::detail::pdl_function_builder::buildConstraintFn(
+ std::forward<ConstraintFnTy>(fn)));
+ }
+
private:
SmallVector<Initializer> opInitializers;
SmallVector<DialectLoader> dialectLoaders;
+
+ /// A list of constraints that should be made availble to PDL patterns
+ /// processed by PDLMatchOp in the Transform dialect.
+ ///
+ /// Declared as mutable so its contents can be moved in the `apply` const
+ /// method, which is only called once.
+ mutable llvm::StringMap<PDLConstraintFunction> pdlMatchConstraintFns;
};
} // namespace transform
`LoopTransformDialectExtension` in the cases above. Unprefixed operation
names are reserved for ops defined directly in the Transform dialect.
+ Overall, Transform IR ops are expected to be contained in a single top-level
+ op. Such top-level ops specifie how to apply the transformations described
+ by operations they contain, e.g., `transform.sequence` executes
+ transformations one by one and fails if any of them fails. Such ops are
+ expected to have the `PossibleTopLevelTransformOpTrait` and may be used
+ without arguments.
+
## Intended Use and Integrations
The transformation control infrastructure provided by this dialect is
let cppNamespace = "::mlir::transform";
let emitAccessorPrefix = kEmitAccessorPrefix_Prefixed;
+ let dependentDialects = [
+ "::mlir::pdl::PDLDialect",
+ "::mlir::pdl_interp::PDLInterpDialect",
+ ];
+
let extraClassDeclaration = [{
- // Make addOperations available to the TransformDialectExtension class.
+ /// Returns the named PDL constraint functions available in the dialect
+ /// as a map from their name to the function.
+ const ::llvm::StringMap<::mlir::PDLConstraintFunction> &
+ getPDLConstraintHooks() const;
+
private:
+ // Make addOperations available to the TransformDialectExtension class.
using ::mlir::Dialect::addOperations;
template <typename, typename...>
friend class TransformDialectExtension;
+
+ /// Takes ownership of the named PDL constraint function from the given
+ /// map and makes them available for use by the operations in the dialect.
+ void mergeInPDLMatchHooks(
+ ::llvm::StringMap<::mlir::PDLConstraintFunction> &&constraintFns);
+
+ /// A container for PDL constraint function that can be used by
+ /// operations in this dialect.
+ PDLPatternModule pdlMatchHooks;
}];
}
class TransformDialectOp<string mnemonic, list<Trait> traits = []>
: Op<Transform_Dialect, mnemonic, traits>;
+// Trait for operations that may be top-level operations in Transform IR.
+// Operations must have one single-block region and must be usable without
+// operands. See the C++ definition of the trait for more information.
+def PossibleTopLevelTransformOpTrait
+ : NativeOpTrait<"PossibleTopLevelTransformOpTrait"> {
+ let cppNamespace = "::mlir::transform";
+}
+
#endif // MLIR_DIALECT_TRANSFORM_IR_TRANSFORMDIALECT
};
friend class RegionScope;
+ /// Base class for TransformState extensions that allow TransformState to
+ /// contain user-specified information in the state object. Clients are
+ /// expected to derive this class, add the desired fields, and make the
+ /// derived class compatible with the MLIR TypeID mechanism:
+ ///
+ /// ```mlir
+ /// class MyExtension final : public TransformState::Extension {
+ /// public:
+ /// MyExtension(TranfsormState &state, int myData)
+ /// : Extension(state) {...}
+ /// private:
+ /// int mySupplementaryData;
+ /// };
+ /// ```
+ ///
+ /// Instances of this and derived classes are not expected to be created by
+ /// the user, instead they are directly constructed within a TransformState. A
+ /// TransformState can only contain one extension with the given TypeID.
+ /// Extensions can be obtained from a TransformState instance, and can be
+ /// removed when they are no longer required.
+ ///
+ /// ```mlir
+ /// transformState.addExtension<MyExtension>(/*myData=*/42);
+ /// MyExtension *ext = transformState.getExtension<MyExtension>();
+ /// ext->doSomething();
+ /// ```
+ class Extension {
+ // Allow TransformState to allocate Extensions.
+ friend class TransformState;
+
+ public:
+ /// Base virtual destructor.
+ // Out-of-line definition ensures symbols are emitted in a single object
+ // file.
+ virtual ~Extension();
+
+ protected:
+ /// Constructs an extension of the given TransformState object.
+ Extension(TransformState &state) : state(state) {}
+
+ private:
+ /// Back-reference to the state that is being extended.
+ TransformState &state;
+ };
+
+ /// Adds a new Extension of the type specified as template parameter,
+ /// constructing it with the arguments provided. The extension is owned by the
+ /// TransformState. It is expected that the state does not already have an
+ /// extension of the same type. Extension constructors are expected to take
+ /// a reference to TransformState as first argument, automatically supplied
+ /// by this call.
+ template <typename Ty, typename... Args>
+ Ty &addExtension(Args &&...args) {
+ static_assert(
+ std::is_base_of<Extension, Ty>::value,
+ "only an class derived from TransformState::Extension is allowed here");
+ auto ptr = std::make_unique<Ty>(*this, std::forward<Args>(args)...);
+ auto result = extensions.try_emplace(TypeID::get<Ty>(), std::move(ptr));
+ assert(result.second && "extension already added");
+ return *static_cast<Ty *>(result.first->second.get());
+ }
+
+ /// Returns the extension of the specified type.
+ template <typename Ty>
+ Ty *getExtension() {
+ static_assert(
+ std::is_base_of<Extension, Ty>::value,
+ "only an class derived from TransformState::Extension is allowed here");
+ auto iter = extensions.find(TypeID::get<Ty>());
+ if (iter == extensions.end())
+ return nullptr;
+ return static_cast<Ty *>(iter->second.get());
+ }
+
+ /// Removes the extension of the specified type.
+ template <typename Ty>
+ void removeExtension() {
+ static_assert(
+ std::is_base_of<Extension, Ty>::value,
+ "only an class derived from TransformState::Extension is allowed here");
+ extensions.erase(TypeID::get<Ty>());
+ }
+
private:
/// Identifier for storing top-level value in the `operations` mapping.
static constexpr Value kTopLevelValue = Value();
/// the region in which the transform IR values are defined.
llvm::SmallDenseMap<Region *, Mappings> mappings;
+ /// Extensions attached to the TransformState, identified by the TypeID of
+ /// their type. Only one extension of any given type is allowed.
+ DenseMap<TypeID, std::unique_ptr<Extension>> extensions;
+
/// The top-level operation that contains all payload IR, typically a module.
Operation *topLevel;
return RegionScope(*this, region);
}
+namespace detail {
+/// Maps the only block argument of the op with PossibleTopLevelTransformOpTrait
+/// to either the list of operations associated with its operand or the root of
+/// the payload IR, depending on what is available in the context.
+LogicalResult
+mapPossibleTopLevelTransformOpBlockArguments(TransformState &state,
+ Operation *op);
+
+/// Verification hook for PossibleTopLevelTransformOpTrait.
+LogicalResult verifyPossibleTopLevelTransformOpTrait(Operation *op);
+} // namespace detail
+
+/// This trait is supposed to be attached to Transform dialect operations that
+/// can be standalone top-level transforms. Such operations typically contain
+/// other Transform dialect operations that can be executed following some
+/// control flow logic specific to the current operation. The operations with
+/// this trait are expected to have exactly one single-block region with one
+/// argument of PDL Operation type. The operations are also expected to be valid
+/// without operands, in which case they are considered top-level, and with one
+/// or more arguments, in which case they are considered nested. Top-level
+/// operations have the block argument of the entry block in the Transform IR
+/// correspond to the root operation of Payload IR. Nested operations have the
+/// block argument of the entry block in the Transform IR correspond to a list
+/// of Payload IR operations mapped to the first operand of the Transform IR
+/// operation. The operation must implement TransformOpInterface.
+template <typename OpTy>
+class PossibleTopLevelTransformOpTrait
+ : public OpTrait::TraitBase<OpTy, PossibleTopLevelTransformOpTrait> {
+public:
+ /// Verifies that `op` satisfies the invariants of this trait. Not expected to
+ /// be called directly.
+ static LogicalResult verifyTrait(Operation *op) {
+ return detail::verifyPossibleTopLevelTransformOpTrait(op);
+ }
+
+ /// Returns the single block of the op's only region.
+ Block *getBodyBlock() { return &this->getOperation()->getRegion(0).front(); }
+
+ /// Sets up the mapping between the entry block of the only region of this op
+ /// and the relevant list of Payload IR operations in the given state. The
+ /// state is expected to be already scoped at the region of this operation.
+ /// Returns failure if the mapping failed, e.g., the value is already mapped.
+ LogicalResult mapBlockArguments(TransformState &state) {
+ return detail::mapPossibleTopLevelTransformOpBlockArguments(
+ state, this->getOperation());
+ }
+};
+
} // namespace transform
} // namespace mlir
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/SymbolTable.h"
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.h.inc"
#define MLIR_DIALECT_TRANSFORM_IR_TRANSFORMOPS
include "mlir/IR/OpAsmInterface.td"
+include "mlir/IR/SymbolInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+def PDLMatchOp : TransformDialectOp<"pdl_match",
+ [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let summary = "Finds ops that match the named PDL pattern";
+ let description = [{
+ Find Payload IR ops nested within the Payload IR op associated with the
+ operand that match the PDL pattern identified by its name. The pattern is
+ expected to be defined in the closest surrounding `WithPDLPatternsOp`.
+
+ Produces a Transform IR value associated with the list of Payload IR ops
+ that matched the pattern. The order of results in the list is that of the
+ Operation::walk, clients are advised not to rely on a specific order though.
+ If the operand is assocaited with multiple Payload IR ops, finds matching
+ ops nested within each of those and produces a single list containing all
+ of the matched ops.
+
+ The tranfsormation is considered successful regardless of whether some
+ Payload IR ops actually matched the pattern and only fails if the pattern
+ could not be looked up or compiled.
+ }];
+
+ let arguments = (ins PDL_Operation:$root, SymbolRefAttr:$pattern_name);
+ let results = (outs PDL_Operation:$matched);
+
+ let assemblyFormat = "$pattern_name `in` $root attr-dict";
+}
+
def SequenceOp : TransformDialectOp<"sequence",
[DeclareOpInterfaceMethods<TransformOpInterface>, OpAsmOpInterface,
+ PossibleTopLevelTransformOpTrait,
SingleBlockImplicitTerminator<"::mlir::transform::YieldOp">]> {
let summary = "Contains a sequence of other transform ops to apply";
let description = [{
let extraClassDeclaration = [{
/// Allow the dialect prefix to be omitted.
static StringRef getDefaultDialect() { return "transform"; }
+ }];
+
+ let hasVerifier = 1;
+}
- Block *getBodyBlock() {
- return &getBody().front();
+def WithPDLPatternsOp : TransformDialectOp<"with_pdl_patterns",
+ [DeclareOpInterfaceMethods<TransformOpInterface>, NoTerminator,
+ OpAsmOpInterface, PossibleTopLevelTransformOpTrait, SymbolTable]> {
+ let summary = "Contains PDL patterns available for use in transforms";
+ let description = [{
+ This op contains a set of named PDL patterns that are available for the
+ Transform dialect operations to be used for pattern matching. For example,
+ PDLMatchOp can be used to produce a Transform IR value associated with all
+ Payload IR operations that match the pattern as follows:
+
+ ```mlir
+ transform.with_pdl_patterns {
+ ^bb0(%arg0: !pdl.operation):
+ pdl.pattern @my_pattern : benefit(1) {
+ %0 = pdl.operation //...
+ // Regular PDL goes here.
+ pdl.rewrite %0 with "transform.dialect"
+ }
+
+ sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ %1 = pdl_match @my_pattern in %arg1
+ // Use %1 as handle
+ }
}
+ ```
+
+ Note that the pattern is expected to finish with a `pdl.rewrite` terminator
+ that points to the custom rewriter named "transform.dialect". The rewriter
+ actually does nothing, but the transform application will keep track of the
+ operations that matched the pattern.
+
+ This op is expected to contain `pdl.pattern` operations and exactly one
+ another Transform dialect operation that gets executed with all patterns
+ available. This op is a possible top-level Transform IR op, the argument of
+ its entry block corresponds to either the root op of the payload IR or the
+ ops associated with its operand when provided.
}];
+ let arguments = (ins Optional<PDL_Operation>:$root);
+ let regions = (region SizedRegion<1>:$body);
+ let assemblyFormat = "($root^)? attr-dict-with-keyword regions";
+
let hasVerifier = 1;
+
+ let extraClassDeclaration = [{
+ /// Allow the dialect prefix to be omitted.
+ static StringRef getDefaultDialect() { return "transform"; }
+ }];
}
def YieldOp : TransformDialectOp<"yield", [Terminator]> {
MLIRIR
MLIRPDL
MLIRPDLInterp
+ MLIRRewrite
)
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
>();
}
+
+void transform::TransformDialect::mergeInPDLMatchHooks(
+ llvm::StringMap<PDLConstraintFunction> &&constraintFns) {
+ // Steal the constraint functions form the given map.
+ for (auto &it : constraintFns)
+ pdlMatchHooks.registerConstraintFunction(it.getKey(), std::move(it.second));
+}
+
+const llvm::StringMap<PDLConstraintFunction> &
+transform::TransformDialect::getPDLConstraintHooks() const {
+ return pdlMatchHooks.getConstraintFunctions();
+}
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Operation.h"
#include "llvm/ADT/ScopeExit.h"
return success();
}
+transform::TransformState::Extension::~Extension() = default;
+
//===----------------------------------------------------------------------===//
// TransformResults
//===----------------------------------------------------------------------===//
return segments[resultNumber];
}
+//===----------------------------------------------------------------------===//
+// Utilities for PossibleTopLevelTransformOpTrait.
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::detail::mapPossibleTopLevelTransformOpBlockArguments(
+ TransformState &state, Operation *op) {
+ SmallVector<Operation *> targets;
+ if (op->getNumOperands() != 0)
+ llvm::append_range(targets, state.getPayloadOps(op->getOperand(0)));
+ else
+ targets.push_back(state.getTopLevel());
+
+ return state.mapBlockArguments(op->getRegion(0).front().getArgument(0),
+ targets);
+}
+
+LogicalResult
+transform::detail::verifyPossibleTopLevelTransformOpTrait(Operation *op) {
+ // Attaching this trait without the interface is a misuse of the API, but it
+ // cannot be caught via a static_assert because interface registration is
+ // dynamic.
+ assert(isa<TransformOpInterface>(op) &&
+ "should implement TransformOpInterface to have "
+ "PossibleTopLevelTransformOpTrait");
+
+ if (op->getNumRegions() != 1)
+ return op->emitOpError() << "expects one region";
+
+ Region *bodyRegion = &op->getRegion(0);
+ if (!llvm::hasNItems(*bodyRegion, 1))
+ return op->emitOpError() << "expects a single-block region";
+
+ Block *body = &bodyRegion->front();
+ if (body->getNumArguments() != 1 ||
+ !body->getArgumentTypes()[0].isa<pdl::OperationType>()) {
+ return op->emitOpError()
+ << "expects the entry block to have one argument of type "
+ << pdl::OperationType::get(op->getContext());
+ }
+
+ if (auto *parent =
+ op->getParentWithTrait<PossibleTopLevelTransformOpTrait>()) {
+ if (op->getNumOperands() == 0) {
+ InFlightDiagnostic diag =
+ op->emitOpError()
+ << "expects the root operation to be provided for a nested op";
+ diag.attachNote(parent->getLoc())
+ << "nested in another possible top-level op";
+ return diag;
+ }
+ }
+
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// Generated interface implementation.
//===----------------------------------------------------------------------===//
//===----------------------------------------------------------------------===//
#include "mlir/Dialect/Transform/IR/TransformOps.h"
+#include "mlir/Dialect/PDL/IR/PDLOps.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/Builders.h"
-
#include "mlir/IR/OpImplementation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Rewrite/FrozenRewritePatternSet.h"
+#include "mlir/Rewrite/PatternApplicator.h"
+#include "llvm/ADT/ScopeExit.h"
using namespace mlir;
#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/IR/TransformOps.cpp.inc"
-LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
+//===----------------------------------------------------------------------===//
+// PatternApplicatorExtension
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A simple pattern rewriter that can be constructed from a context. This is
+/// necessary to apply patterns to a specific op locally.
+class TrivialPatternRewriter : public PatternRewriter {
+public:
+ explicit TrivialPatternRewriter(MLIRContext *context)
+ : PatternRewriter(context) {}
+};
+
+/// A TransformState extension that keeps track of compiled PDL pattern sets.
+/// This is intended to be used along the WithPDLPatterns op. The extension
+/// can be constructed given an operation that has a SymbolTable trait and
+/// contains pdl::PatternOp instances. The patterns are compiled lazily and one
+/// by one when requested; this behavior is subject to change.
+class PatternApplicatorExtension : public transform::TransformState::Extension {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PatternApplicatorExtension)
+
+ /// Creates the extension for patterns contained in `patternContainer`.
+ explicit PatternApplicatorExtension(transform::TransformState &state,
+ Operation *patternContainer)
+ : Extension(state), patterns(patternContainer) {}
+
+ /// Appends to `results` the operations contained in `root` that matched the
+ /// PDL pattern with the given name. Note that `root` may or may not be the
+ /// operation that contains PDL patterns. Reports an error if the pattern
+ /// cannot be found. Note that when no operations are matched, this still
+ /// succeeds as long as the pattern exists.
+ LogicalResult findAllMatches(StringRef patternName, Operation *root,
+ SmallVectorImpl<Operation *> &results);
+
+private:
+ /// Map from the pattern name to a singleton set of rewrite patterns that only
+ /// contains the pattern with this name. Populated when the pattern is first
+ /// requested.
+ // TODO: reconsider the efficiency of this storage when more usage data is
+ // available. Storing individual patterns in a set and triggering compilation
+ // for each of them has overhead. So does compiling a large set of patterns
+ // only to apply a handlful of them.
+ llvm::StringMap<FrozenRewritePatternSet> compiledPatterns;
+
+ /// A symbol table operation containing the relevant PDL patterns.
+ SymbolTable patterns;
+};
+
+LogicalResult PatternApplicatorExtension::findAllMatches(
+ StringRef patternName, Operation *root,
+ SmallVectorImpl<Operation *> &results) {
+ auto it = compiledPatterns.find(patternName);
+ if (it == compiledPatterns.end()) {
+ auto patternOp = patterns.lookup<pdl::PatternOp>(patternName);
+ if (!patternOp)
+ return failure();
+
+ OwningOpRef<ModuleOp> pdlModuleOp = ModuleOp::create(patternOp.getLoc());
+ patternOp->moveBefore(pdlModuleOp->getBody(),
+ pdlModuleOp->getBody()->end());
+ PDLPatternModule patternModule(std::move(pdlModuleOp));
+
+ // Merge in the hooks owned by the dialect. Make a copy as they may be
+ // also used by the following operations.
+ auto *dialect =
+ root->getContext()->getLoadedDialect<transform::TransformDialect>();
+ for (const auto &pair : dialect->getPDLConstraintHooks())
+ patternModule.registerConstraintFunction(pair.first(), pair.second);
+
+ // Register a noop rewriter because PDL requires patterns to end with some
+ // rewrite call.
+ patternModule.registerRewriteFunction(
+ "transform.dialect", [](PatternRewriter &, Operation *) {});
+
+ it = compiledPatterns
+ .try_emplace(patternOp.getName(), std::move(patternModule))
+ .first;
+ }
+
+ PatternApplicator applicator(it->second);
+ TrivialPatternRewriter rewriter(root->getContext());
+ applicator.applyDefaultCostModel();
+ root->walk([&](Operation *op) {
+ if (succeeded(applicator.matchAndRewrite(op, rewriter)))
+ results.push_back(op);
+ });
+
+ return success();
+}
+} // namespace
+
+//===----------------------------------------------------------------------===//
+// PDLMatchOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::PDLMatchOp::apply(transform::TransformResults &results,
transform::TransformState &state) {
+ auto *extension = state.getExtension<PatternApplicatorExtension>();
+ assert(extension &&
+ "expected PatternApplicatorExtension to be attached by the parent op");
SmallVector<Operation *> targets;
- if (getRoot())
- llvm::append_range(targets, state.getPayloadOps(getRoot()));
- else
- targets.push_back(state.getTopLevel());
+ for (Operation *root : state.getPayloadOps(getRoot())) {
+ if (failed(extension->findAllMatches(
+ getPatternName().getLeafReference().getValue(), root, targets))) {
+ return emitOpError() << "could not find pattern '" << getPatternName()
+ << "'";
+ }
+ }
+ results.set(getResult().cast<OpResult>(), targets);
+ return success();
+}
+//===----------------------------------------------------------------------===//
+// SequenceOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult transform::SequenceOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
// Map the entry block argument to the list of operations.
auto scope = state.make_region_scope(*getBodyBlock()->getParent());
- if (failed(state.mapBlockArguments(getBodyBlock()->getArgument(0), targets)))
+ if (failed(mapBlockArguments(state)))
return failure();
// Apply the sequenced ops one by one.
}
LogicalResult transform::SequenceOp::verify() {
- if (getBodyBlock()->getNumArguments() != 1 ||
- !getBodyBlock()->getArgumentTypes()[0].isa<pdl::OperationType>()) {
- return emitOpError()
- << "expected the entry block to have one argument of type "
- << pdl::OperationType::get(getContext());
- }
-
- if (auto parent = getOperation()->getParentOfType<transform::SequenceOp>()) {
- if (!getRoot()) {
- InFlightDiagnostic diag =
- emitOpError()
- << "expected the root operation to be provided for a nested sequence";
- diag.attachNote(parent.getLoc()) << "nested in another sequence";
- return diag;
- }
- }
-
for (Operation &child : *getBodyBlock()) {
if (!isa<TransformOpInterface>(child) &&
&child != &getBodyBlock()->back()) {
}
return success();
}
+
+//===----------------------------------------------------------------------===//
+// WithPDLPatternsOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult
+transform::WithPDLPatternsOp::apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ OwningOpRef<ModuleOp> pdlModuleOp =
+ ModuleOp::create(getOperation()->getLoc());
+ TransformOpInterface transformOp = nullptr;
+ for (Operation &nested : getBody().front()) {
+ if (!isa<pdl::PatternOp>(nested)) {
+ transformOp = cast<TransformOpInterface>(nested);
+ break;
+ }
+ }
+
+ state.addExtension<PatternApplicatorExtension>(getOperation());
+ auto guard = llvm::make_scope_exit(
+ [&]() { state.removeExtension<PatternApplicatorExtension>(); });
+
+ auto scope = state.make_region_scope(getBody());
+ if (failed(mapBlockArguments(state)))
+ return failure();
+ return state.applyTransform(transformOp);
+}
+
+LogicalResult transform::WithPDLPatternsOp::verify() {
+ Block *body = getBodyBlock();
+ Operation *topLevelOp = nullptr;
+ for (Operation &op : body->getOperations()) {
+ if (isa<pdl::PatternOp>(op))
+ continue;
+
+ if (op.hasTrait<::mlir::transform::PossibleTopLevelTransformOpTrait>()) {
+ if (topLevelOp) {
+ InFlightDiagnostic diag =
+ emitOpError() << "expects only one non-pattern op in its body";
+ diag.attachNote(topLevelOp->getLoc()) << "first non-pattern op";
+ diag.attachNote(op.getLoc()) << "second non-pattern op";
+ return diag;
+ }
+ topLevelOp = &op;
+ continue;
+ }
+
+ InFlightDiagnostic diag =
+ emitOpError()
+ << "expects only pattern and top-level transform ops in its body";
+ diag.attachNote(op.getLoc()) << "offending op";
+ return diag;
+ }
+
+ if (auto parent = getOperation()->getParentOfType<WithPDLPatternsOp>()) {
+ InFlightDiagnostic diag = emitOpError() << "cannot be nested";
+ diag.attachNote(parent.getLoc()) << "parent operation";
+ return diag;
+ }
+
+ return success();
+}
// RUN: mlir-opt %s -split-input-file -verify-diagnostics
-// expected-error @below {{expected the entry block to have one argument of type '!pdl.operation'}}
+// expected-error @below {{expects the entry block to have one argument of type '!pdl.operation'}}
transform.sequence {
}
// -----
-// expected-note @below {{nested in another sequence}}
+// expected-note @below {{nested in another possible top-level op}}
transform.sequence {
^bb0(%arg0: !pdl.operation):
- // expected-error @below {{expected the root operation to be provided for a nested sequence}}
+ // expected-error @below {{expects the root operation to be provided for a nested op}}
transform.sequence {
^bb1(%arg1: !pdl.operation):
}
// expected-note @below {{terminator}}
transform.yield
} : !pdl.operation
+
+// -----
+
+// expected-note @below {{nested in another possible top-level op}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ // expected-error @below {{expects the root operation to be provided for a nested op}}
+ transform.sequence {
+ ^bb1(%arg1: !pdl.operation):
+ }
+}
+
+// -----
+
+// expected-error @below {{expects only one non-pattern op in its body}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ // expected-note @below {{first non-pattern op}}
+ transform.sequence {
+ ^bb1(%arg1: !pdl.operation):
+ }
+ // expected-note @below {{second non-pattern op}}
+ transform.sequence {
+ ^bb1(%arg1: !pdl.operation):
+ }
+}
+
+// -----
+
+// expected-error @below {{expects only pattern and top-level transform ops in its body}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ // expected-note @below {{offending op}}
+ "test.something"() : () -> ()
+}
+
+// -----
+
+// expected-note @below {{parent operation}}
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ // expected-error @below {{op cannot be nested}}
+ transform.with_pdl_patterns %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ }
+}
+
+// -----
+
+// expected-error @below {{expects one region}}
+"transform.test_transform_unrestricted_op_no_interface"() : () -> ()
+
+// -----
+
+// expected-error @below {{expects a single-block region}}
+"transform.test_transform_unrestricted_op_no_interface"() ({
+^bb0(%arg0: !pdl.operation):
+ "test.potential_terminator"() : () -> ()
+^bb1:
+ "test.potential_terminator"() : () -> ()
+}) : () -> ()
^bb1(%arg1: !pdl.operation):
}
}
+
+// CHECK: transform.with_pdl_patterns
+// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation):
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ // CHECK: sequence %[[ARG]]
+ sequence %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ }
+}
+
+// CHECK: transform.sequence
+// CHECK: ^{{.+}}(%[[ARG:.+]]: !pdl.operation):
+transform.sequence {
+^bb0(%arg0: !pdl.operation):
+ // CHECK: with_pdl_patterns %[[ARG]]
+ with_pdl_patterns %arg0 {
+ ^bb1(%arg1: !pdl.operation):
+ }
+}
// expected-remark @below {{succeeded}}
test_consume_operand_if_matches_param_or_fail %0[42]
}
+
+// -----
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+ sequence %arg0 {
+ ^bb0(%arg1: !pdl.operation):
+ %0 = pdl_match @some in %arg1
+ test_print_remark_at_operand %0, "matched"
+ }
+
+ pdl.pattern @some : benefit(1) {
+ %0 = pdl.operation "test.some_op"
+ pdl.rewrite %0 with "transform.dialect"
+ }
+
+ pdl.pattern @other : benefit(1) {
+ %0 = pdl.operation "test.other_op"
+ pdl.rewrite %0 with "transform.dialect"
+ }
+}
+
+// expected-remark @below {{matched}}
+"test.some_op"() : () -> ()
+"test.other_op"() : () -> ()
+// expected-remark @below {{matched}}
+"test.some_op"() : () -> ()
+
namespace {
/// Simple transform op defined outside of the dialect. Just emits a remark when
-/// applied.
+/// applied. This op is defined in C++ to test that C++ definitions also work
+/// for op injection into the Transform dialect.
class TestTransformOp
: public Op<TestTransformOp, transform::TransformOpInterface::Trait> {
public:
printer << " " << getMessage();
}
};
+
+/// A test op to exercise the verifier of the PossibleTopLevelTransformOpTrait
+/// in cases where it is attached to ops that do not comply with the trait
+/// requirements. This op cannot be defined in ODS because ODS generates strict
+/// verifiers that overalp with those in the trait and run earlier.
+class TestTransformUnrestrictedOpNoInterface
+ : public Op<TestTransformUnrestrictedOpNoInterface,
+ transform::PossibleTopLevelTransformOpTrait,
+ transform::TransformOpInterface::Trait> {
+public:
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+ TestTransformUnrestrictedOpNoInterface)
+
+ using Op::Op;
+
+ static ArrayRef<StringRef> getAttributeNames() { return {}; }
+
+ static constexpr llvm::StringLiteral getOperationName() {
+ return llvm::StringLiteral(
+ "transform.test_transform_unrestricted_op_no_interface");
+ }
+
+ LogicalResult apply(transform::TransformResults &results,
+ transform::TransformState &state) {
+ return success();
+ }
+};
} // namespace
LogicalResult mlir::test::TestProduceParamOrForwardOperandOp::apply(
return success();
}
+LogicalResult mlir::test::TestPrintRemarkAtOperandOp::apply(
+ transform::TransformResults &results, transform::TransformState &state) {
+ ArrayRef<Operation *> payload = state.getPayloadOps(getOperand());
+ for (Operation *op : payload)
+ op->emitRemark() << getMessage();
+
+ return success();
+}
+
namespace {
/// Test extension of the Transform dialect. Registers additional ops and
/// declares PDL as dependent dialect since the additional ops are using PDL
TestTransformDialectExtension() {
declareDependentDialect<pdl::PDLDialect>();
registerTransformOps<TestTransformOp,
+ TestTransformUnrestrictedOpNoInterface,
#define GET_OP_LIST
#include "TestTransformDialectExtension.cpp.inc"
>();
let cppNamespace = "::mlir::test";
}
+def TestPrintRemarkAtOperandOp
+ : Op<Transform_Dialect, "test_print_remark_at_operand",
+ [DeclareOpInterfaceMethods<TransformOpInterface>]> {
+ let arguments = (ins PDL_Operation:$operand, StrAttr:$message);
+ let assemblyFormat = "$operand `,` $message attr-dict";
+ let cppNamespace = "::mlir::test";
+}
+
#endif // MLIR_TESTTRANSFORMDIALECTEXTENSION_TD
deps = [
":IR",
":PDLDialect",
+ ":PDLInterpDialect",
+ ":Rewrite",
":Support",
":TransformDialectIncGen",
":TransformDialectInterfacesIncGen",