[mlir-translate] Support parsing operations other than 'builtin.module' as top-level
authorrkayaith <rkayaith@gmail.com>
Wed, 28 Sep 2022 01:31:16 +0000 (21:31 -0400)
committerrkayaith <rkayaith@gmail.com>
Fri, 21 Oct 2022 19:54:06 +0000 (15:54 -0400)
This adds a '--no-implicit-module' option, which disables the insertion
of a top-level 'builtin.module' during parsing.

The translation APIs are also updated to take/return 'Operation*'
instead of 'ModuleOp', to allow other operation types to be used. To
simplify translations which are restricted to specific operation types,
'TranslateFromMLIRRegistration' has an overload which performs the
necessary cast and error checking.

Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D134237

mlir/examples/standalone/standalone-translate/standalone-translate.cpp
mlir/include/mlir/Tools/mlir-translate/Translation.h
mlir/lib/Target/Cpp/TranslateRegistration.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Target/SPIRV/TranslateRegistration.cpp
mlir/lib/Tools/mlir-translate/Translation.cpp
mlir/test/Target/LLVMIR/invalid-module.mlir [new file with mode: 0644]
mlir/test/Target/SPIRV/invalid-module.mlir [new file with mode: 0644]

index 31ddef4..94277c1 100644 (file)
@@ -12,7 +12,6 @@
 //===----------------------------------------------------------------------===//
 
 #include "Standalone/StandaloneDialect.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/InitAllTranslations.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
@@ -24,7 +23,7 @@ int main(int argc, char **argv) {
   // TODO: Register standalone translations here.
   mlir::TranslateFromMLIRRegistration withdescription(
       "option", "different from option",
-      [](mlir::ModuleOp op, llvm::raw_ostream &output) {
+      [](mlir::Operation *op, llvm::raw_ostream &output) {
         return mlir::LogicalResult::success();
       },
       [](mlir::DialectRegistry &a) {});
index c8b5b70..bc4636c 100644 (file)
 #ifndef MLIR_TOOLS_MLIRTRANSLATE_TRANSLATION_H
 #define MLIR_TOOLS_MLIRTRANSLATE_TRANSLATION_H
 
+#include "mlir/IR/Operation.h"
 #include "llvm/Support/CommandLine.h"
 
-namespace llvm {
-class MemoryBuffer;
-class SourceMgr;
-class StringRef;
-} // namespace llvm
-
 namespace mlir {
-class DialectRegistry;
-struct LogicalResult;
-class MLIRContext;
-class ModuleOp;
 template <typename OpTy>
 class OwningOpRef;
 
 /// Interface of the function that translates the sources managed by `sourceMgr`
 /// to MLIR. The source manager has at least one buffer. The implementation
-/// should create a new MLIR ModuleOp in the given context and return a pointer
-/// to it, or a nullptr in case of any error.
-using TranslateSourceMgrToMLIRFunction = std::function<OwningOpRef<ModuleOp>(
+/// should create a new MLIR Operation in the given context and return a
+/// pointer to it, or a nullptr in case of any error.
+using TranslateSourceMgrToMLIRFunction = std::function<OwningOpRef<Operation *>(
     llvm::SourceMgr &sourceMgr, MLIRContext *)>;
 
 /// Interface of the function that translates the given string to MLIR. The
-/// implementation should create a new MLIR ModuleOp in the given context. If
+/// implementation should create a new MLIR Operation in the given context. If
 /// source-related error reporting is required from within the function, use
 /// TranslateSourceMgrToMLIRFunction instead.
 using TranslateStringRefToMLIRFunction =
-    std::function<OwningOpRef<ModuleOp>(llvm::StringRef, MLIRContext *)>;
+    std::function<OwningOpRef<Operation *>(llvm::StringRef, MLIRContext *)>;
 
 /// Interface of the function that translates MLIR to a different format and
-/// outputs the result to a stream. It is allowed to modify the module.
+/// outputs the result to a stream. It is allowed to modify the operation.
 using TranslateFromMLIRFunction =
-    std::function<LogicalResult(ModuleOp, llvm::raw_ostream &output)>;
+    std::function<LogicalResult(Operation *, llvm::raw_ostream &output)>;
 
 /// Interface of the function that performs file-to-file translation involving
 /// MLIR. The input file is held in the given MemoryBuffer; the output file
@@ -83,6 +74,23 @@ struct TranslateFromMLIRRegistration {
       const TranslateFromMLIRFunction &function,
       const std::function<void(DialectRegistry &)> &dialectRegistration =
           [](DialectRegistry &) {});
+
+  template <typename FuncTy, typename OpTy = detail::first_argument<FuncTy>,
+            typename = std::enable_if_t<!std::is_same_v<OpTy, Operation *>>>
+  TranslateFromMLIRRegistration(
+      llvm::StringRef name, llvm::StringRef description, FuncTy function,
+      const std::function<void(DialectRegistry &)> &dialectRegistration =
+          [](DialectRegistry &) {})
+      : TranslateFromMLIRRegistration(
+            name, description,
+            [function](Operation *op, raw_ostream &os) -> LogicalResult {
+              if (auto casted = dyn_cast<OpTy>(op))
+                return function(casted, os);
+              return emitError(op->getLoc())
+                     << "expected a '" << OpTy::getOperationName()
+                     << "' op, got '" << op->getName().getStringRef() << "'";
+            },
+            dialectRegistration){};
 };
 struct TranslateRegistration {
   TranslateRegistration(llvm::StringRef name, llvm::StringRef description,
index 7eb393f..b486e54 100644 (file)
@@ -34,9 +34,9 @@ void registerToCppTranslation() {
 
   TranslateFromMLIRRegistration reg(
       "mlir-to-cpp", "translate from mlir to cpp",
-      [](ModuleOp module, raw_ostream &output) {
+      [](Operation *op, raw_ostream &output) {
         return emitc::translateToCpp(
-            module, output,
+            op, output,
             /*declareVariablesAtTop=*/declareVariablesAtTop);
       },
       [](DialectRegistry &registry) {
index 2995ad3..865add9 100644 (file)
@@ -1152,8 +1152,8 @@ mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
 
 // Deserializes the LLVM bitcode stored in `input` into an MLIR module in the
 // LLVM dialect.
-OwningOpRef<ModuleOp> translateLLVMIRToModule(llvm::SourceMgr &sourceMgr,
-                                              MLIRContext *context) {
+static OwningOpRef<Operation *>
+translateLLVMIRToModule(llvm::SourceMgr &sourceMgr, MLIRContext *context) {
   llvm::SMDiagnostic err;
   llvm::LLVMContext llvmContext;
   std::unique_ptr<llvm::Module> llvmModule = llvm::parseIR(
index 1ae7b83..4f44770 100644 (file)
@@ -25,9 +25,9 @@ namespace mlir {
 void registerToLLVMIRTranslation() {
   TranslateFromMLIRRegistration registration(
       "mlir-to-llvmir", "translate mlir to llvmir",
-      [](ModuleOp module, raw_ostream &output) {
+      [](Operation *op, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
-        auto llvmModule = translateModuleToLLVMIR(module, llvmContext);
+        auto llvmModule = translateModuleToLLVMIR(op, llvmContext);
         if (!llvmModule)
           return failure();
 
index e409b66..5e27dab 100644 (file)
@@ -1189,8 +1189,10 @@ prepareLLVMModule(Operation *m, llvm::LLVMContext &llvmContext,
 std::unique_ptr<llvm::Module>
 mlir::translateModuleToLLVMIR(Operation *module, llvm::LLVMContext &llvmContext,
                               StringRef name) {
-  if (!satisfiesLLVMModule(module))
+  if (!satisfiesLLVMModule(module)) {
+    module->emitOpError("can not be translated to an LLVMIR module");
     return nullptr;
+  }
 
   std::unique_ptr<llvm::Module> llvmModule =
       prepareLLVMModule(module, llvmContext, name);
index d24e578..b293d87 100644 (file)
@@ -36,8 +36,8 @@ using namespace mlir;
 
 // Deserializes the SPIR-V binary module stored in the file named as
 // `inputFilename` and returns a module containing the SPIR-V module.
-static OwningOpRef<ModuleOp> deserializeModule(const llvm::MemoryBuffer *input,
-                                               MLIRContext *context) {
+static OwningOpRef<Operation *>
+deserializeModule(const llvm::MemoryBuffer *input, MLIRContext *context) {
   context->loadDialect<spirv::SPIRVDialect>();
 
   // Make sure the input stream can be treated as a stream of SPIR-V words
@@ -61,7 +61,7 @@ static OwningOpRef<ModuleOp> deserializeModule(const llvm::MemoryBuffer *input,
       context, input->getBufferIdentifier(), /*line=*/0, /*column=*/0)));
   module->getBody()->push_front(spirvModule.release());
 
-  return module;
+  return std::move(module);
 }
 
 namespace mlir {
index 7517ae0..9e97e1b 100644 (file)
@@ -16,6 +16,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser/Parser.h"
+#include "mlir/Tools/ParseUtilties.h"
 #include "llvm/Support/SourceMgr.h"
 
 using namespace mlir;
@@ -65,10 +66,10 @@ static void registerTranslateToMLIRFunction(
     const TranslateSourceMgrToMLIRFunction &function) {
   auto wrappedFn = [function](llvm::SourceMgr &sourceMgr, raw_ostream &output,
                               MLIRContext *context) {
-    OwningOpRef<ModuleOp> module = function(sourceMgr, context);
-    if (!module || failed(verify(*module)))
+    OwningOpRef<Operation *> op = function(sourceMgr, context);
+    if (!op || failed(verify(*op)))
       return failure();
-    module->print(output);
+    op.get()->print(output);
     return success();
   };
   registerTranslation(name, description, wrappedFn);
@@ -101,6 +102,12 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
     StringRef name, StringRef description,
     const TranslateFromMLIRFunction &function,
     const std::function<void(DialectRegistry &)> &dialectRegistration) {
+
+  static llvm::cl::opt<bool> noImplicitModule{
+      "no-implicit-module",
+      llvm::cl::desc("Disable the parsing of an implicit top-level module op"),
+      llvm::cl::init(false)};
+
   registerTranslation(name, description,
                       [function, dialectRegistration](
                           llvm::SourceMgr &sourceMgr, raw_ostream &output,
@@ -108,11 +115,11 @@ TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
                         DialectRegistry registry;
                         dialectRegistration(registry);
                         context->appendDialectRegistry(registry);
-                        auto module =
-                            parseSourceFile<ModuleOp>(sourceMgr, context);
-                        if (!module || failed(verify(*module)))
+                        OwningOpRef<Operation *> op = parseSourceFileForTool(
+                            sourceMgr, context, !noImplicitModule);
+                        if (!op || failed(verify(*op)))
                           return failure();
-                        return function(module.get(), output);
+                        return function(op.get(), output);
                       });
 }
 
diff --git a/mlir/test/Target/LLVMIR/invalid-module.mlir b/mlir/test/Target/LLVMIR/invalid-module.mlir
new file mode 100644 (file)
index 0000000..7fd5f26
--- /dev/null
@@ -0,0 +1,6 @@
+// RUN: mlir-translate -verify-diagnostics -mlir-to-llvmir --no-implicit-module %s
+
+// expected-error@below {{'llvm.func' op can not be translated to an LLVMIR module}}
+llvm.func @foo() {
+  llvm.return
+}
diff --git a/mlir/test/Target/SPIRV/invalid-module.mlir b/mlir/test/Target/SPIRV/invalid-module.mlir
new file mode 100644 (file)
index 0000000..e388a03
--- /dev/null
@@ -0,0 +1,4 @@
+// RUN: mlir-translate %s -serialize-spirv -no-implicit-module -verify-diagnostics
+
+// expected-error@below {{expected a 'builtin.module' op, got 'spirv.module'}}
+spirv.module Logical Simple {}