MLIRContext *context, StringRef transformFileName,
StringRef transformLibraryFileName,
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
- std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule);
+ std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+ function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+ moduleBuilder = nullptr);
/// Template-free implementation of
/// TransformInterpreterPassBase::runOnOperation.
static_cast<Concrete *>(this)->transformLibraryFileName;
return detail::interpreterBaseInitializeImpl(
context, transformFileName, transformLibraryFileName,
- sharedTransformModule, transformLibraryModule);
+ sharedTransformModule, transformLibraryModule,
+ [this](OpBuilder &builder, Location loc) {
+ return static_cast<Concrete *>(this)->constructTransformModule(
+ builder, loc);
+ });
}
/// Hook for passes to run additional logic in the pass before the
/// fails.
LogicalResult runAfterInterpreter(Operation *) { return success(); }
+ /// Hook for passes to run custom logic to construct the transform module.
+ /// This will run during initialization. If the external script is provided,
+ /// it overrides the construction, which will not be called.
+ std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+ Location loc) {
+ return std::nullopt;
+ }
+
void runOnOperation() override {
auto *pass = static_cast<Concrete *>(this);
Operation *op = pass->getOperation();
MLIRContext *context, StringRef transformFileName,
StringRef transformLibraryFileName,
std::shared_ptr<OwningOpRef<ModuleOp>> &module,
- std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule) {
+ std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+ function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+ moduleBuilder) {
OwningOpRef<ModuleOp> parsed;
if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
return failure();
if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
return failure();
- module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+ if (parsed) {
+ module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+ } else if (moduleBuilder) {
+ // TODO: better location story.
+ auto location = UnknownLoc::get(context);
+ auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
+ ModuleOp::create(location, "__transform"));
+
+ OpBuilder b(context);
+ b.setInsertionPointToEnd(localModule->get().getBody());
+ if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+ if (failed(*result))
+ return failure();
+ module = std::move(localModule);
+ }
+ }
+
if (!parsedLibrary || !*parsedLibrary)
return success();
--- /dev/null
+// RUN: mlir-opt %s --test-transform-dialect-interpreter=test-module-generation=1 --verify-diagnostics
+
+// expected-remark @below {{remark from generated}}
+module {}
//
//===----------------------------------------------------------------------===//
+#include "TestTransformDialectExtension.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
return "apply transform dialect operations one by one";
}
+ void getDependentDialects(DialectRegistry ®istry) const override {
+ registry.insert<transform::TransformDialect>();
+ }
+
void findOperationsByName(Operation *root, StringRef name,
SmallVectorImpl<Operation *> &operations) {
root->walk([&](Operation *op) {
return numSetValues;
}
+ std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+ Location loc) {
+ if (!testModuleGeneration)
+ return std::nullopt;
+
+ builder.create<transform::SequenceOp>(
+ loc, TypeRange(), transform::FailurePropagationMode::Propagate,
+ builder.getType<transform::AnyOpType>(),
+ [](OpBuilder &b, Location nested, Value rootH) {
+ b.create<mlir::test::TestPrintRemarkAtOperandOp>(
+ nested, rootH, "remark from generated");
+ b.create<transform::YieldOp>(nested, ValueRange());
+ });
+ return success();
+ }
+
void runOnOperation() override {
unsigned firstSetOptions =
numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
llvm::cl::desc(
"Optional name of the file containing transform dialect symbol "
"definitions to be injected into the transform module.")};
+
+ Option<bool> testModuleGeneration{
+ *this, "test-module-generation", llvm::cl::init(false),
+ llvm::cl::desc("test the generation of the transform module during pass "
+ "initialization, overridden by parsing")};
};
struct TestTransformDialectEraseSchedulePass