[mlir-cpu-runner] Support parsing operations other than 'builtin.module' as top-level
authorrkayaith <rkayaith@gmail.com>
Wed, 14 Sep 2022 19:34:27 +0000 (15:34 -0400)
committerrkayaith <rkayaith@gmail.com>
Mon, 3 Oct 2022 19:36:59 +0000 (15:36 -0400)
This adds a `--no-implicit-module` option, which disables the insertion
of a top-level `builtin.module` during parsing. The top-level op is
required to have the `SymbolTable` trait.

The majority of the change here is removing `ModuleOp` from interfaces.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D134238

mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
mlir/include/mlir/ExecutionEngine/JitRunner.h
mlir/lib/ExecutionEngine/ExecutionEngine.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/test/mlir-cpu-runner/invalid.mlir [new file with mode: 0644]
mlir/tools/mlir-spirv-cpu-runner/mlir-spirv-cpu-runner.cpp

index 7d4a708..17776d0 100644 (file)
@@ -34,7 +34,7 @@ class MemoryBuffer;
 
 namespace mlir {
 
-class ModuleOp;
+class Operation;
 
 /// A simple object cache following Lang's LLJITWithObjectCache example.
 class SimpleObjectCache : public llvm::ObjectCache {
@@ -51,10 +51,10 @@ private:
 };
 
 struct ExecutionEngineOptions {
-  /// If `llvmModuleBuilder` is provided, it will be used to create LLVM module
-  /// from the given MLIR module. Otherwise, a default `translateModuleToLLVMIR`
-  /// function will be used to translate MLIR module to LLVM IR.
-  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+  /// If `llvmModuleBuilder` is provided, it will be used to create an LLVM
+  /// module from the given MLIR IR. Otherwise, a default
+  /// `translateModuleToLLVMIR` function will be used to translate to LLVM IR.
+  llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *,
                                                    llvm::LLVMContext &)>
       llvmModuleBuilder = nullptr;
 
@@ -89,9 +89,9 @@ struct ExecutionEngineOptions {
   bool enablePerfNotificationListener = true;
 };
 
-/// JIT-backed execution engine for MLIR modules.  Assumes the module can be
-/// converted to LLVM IR.  For each function, creates a wrapper function with
-/// the fixed interface
+/// JIT-backed execution engine for MLIR. Assumes the IR can be converted to
+/// LLVM IR. For each function, creates a wrapper function with the fixed
+/// interface
 ///
 ///     void _mlir_funcName(void **)
 ///
@@ -104,9 +104,9 @@ public:
   ExecutionEngine(bool enableObjectCache, bool enableGDBNotificationListener,
                   bool enablePerfNotificationListener);
 
-  /// Creates an execution engine for the given module.
+  /// Creates an execution engine for the given MLIR IR.
   static llvm::Expected<std::unique_ptr<ExecutionEngine>>
-  create(ModuleOp m, const ExecutionEngineOptions &options = {});
+  create(Operation *op, const ExecutionEngineOptions &options = {});
 
   /// Looks up a packed-argument function wrapping the function with the given
   /// name and returns a pointer to it. Propagates errors in case of failure.
index 3f04387..b25c604 100644 (file)
@@ -33,17 +33,18 @@ class MangleAndInterner;
 namespace mlir {
 
 class DialectRegistry;
-class ModuleOp;
+class Operation;
 struct LogicalResult;
 
 struct JitRunnerConfig {
   /// MLIR transformer applied after parsing the input into MLIR IR and before
-  /// passing the MLIR module to the ExecutionEngine.
-  llvm::function_ref<LogicalResult(mlir::ModuleOp)> mlirTransformer = nullptr;
+  /// passing the MLIR IR to the ExecutionEngine.
+  llvm::function_ref<LogicalResult(mlir::Operation *)> mlirTransformer =
+      nullptr;
 
-  /// A custom function that is passed to ExecutionEngine. It processes MLIR
-  /// module and creates LLVM IR module.
-  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+  /// A custom function that is passed to ExecutionEngine. It processes MLIR and
+  /// creates an LLVM IR module.
+  llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *,
                                                    llvm::LLVMContext &)>
       llvmModuleBuilder = nullptr;
 
index 72e6c05..0c38d85 100644 (file)
@@ -232,7 +232,7 @@ ExecutionEngine::ExecutionEngine(bool enableObjectCache,
 }
 
 Expected<std::unique_ptr<ExecutionEngine>>
-ExecutionEngine::create(ModuleOp m, const ExecutionEngineOptions &options) {
+ExecutionEngine::create(Operation *m, const ExecutionEngineOptions &options) {
   auto engine = std::make_unique<ExecutionEngine>(
       options.enableObjectCache, options.enableGDBNotificationListener,
       options.enablePerfNotificationListener);
index 6005c9b..cfa22fe 100644 (file)
@@ -23,6 +23,7 @@
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/Parser/Parser.h"
 #include "mlir/Support/FileUtilities.h"
+#include "mlir/Tools/ParseUtilties.h"
 
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ExecutionEngine/Orc/JITTargetMachineBuilder.h"
@@ -91,6 +92,12 @@ struct Options {
   llvm::cl::opt<bool> hostSupportsJit{"host-supports-jit",
                                       llvm::cl::desc("Report host JIT support"),
                                       llvm::cl::Hidden};
+
+  llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc(
+          "Disable implicit addition of a top-level module op during parsing"),
+      llvm::cl::init(false)};
 };
 
 struct CompileAndExecuteConfig {
@@ -99,7 +106,7 @@ struct CompileAndExecuteConfig {
 
   /// A custom function that is passed to ExecutionEngine. It processes MLIR
   /// module and creates LLVM IR module.
-  llvm::function_ref<std::unique_ptr<llvm::Module>(ModuleOp,
+  llvm::function_ref<std::unique_ptr<llvm::Module>(Operation *,
                                                    llvm::LLVMContext &)>
       llvmModuleBuilder;
 
@@ -111,8 +118,9 @@ struct CompileAndExecuteConfig {
 
 } // namespace
 
-static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
-                                            MLIRContext *context) {
+static OwningOpRef<Operation *> parseMLIRInput(StringRef inputFilename,
+                                               bool insertImplicitModule,
+                                               MLIRContext *context) {
   // Set up the input file.
   std::string errorMessage;
   auto file = openInputFile(inputFilename, &errorMessage);
@@ -123,7 +131,15 @@ static OwningOpRef<ModuleOp> parseMLIRInput(StringRef inputFilename,
 
   llvm::SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(file), SMLoc());
-  return parseSourceFile<ModuleOp>(sourceMgr, context);
+  OwningOpRef<Operation *> module =
+      parseSourceFileForTool(sourceMgr, context, insertImplicitModule);
+  if (!module)
+    return nullptr;
+  if (!module.get()->hasTrait<OpTrait::SymbolTable>()) {
+    llvm::errs() << "Error: top-level op must be a symbol table.\n";
+    return nullptr;
+  }
+  return module;
 }
 
 static inline Error makeStringError(const Twine &message) {
@@ -148,7 +164,7 @@ static Optional<unsigned> getCommandLineOptLevel(Options &options) {
 }
 
 // JIT-compile the given module and run "entryPoint" with "args" as arguments.
-static Error compileAndExecute(Options &options, ModuleOp module,
+static Error compileAndExecute(Options &options, Operation *module,
                                StringRef entryPoint,
                                CompileAndExecuteConfig config, void **args) {
   Optional<llvm::CodeGenOpt::Level> jitCodeGenOptLevel;
@@ -240,10 +256,11 @@ static Error compileAndExecute(Options &options, ModuleOp module,
   return Error::success();
 }
 
-static Error compileAndExecuteVoidFunction(Options &options, ModuleOp module,
+static Error compileAndExecuteVoidFunction(Options &options, Operation *module,
                                            StringRef entryPoint,
                                            CompileAndExecuteConfig config) {
-  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
+  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(module, entryPoint));
   if (!mainFunction || mainFunction.empty())
     return makeStringError("entry point not found");
   void *empty = nullptr;
@@ -283,10 +300,11 @@ Error checkCompatibleReturnType<float>(LLVM::LLVMFuncOp mainFunction) {
   return Error::success();
 }
 template <typename Type>
-Error compileAndExecuteSingleReturnFunction(Options &options, ModuleOp module,
+Error compileAndExecuteSingleReturnFunction(Options &options, Operation *module,
                                             StringRef entryPoint,
                                             CompileAndExecuteConfig config) {
-  auto mainFunction = module.lookupSymbol<LLVM::LLVMFuncOp>(entryPoint);
+  auto mainFunction = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+      SymbolTable::lookupSymbolIn(module, entryPoint));
   if (!mainFunction || mainFunction.isExternal())
     return makeStringError("entry point not found");
 
@@ -339,7 +357,8 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
 
   MLIRContext context(registry);
 
-  auto m = parseMLIRInput(options.inputFilename, &context);
+  auto m = parseMLIRInput(options.inputFilename, !options.noImplicitModule,
+                          &context);
   if (!m) {
     llvm::errs() << "could not parse the input IR\n";
     return 1;
@@ -370,7 +389,7 @@ int mlir::JitRunnerMain(int argc, char **argv, const DialectRegistry &registry,
 
   // Get the function used to compile and execute the module.
   using CompileAndExecuteFnT =
-      Error (*)(Options &, ModuleOp, StringRef, CompileAndExecuteConfig);
+      Error (*)(Options &, Operation *, StringRef, CompileAndExecuteConfig);
   auto compileAndExecuteFn =
       StringSwitch<CompileAndExecuteFnT>(options.mainFuncType.getValue())
           .Case("i32", compileAndExecuteSingleReturnFunction<int32_t>)
diff --git a/mlir/test/mlir-cpu-runner/invalid.mlir b/mlir/test/mlir-cpu-runner/invalid.mlir
new file mode 100644 (file)
index 0000000..74bfbb7
--- /dev/null
@@ -0,0 +1,4 @@
+// RUN: not mlir-cpu-runner --no-implicit-module %s |& FileCheck %s
+
+// CHECK: Error: top-level op must be a symbol table.
+llvm.func @main()
index 87305a7..b87221a 100644 (file)
@@ -51,7 +51,10 @@ using namespace mlir;
 /// Each of these two modules is translated to LLVM IR module, then they are
 /// linked together and returned.
 static std::unique_ptr<llvm::Module>
-convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) {
+convertMLIRModule(Operation *op, llvm::LLVMContext &context) {
+  auto module = dyn_cast<ModuleOp>(op);
+  if (!module)
+    return op->emitError("op must be a 'builtin.module"), nullptr;
   // Verify that there is only one nested module.
   auto modules = module.getOps<ModuleOp>();
   if (!llvm::hasSingleElement(modules)) {
@@ -71,8 +74,9 @@ convertMLIRModule(ModuleOp module, llvm::LLVMContext &context) {
   return mainModule;
 }
 
-static LogicalResult runMLIRPasses(ModuleOp module) {
-  PassManager passManager(module.getContext());
+static LogicalResult runMLIRPasses(Operation *module) {
+  PassManager passManager(module->getContext(),
+                          module->getName().getStringRef());
   applyPassManagerCLOptions(passManager);
   passManager.addPass(createGpuKernelOutliningPass());
   passManager.addPass(createConvertGPUToSPIRVPass(/*mapMemorySpace=*/true));