[mlir] JitRunner: add a config option to register symbols with ExecutionEngine at...
authorEugene Zhulenev <ezhulenev@google.com>
Tue, 27 Oct 2020 21:12:47 +0000 (14:12 -0700)
committerEugene Zhulenev <ezhulenev@google.com>
Tue, 27 Oct 2020 22:57:34 +0000 (15:57 -0700)
Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D90264

mlir/include/mlir/ExecutionEngine/JitRunner.h
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/tools/mlir-cpu-runner/mlir-cpu-runner.cpp
mlir/tools/mlir-cuda-runner/mlir-cuda-runner.cpp
mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp
mlir/tools/mlir-vulkan-runner/mlir-vulkan-runner.cpp

index 2b7518c..43c9f96 100644 (file)
 #ifndef MLIR_SUPPORT_JITRUNNER_H_
 #define MLIR_SUPPORT_JITRUNNER_H_
 
-#include "mlir/IR/Module.h"
-
 #include "llvm/ADT/STLExtras.h"
-#include "llvm/IR/Module.h"
+#include "llvm/ExecutionEngine/Orc/Core.h"
 
-namespace mlir {
+namespace llvm {
+class Module;
+class LLVMContext;
 
-using TranslationCallback = llvm::function_ref<std::unique_ptr<llvm::Module>(
-    ModuleOp, llvm::LLVMContext &)>;
+namespace orc {
+class MangleAndInterner;
+} // namespace orc
+} // namespace llvm
+
+namespace mlir {
 
 class ModuleOp;
 struct LogicalResult;
 
+struct JitRunnerConfig {
+  /// MLIR transformer applied after parsing the input into MLIR IR and before
+  /// passing the MLIR module to the ExecutionEngine.
+  llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;
+
+  /// A custom function that is passed to ExecutionEngine. It processes MLIR
+  /// module and creates LLVM IR module.
+  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+                                                   llvm::LLVMContext &)>
+      llvmModuleBuilder = nullptr;
+
+  /// A callback to register symbols with ExecutionEngine at runtime.
+  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
+      runtimesymbolMap = nullptr;
+};
+
 // Entry point for all CPU runners. Expects the common argc/argv arguments for
-// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
-/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
-/// passing the MLIR module to the ExecutionEngine.
-/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
-/// It processes MLIR module and creates LLVM IR module.
-int JitRunnerMain(
-    int argc, char **argv,
-    llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
-    TranslationCallback llvmModuleBuilder = nullptr);
+// standard C++ main functions.
+int JitRunnerMain(int argc, char **argv, JitRunnerConfig config = {});
 
 } // namespace mlir
 
index c1bb1e6..d9e0ff6 100644 (file)
@@ -92,6 +92,23 @@ struct Options {
       "object-filename",
       llvm::cl::desc("Dump JITted-compiled object to file <input file>.o")};
 };
+
+struct CompileAndExecuteConfig {
+  /// LLVM module transformer that is passed to ExecutionEngine.
+  llvm::function_ref<llvm::Error(llvm::Module *)> transformer;
+
+  /// A custom function that is passed to ExecutionEngine. It processes MLIR
+  /// module and creates LLVM IR module.
+  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+                                                   llvm::LLVMContext &)>
+      llvmModuleBuilder;
+
+  /// A custom function that is passed to ExecutinEngine to register symbols at
+  /// runtime.
+  llvm::function_ref<llvm::orc::SymbolMap(llvm::orc::MangleAndInterner)>
+      runtimeSymbolMap;
+};
+
 } // end anonymous namespace
 
 static OwningModuleRef parseMLIRInput(StringRef inputFilename,
@@ -131,11 +148,9 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
 }
 
 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
-static Error
-compileAndExecute(Options &options, ModuleOp module,
-                  TranslationCallback llvmModuleBuilder, StringRef entryPoint,
-                  std::function<llvm::Error(llvm::Module *)> transformer,
-                  void **args) {
+static Error compileAndExecute(Options &options, ModuleOp module,
+                               StringRef entryPoint,
+                               CompileAndExecuteConfig config, void **args) {
   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
   if (auto clOptLevel = getCommandLineOptLevel(options))
     jitCodeGenOptLevel =
@@ -143,11 +158,15 @@ compileAndExecute(Options &options, ModuleOp module,
   SmallVector<StringRef, 4> libs(options.clSharedLibs.begin(),
                                  options.clSharedLibs.end());
   auto expectedEngine = mlir::ExecutionEngine::create(
-      module, llvmModuleBuilder, transformer, jitCodeGenOptLevel, libs);
+      module, config.llvmModuleBuilder, config.transformer, jitCodeGenOptLevel,
+      libs);
   if (!expectedEngine)
     return expectedEngine.takeError();
 
   auto engine = std::move(*expectedEngine);
+  if (config.runtimeSymbolMap)
+    engine->registerSymbols(config.runtimeSymbolMap);
+
   auto expectedFPtr = engine->lookup(entryPoint);
   if (!expectedFPtr)
     return expectedFPtr.takeError();
@@ -163,16 +182,14 @@ compileAndExecute(Options &options, ModuleOp module,
   return Error::success();
 }
 
-static Error compileAndExecuteVoidFunction(
-    Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
-    StringRef entryPoint,
-    std::function<llvm::Error(llvm::Module *)> transformer) {
+static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
+                                           StringRef entryPoint,
+                                           CompileAndExecuteConfig config) {
   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
   if (!mainFunction || mainFunction.empty())
     return make_string_error("entry point not found");
   void *empty = nullptr;
-  return compileAndExecute(options, module, llvmModuleBuilder, entryPoint,
-                           transformer, &empty);
+  return compileAndExecute(options, module, entryPoint, config, &empty);
 }
 
 template <typename Type>
@@ -196,10 +213,9 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
   return Error::success();
 }
 template <typename Type>
-Error compileAndExecuteSingleReturnFunction(
-    Options &options, ModuleOp module, TranslationCallback llvmModuleBuilder,
-    StringRef entryPoint,
-    std::function<llvm::Error(llvm::Module *)> transformer) {
+Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
+                                            StringRef entryPoint,
+                                            CompileAndExecuteConfig config) {
   auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
   if (!mainFunction || mainFunction.isExternal())
     return make_string_error("entry point not found");
@@ -215,8 +231,8 @@ Error compileAndExecuteSingleReturnFunction(
     void *data;
   } data;
   data.data = &res;
-  if (auto error = compileAndExecute(options, module, llvmModuleBuilder,
-                                     entryPoint, transformer, (void **)&data))
+  if (auto error = compileAndExecute(options, module, entryPoint, config,
+                                     (void **)&data))
     return error;
 
   // Intentional printing of the output so we can test.
@@ -226,15 +242,8 @@ Error compileAndExecuteSingleReturnFunction(
 }
 
 /// Entry point for all CPU runners. Expects the common argc/argv arguments for
-/// standard C++ main functions, `mlirTransformer` and `llvmModuleBuilder`.
-/// `mlirTransformer` is applied after parsing the input into MLIR IR and before
-/// passing the MLIR module to the ExecutionEngine.
-/// `llvmModuleBuilder` is a custom function that is passed to ExecutionEngine.
-/// It processes MLIR module and creates LLVM IR module.
-int mlir::JitRunnerMain(
-    int argc, char **argv,
-    function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer,
-    TranslationCallback llvmModuleBuilder) {
+/// standard C++ main functions.
+int mlir::JitRunnerMain(int argc, char **argv, JitRunnerConfig config) {
   // Create the options struct containing the command line options for the
   // runner. This must come before the command line options are parsed.
   Options options;
@@ -274,8 +283,8 @@ int mlir::JitRunnerMain(
     return 1;
   }
 
-  if (mlirTransformer)
-    if (failed(mlirTransformer(m.get())))
+  if (config.mlirTransformer)
+    if (failed(config.mlirTransformer(m.get())))
       return EXIT_FAILURE;
 
   auto tmBuilderOrError = llvm::orc::JITTargetMachineBuilder::detectHost();
@@ -292,10 +301,14 @@ int mlir::JitRunnerMain(
   auto transformer = mlir::makeLLVMPassesTransformer(
       passes, optLevel, /*targetMachine=*/tmOrError->get(), optPosition);
 
+  CompileAndExecuteConfig compileAndExecuteConfig;
+  compileAndExecuteConfig.transformer = transformer;
+  compileAndExecuteConfig.llvmModuleBuilder = config.llvmModuleBuilder;
+  compileAndExecuteConfig.runtimeSymbolMap = config.runtimesymbolMap;
+
   // Get the function used to compile and execute the module.
   using CompileAndExecuteFnT =
-      Error (*)(Options &, ModuleOp, TranslationCallback, StringRef,
-                std::function<llvm::Error(llvm::Module *)>);
+      Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
   auto compileAndExecuteFn =
       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
@@ -304,11 +317,11 @@ int mlir::JitRunnerMain(
           .Case("void", compileAndExecuteVoidFunction)
           .Default(nullptr);
 
-  Error error =
-      compileAndExecuteFn
-          ? compileAndExecuteFn(options, m.get(), llvmModuleBuilder,
-                                options.mainFuncName.getValue(), transformer)
-          : make_string_error("unsupported function type");
+  Error error = compileAndExecuteFn
+                    ? compileAndExecuteFn(options, m.get(),
+                                          options.mainFuncName.getValue(),
+                                          compileAndExecuteConfig)
+                    : make_string_error("unsupported function type");
 
   int exitCode = EXIT_SUCCESS;
   llvm::handleAllErrors(std::move(error),
index 7667908..a2661c1 100644 (file)
@@ -24,5 +24,5 @@ int main(int argc, char **argv) {
   llvm::InitializeNativeTargetAsmPrinter();
   mlir::initializeLLVMPasses();
 
-  return mlir::JitRunnerMain(argc, argv, nullptr);
+  return mlir::JitRunnerMain(argc, argv);
 }
index be00646..cfffaaa 100644 (file)
@@ -136,5 +136,9 @@ int main(int argc, char **argv) {
   LLVMInitializeNVPTXAsmPrinter();
 
   mlir::initializeLLVMPasses();
-  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
+
+  mlir::JitRunnerConfig jitRunnerConfig;
+  jitRunnerConfig.mlirTransformer = &runMLIRPasses;
+
+  return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
 }
index 9979801..cc0f503 100644 (file)
@@ -86,5 +86,9 @@ int main(int argc, char **argv) {
   llvm::InitializeNativeTargetAsmPrinter();
   mlir::initializeLLVMPasses();
 
-  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses, &convertMLIRModule);
+  mlir::JitRunnerConfig jitRunnerConfig;
+  jitRunnerConfig.mlirTransformer = &runMLIRPasses;
+  jitRunnerConfig.llvmModuleBuilder = &convertMLIRModule;
+
+  return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
 }
index 905d2e4..322f949 100644 (file)
@@ -58,5 +58,8 @@ int main(int argc, char **argv) {
   llvm::InitializeNativeTargetAsmPrinter();
   mlir::initializeLLVMPasses();
 
-  return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
+  mlir::JitRunnerConfig jitRunnerConfig;
+  jitRunnerConfig.mlirTransformer = &runMLIRPasses;
+
+  return mlir::JitRunnerMain(argc, argv, jitRunnerConfig);
 }