Add a `dialect_registration` callback for "translations" registered with mlir-translate
authorMehdi Amini <joker.eph@gmail.com>
Sun, 23 Aug 2020 00:57:47 +0000 (00:57 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 23 Aug 2020 01:00:39 +0000 (01:00 +0000)
This will allow out-of-tree translation to register the dialects they expect
to see in their input, on the model of getDependentDialects() for passes.

Differential Revision: https://reviews.llvm.org/D86409

mlir/include/mlir/Translation.h
mlir/lib/Dialect/SPIRV/Serialization/TranslateRegistration.cpp
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToROCDLIR.cpp
mlir/lib/Target/LLVMIR/LLVMAVX512Intr.cpp
mlir/lib/Translation/Translation.cpp
mlir/tools/mlir-translate/mlir-translate.cpp

index bdc391a..ab9dad8 100644 (file)
@@ -76,8 +76,10 @@ struct TranslateToMLIRRegistration {
 };
 
 struct TranslateFromMLIRRegistration {
-  TranslateFromMLIRRegistration(llvm::StringRef name,
-                                const TranslateFromMLIRFunction &function);
+  TranslateFromMLIRRegistration(
+      llvm::StringRef name, const TranslateFromMLIRFunction &function,
+      std::function<void(DialectRegistry &)> dialectRegistration =
+          [](DialectRegistry &) {});
 };
 struct TranslateRegistration {
   TranslateRegistration(llvm::StringRef name,
index d0aa277..8a6032f 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVModule.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Parser.h"
@@ -105,8 +107,12 @@ static LogicalResult serializeModule(ModuleOp module, raw_ostream &output) {
 namespace mlir {
 void registerToSPIRVTranslation() {
   TranslateFromMLIRRegistration toBinary(
-      "serialize-spirv", [](ModuleOp module, raw_ostream &output) {
+      "serialize-spirv",
+      [](ModuleOp module, raw_ostream &output) {
         return serializeModule(module, output);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<spirv::SPIRVDialect>();
       });
 }
 } // namespace mlir
@@ -147,15 +153,23 @@ static LogicalResult roundTripModule(ModuleOp srcModule, bool emitDebugInfo,
 namespace mlir {
 void registerTestRoundtripSPIRV() {
   TranslateFromMLIRRegistration roundtrip(
-      "test-spirv-roundtrip", [](ModuleOp module, raw_ostream &output) {
+      "test-spirv-roundtrip",
+      [](ModuleOp module, raw_ostream &output) {
         return roundTripModule(module, /*emitDebugInfo=*/false, output);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<spirv::SPIRVDialect>();
       });
 }
 
 void registerTestRoundtripDebugSPIRV() {
   TranslateFromMLIRRegistration roundtrip(
-      "test-spirv-roundtrip-debug", [](ModuleOp module, raw_ostream &output) {
+      "test-spirv-roundtrip-debug",
+      [](ModuleOp module, raw_ostream &output) {
         return roundTripModule(module, /*emitDebugInfo=*/true, output);
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<spirv::SPIRVDialect>();
       });
 }
 } // namespace mlir
index 6027ab3..89fb5b1 100644 (file)
@@ -30,7 +30,8 @@ mlir::translateModuleToLLVMIR(ModuleOp m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerToLLVMIRTranslation() {
   TranslateFromMLIRRegistration registration(
-      "mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+      "mlir-to-llvmir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = LLVM::ModuleTranslation::translateModule<>(
             module, llvmContext, "LLVMDialectModule");
@@ -39,6 +40,7 @@ void registerToLLVMIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
-      });
+      },
+      [](DialectRegistry &registry) { registry.insert<LLVM::LLVMDialect>(); });
 }
 } // namespace mlir
index fc2f650..bee7d18 100644 (file)
@@ -99,7 +99,8 @@ mlir::translateModuleToNVVMIR(Operation *m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerToNVVMIRTranslation() {
   TranslateFromMLIRRegistration registration(
-      "mlir-to-nvvmir", [](ModuleOp module, raw_ostream &output) {
+      "mlir-to-nvvmir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = mlir::translateModuleToNVVMIR(module, llvmContext);
         if (!llvmModule)
@@ -107,6 +108,9 @@ void registerToNVVMIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<LLVM::LLVMDialect, NVVM::NVVMDialect>();
       });
 }
 } // namespace mlir
index 5bd04e4..6bfcf19 100644 (file)
@@ -103,7 +103,8 @@ mlir::translateModuleToROCDLIR(Operation *m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerToROCDLIRTranslation() {
   TranslateFromMLIRRegistration registration(
-      "mlir-to-rocdlir", [](ModuleOp module, raw_ostream &output) {
+      "mlir-to-rocdlir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = mlir::translateModuleToROCDLIR(module, llvmContext);
         if (!llvmModule)
@@ -111,6 +112,9 @@ void registerToROCDLIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<ROCDL::ROCDLDialect, LLVM::LLVMDialect>();
       });
 }
 } // namespace mlir
index 64f9fef..52f1792 100644 (file)
@@ -45,7 +45,8 @@ translateLLVMAVX512ModuleToLLVMIR(Operation *m, llvm::LLVMContext &llvmContext,
 namespace mlir {
 void registerAVX512ToLLVMIRTranslation() {
   TranslateFromMLIRRegistration reg(
-      "avx512-mlir-to-llvmir", [](ModuleOp module, raw_ostream &output) {
+      "avx512-mlir-to-llvmir",
+      [](ModuleOp module, raw_ostream &output) {
         llvm::LLVMContext llvmContext;
         auto llvmModule = translateLLVMAVX512ModuleToLLVMIR(
             module, llvmContext, "LLVMDialectModule");
@@ -54,6 +55,9 @@ void registerAVX512ToLLVMIRTranslation() {
 
         llvmModule->print(output, nullptr);
         return success();
+      },
+      [](DialectRegistry &registry) {
+        registry.insert<LLVM::LLVMAVX512Dialect, LLVM::LLVMDialect>();
       });
 }
 } // namespace mlir
index 99f0acd..991bdf9 100644 (file)
@@ -92,10 +92,12 @@ TranslateToMLIRRegistration::TranslateToMLIRRegistration(
 //===----------------------------------------------------------------------===//
 
 TranslateFromMLIRRegistration::TranslateFromMLIRRegistration(
-    StringRef name, const TranslateFromMLIRFunction &function) {
-  registerTranslation(name, [function](llvm::SourceMgr &sourceMgr,
-                                       raw_ostream &output,
-                                       MLIRContext *context) {
+    StringRef name, const TranslateFromMLIRFunction &function,
+    std::function<void(DialectRegistry &)> dialectRegistration) {
+  registerTranslation(name, [function, dialectRegistration](
+                                llvm::SourceMgr &sourceMgr, raw_ostream &output,
+                                MLIRContext *context) {
+    dialectRegistration(context->getDialectRegistry());
     auto module = OwningModuleRef(parseSourceFile(sourceMgr, context));
     if (!module)
       return failure();
@@ -173,7 +175,7 @@ LogicalResult mlir::mlirTranslateMain(int argc, char **argv,
   // Processes the memory buffer with a new MLIRContext.
   auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
                            raw_ostream &os) {
-    MLIRContext context;
+    MLIRContext context(false);
     context.printOpOnDiagnostic(!verifyDiagnostics);
     llvm::SourceMgr sourceMgr;
     sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), llvm::SMLoc());
index 6388a4c..cf84856 100644 (file)
@@ -32,7 +32,5 @@ static void registerTestTranslations() {
 int main(int argc, char **argv) {
   registerAllTranslations();
   registerTestTranslations();
-  // TODO: remove the global dialect registry
-  registerAllDialects();
   return failed(mlirTranslateMain(argc, argv, "MLIR Translation Testing Tool"));
 }