SymbolNameAttr:$sym_name,
TypeAttrBase<"::mlir::FunctionType",
"function type attribute">:$function_type,
+ OptionalAttr<StrAttr>:$sym_visibility,
OptionalAttr<DictArrayAttr>:$arg_attrs,
OptionalAttr<DictArrayAttr>:$res_attrs);
- let regions = (region SizedRegion<1>:$body);
+ let regions = (region MaxSizedRegion<1>:$body);
let hasCustomAssemblyFormat = 1;
let hasVerifier = 1;
namespace transform {
namespace detail {
/// Template-free implementation of TransformInterpreterPassBase::initialize.
-LogicalResult
-interpreterBaseInitializeImpl(MLIRContext *context, StringRef transformFileName,
- std::shared_ptr<OwningOpRef<ModuleOp>> &module);
+LogicalResult interpreterBaseInitializeImpl(
+ MLIRContext *context, StringRef transformFileName,
+ StringRef transformLibraryFileName,
+ std::shared_ptr<OwningOpRef<ModuleOp>> &module,
+ std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule);
/// Template-free implementation of
/// TransformInterpreterPassBase::runOnOperation.
LogicalResult interpreterBaseRunOnOperationImpl(
Operation *target, StringRef passName,
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+ const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
+ const Pass::Option<std::string> &transformLibraryFileName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName);
/// transform script. If empty, `debugTransformRootTag` is considered or the
/// pass root operation must contain a single top-level transform op that
/// will be interpreted.
+/// - transformLibraryFileName: if non-empty, the name of the file containing
+/// definitions of external symbols referenced in the transform script.
+/// These definitions will be used to replace declarations.
/// - debugPayloadRootTag: if non-empty, the value of the attribute named
/// `kTransformDialectTagAttrName` indicating the single op that is
/// considered the payload root of the transform interpreter; otherwise, the
REQUIRE_PASS_OPTION(transformFileName);
REQUIRE_PASS_OPTION(debugPayloadRootTag);
REQUIRE_PASS_OPTION(debugTransformRootTag);
+ REQUIRE_PASS_OPTION(transformLibraryFileName);
#undef REQUIRE_PASS_OPTION
StringRef transformFileName =
static_cast<Concrete *>(this)->transformFileName;
- return detail::interpreterBaseInitializeImpl(context, transformFileName,
- sharedTransformModule);
+ StringRef transformLibraryFileName =
+ static_cast<Concrete *>(this)->transformLibraryFileName;
+ return detail::interpreterBaseInitializeImpl(
+ context, transformFileName, transformLibraryFileName,
+ sharedTransformModule, transformLibraryModule);
}
/// Hook for passes to run additional logic in the pass before the
if (failed(pass->runBeforeInterpreter(op)) ||
failed(detail::interpreterBaseRunOnOperationImpl(
op, pass->getArgument(), sharedTransformModule,
+ transformLibraryModule,
/*extraMappings=*/{}, options, pass->transformFileName,
- pass->debugPayloadRootTag, pass->debugTransformRootTag,
- binaryName)) ||
+ pass->transformLibraryFileName, pass->debugPayloadRootTag,
+ pass->debugTransformRootTag, binaryName)) ||
failed(pass->runAfterInterpreter(op))) {
return pass->signalPassFailure();
}
return sharedTransformModule;
}
+ /// Returns a read-only reference to the transform library module.
+ const std::shared_ptr<OwningOpRef<ModuleOp>> &
+ getTransformLibraryModule() const {
+ return transformLibraryModule;
+ }
+
private:
/// The separate transform module to be used for transformations, shared
/// across multiple instances of the pass if it is applied in parallel to
/// avoid potentially expensive cloning. MUST NOT be modified after the pass
/// has been initialized.
std::shared_ptr<OwningOpRef<ModuleOp>> sharedTransformModule = nullptr;
+
+ /// The transform module containing symbol definitions that become available
+ /// in the transform scripts. Similar to dynamic linking for binaries. This is
+ /// shared across multiple instances of the pass and therefore MUST NOT be
+ /// modified after the pass has been initialized.
+ std::shared_ptr<OwningOpRef<ModuleOp>> transformLibraryModule = nullptr;
};
} // namespace transform
getOperation(), getTarget());
assert(callee && "unverified reference to unknown symbol");
+ if (callee.isExternal())
+ return emitDefiniteFailure() << "unresolved external named sequence";
+
// Map operands to block arguments.
SmallVector<SmallVector<MappedValue>> mappings;
detail::prepareValueMappings(mappings, getOperands(), state);
}
// Carry over effects from the callee.
- remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
+ // TODO: external callees must provides attributes annotating the
+ // readonly/consume effects on operands.
+ if (!callee.isExternal())
+ remapArgumentEffects(callee.getBody().front(), getOperands(), effects);
// Proper effects.
onlyReadsHandle(getOperands(), effects);
/// verifier runs, e.g., during trait verification.
static DiagnosedSilenceableFailure
verifyNamedSequenceOp(transform::NamedSequenceOp op) {
- if (op.isExternal())
- return emitSilenceableFailure(op) << "cannot be empty";
-
if (Operation *parent = op->getParentWithTrait<OpTrait::SymbolTable>()) {
if (!parent->getAttr(
transform::TransformDialect::kWithNamedSequenceAttrName)) {
return diag;
}
+ if (op.isExternal() || op.getBody().empty())
+ return DiagnosedSilenceableFailure::success();
+
if (op.getBody().front().empty())
return emitSilenceableFailure(op) << "expected a non-empty body block";
#include "mlir/Dialect/Transform/Transforms/TransformInterpreterPassBase.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/FunctionInterfaces.h"
+#include "mlir/IR/Verifier.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/FileUtilities.h"
+#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FileSystem.h"
printReproCall(llvm::raw_ostream &os, StringRef rootOpName, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
+ const Pass::Option<std::string> &transformLibraryFileName,
StringRef binaryName) {
+ std::string transformLibraryOption = "";
+ if (!transformLibraryFileName.empty()) {
+ transformLibraryOption =
+ llvm::formatv(" {0}={1}", transformLibraryFileName.getArgStr(),
+ transformLibraryFileName.getValue())
+ .str();
+ }
os << llvm::formatv(
- "{6} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}})\"", rootOpName,
+ "{7} --pass-pipeline=\"{0}({1}{{{2}={3} {4}={5}{6}})\"", rootOpName,
passName, debugPayloadRootTag.getArgStr(),
debugPayloadRootTag.empty()
? StringRef(kTransformDialectTagPayloadRootValue)
debugTransformRootTag.empty()
? StringRef(kTransformDialectTagTransformContainerValue)
: debugTransformRootTag,
- binaryName);
+ transformLibraryOption, binaryName);
return os;
}
/// Saves the payload and the transform IR into a temporary file and reports
/// the file name to `os`.
-void saveReproToTempFile(llvm::raw_ostream &os, Operation *target,
- Operation *transform, StringRef passName,
- const Pass::Option<std::string> &debugPayloadRootTag,
- const Pass::Option<std::string> &debugTransformRootTag,
- StringRef binaryName) {
+void saveReproToTempFile(
+ llvm::raw_ostream &os, Operation *target, Operation *transform,
+ StringRef passName, const Pass::Option<std::string> &debugPayloadRootTag,
+ const Pass::Option<std::string> &debugTransformRootTag,
+ const Pass::Option<std::string> &transformLibraryFileName,
+ StringRef binaryName) {
using llvm::sys::fs::TempFile;
Operation *root = getRootOperation(target);
os << "=== Transform Interpreter Repro ===\n";
printReproCall(os, root->getName().getStringRef(), passName,
- debugPayloadRootTag, debugTransformRootTag, binaryName)
+ debugPayloadRootTag, debugTransformRootTag,
+ transformLibraryFileName, binaryName)
<< " " << filename << "\n";
os << "===================================\n";
}
Operation *target, Operation *transform, StringRef passName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
+ const Pass::Option<std::string> &transformLibraryFileName,
StringRef binaryName) {
MLIRContext *context = target->getContext();
llvm::dbgs() << "=== Transform Interpreter Repro ===\n";
printReproCall(llvm::dbgs() << "cat <<EOF | ",
root->getName().getStringRef(), passName,
- debugPayloadRootTag, debugTransformRootTag, binaryName)
+ debugPayloadRootTag, debugTransformRootTag,
+ transformLibraryFileName, binaryName)
<< "\n";
printModuleForRepro(llvm::dbgs(), root, transform);
llvm::dbgs() << "\nEOF\n";
(void)root;
DEBUG_WITH_TYPE(DEBUG_TYPE_DUMP_FILE, {
saveReproToTempFile(llvm::dbgs(), target, transform, passName,
- debugPayloadRootTag, debugTransformRootTag, binaryName);
+ debugPayloadRootTag, debugTransformRootTag,
+ transformLibraryFileName, binaryName);
});
}
+/// Replaces external symbols in `block` with their (non-external) definitions
+/// from the given module.
+static LogicalResult defineDeclaredSymbols(Block &block, ModuleOp definitions) {
+ for (Operation &op : llvm::make_early_inc_range(block)) {
+ LLVM_DEBUG(DBGS() << op << "\n");
+ auto symbol = dyn_cast<SymbolOpInterface>(op);
+ if (!symbol)
+ continue;
+ if (symbol->getNumRegions() == 1 && !symbol->getRegion(0).empty())
+ continue;
+
+ LLVM_DEBUG(DBGS() << "looking for definition of symbol "
+ << symbol.getNameAttr() << ":");
+ SymbolTable symbolTable(definitions);
+ Operation *externalSymbol = symbolTable.lookup(symbol.getNameAttr());
+ if (!externalSymbol || externalSymbol->getNumRegions() != 1 ||
+ externalSymbol->getRegion(0).empty()) {
+ LLVM_DEBUG(llvm::dbgs() << "not found\n");
+ continue;
+ }
+
+ auto symbolFunc = dyn_cast<FunctionOpInterface>(op);
+ auto externalSymbolFunc = dyn_cast<FunctionOpInterface>(externalSymbol);
+ if (!symbolFunc || !externalSymbolFunc) {
+ LLVM_DEBUG(llvm::dbgs() << "cannot compare types\n");
+ continue;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "found @" << externalSymbol << "\n");
+ if (symbolFunc.getFunctionType() != externalSymbolFunc.getFunctionType()) {
+ return symbolFunc.emitError()
+ << "external definition has a mismatching signature ("
+ << externalSymbolFunc.getFunctionType() << ")";
+ }
+
+ OpBuilder builder(&op);
+ builder.setInsertionPoint(&op);
+ builder.clone(*externalSymbol);
+ symbol->erase();
+ }
+
+ return success();
+}
+
LogicalResult transform::detail::interpreterBaseRunOnOperationImpl(
Operation *target, StringRef passName,
const std::shared_ptr<OwningOpRef<ModuleOp>> &sharedTransformModule,
+ const std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule,
const RaggedArray<MappedValue> &extraMappings,
const TransformOptions &options,
const Pass::Option<std::string> &transformFileName,
+ const Pass::Option<std::string> &transformLibraryFileName,
const Pass::Option<std::string> &debugPayloadRootTag,
const Pass::Option<std::string> &debugTransformRootTag,
StringRef binaryName) {
// Step 3
// ------
+ // Copy external defintions for symbols if provided. Be aware of potential
+ // concurrent execution (normally, the error shouldn't be triggered unless the
+ // transform IR modifies itself in a pass, which is also forbidden elsewhere).
+ if (!sharedTransform && libraryModule && *libraryModule) {
+ if (!target->isProperAncestor(transformRoot)) {
+ InFlightDiagnostic diag =
+ transformRoot->emitError()
+ << "cannot inject transform definitions next to pass anchor op";
+ diag.attachNote(target->getLoc()) << "pass anchor op";
+ return diag;
+ }
+ if (failed(defineDeclaredSymbols(*transformRoot->getBlock(),
+ libraryModule->get())))
+ return failure();
+ }
+
+ // Step 4
+ // ------
// Optionally perform debug actions requested by the user to dump IR and a
// repro to stderr and/or a file.
performOptionalDebugActions(target, transformRoot, passName,
debugPayloadRootTag, debugTransformRootTag,
- binaryName);
+ transformLibraryFileName, binaryName);
- // Step 4
+ // Step 5
// ------
// Apply the transform to the IR
return applyTransforms(payloadRoot, cast<TransformOpInterface>(transformRoot),
LogicalResult transform::detail::interpreterBaseInitializeImpl(
MLIRContext *context, StringRef transformFileName,
- std::shared_ptr<OwningOpRef<ModuleOp>> &module) {
+ StringRef transformLibraryFileName,
+ std::shared_ptr<OwningOpRef<ModuleOp>> &module,
+ std::shared_ptr<OwningOpRef<ModuleOp>> &libraryModule) {
OwningOpRef<ModuleOp> parsed;
if (failed(parseTransformModuleFromFile(context, transformFileName, parsed)))
return failure();
+ if (parsed && failed(mlir::verify(*parsed)))
+ return failure();
+
+ OwningOpRef<ModuleOp> parsedLibrary;
+ if (failed(parseTransformModuleFromFile(context, transformLibraryFileName,
+ parsedLibrary)))
+ return failure();
+ if (parsedLibrary && failed(mlir::verify(*parsedLibrary)))
+ return failure();
module = std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsed));
+ if (!parsedLibrary || !*parsedLibrary)
+ return success();
+
+ if (module && *module) {
+ if (failed(defineDeclaredSymbols(*module->get().getBody(),
+ parsedLibrary.get())))
+ return failure();
+ } else {
+ libraryModule =
+ std::make_shared<OwningOpRef<ModuleOp>>(std::move(parsedLibrary));
+ }
return success();
}
// -----
module attributes { transform.with_named_sequence } {
- // expected-error @below {{failed to verify constraint: region with 1 blocks}}
- "transform.named_sequence"() ({}) { sym_name = "external_named_sequence", function_type = () -> () } : () -> ()
+ // expected-error @below {{expected a non-empty body block}}
+ "transform.named_sequence"() ({
+ ^bb0:
+ }) { sym_name = "external_named_sequence", function_type = () -> () } : () -> ()
transform.sequence failures(propagate) {
^bb0(%arg0: !transform.any_op):
--- /dev/null
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: --verify-diagnostics
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-file-name=%p/test-interpreter-external-symbol-decl.mlir transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: --verify-diagnostics
+
+// The external transform script has a declaration to the named sequence @foo,
+// the definition of which is provided in another file. Repeated application
+// of the same pass should not be a problem. Note that the same diagnostic
+// produced twice at the same location only needs to be matched once.
+
+// expected-remark @below {{message}}
+module {}
--- /dev/null
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: --verify-diagnostics --split-input-file
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option.
+
+module attributes {transform.with_named_sequence} {
+ // expected-error @below {{external definition has a mismatching signature}}
+ transform.named_sequence private @foo(!transform.op<"builtin.module">)
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.op<"builtin.module">):
+ include @foo failures(propagate) (%arg0) : (!transform.op<"builtin.module">) -> ()
+ }
+}
+
+// -----
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence private @undefined_sequence()
+
+ transform.sequence failures(suppress) {
+ ^bb0(%arg0: !transform.any_op):
+ // expected-error @below {{unresolved external named sequence}}
+ include @undefined_sequence failures(suppress) () : () -> ()
+ }
+}
--- /dev/null
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter)" \
+// RUN: --verify-diagnostics | FileCheck %s
+
+// RUN: mlir-opt %s --pass-pipeline="builtin.module(test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir}, test-transform-dialect-interpreter{transform-library-file-name=%p/test-interpreter-external-symbol-def.mlir})" \
+// RUN: --verify-diagnostics | FileCheck %s
+
+// The definition of the @foo named sequence is provided in another file. It
+// will be included because of the pass option. Repeated application of the
+// same pass, with or without the library option, should not be a problem.
+// Note that the same diagnostic produced twice at the same location only
+// needs to be matched once.
+
+// expected-remark @below {{message}}
+module attributes {transform.with_named_sequence} {
+ // CHECK: transform.named_sequence @foo
+ // CHECK: test_print_remark_at_operand %{{.*}}, "message"
+ transform.named_sequence private @foo(!transform.any_op)
+
+ transform.sequence failures(propagate) {
+ ^bb0(%arg0: !transform.any_op):
+ include @foo failures(propagate) (%arg0) : (!transform.any_op) -> ()
+ }
+}
--- /dev/null
+// RUN: mlir-opt %s
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @foo(%arg0: !transform.any_op) {
+ transform.test_print_remark_at_operand %arg0, "message" : !transform.any_op
+ transform.yield
+ }
+}
options = options.enableExpensiveChecks(enableExpensiveChecks);
if (failed(transform::detail::interpreterBaseRunOnOperationImpl(
getOperation(), getArgument(), getSharedTransformModule(),
- extraMapping, options, transformFileName, debugPayloadRootTag,
+ getTransformLibraryModule(), extraMapping, options,
+ transformFileName, transformLibraryFileName, debugPayloadRootTag,
debugTransformRootTag, getBinaryName())))
return signalPassFailure();
}
"the given value as container IR for top-level transform ops. This "
"allows user control on what transformation to apply. If empty, "
"select the container of the top-level transform op.")};
+ Option<std::string> transformLibraryFileName{
+ *this, "transform-library-file-name", llvm::cl::init(""),
+ llvm::cl::desc(
+ "Optional name of the file containing transform dialect symbol "
+ "definitions to be injected into the transform module.")};
};
struct TestTransformDialectEraseSchedulePass
"//mlir/test:lit_data",
] + glob([
"Transform/*-source.mlir",
+ "Transform/*-symbol-def.mlir",
])
)
for src in glob(
include=["**/*.mlir"],
- exclude=["Transform/*-source.mlir"]
+ exclude=[
+ "Transform/*-source.mlir",
+ "Transform/*-symbol-def.mlir",
+ "Transform/*-symbol-decl-and-schedule.mlir",
+ ]
)
]