[mlir][transform] generate transform module on-the-fly
authorAlex Zinenko <zinenko@google.com>
Mon, 5 Jun 2023 18:42:15 +0000 (18:42 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 6 Jun 2023 09:34:54 +0000 (09:34 +0000)
Add a TransformInterpreterPassBase capability to generate the (shared)
module containing the transform script during the pass initialization.
This is helpful to programmatically generate the script as opposed to
parsing it from the textual module.

Reviewed By: springerm

Differential Revision: https://reviews.llvm.org/D152185

mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp

index d7c61ef..91903e2 100644 (file)
@@ -35,7 +35,9 @@ LogicalResult interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule);
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+        moduleBuilder = nullptr);
 
 /// Template-free implementation of
 /// TransformInterpreterPassBase::runOnOperation.
@@ -123,7 +125,11 @@ public:
         static_cast<Concrete *>(this)->transformLibraryFileName;
     return detail::interpreterBaseInitializeImpl(
         context, transformFileName, transformLibraryFileName,
-        sharedTransformModule, transformLibraryModule);
+        sharedTransformModule, transformLibraryModule,
+        [this](OpBuilder &builder, Location loc) {
+          return static_cast<Concrete *>(this)->constructTransformModule(
+              builder, loc);
+        });
   }
 
   /// Hook for passes to run additional logic in the pass before the
@@ -136,6 +142,14 @@ public:
   /// fails.
   LogicalResult runAfterInterpreter(Operation *) { return success(); }
 
+  /// Hook for passes to run custom logic to construct the transform module.
+  /// This will run during initialization. If the external script is provided,
+  /// it overrides the construction, which will not be called.
+  std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+                                                        Location loc) {
+    return std::nullopt;
+  }
+
   void runOnOperation() override {
     auto *pass = static_cast<Concrete *>(this);
     Operation *op = pass->getOperation();
index 9dc9161..b9380f5 100644 (file)
@@ -462,7 +462,9 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
     StringRef transformLibraryFileName,
     std::shared_ptr<OwningOpRef<ModuleOp>> &module,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule) {
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
+    function_ref<std::optional<LogicalResult>(OpBuilder &, Location)>
+        moduleBuilder) {
   OwningOpRef<ModuleOp> parsed;
   if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
     return failure();
@@ -476,7 +478,23 @@ LogicalResult transform::detail::interpreterBaseInitializeImpl(
   if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
     return failure();
 
-  module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  if (parsed) {
+    module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  } else if (moduleBuilder) {
+    // TODO: better location story.
+    auto location = UnknownLoc::get(context);
+    auto localModule = std::make_shared<OwningOpRef<ModuleOp>>(
+        ModuleOp::create(location, "__transform"));
+
+    OpBuilder b(context);
+    b.setInsertionPointToEnd(localModule->get().getBody());
+    if (std::optional<LogicalResult> result = moduleBuilder(b, location)) {
+      if (failed(*result))
+        return failure();
+      module = std::move(localModule);
+    }
+  }
+
   if (!parsedLibrary || !*parsedLibrary)
     return success();
 
diff --git a/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir b/mlir/test/Dialect/Transform/test-interpreter-module-generation.mlir
new file mode 100644 (file)
index 0000000..159aed7
--- /dev/null
@@ -0,0 +1,4 @@
+// RUN: mlir-opt %s --test-transform-dialect-interpreter=test-module-generation=1 --verify-diagnostics
+
+// expected-remark @below {{remark from generated}}
+module {}
index b7e9d08..f73deef 100644 (file)
@@ -11,7 +11,9 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "TestTransformDialectExtension.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Transform/IR/TransformOps.h"
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
@@ -46,6 +48,10 @@ public:
     return "apply transform dialect operations one by one";
   }
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<transform::TransformDialect>();
+  }
+
   void findOperationsByName(Operation *root, StringRef name,
                             SmallVectorImpl<Operation *> &operations) {
     root->walk([&](Operation *op) {
@@ -86,6 +92,22 @@ public:
     return numSetValues;
   }
 
+  std::optional<LogicalResult> constructTransformModule(OpBuilder &builder,
+                                                        Location loc) {
+    if (!testModuleGeneration)
+      return std::nullopt;
+
+    builder.create<transform::SequenceOp>(
+        loc, TypeRange(), transform::FailurePropagationMode::Propagate,
+        builder.getType<transform::AnyOpType>(),
+        [](OpBuilder &b, Location nested, Value rootH) {
+          b.create<mlir::test::TestPrintRemarkAtOperandOp>(
+              nested, rootH, "remark from generated");
+          b.create<transform::YieldOp>(nested, ValueRange());
+        });
+    return success();
+  }
+
   void runOnOperation() override {
     unsigned firstSetOptions =
         numberOfSetOptions(bindFirstExtraToOps, bindFirstExtraToParams,
@@ -199,6 +221,11 @@ public:
       llvm::cl::desc(
           "Optional name of the file containing transform dialect symbol "
           "definitions to be injected into the transform module.")};
+
+  Option<bool> testModuleGeneration{
+      *this, "test-module-generation", llvm::cl::init(false),
+      llvm::cl::desc("test the generation of the transform module during pass "
+                     "initialization, overridden by parsing")};
 };
 
 struct TestTransformDialectEraseSchedulePass