#ifndef MLIR_SUPPORT_JITRUNNER_H_
#define MLIR_SUPPORT_JITRUNNER_H_
-#include "mlir/IR/Module.h"
-
#include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Module.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
-namespace mlir {
+namespace llvm {
+class Module;
+class LLVMContext;
-using TranslationCallback = llvm::function_ref<std::unique_ptr<llvm::Module>(
- ModuleOp, llvm::LLVMContext &)>;
+namespace orc {
+class MangleAndInterner;
+} // namespace orc
+} // namespace llvm
+
+namespace mlir {
class ModuleOp;
struct LogicalResult;
+struct JitRunnerConfig {
+ /// MLIR transformer applied after parsing the input into MLIR IR and before
+ /// passing the MLIR module to the ExecutionEngine.
+ llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;
+
+ /// A custom function that is passed to ExecutionEngine. It processes MLIR
+ /// module and creates LLVM IR module.
+ llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+ llvm::LLVMContext &)>
+ llvmModuleBuilder = nullptr;
+
+ /// A callback to register symbols with ExecutionEngine at runtime.
+ llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
+ runtimesymbolMap = nullptr;
+};
+
// Entry point for all CPU runners. Expects the common argc/argv arguments for
-// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
-/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
-/// passing the MLIR module to the ExecutionEngine.
-/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
-/// It processes MLIR module and creates LLVM IR module.
-int JitRunnerMain(
- int argc, char **argv,
- llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
- TranslationCallback llvmModuleBuilder = nullptr);
+// standard C++ main functions.
+int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {});
} // namespace mlir
"object-filename",
llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
};
+
+struct CompileAndExecuteConfig {
+ /// LLVM module transformer that is passed to ExecutionEngine.
+ llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
+
+ /// A custom function that is passed to ExecutionEngine. It processes MLIR
+ /// module and creates LLVM IR module.
+ llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+ llvm::LLVMContext &)>
+ llvmModuleBuilder;
+
+ /// A custom function that is passed to ExecutinEngine to register symbols at
+ /// runtime.
+ llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
+ runtimeSymbolMap;
+};
+
} // end anonymous namespace
static OwningModuleRef parseMLIRInput(StringRef inputFilename,
}
// JIT-compile the given module and run "entryPoint" with "args" as arguments.
-static Error
-compileAndExecute(Options &options, ModuleOp module,
- TranslationCallback llvmModuleBuilder, StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer,
- void **args) {
+static Error compileAndExecute(Options &options, ModuleOp module,
+ StringRef entryPoint,
+ CompileAndExecuteConfig config, void **args) {
Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
if (auto clOptLevel = getCommandLineOptLevel(options))
jitCodeGenOptLevel =
SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
options.clSharedLibs.end());
auto expectedEngine = mlir::ExecutionEngine::create(
- module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
+ module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
+ libs);
if (!expectedEngine)
return expectedEngine.takeError();
auto engine = std::move(*expectedEngine);
+ if (config.runtimeSymbolMap)
+ engine->registerSymbols(config.runtimeSymbolMap);
+
auto expectedFPtr = engine->lookup(entryPoint);
if (!expectedFPtr)
return expectedFPtr.takeError();
return Error::success();
}
-static Error compileAndExecuteVoidFunction(
- Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
- StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
+static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
+ StringRef entryPoint,
+ CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.empty())
return make_string_error("entry point not found");
void *empty = nullptr;
- return compileAndExecute(options, module, llvmModuleBuilder, entryPoint,
- transformer, &empty);
+ return compileAndExecute(options, module, entryPoint, config, &empty);
}
template <typename Type>
return Error::success();
}
template <typename Type>
-Error compileAndExecuteSingleReturnFunction(
- Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
- StringRef entryPoint,
- std::function<llvm::Error(llvm::Module *)> transformer) {
+Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
+ StringRef entryPoint,
+ CompileAndExecuteConfig config) {
auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
if (!mainFunction || mainFunction.isExternal())
return make_string_error("entry point not found");
void *data;
} data;
data.data = &res;
- if (auto error = compileAndExecute(options, module, llvmModuleBuilder,
- entryPoint, transformer, (void **)&data))
+ if (auto error = compileAndExecute(options, module, entryPoint, config,
+ (void **)&data))
return error;
// Intentional printing of the output so we can test.
}
/// Entry point for all CPU runners. Expects the common argc/argv arguments for
-/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
-/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
-/// passing the MLIR module to the ExecutionEngine.
-/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
-/// It processes MLIR module and creates LLVM IR module.
-int mlir::JitRunnerMain(
- int argc, char **argv,
- function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
- TranslationCallback llvmModuleBuilder) {
+/// standard C++ main functions.
+int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
// Create the options struct containing the command line options for the
// runner. This must come before the command line options are parsed.
Options options;
return 1;
}
- if (mlirTransformer)
- if (failed(mlirTransformer(m.get())))
+ if (config.mlirTransformer)
+ if (failed(config.mlirTransformer(m.get())))
return EXIT_FAILURE;
auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
auto transformer = mlir::makeLLVMPassesTransformer(
passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
+ CompileAndExecuteConfig compileAndExecuteConfig;
+ compileAndExecuteConfig.transformer = transformer;
+ compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
+ compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
+
// Get the function used to compile and execute the module.
using CompileAndExecuteFnT =
- Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
- std::function<llvm::Error(llvm::Module *)>);
+ Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
auto compileAndExecuteFn =
StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
.Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
.Case("void", compileAndExecuteVoidFunction)
.Default(nullptr);
- Error error =
- compileAndExecuteFn
- ? compileAndExecuteFn(options, m.get(), llvmModuleBuilder,
- options.mainFuncName.getValue(), transformer)
- : make_string_error("unsupported function type");
+ Error error = compileAndExecuteFn
+ ? compileAndExecuteFn(options, m.get(),
+ options.mainFuncName.getValue(),
+ compileAndExecuteConfig)
+ : make_string_error("unsupported function type");
int exitCode = EXIT_SUCCESS;
llvm::handleAllErrors(std::move(error),