[mlir] support external named transform libraries
authorAlex Zinenko <zinenko@google.com>
Mon, 27 Mar 2023 14:03:03 +0000 (14:03 +0000)
committerAlex Zinenko <zinenko@google.com>
Tue, 28 Mar 2023 09:47:19 +0000 (09:47 +0000)
Introduce support for external definitions of named sequences in the
transform dialect by letting the TransformInterpreterPassBase read a
"library" MLIR file. This file is expected to contain definitions for
named sequences that are only declared in the main transformation
script. This allows for sharing non-trivial transform combinations
without duplication.

This patch provides only the minimal plumbing for a single textual IR
file. Further changes are possible to support multiple libraries and
bytecode files.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D146961

mlir/include/mlir/Dialect/Transform/IR/TransformOps.td
mlir/include/mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h
mlir/lib/Dialect/Transform/IR/TransformOps.cpp
mlir/lib/Dialect/Transform/Transforms/TransformInterpreterPassBase.cpp
mlir/test/Dialect/Transform/ops-invalid.mlir
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir [new file with mode: 0644]
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir [new file with mode: 0644]
mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir [new file with mode: 0644]
mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir [new file with mode: 0644]
mlir/test/lib/Dialect/Transform/TestTransformDialectInterpreter.cpp
utils/bazel/llvm-project-overlay/mlir/test/Dialect/BUILD.bazel

index 46dea74..acded77 100644 (file)
@@ -373,9 +373,10 @@ def NamedSequenceOp : TransformDialectOp<"named_sequence",
     SymbolNameAttr:$sym_name,
     TypeAttrBase<"::mlir::FunctionType",
                  "function type attribute">:$function_type,
+    OptionalAttr<StrAttr>:$sym_visibility,
     OptionalAttr<DictArrayAttr>:$arg_attrs,
     OptionalAttr<DictArrayAttr>:$res_attrs);
-  let regions = (region SizedRegion<1>:$body);
+  let regions = (region MaxSizedRegion<1>:$body);
 
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
index 0a60b4c..d7c61ef 100644 (file)
@@ -31,18 +31,22 @@ class Region;
 namespace transform {
 namespace detail {
 /// Template-free implementation of TransformInterpreterPassBase::initialize.
-LogicalResult
-interpreterBaseInitializeImpl(MLIRContext *context, StringRef transformFileName,
-                              std::shared_ptr<OwningOpRef<ModuleOp>> &module);
+LogicalResult interpreterBaseInitializeImpl(
+    MLIRContext *context, StringRef transformFileName,
+    StringRef transformLibraryFileName,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule);
 
 /// Template-free implementation of
 /// TransformInterpreterPassBase::runOnOperation.
 LogicalResult interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
+    const Pass::Option<std::string> &transformLibraryFileName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName);
@@ -56,6 +60,9 @@ LogicalResult interpreterBaseRunOnOperationImpl(
 ///     transform script. If empty, `debugTransformRootTag` is considered or the
 ///     pass root operation must contain a single top-level transform op that
 ///     will be interpreted.
+///   - transformLibraryFileName: if non-empty, the name of the file containing
+///     definitions of external symbols referenced in the transform script.
+///     These definitions will be used to replace declarations.
 ///   - debugPayloadRootTag: if non-empty, the value of the attribute named
 ///     `kTransformDialectTagAttrName` indicating the single op that is
 ///     considered the payload root of the transform interpreter; otherwise, the
@@ -106,13 +113,17 @@ public:
     REQUIRE_PASS_OPTION(transformFileName);
     REQUIRE_PASS_OPTION(debugPayloadRootTag);
     REQUIRE_PASS_OPTION(debugTransformRootTag);
+    REQUIRE_PASS_OPTION(transformLibraryFileName);
 
 #undef REQUIRE_PASS_OPTION
 
     StringRef transformFileName =
         static_cast<Concrete *>(this)->transformFileName;
-    return detail::interpreterBaseInitializeImpl(context, transformFileName,
-                                                 sharedTransformModule);
+    StringRef transformLibraryFileName =
+        static_cast<Concrete *>(this)->transformLibraryFileName;
+    return detail::interpreterBaseInitializeImpl(
+        context, transformFileName, transformLibraryFileName,
+        sharedTransformModule, transformLibraryModule);
   }
 
   /// Hook for passes to run additional logic in the pass before the
@@ -132,9 +143,10 @@ public:
     if (failed(pass->runBeforeInterpreter(op)) ||
         failed(detail::interpreterBaseRunOnOperationImpl(
             op, pass->getArgument(), sharedTransformModule,
+            transformLibraryModule,
             /*extraMappings=*/{}, options, pass->transformFileName,
-            pass->debugPayloadRootTag, pass->debugTransformRootTag,
-            binaryName)) ||
+            pass->transformLibraryFileName, pass->debugPayloadRootTag,
+            pass->debugTransformRootTag, binaryName)) ||
         failed(pass->runAfterInterpreter(op))) {
       return pass->signalPassFailure();
     }
@@ -150,12 +162,24 @@ protected:
     return sharedTransformModule;
   }
 
+  /// Returns a read-only reference to the transform library module.
+  const std::shared_ptr<OwningOpRef<ModuleOp>> &
+  getTransformLibraryModule() const {
+    return transformLibraryModule;
+  }
+
 private:
   /// The separate transform module to be used for transformations, shared
   /// across multiple instances of the pass if it is applied in parallel to
   /// avoid potentially expensive cloning. MUST NOT be modified after the pass
   /// has been initialized.
   std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
+
+  /// The transform module containing symbol definitions that become available
+  /// in the transform scripts. Similar to dynamic linking for binaries. This is
+  /// shared across multiple instances of the pass and therefore MUST NOT be
+  /// modified after the pass has been initialized.
+  std::shared_ptr<OwningOpRef<ModuleOp>> transformLibraryModule = nullptr;
 };
 
 } // namespace transform
index 6051007..9ec4c8e 100644 (file)
@@ -573,6 +573,9 @@ transform::IncludeOp::apply(transform::TransformResults &results,
       getOperation(), getTarget());
   assert(callee && "unverified reference to unknown symbol");
 
+  if (callee.isExternal())
+    return emitDefiniteFailure() << "unresolved external named sequence";
+
   // Map operands to block arguments.
   SmallVector<SmallVector<MappedValue>> mappings;
   detail::prepareValueMappings(mappings, getOperands(), state);
@@ -648,7 +651,10 @@ void transform::IncludeOp::getEffects(
   }
 
   // Carry over effects from the callee.
-  remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
+  // TODO: external callees must provides attributes annotating the
+  // readonly/consume effects on operands.
+  if (!callee.isExternal())
+    remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
 
   // Proper effects.
   onlyReadsHandle(getOperands(), effects);
@@ -784,9 +790,6 @@ void transform::NamedSequenceOp::print(OpAsmPrinter &printer) {
 /// verifier runs, e.g., during trait verification.
 static DiagnosedSilenceableFailure
 verifyNamedSequenceOp(transform::NamedSequenceOp op) {
-  if (op.isExternal())
-    return emitSilenceableFailure(op) << "cannot be empty";
-
   if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
     if (!parent->getAttr(
             transform::TransformDialect::kWithNamedSequenceAttrName)) {
@@ -808,6 +811,9 @@ verifyNamedSequenceOp(transform::NamedSequenceOp op) {
     return diag;
   }
 
+  if (op.isExternal() || op.getBody().empty())
+    return DiagnosedSilenceableFailure::success();
+
   if (op.getBody().front().empty())
     return emitSilenceableFailure(op) << "expected a non-empty body block";
 
index 40624e6..b3fe45e 100644 (file)
 #include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
 #include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Verifier.h"
 #include "mlir/IR/Visitors.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/FileSystem.h"
@@ -157,9 +160,17 @@ static llvm::raw_ostream &
 printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
                const Pass::Option<std::string> &debugPayloadRootTag,
                const Pass::Option<std::string> &debugTransformRootTag,
+               const Pass::Option<std::string> &transformLibraryFileName,
                StringRef binaryName) {
+  std::string transformLibraryOption = "";
+  if (!transformLibraryFileName.empty()) {
+    transformLibraryOption =
+        llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
+                      transformLibraryFileName.getValue())
+            .str();
+  }
   os << llvm::formatv(
-      "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName,
+      "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
       passName, debugPayloadRootTag.getArgStr(),
       debugPayloadRootTag.empty()
           ? StringRef(kTransformDialectTagPayloadRootValue)
@@ -168,7 +179,7 @@ printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
       debugTransformRootTag.empty()
           ? StringRef(kTransformDialectTagTransformContainerValue)
           : debugTransformRootTag,
-      binaryName);
+      transformLibraryOption, binaryName);
   return os;
 }
 
@@ -184,11 +195,12 @@ llvm::raw_ostream &printModuleForRepro(llvm::raw_ostream &os, Operation *root,
 
 /// Saves the payload and the transform IR into a temporary file and reports
 /// the file name to `os`.
-void saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
-                         Operation *transform, StringRef passName,
-                         const Pass::Option<std::string> &debugPayloadRootTag,
-                         const Pass::Option<std::string> &debugTransformRootTag,
-                         StringRef binaryName) {
+void saveReproToTempFile(
+    llvm::raw_ostream &os, Operation *target, Operation *transform,
+    StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
+    const Pass::Option<std::string> &debugTransformRootTag,
+    const Pass::Option<std::string> &transformLibraryFileName,
+    StringRef binaryName) {
   using llvm::sys::fs::TempFile;
   Operation *root = getRootOperation(target);
 
@@ -213,7 +225,8 @@ void saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
 
   os << "=== Transform Interpreter Repro ===\n";
   printReproCall(os, root->getName().getStringRef(), passName,
-                 debugPayloadRootTag, debugTransformRootTag, binaryName)
+                 debugPayloadRootTag, debugTransformRootTag,
+                 transformLibraryFileName, binaryName)
       << " " << filename << "\n";
   os << "===================================\n";
 }
@@ -224,6 +237,7 @@ static void performOptionalDebugActions(
     Operation *target, Operation *transform, StringRef passName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
+    const Pass::Option<std::string> &transformLibraryFileName,
     StringRef binaryName) {
   MLIRContext *context = target->getContext();
 
@@ -266,7 +280,8 @@ static void performOptionalDebugActions(
     llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
     printReproCall(llvm::dbgs() << "cat <<EOF | ",
                    root->getName().getStringRef(), passName,
-                   debugPayloadRootTag, debugTransformRootTag, binaryName)
+                   debugPayloadRootTag, debugTransformRootTag,
+                   transformLibraryFileName, binaryName)
         << "\n";
     printModuleForRepro(llvm::dbgs(), root, transform);
     llvm::dbgs() << "\nEOF\n";
@@ -275,16 +290,63 @@ static void performOptionalDebugActions(
   (void)root;
   DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
     saveReproToTempFile(llvm::dbgs(), target, transform, passName,
-                        debugPayloadRootTag, debugTransformRootTag, binaryName);
+                        debugPayloadRootTag, debugTransformRootTag,
+                        transformLibraryFileName, binaryName);
   });
 }
 
+/// Replaces external symbols in `block` with their (non-external) definitions
+/// from the given module.
+static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
+  for (Operation &op : llvm::make_early_inc_range(block)) {
+    LLVM_DEBUG(DBGS() << op << "\n");
+    auto symbol = dyn_cast<SymbolOpInterface>(op);
+    if (!symbol)
+      continue;
+    if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
+      continue;
+
+    LLVM_DEBUG(DBGS() << "looking for definition of symbol "
+                      << symbol.getNameAttr() << ":");
+    SymbolTable symbolTable(definitions);
+    Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+    if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
+        externalSymbol->getRegion(0).empty()) {
+      LLVM_DEBUG(llvm::dbgs() << "not found\n");
+      continue;
+    }
+
+    auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+    auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
+    if (!symbolFunc || !externalSymbolFunc) {
+      LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+      continue;
+    }
+
+    LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+    if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
+      return symbolFunc.emitError()
+             << "external definition has a mismatching signature ("
+             << externalSymbolFunc.getFunctionType() << ")";
+    }
+
+    OpBuilder builder(&op);
+    builder.setInsertionPoint(&op);
+    builder.clone(*externalSymbol);
+    symbol->erase();
+  }
+
+  return success();
+}
+
 LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
     Operation *target, StringRef passName,
     const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+    const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
     const RaggedArray<MappedValue> &extraMappings,
     const TransformOptions &options,
     const Pass::Option<std::string> &transformFileName,
+    const Pass::Option<std::string> &transformLibraryFileName,
     const Pass::Option<std::string> &debugPayloadRootTag,
     const Pass::Option<std::string> &debugTransformRootTag,
     StringRef binaryName) {
@@ -328,13 +390,31 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 
   // Step 3
   // ------
+  // Copy external defintions for symbols if provided. Be aware of potential
+  // concurrent execution (normally, the error shouldn't be triggered unless the
+  // transform IR modifies itself in a pass, which is also forbidden elsewhere).
+  if (!sharedTransform && libraryModule && *libraryModule) {
+    if (!target->isProperAncestor(transformRoot)) {
+      InFlightDiagnostic diag =
+          transformRoot->emitError()
+          << "cannot inject transform definitions next to pass anchor op";
+      diag.attachNote(target->getLoc()) << "pass anchor op";
+      return diag;
+    }
+    if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
+                                     libraryModule->get())))
+      return failure();
+  }
+
+  // Step 4
+  // ------
   // Optionally perform debug actions requested by the user to dump IR and a
   // repro to stderr and/or a file.
   performOptionalDebugActions(target, transformRoot, passName,
                               debugPayloadRootTag, debugTransformRootTag,
-                              binaryName);
+                              transformLibraryFileName, binaryName);
 
-  // Step 4
+  // Step 5
   // ------
   // Apply the transform to the IR
   return applyTransforms(payloadRoot, cast<TransformOpInterface>(transformRoot),
@@ -343,11 +423,33 @@ LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
 
 LogicalResult transform::detail::interpreterBaseInitializeImpl(
     MLIRContext *context, StringRef transformFileName,
-    std::shared_ptr<OwningOpRef<ModuleOp>> &module) {
+    StringRef transformLibraryFileName,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &module,
+    std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule) {
   OwningOpRef<ModuleOp> parsed;
   if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
     return failure();
+  if (parsed && failed(mlir::verify(*parsed)))
+    return failure();
+
+  OwningOpRef<ModuleOp> parsedLibrary;
+  if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
+                                          parsedLibrary)))
+    return failure();
+  if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+    return failure();
 
   module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+  if (!parsedLibrary || !*parsedLibrary)
+    return success();
+
+  if (module && *module) {
+    if (failed(defineDeclaredSymbols(*module->get().getBody(),
+                                     parsedLibrary.get())))
+      return failure();
+  } else {
+    libraryModule =
+        std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+  }
   return success();
 }
index ee03d9e..a6bfa64 100644 (file)
@@ -293,8 +293,10 @@ transform.sequence failures(suppress) {
 // -----
 
 module attributes { transform.with_named_sequence } {
-  // expected-error @below {{failed to verify constraint: region with 1 blocks}}
-  "transform.named_sequence"() ({}) { sym_name = "external_named_sequence", function_type = () -> () } : () -> ()
+  // expected-error @below {{expected a non-empty body block}}
+  "transform.named_sequence"() ({
+  ^bb0:
+  }) { sym_name = "external_named_sequence", function_type = () -> () } : () -> ()
 
   transform.sequence failures(propagate) {
   ^bb0(%arg0: !transform.any_op):
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-and-schedule.mlir
new file mode 100644 (file)
index 0000000..da50c6b
--- /dev/null
@@ -0,0 +1,13 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics
+
+// The external transform script has a declaration to the named sequence @foo,
+// the definition of which is provided in another file. Repeated application
+// of the same pass should not be a problem. Note that the same diagnostic
+// produced twice at the same location only needs to be matched once.
+
+// expected-remark @below {{message}}
+module {}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl-invalid.mlir
new file mode 100644 (file)
index 0000000..bb8acf8
--- /dev/null
@@ -0,0 +1,27 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics --split-input-file
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option.
+
+module attributes {transform.with_named_sequence} {
+  // expected-error @below {{external definition has a mismatching signature}}
+  transform.named_sequence private @foo(!transform.op<"builtin.module">)
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.op<"builtin.module">):
+    include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
+  }
+}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence private @undefined_sequence()
+
+  transform.sequence failures(suppress) {
+  ^bb0(%arg0: !transform.any_op):
+    // expected-error @below {{unresolved external named sequence}}
+    include @undefined_sequence failures(suppress) () : () -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-decl.mlir
new file mode 100644 (file)
index 0000000..6e23641
--- /dev/null
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN:             --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN:             --verify-diagnostics | FileCheck %s
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option. Repeated application of the
+// same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+  // CHECK: transform.named_sequence @foo
+  // CHECK: test_print_remark_at_operand %{{.*}}, "message"
+  transform.named_sequence private @foo(!transform.any_op)
+
+  transform.sequence failures(propagate) {
+  ^bb0(%arg0: !transform.any_op):
+    include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir b/mlir/test/Dialect/Transform/test-interpreter-external-symbol-def.mlir
new file mode 100644 (file)
index 0000000..509612b
--- /dev/null
@@ -0,0 +1,8 @@
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @foo(%arg0: !transform.any_op) {
+    transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+    transform.yield
+  }
+}
index 7beae91..b7e9d08 100644 (file)
@@ -138,7 +138,8 @@ public:
     options = options.enableExpensiveChecks(enableExpensiveChecks);
     if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
             getOperation(), getArgument(), getSharedTransformModule(),
-            extraMapping, options, transformFileName, debugPayloadRootTag,
+            getTransformLibraryModule(), extraMapping, options,
+            transformFileName, transformLibraryFileName, debugPayloadRootTag,
             debugTransformRootTag, getBinaryName())))
       return signalPassFailure();
   }
@@ -193,6 +194,11 @@ public:
           "the given value as container IR for top-level transform ops. This "
           "allows user control on what transformation to apply. If empty, "
           "select the container of the top-level transform op.")};
+  Option<std::string> transformLibraryFileName{
+      *this, "transform-library-file-name", llvm::cl::init(""),
+      llvm::cl::desc(
+          "Optional name of the file containing transform dialect symbol "
+          "definitions to be injected into the transform module.")};
 };
 
 struct TestTransformDialectEraseSchedulePass
index adfb3d0..137db36 100644 (file)
@@ -15,10 +15,15 @@ package(default_visibility = ["//visibility:public"])
             "//mlir/test:lit_data",
         ] + glob([
             "Transform/*-source.mlir",
+            "Transform/*-symbol-def.mlir",
         ])
     )
     for src in glob(
         include=["**/*.mlir"],
-        exclude=["Transform/*-source.mlir"]
+        exclude=[
+            "Transform/*-source.mlir",
+            "Transform/*-symbol-def.mlir",
+            "Transform/*-symbol-decl-and-schedule.mlir",
+        ]
     )
 ]