From 200889fbd91b661a8cf4e5a914d05010300ffee1 Mon Sep 17 00:00:00 2001 From: rkayaith Date: Wed, 14 Sep 2022 15:34:27 -0400 Subject: [PATCH] [mlir-cpu-runner] Support parsing operations other than 'builtin.module' as top-level This adds a `--no-implicit-module` option, which disables the insertion of a top-level `builtin.module` during parsing. The top-level op is required to have the `SymbolTable` trait. The majority of the change here is removing `ModuleOp` from interfaces. Reviewed By: rriddle Differential Revision: https://reviews.llvm.org/D134238 --- .../include/mlir/ExecutionEngine/ExecutionEngine.h | 20 +++++------ mlir/include/mlir/ExecutionEngine/JitRunner.h | 13 +++---- mlir/lib/ExecutionEngine/ExecutionEngine.cpp | 2 +- mlir/lib/ExecutionEngine/JitRunner.cpp | 41 ++++++++++++++++------ mlir/test/mlir-cpu-runner/invalid.mlir | 4 +++ .../mlir-spirv-cpu-runner.cpp | 10 ++++-- 6 files changed, 59 insertions(+), 31 deletions(-) create mode 100644 mlir/test/mlir-cpu-runner/invalid.mlir diff --git a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h index 7d4a708..17776d0 100644 --- a/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h +++ b/mlir/include/mlir/ExecutionEngine/ExecutionEngine.h @@ -34,7 +34,7 @@ class MemoryBuffer; namespace mlir { -class ModuleOp; +class Operation; /// A simple object cache following Lang's LLJITWithObjectCache example. class SimpleObjectCache : public llvm::ObjectCache { @@ -51,10 +51,10 @@ private: }; struct ExecutionEngineOptions { - /// If `llvmModuleBuilder` is provided, it will be used to create LLVM module - /// from the given MLIR module. Otherwise, a default `translateModuleToLLVMIR` - /// function will be used to translate MLIR module to LLVM IR. - llvm::function_ref(ModuleOp, + /// If `llvmModuleBuilder` is provided, it will be used to create an LLVM + /// module from the given MLIR IR. Otherwise, a default + /// `translateModuleToLLVMIR` function will be used to translate to LLVM IR. + llvm::function_ref(Operation *, llvm::LLVMContext &)> llvmModuleBuilder = nullptr; @@ -89,9 +89,9 @@ struct ExecutionEngineOptions { bool enablePerfNotificationListener = true; }; -/// JIT-backed execution engine for MLIR modules. Assumes the module can be -/// converted to LLVM IR. For each function, creates a wrapper function with -/// the fixed interface +/// JIT-backed execution engine for MLIR. Assumes the IR can be converted to +/// LLVM IR. For each function, creates a wrapper function with the fixed +/// interface /// /// void _mlir_funcName(void **) /// @@ -104,9 +104,9 @@ public: ExecutionEngine(bool enableObjectCache, bool enableGDBNotificationListener, bool enablePerfNotificationListener); - /// Creates an execution engine for the given module. + /// Creates an execution engine for the given MLIR IR. static llvm::Expected> - create(ModuleOp m, const ExecutionEngineOptions &options = {}); + create(Operation *op, const ExecutionEngineOptions &options = {}); /// Looks up a packed-argument function wrapping the function with the given /// name and returns a pointer to it. Propagates errors in case of failure. diff --git a/mlir/include/mlir/ExecutionEngine/JitRunner.h b/mlir/include/mlir/ExecutionEngine/JitRunner.h index 3f04387..b25c604d 100644 --- a/mlir/include/mlir/ExecutionEngine/JitRunner.h +++ b/mlir/include/mlir/ExecutionEngine/JitRunner.h @@ -33,17 +33,18 @@ class MangleAndInterner; namespace mlir { class DialectRegistry; -class ModuleOp; +class Operation; struct LogicalResult; struct JitRunnerConfig { /// MLIR transformer applied after parsing the input into MLIR IR and before - /// passing the MLIR module to the ExecutionEngine. - llvm::function_ref mlirTransformer = nullptr; + /// passing the MLIR IR to the ExecutionEngine. + llvm::function_ref mlirTransformer = + nullptr; - /// A custom function that is passed to ExecutionEngine. It processes MLIR - /// module and creates LLVM IR module. - llvm::function_ref(ModuleOp, + /// A custom function that is passed to ExecutionEngine. It processes MLIR and + /// creates an LLVM IR module. + llvm::function_ref(Operation *, llvm::LLVMContext &)> llvmModuleBuilder = nullptr; diff --git a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp index 72e6c05..0c38d85 100644 --- a/mlir/lib/ExecutionEngine/ExecutionEngine.cpp +++ b/mlir/lib/ExecutionEngine/ExecutionEngine.cpp @@ -232,7 +232,7 @@ ExecutionEngine::ExecutionEngine(bool enableObjectCache, } Expected> -ExecutionEngine::create(ModuleOp m, const ExecutionEngineOptions &options) { +ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options) { auto engine = std::make_unique( options.enableObjectCache, options.enableGDBNotificationListener, options.enablePerfNotificationListener); diff --git a/mlir/lib/ExecutionEngine/JitRunner.cpp b/mlir/lib/ExecutionEngine/JitRunner.cpp index 6005c9b..cfa22fe 100644 --- a/mlir/lib/ExecutionEngine/JitRunner.cpp +++ b/mlir/lib/ExecutionEngine/JitRunner.cpp @@ -23,6 +23,7 @@ #include "mlir/IR/MLIRContext.h" #include "mlir/Parser/Parser.h" #include "mlir/Support/FileUtilities.h" +#include "mlir/Tools/ParseUtilties.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h" @@ -91,6 +92,12 @@ struct Options { llvm::cl::opt hostSupportsJit{"host-supports-jit", llvm::cl::desc("Report host JIT support"), llvm::cl::Hidden}; + + llvm::cl::opt noImplicitModule{ + "no-implicit-module", + llvm::cl::desc( + "Disable implicit addition of a top-level module op during parsing"), + llvm::cl::init(false)}; }; struct CompileAndExecuteConfig { @@ -99,7 +106,7 @@ struct CompileAndExecuteConfig { /// A custom function that is passed to ExecutionEngine. It processes MLIR /// module and creates LLVM IR module. - llvm::function_ref(ModuleOp, + llvm::function_ref(Operation *, llvm::LLVMContext &)> llvmModuleBuilder; @@ -111,8 +118,9 @@ struct CompileAndExecuteConfig { } // namespace -static OwningOpRef parseMLIRInput(StringRef inputFilename, - MLIRContext *context) { +static OwningOpRef parseMLIRInput(StringRef inputFilename, + bool insertImplicitModule, + MLIRContext *context) { // Set up the input file. std::string errorMessage; auto file = openInputFile(inputFilename, &errorMessage); @@ -123,7 +131,15 @@ static OwningOpRef parseMLIRInput(StringRef inputFilename, llvm::SourceMgr sourceMgr; sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc()); - return parseSourceFile(sourceMgr, context); + OwningOpRef module = + parseSourceFileForTool(sourceMgr, context, insertImplicitModule); + if (!module) + return nullptr; + if (!module.get()->hasTrait()) { + llvm::errs() << "Error: top-level op must be a symbol table.\n"; + return nullptr; + } + return module; } static inline Error makeStringError(const Twine &message) { @@ -148,7 +164,7 @@ static Optional getCommandLineOptLevel(Options &options) { } // JIT-compile the given module and run "entryPoint" with "args" as arguments. -static Error compileAndExecute(Options &options, ModuleOp module, +static Error compileAndExecute(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config, void **args) { Optional jitCodeGenOptLevel; @@ -240,10 +256,11 @@ static Error compileAndExecute(Options &options, ModuleOp module, return Error::success(); } -static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module, +static Error compileAndExecuteVoidFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config) { - auto mainFunction = module.lookupSymbol(entryPoint); + auto mainFunction = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, entryPoint)); if (!mainFunction || mainFunction.empty()) return makeStringError("entry point not found"); void *empty = nullptr; @@ -283,10 +300,11 @@ Error checkCompatibleReturnType(LLVM::LLVMFuncOp mainFunction) { return Error::success(); } template -Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module, +Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module, StringRef entryPoint, CompileAndExecuteConfig config) { - auto mainFunction = module.lookupSymbol(entryPoint); + auto mainFunction = dyn_cast_or_null( + SymbolTable::lookupSymbolIn(module, entryPoint)); if (!mainFunction || mainFunction.isExternal()) return makeStringError("entry point not found"); @@ -339,7 +357,8 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, MLIRContext context(registry); - auto m = parseMLIRInput(options.inputFilename, &context); + auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule, + &context); if (!m) { llvm::errs() << "could not parse the input IR\n"; return 1; @@ -370,7 +389,7 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry ®istry, // Get the function used to compile and execute the module. using CompileAndExecuteFnT = - Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig); + Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig); auto compileAndExecuteFn = StringSwitch(options.mainFuncType.getValue()) .Case("i32", compileAndExecuteSingleReturnFunction) diff --git a/mlir/test/mlir-cpu-runner/invalid.mlir b/mlir/test/mlir-cpu-runner/invalid.mlir new file mode 100644 index 0000000..74bfbb7 --- /dev/null +++ b/mlir/test/mlir-cpu-runner/invalid.mlir @@ -0,0 +1,4 @@ +// RUN: not mlir-cpu-runner --no-implicit-module %s |& FileCheck %s + +// CHECK: Error: top-level op must be a symbol table. +llvm.func @main() diff --git a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp index 87305a7..b87221a 100644 --- a/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp +++ b/mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp @@ -51,7 +51,10 @@ using namespace mlir; /// Each of these two modules is translated to LLVM IR module, then they are /// linked together and returned. static std::unique_ptr -convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) { +convertMLIRModule(Operation *op, llvm::LLVMContext &context) { + auto module = dyn_cast(op); + if (!module) + return op->emitError("op must be a 'builtin.module"), nullptr; // Verify that there is only one nested module. auto modules = module.getOps(); if (!llvm::hasSingleElement(modules)) { @@ -71,8 +74,9 @@ convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) { return mainModule; } -static LogicalResult runMLIRPasses(ModuleOp module) { - PassManager passManager(module.getContext()); +static LogicalResult runMLIRPasses(Operation *module) { + PassManager passManager(module->getContext(), + module->getName().getStringRef()); applyPassManagerCLOptions(passManager); passManager.addPass(createGpuKernelOutliningPass()); passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true)); -- 2.7.4