From 9d30c6a721edf75d0726e07fb82cc5538fb95c16 Mon Sep 17 00:00:00 2001 From: Alex Zinenko Date: Mon, 5 Jun 2023 18:42:15 +0000 Subject: [PATCH] [mlir][transform] generate transform module on-the-fly Add a TransformInterpreterPassBase capability to generate the (shared) module containing the transform script during the pass initialization. This is helpful to programmatically generate the script as opposed to parsing it from the textual module. Reviewed By: springerm Differential Revision: https://reviews.llvm.org/D152185 --- .../Transforms/TransformInterpreterPassBase.h | 18 +++++++++++++-- .../Transforms/TransformInterpreterPassBase.cpp | 22 ++++++++++++++++-- .../test-interpreter-module-generation.mlir | 4 ++++ .../Transform/TestTransformDialectInterpreter.cpp | 27 ++++++++++++++++++++++ 4 files changed, 67 insertions(+), 4 deletions(-) create mode 100644 mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir diff --git a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h index d7c61ef..91903e2 100644 --- a/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h +++ b/mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h @@ -35,7 +35,9 @@ LogicalResult interpreterBaseInitializeImpl( MLIRContext *context, StringRef transformFileName, StringRef transformLibraryFileName, std::shared_ptr> &module, - std::shared_ptr> &libraryModule); + std::shared_ptr> &libraryModule, + function_ref(OpBuilder &, Location)> + moduleBuilder = nullptr); /// Template-free implementation of /// TransformInterpreterPassBase::runOnOperation. @@ -123,7 +125,11 @@ public: static_cast(this)->transformLibraryFileName; return detail::interpreterBaseInitializeImpl( context, transformFileName, transformLibraryFileName, - sharedTransformModule, transformLibraryModule); + sharedTransformModule, transformLibraryModule, + [this](OpBuilder &builder, Location loc) { + return static_cast(this)->constructTransformModule( + builder, loc); + }); } /// Hook for passes to run additional logic in the pass before the @@ -136,6 +142,14 @@ public: /// fails. LogicalResult runAfterInterpreter(Operation *) { return success(); } + /// Hook for passes to run custom logic to construct the transform module. + /// This will run during initialization. If the external script is provided, + /// it overrides the construction, which will not be called. + std::optional constructTransformModule(OpBuilder &builder, + Location loc) { + return std::nullopt; + } + void runOnOperation() override { auto *pass = static_cast(this); Operation *op = pass->getOperation(); diff --git a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp index 9dc9161..b9380f5 100644 --- a/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp +++ b/mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp @@ -462,7 +462,9 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( MLIRContext *context, StringRef transformFileName, StringRef transformLibraryFileName, std::shared_ptr> &module, - std::shared_ptr> &libraryModule) { + std::shared_ptr> &libraryModule, + function_ref(OpBuilder &, Location)> + moduleBuilder) { OwningOpRef parsed; if (failed(parseTransformModuleFromFile(context, transformFileName, parsed))) return failure(); @@ -476,7 +478,23 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl( if (parsedLibrary && failed(mlir::verify(*parsedLibrary))) return failure(); - module = std::make_shared>(std::move(parsed)); + if (parsed) { + module = std::make_shared>(std::move(parsed)); + } else if (moduleBuilder) { + // TODO: better location story. + auto location = UnknownLoc::get(context); + auto localModule = std::make_shared>( + ModuleOp::create(location, "__transform")); + + OpBuilder b(context); + b.setInsertionPointToEnd(localModule->get().getBody()); + if (std::optional result = moduleBuilder(b, location)) { + if (failed(*result)) + return failure(); + module = std::move(localModule); + } + } + if (!parsedLibrary || !*parsedLibrary) return success(); diff --git a/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir b/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir new file mode 100644 index 0000000..159aed72 --- /dev/null +++ b/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir @@ -0,0 +1,4 @@ +// RUN: mlir-opt %s --test-transform-dialect-interpreter=test-module-generation=1 --verify-diagnostics + +// expected-remark @below {{remark from generated}} +module {} diff --git a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp index b7e9d08..f73deef 100644 --- a/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp +++ b/mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp @@ -11,7 +11,9 @@ // //===----------------------------------------------------------------------===// +#include "TestTransformDialectExtension.h" #include "mlir/Dialect/Transform/IR/TransformInterfaces.h" +#include "mlir/Dialect/Transform/IR/TransformOps.h" #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinOps.h" @@ -46,6 +48,10 @@ public: return "apply transform dialect operations one by one"; } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void findOperationsByName(Operation *root, StringRef name, SmallVectorImpl &operations) { root->walk([&](Operation *op) { @@ -86,6 +92,22 @@ public: return numSetValues; } + std::optional constructTransformModule(OpBuilder &builder, + Location loc) { + if (!testModuleGeneration) + return std::nullopt; + + builder.create( + loc, TypeRange(), transform::FailurePropagationMode::Propagate, + builder.getType(), + [](OpBuilder &b, Location nested, Value rootH) { + b.create( + nested, rootH, "remark from generated"); + b.create(nested, ValueRange()); + }); + return success(); + } + void runOnOperation() override { unsigned firstSetOptions = numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams, @@ -199,6 +221,11 @@ public: llvm::cl::desc( "Optional name of the file containing transform dialect symbol " "definitions to be injected into the transform module.")}; + + Option testModuleGeneration{ + *this, "test-module-generation", llvm::cl::init(false), + llvm::cl::desc("test the generation of the transform module during pass " + "initialization, overridden by parsing")}; }; struct TestTransformDialectEraseSchedulePass -- 2.7.4