Separate the Registration from Loading dialects in the Context
authorMehdi Amini <joker.eph@gmail.com>
Tue, 18 Aug 2020 20:01:19 +0000 (20:01 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Tue, 18 Aug 2020 23:23:56 +0000 (23:23 +0000)
This changes the behavior of constructing MLIRContext to no longer load globally
registered dialects on construction. Instead Dialects are only loaded explicitly
on demand:
- the Parser is lazily loading Dialects in the context as it encounters them
during parsing. This is the only purpose for registering dialects and not load
them in the context.
- Passes are expected to declare the dialects they will create entity from
(Operations, Attributes, or Types), and the PassManager is loading Dialects into
the Context when starting a pipeline.

This changes simplifies the configuration of the registration: a compiler only
need to load the dialect for the IR it will emit, and the optimizer is
self-contained and load the required Dialects. For example in the Toy tutorial,
the compiler only needs to load the Toy dialect in the Context, all the others
(linalg, affine, std, LLVM, ...) are automatically loaded depending on the
optimization pipeline enabled.

To adjust to this change, stop using the existing dialect registration: the
global registry will be removed soon.

1) For passes, you need to override the method:

virtual void getDependentDialects(DialectRegistry &registry) const {}

and registery on the provided registry any dialect that this pass can produce.
Passes defined in TableGen can provide this list in the dependentDialects list
field.

2) For dialects, on construction you can register dependent dialects using the
provided MLIRContext: `context.getOrLoadDialect<DialectName>()`
This is useful if a dialect may canonicalize or have interfaces involving
another dialect.

3) For loading IR, dialect that can be in the input file must be explicitly
registered with the context. `MlirOptMain()` is taking an explicit registry for
this purpose. See how the standalone-opt.cpp example is setup:

  mlir::DialectRegistry registry;
  registry.insert<mlir::standalone::StandaloneDialect>();
  registry.insert<mlir::StandardOpsDialect>();

Only operations from these two dialects can be in the input file. To include all
of the dialects in MLIR Core, you can populate the registry this way:

  mlir::registerAllDialects(registry);

4) For `mlir-translate` callback, as well as frontend, Dialects can be loaded in
the context before emitting the IR: context.getOrLoadDialect<ToyDialect>()

Differential Revision: https://reviews.llvm.org/D85622

95 files changed:
mlir/examples/standalone/standalone-opt/standalone-opt.cpp
mlir/examples/standalone/test/Standalone/standalone-opt.mlir
mlir/examples/toy/Ch2/toyc.cpp
mlir/examples/toy/Ch3/toyc.cpp
mlir/examples/toy/Ch4/toyc.cpp
mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch5/toyc.cpp
mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch6/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch6/toyc.cpp
mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
mlir/examples/toy/Ch7/mlir/LowerToLLVM.cpp
mlir/examples/toy/Ch7/toyc.cpp
mlir/include/mlir-c/Registration.h
mlir/include/mlir/Conversion/Passes.td
mlir/include/mlir/Dialect/Affine/Passes.td
mlir/include/mlir/Dialect/LLVMIR/LLVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td
mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h
mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
mlir/include/mlir/Dialect/LLVMIR/ROCDLDialect.h
mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td
mlir/include/mlir/Dialect/Linalg/Passes.td
mlir/include/mlir/Dialect/SCF/Passes.td
mlir/include/mlir/IR/Dialect.h
mlir/include/mlir/IR/FunctionSupport.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/IR/OpBase.td
mlir/include/mlir/InitAllDialects.h
mlir/include/mlir/InitAllTranslations.h
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassBase.td
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Support/MlirOptMain.h
mlir/include/mlir/TableGen/Dialect.h
mlir/include/mlir/TableGen/Pass.h
mlir/include/mlir/Transforms/Passes.td
mlir/lib/CAPI/IR/IR.cpp
mlir/lib/CAPI/Registration/Registration.cpp
mlir/lib/Conversion/GPUToVulkan/ConvertGPULaunchFuncToVulkanLaunchFunc.cpp
mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp
mlir/lib/Conversion/PassDetail.h
mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp
mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp
mlir/lib/Dialect/Affine/Transforms/PassDetail.h
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/PassDetail.h
mlir/lib/Dialect/SCF/Transforms/PassDetail.h
mlir/lib/Dialect/SDBM/SDBMExpr.cpp
mlir/lib/ExecutionEngine/JitRunner.cpp
mlir/lib/IR/Dialect.cpp
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/Operation.cpp
mlir/lib/IR/Verifier.cpp
mlir/lib/Parser/AttributeParser.cpp
mlir/lib/Parser/DialectSymbolParser.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/Support/MlirOptMain.cpp
mlir/lib/TableGen/Dialect.cpp
mlir/lib/TableGen/Pass.cpp
mlir/lib/Target/LLVMIR/ConvertFromLLVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Transforms/PassDetail.h
mlir/test/CAPI/ir.c
mlir/test/EDSC/builder-api-test.cpp
mlir/test/SDBM/sdbm-api-test.cpp
mlir/test/lib/Dialect/Affine/TestVectorizationUtils.cpp
mlir/test/lib/Dialect/SPIRV/TestAvailability.cpp
mlir/test/lib/Dialect/Test/TestDialect.cpp
mlir/test/lib/Dialect/Test/TestDialect.h
mlir/test/lib/Dialect/Test/TestPatterns.cpp
mlir/test/lib/Transforms/TestAllReduceLowering.cpp
mlir/test/lib/Transforms/TestBufferPlacement.cpp
mlir/test/lib/Transforms/TestGpuMemoryPromotion.cpp
mlir/test/lib/Transforms/TestLinalgHoisting.cpp
mlir/test/lib/Transforms/TestLinalgTransforms.cpp
mlir/test/lib/Transforms/TestVectorTransforms.cpp
mlir/test/mlir-opt/commandline.mlir
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp
mlir/tools/mlir-opt/mlir-opt.cpp
mlir/tools/mlir-tblgen/DialectGen.cpp
mlir/tools/mlir-tblgen/PassGen.cpp
mlir/tools/mlir-translate/mlir-translate.cpp
mlir/unittests/Dialect/Quant/QuantizationUtilsTest.cpp
mlir/unittests/Dialect/SPIRV/DeserializationTest.cpp
mlir/unittests/Dialect/SPIRV/SerializationTest.cpp
mlir/unittests/IR/AttributeTest.cpp
mlir/unittests/IR/DialectTest.cpp
mlir/unittests/IR/OperationSupportTest.cpp
mlir/unittests/Pass/AnalysisManagerTest.cpp
mlir/unittests/SDBM/SDBMTest.cpp
mlir/unittests/TableGen/OpBuildGen.cpp
mlir/unittests/TableGen/StructsGenTest.cpp

index b33dab2..86cf679 100644 (file)
 int main(int argc, char **argv) {
   mlir::registerAllDialects();
   mlir::registerAllPasses();
-
-  mlir::registerDialect<mlir::standalone::StandaloneDialect>();
   // TODO: Register standalone passes here.
 
-  return failed(mlir::MlirOptMain(argc, argv, "Standalone optimizer driver\n"));
+  mlir::DialectRegistry registry;
+  registry.insert<mlir::standalone::StandaloneDialect>();
+  registry.insert<mlir::StandardOpsDialect>();
+  // Add the following to include *all* MLIR Core dialects, or selectively
+  // include what you need like above. You only need to register dialects that
+  // will be *parsed* by the tool, not the one generated
+  // registerAllDialects(registry);
+
+  return failed(
+      mlir::MlirOptMain(argc, argv, "Standalone optimizer driver\n", registry));
 }
index fac0814..1a78a9d 100644 (file)
@@ -1,3 +1,3 @@
 // RUN: standalone-opt --show-dialects | FileCheck %s
-// CHECK: Registered Dialects:
+// CHECK: Available Dialects:
 // CHECK: standalone
index d0880ce..99232d8 100644 (file)
@@ -68,10 +68,9 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
 }
 
 int dumpMLIR() {
-  // Register our Dialect with MLIR.
-  mlir::registerDialect<mlir::toy::ToyDialect>();
-
-  mlir::MLIRContext context;
+  mlir::MLIRContext context(/*loadAllDialects=*/false);
+  // Load our Dialect in this MLIR Context.
+  context.getOrLoadDialect<mlir::toy::ToyDialect>();
 
   // Handle '.toy' input to the compiler.
   if (inputType != InputType::MLIR &&
index f9d5631..d0430ce 100644 (file)
@@ -102,10 +102,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
 }
 
 int dumpMLIR() {
-  // Register our Dialect with MLIR.
-  mlir::registerDialect<mlir::toy::ToyDialect>();
+  mlir::MLIRContext context(/*loadAllDialects=*/false);
+  // Load our Dialect in this MLIR Context.
+  context.getOrLoadDialect<mlir::toy::ToyDialect>();
 
-  mlir::MLIRContext context;
   mlir::OwningModuleRef module;
   llvm::SourceMgr sourceMgr;
   mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
index e11f35c..9f95887 100644 (file)
@@ -103,10 +103,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
 }
 
 int dumpMLIR() {
-  // Register our Dialect with MLIR.
-  mlir::registerDialect<mlir::toy::ToyDialect>();
+  mlir::MLIRContext context(/*loadAllDialects=*/false);
+  // Load our Dialect in this MLIR Context.
+  context.getOrLoadDialect<mlir::toy::ToyDialect>();
 
-  mlir::MLIRContext context;
   mlir::OwningModuleRef module;
   llvm::SourceMgr sourceMgr;
   mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
index 3097681..92fd246 100644 (file)
@@ -256,6 +256,9 @@ struct TransposeOpLowering : public ConversionPattern {
 namespace {
 struct ToyToAffineLoweringPass
     : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, StandardOpsDialect>();
+  }
   void runOnFunction() final;
 };
 } // end anonymous namespace.
index ed04969..16faac0 100644 (file)
@@ -106,10 +106,10 @@ int loadMLIR(llvm::SourceMgr &sourceMgr, mlir::MLIRContext &context,
 }
 
 int dumpMLIR() {
-  // Register our Dialect with MLIR.
-  mlir::registerDialect<mlir::toy::ToyDialect>();
+  mlir::MLIRContext context(/*loadAllDialects=*/false);
+  // Load our Dialect in this MLIR Context.
+  context.getOrLoadDialect<mlir::toy::ToyDialect>();
 
-  mlir::MLIRContext context;
   mlir::OwningModuleRef module;
   llvm::SourceMgr sourceMgr;
   mlir::SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context);
index cac3415..f3857f3 100644 (file)
@@ -255,6 +255,9 @@ struct TransposeOpLowering : public ConversionPattern {
 namespace {
 struct ToyToAffineLoweringPass
     : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, StandardOpsDialect>();
+  }
   void runOnFunction() final;
 };
 } // end anonymous namespace.
index 74b32dc..19bf27e 100644 (file)
@@ -159,6 +159,9 @@ private:
 namespace {
 struct ToyToLLVMLoweringPass
     : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
+  }
   void runOnOperation() final;
 };
 } // end anonymous namespace
index bdcdf1a..9504a38 100644 (file)
@@ -255,10 +255,10 @@ int main(int argc, char **argv) {
 
   // If we aren't dumping the AST, then we are compiling with/to MLIR.
 
-  // Register our Dialect with MLIR.
-  mlir::registerDialect<mlir::toy::ToyDialect>();
+  mlir::MLIRContext context(/*loadAllDialects=*/false);
+  // Load our Dialect in this MLIR Context.
+  context.getOrLoadDialect<mlir::toy::ToyDialect>();
 
-  mlir::MLIRContext context;
   mlir::OwningModuleRef module;
   if (int error = loadAndProcessMLIR(context, module))
     return error;
index 3097681..92fd246 100644 (file)
@@ -256,6 +256,9 @@ struct TransposeOpLowering : public ConversionPattern {
 namespace {
 struct ToyToAffineLoweringPass
     : public PassWrapper<ToyToAffineLoweringPass, FunctionPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, StandardOpsDialect>();
+  }
   void runOnFunction() final;
 };
 } // end anonymous namespace.
index 74b32dc..19bf27e 100644 (file)
@@ -159,6 +159,9 @@ private:
 namespace {
 struct ToyToLLVMLoweringPass
     : public PassWrapper<ToyToLLVMLoweringPass, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<LLVM::LLVMDialect, scf::SCFDialect>();
+  }
   void runOnOperation() final;
 };
 } // end anonymous namespace
index c1cc207..cb3b455 100644 (file)
@@ -256,10 +256,10 @@ int main(int argc, char **argv) {
 
   // If we aren't dumping the AST, then we are compiling with/to MLIR.
 
-  // Register our Dialect with MLIR.
-  mlir::registerDialect<mlir::toy::ToyDialect>();
+  mlir::MLIRContext context(/*loadAllDialects=*/false);
+  // Load our Dialect in this MLIR Context.
+  context.getOrLoadDialect<mlir::toy::ToyDialect>();
 
-  mlir::MLIRContext context;
   mlir::OwningModuleRef module;
   if (int error = loadAndProcessMLIR(context, module))
     return error;
index 5e5aa0e..05d4aac 100644 (file)
 #ifndef MLIR_C_REGISTRATION_H
 #define MLIR_C_REGISTRATION_H
 
+#include "mlir-c/IR.h"
+
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-/** Registers all dialects known to core MLIR with the system. This must be
- * called before creating an MlirContext if it needs access to the registered
- * dialects. */
-void mlirRegisterAllDialects();
+/** Registers all dialects known to core MLIR with the provided Context.
+ * This is needed before creating IR for these Dialects.
+ */
+void mlirRegisterAllDialects(MlirContext context);
 
 #ifdef __cplusplus
 }
index 4ff23d7..0a043c0 100644 (file)
@@ -66,6 +66,11 @@ def ConvertAffineToStandard : Pass<"lower-affine"> {
         `affine.apply`.
   }];
   let constructor = "mlir::createLowerAffinePass()";
+  let dependentDialects = [
+    "scf::SCFDialect",
+    "StandardOpsDialect",
+    "vector::VectorDialect"
+  ];
 }
 
 //===----------------------------------------------------------------------===//
@@ -76,6 +81,7 @@ def ConvertAVX512ToLLVM : Pass<"convert-avx512-to-llvm", "ModuleOp"> {
   let summary = "Convert the operations from the avx512 dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertAVX512ToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect", "LLVM::LLVMAVX512Dialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -98,6 +104,7 @@ def GpuToLLVMConversionPass : Pass<"gpu-to-llvm", "ModuleOp"> {
 def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
   let summary = "Generate NVVM operations for gpu operations";
   let constructor = "mlir::createLowerGpuOpsToNVVMOpsPass()";
+  let dependentDialects = ["NVVM::NVVMDialect"];
   let options = [
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@@ -112,6 +119,7 @@ def ConvertGpuOpsToNVVMOps : Pass<"convert-gpu-to-nvvm", "gpu::GPUModuleOp"> {
 def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
   let summary = "Generate ROCDL operations for gpu operations";
   let constructor = "mlir::createLowerGpuOpsToROCDLOpsPass()";
+  let dependentDialects = ["ROCDL::ROCDLDialect"];
   let options = [
     Option<"indexBitwidth", "index-bitwidth", "unsigned",
            /*default=kDeriveIndexBitwidthFromDataLayout*/"0",
@@ -126,6 +134,7 @@ def ConvertGpuOpsToROCDLOps : Pass<"convert-gpu-to-rocdl", "gpu::GPUModuleOp"> {
 def ConvertGPUToSPIRV : Pass<"convert-gpu-to-spirv", "ModuleOp"> {
   let summary = "Convert GPU dialect to SPIR-V dialect";
   let constructor = "mlir::createConvertGPUToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -136,6 +145,7 @@ def ConvertGpuLaunchFuncToVulkanLaunchFunc
     : Pass<"convert-gpu-launch-to-vulkan-launch", "ModuleOp"> {
   let summary = "Convert gpu.launch_func to vulkanLaunch external call";
   let constructor = "mlir::createConvertGpuLaunchFuncToVulkanLaunchFuncPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
 def ConvertVulkanLaunchFuncToVulkanCalls
@@ -143,6 +153,7 @@ def ConvertVulkanLaunchFuncToVulkanCalls
   let summary = "Convert vulkanLaunch external call to Vulkan runtime external "
                 "calls";
   let constructor = "mlir::createConvertVulkanLaunchFuncToVulkanCallsPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -153,6 +164,7 @@ def ConvertLinalgToLLVM : Pass<"convert-linalg-to-llvm", "ModuleOp"> {
   let summary = "Convert the operations from the linalg dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertLinalgToLLVMPass()";
+  let dependentDialects = ["scf::SCFDialect", "LLVM::LLVMDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -163,6 +175,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
   let summary = "Convert the operations from the linalg dialect into the "
                 "Standard dialect";
   let constructor = "mlir::createConvertLinalgToStandardPass()";
+  let dependentDialects = ["StandardOpsDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -172,6 +185,7 @@ def ConvertLinalgToStandard : Pass<"convert-linalg-to-std", "ModuleOp"> {
 def ConvertLinalgToSPIRV : Pass<"convert-linalg-to-spirv", "ModuleOp"> {
   let summary = "Convert Linalg ops to SPIR-V ops";
   let constructor = "mlir::createLinalgToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -182,6 +196,7 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
   let summary = "Convert SCF dialect to Standard dialect, replacing structured"
                 " control flow with a CFG";
   let constructor = "mlir::createLowerToCFGPass()";
+  let dependentDialects = ["StandardOpsDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -191,6 +206,7 @@ def SCFToStandard : Pass<"convert-scf-to-std"> {
 def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
   let summary = "Convert top-level AffineFor Ops to GPU kernels";
   let constructor = "mlir::createAffineForToGPUPass()";
+  let dependentDialects = ["gpu::GPUDialect"];
   let options = [
     Option<"numBlockDims", "gpu-block-dims", "unsigned", /*default=*/"1u",
            "Number of GPU block dimensions for mapping">,
@@ -202,6 +218,7 @@ def ConvertAffineForToGPU : FunctionPass<"convert-affine-for-to-gpu"> {
 def ConvertParallelLoopToGpu : Pass<"convert-parallel-loops-to-gpu"> {
   let summary = "Convert mapped scf.parallel ops to gpu launch operations";
   let constructor = "mlir::createParallelLoopToGpuPass()";
+  let dependentDialects = ["AffineDialect", "gpu::GPUDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -212,6 +229,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
   let summary = "Convert operations from the shape dialect into the standard "
                 "dialect";
   let constructor = "mlir::createConvertShapeToStandardPass()";
+  let dependentDialects = ["StandardOpsDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -221,6 +239,7 @@ def ConvertShapeToStandard : Pass<"convert-shape-to-std", "ModuleOp"> {
 def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
   let summary = "Convert operations from the shape dialect to the SCF dialect";
   let constructor = "mlir::createConvertShapeToSCFPass()";
+  let dependentDialects = ["scf::SCFDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -230,6 +249,7 @@ def ConvertShapeToSCF : FunctionPass<"convert-shape-to-scf"> {
 def ConvertSPIRVToLLVM : Pass<"convert-spirv-to-llvm", "ModuleOp"> {
   let summary = "Convert SPIR-V dialect to LLVM dialect";
   let constructor = "mlir::createConvertSPIRVToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -264,6 +284,7 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
     LLVM IR types.
   }];
   let constructor = "mlir::createLowerToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
   let options = [
     Option<"useAlignedAlloc", "use-aligned-alloc", "bool", /*default=*/"false",
            "Use aligned_alloc in place of malloc for heap allocations">,
@@ -291,11 +312,13 @@ def ConvertStandardToLLVM : Pass<"convert-std-to-llvm", "ModuleOp"> {
 def LegalizeStandardForSPIRV : Pass<"legalize-std-for-spirv"> {
   let summary = "Legalize standard ops for SPIR-V lowering";
   let constructor = "mlir::createLegalizeStdOpsForSPIRVLoweringPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
 def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
   let summary = "Convert Standard Ops to SPIR-V dialect";
   let constructor = "mlir::createConvertStandardToSPIRVPass()";
+  let dependentDialects = ["spirv::SPIRVDialect"];
 }
 
 //===----------------------------------------------------------------------===//
@@ -306,6 +329,7 @@ def ConvertVectorToSCF : FunctionPass<"convert-vector-to-scf"> {
   let summary = "Lower the operations from the vector dialect into the SCF "
                 "dialect";
   let constructor = "mlir::createConvertVectorToSCFPass()";
+  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
   let options = [
     Option<"fullUnroll", "full-unroll", "bool", /*default=*/"false",
            "Perform full unrolling when converting vector transfers to SCF">,
@@ -320,6 +344,7 @@ def ConvertVectorToLLVM : Pass<"convert-vector-to-llvm", "ModuleOp"> {
   let summary = "Lower the operations from the vector dialect into the LLVM "
                 "dialect";
   let constructor = "mlir::createConvertVectorToLLVMPass()";
+  let dependentDialects = ["LLVM::LLVMDialect"];
   let options = [
     Option<"reassociateFPReductions", "reassociate-fp-reductions",
            "bool", /*default=*/"false",
@@ -335,6 +360,7 @@ def ConvertVectorToROCDL : Pass<"convert-vector-to-rocdl", "ModuleOp"> {
   let summary = "Lower the operations from the vector dialect into the ROCDL "
                 "dialect";
   let constructor = "mlir::createConvertVectorToROCDLPass()";
+  let dependentDialects = ["ROCDL::ROCDLDialect"];
 }
 
 #endif // MLIR_CONVERSION_PASSES
index 8106400..f43fabd 100644 (file)
@@ -94,6 +94,7 @@ def AffineLoopUnrollAndJam : FunctionPass<"affine-loop-unroll-jam"> {
 def AffineVectorize : FunctionPass<"affine-super-vectorize"> {
   let summary = "Vectorize to a target independent n-D vector abstraction";
   let constructor = "mlir::createSuperVectorizePass()";
+  let dependentDialects = ["vector::VectorDialect"];
   let options = [
     ListOption<"vectorSizes", "virtual-vector-size", "int64_t",
                "Specify an n-D virtual vector size for vectorization",
index 04700f0..2f465f0 100644 (file)
@@ -15,6 +15,7 @@
 #define MLIR_DIALECT_LLVMIR_LLVMDIALECT_H_
 
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/IR/OpDefinition.h"
index e824f97..2267435 100644 (file)
@@ -19,6 +19,11 @@ include "mlir/IR/OpBase.td"
 def LLVM_Dialect : Dialect {
   let name = "llvm";
   let cppNamespace = "LLVM";
+
+  /// FIXME: at the moment this is a dependency of the translation to LLVM IR,
+  /// not really one of this dialect per-se.
+  let dependentDialects = ["omp::OpenMPDialect"];
+
   let hasRegionArgAttrVerify = 1;
   let hasOperationAttrVerify = 1;
   let extraClassDeclaration = [{
index 86d437c..9cc5314 100644 (file)
@@ -14,6 +14,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
 #define MLIR_DIALECT_LLVMIR_NVVMDIALECT_H_
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
index 5f022e3..7d47e50 100644 (file)
@@ -23,6 +23,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def NVVM_Dialect : Dialect {
   let name = "nvvm";
   let cppNamespace = "NVVM";
+  let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
 //===----------------------------------------------------------------------===//
index bf761c3..eb40373 100644 (file)
@@ -22,6 +22,7 @@
 #ifndef MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
 #define MLIR_DIALECT_LLVMIR_ROCDLDIALECT_H_
 
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Interfaces/SideEffectInterfaces.h"
index 0cd1169..f85c4f0 100644 (file)
@@ -23,6 +23,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 def ROCDL_Dialect : Dialect {
   let name = "rocdl";
   let cppNamespace = "ROCDL";
+  let dependentDialects = ["LLVM::LLVMDialect"];
 }
 
 //===----------------------------------------------------------------------===//
index 11f12ad..dcf4b5e 100644 (file)
@@ -30,17 +30,20 @@ def LinalgFusion : FunctionPass<"linalg-fusion"> {
 def LinalgFusionOfTensorOps : Pass<"linalg-fusion-for-tensor-ops"> {
   let summary = "Fuse operations on RankedTensorType in linalg dialect";
   let constructor = "mlir::createLinalgFusionOfTensorOpsPass()";
+  let dependentDialects = ["AffineDialect"];
 }
 
 def LinalgLowerToAffineLoops : FunctionPass<"convert-linalg-to-affine-loops"> {
   let summary = "Lower the operations from the linalg dialect into affine "
                 "loops";
   let constructor = "mlir::createConvertLinalgToAffineLoopsPass()";
+  let dependentDialects = ["AffineDialect"];
 }
 
 def LinalgLowerToLoops : FunctionPass<"convert-linalg-to-loops"> {
   let summary = "Lower the operations from the linalg dialect into loops";
   let constructor = "mlir::createConvertLinalgToLoopsPass()";
+  let dependentDialects = ["scf::SCFDialect", "AffineDialect"];
 }
 
 def LinalgOnTensorsToBuffers : Pass<"convert-linalg-on-tensors-to-buffers", "ModuleOp"> {
@@ -54,6 +57,7 @@ def LinalgLowerToParallelLoops
   let summary = "Lower the operations from the linalg dialect into parallel "
                 "loops";
   let constructor = "mlir::createConvertLinalgToParallelLoopsPass()";
+  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
 }
 
 def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
@@ -70,6 +74,9 @@ def LinalgPromotion : FunctionPass<"linalg-promote-subviews"> {
 def LinalgTiling : FunctionPass<"linalg-tile"> {
   let summary = "Tile operations in the linalg dialect";
   let constructor = "mlir::createLinalgTilingPass()";
+  let dependentDialects = [
+    "AffineDialect", "scf::SCFDialect"
+  ];
   let options = [
     ListOption<"tileSizes", "linalg-tile-sizes", "int64_t",
                "Test generation of dynamic promoted buffers",
@@ -86,6 +93,7 @@ def LinalgTilingToParallelLoops
                "Test generation of dynamic promoted buffers",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
+  let dependentDialects = ["AffineDialect", "scf::SCFDialect"];
 }
 
 #endif // MLIR_DIALECT_LINALG_PASSES
index 483d0ba..6f3cf0e 100644 (file)
@@ -36,6 +36,7 @@ def SCFParallelLoopTiling : FunctionPass<"parallel-loop-tiling"> {
                "Factors to tile parallel loops by",
                "llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated">
   ];
+  let dependentDialects = ["AffineDialect"];
 }
 
 #endif // MLIR_DIALECT_SCF_PASSES
index 12a19af..0fe1a7f 100644 (file)
@@ -16,6 +16,8 @@
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/TypeID.h"
 
+#include <map>
+
 namespace mlir {
 class DialectAsmParser;
 class DialectAsmPrinter;
@@ -23,7 +25,7 @@ class DialectInterface;
 class OpBuilder;
 class Type;
 
-using DialectAllocatorFunction = std::function<void(MLIRContext *)>;
+using DialectAllocatorFunction = std::function<Dialect *(MLIRContext *)>;
 
 /// Dialects are groups of MLIR operations and behavior associated with the
 /// entire group.  For example, hooks into other systems for constant folding,
@@ -222,30 +224,87 @@ private:
   /// A collection of registered dialect interfaces.
   DenseMap<TypeID, std::unique_ptr<DialectInterface>> registeredInterfaces;
 
-  /// Registers a specific dialect creation function with the global registry.
-  /// Used through the registerDialect template.
-  /// Registrations are deduplicated by dialect TypeID and only the first
-  /// registration will be used.
-  static void
-  registerDialectAllocator(TypeID typeID,
-                           const DialectAllocatorFunction &function);
-  template <typename ConcreteDialect>
   friend void registerDialect();
   friend class MLIRContext;
 };
 
-/// Registers all dialects and hooks from the global registries with the
-/// specified MLIRContext.
+/// The DialectRegistry maps a dialect namespace to a constructor for the
+/// matching dialect.
+/// This allows for decoupling the list of dialects "available" from the
+/// dialects loaded in the Context. The parser in particular will lazily load
+/// dialects in in the Context as operations are encountered.
+class DialectRegistry {
+  using MapTy =
+      std::map<std::string, std::pair<TypeID, DialectAllocatorFunction>>;
+
+public:
+  template <typename ConcreteDialect>
+  void insert() {
+    insert(TypeID::get<ConcreteDialect>(),
+           ConcreteDialect::getDialectNamespace(),
+           static_cast<DialectAllocatorFunction>(([](MLIRContext *ctx) {
+             // Just allocate the dialect, the context
+             // takes ownership of it.
+             return ctx->getOrLoadDialect<ConcreteDialect>();
+           })));
+  }
+
+  template <typename ConcreteDialect, typename OtherDialect,
+            typename... MoreDialects>
+  void insert() {
+    insert<ConcreteDialect>();
+    insert<OtherDialect, MoreDialects...>();
+  }
+
+  /// Add a new dialect constructor to the registry.
+  void insert(TypeID typeID, StringRef name, DialectAllocatorFunction ctor);
+
+  /// Load a dialect for this namespace in the provided context.
+  Dialect *loadByName(StringRef name, MLIRContext *context);
+
+  // Register all dialects available in the current registry with the registry
+  // in the provided context.
+  void appendTo(DialectRegistry &destination) {
+    for (const auto &nameAndRegistrationIt : registry)
+      destination.insert(nameAndRegistrationIt.second.first,
+                         nameAndRegistrationIt.first,
+                         nameAndRegistrationIt.second.second);
+  }
+  // Load all dialects available in the registry in the provided context.
+  void loadAll(MLIRContext *context) {
+    for (const auto &nameAndRegistrationIt : registry)
+      nameAndRegistrationIt.second.second(context);
+  }
+
+  MapTy::const_iterator begin() const { return registry.begin(); }
+  MapTy::const_iterator end() const { return registry.end(); }
+
+private:
+  MapTy registry;
+};
+
+/// Deprecated: this provides a global registry for convenience, while we're
+/// transitionning the registration mechanism to a stateless approach.
+DialectRegistry &getGlobalDialectRegistry();
+
+/// Registers all dialects from the global registries with the
+/// specified MLIRContext. This won't load the dialects in the context,
+/// but only make them available for lazy loading by name.
 /// Note: This method is not thread-safe.
 void registerAllDialects(MLIRContext *context);
 
+/// Register and return the dialect with the given namespace in the provided
+/// context. Returns nullptr is there is no constructor registered for this
+/// dialect.
+inline Dialect *registerDialect(StringRef name, MLIRContext *context) {
+  return getGlobalDialectRegistry().loadByName(name, context);
+}
+
 /// Utility to register a dialect. Client can register their dialect with the
 /// global registry by calling registerDialect<MyDialect>();
 /// Note: This method is not thread-safe.
 template <typename ConcreteDialect> void registerDialect() {
-  Dialect::registerDialectAllocator(
-      TypeID::get<ConcreteDialect>(),
-      [](MLIRContext *ctx) { ctx->getOrCreateDialect<ConcreteDialect>(); });
+  getGlobalDialectRegistry().insert<ConcreteDialect>();
 }
 
 /// DialectRegistration provides a global initializer that registers a Dialect
index 7e281f3..3d467cd 100644 (file)
@@ -428,7 +428,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
       if (!attr.first.strref().contains('.'))
         return funcOp.emitOpError("arguments may only have dialect attributes");
       auto dialectNamePair = attr.first.strref().split('.');
-      if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
+      if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
         if (failed(dialect->verifyRegionArgAttribute(op, /*regionIndex=*/0,
                                                      /*argIndex=*/i, attr)))
           return failure();
@@ -444,7 +444,7 @@ LogicalResult FunctionLike<ConcreteType>::verifyTrait(Operation *op) {
       if (!attr.first.strref().contains('.'))
         return funcOp.emitOpError("results may only have dialect attributes");
       auto dialectNamePair = attr.first.strref().split('.');
-      if (auto *dialect = ctx->getRegisteredDialect(dialectNamePair.first)) {
+      if (auto *dialect = ctx->getLoadedDialect(dialectNamePair.first)) {
         if (failed(dialect->verifyRegionResultAttribute(op, /*regionIndex=*/0,
                                                         /*resultIndex=*/i,
                                                         attr)))
index 0192a8a..e8a5d6e 100644 (file)
@@ -19,10 +19,12 @@ namespace mlir {
 class AbstractOperation;
 class DiagnosticEngine;
 class Dialect;
+class DialectRegistry;
 class InFlightDiagnostic;
 class Location;
 class MLIRContextImpl;
 class StorageUniquer;
+DialectRegistry &getGlobalDialectRegistry();
 
 /// MLIRContext is the top-level object for a collection of MLIR modules.  It
 /// holds immortal uniqued objects like types, and the tables used to unique
@@ -34,34 +36,69 @@ class StorageUniquer;
 ///
 class MLIRContext {
 public:
-  explicit MLIRContext();
+  /// Create a new Context.
+  /// The loadAllDialects parameters allows to load all dialects from the global
+  /// registry on Context construction. It is deprecated and will be removed
+  /// soon.
+  explicit MLIRContext(bool loadAllDialects = true);
   ~MLIRContext();
 
-  /// Return information about all registered IR dialects.
-  std::vector<Dialect *> getRegisteredDialects();
+  /// Return information about all IR dialects loaded in the context.
+  std::vector<Dialect *> getLoadedDialects();
+
+  /// Return the dialect registry associated with this context.
+  DialectRegistry &getDialectRegistry();
+
+  /// Return information about all available dialects in the registry in this
+  /// context.
+  std::vector<StringRef> getAvailableDialects();
 
   /// Get a registered IR dialect with the given namespace. If an exact match is
   /// not found, then return nullptr.
-  Dialect *getRegisteredDialect(StringRef name);
+  Dialect *getLoadedDialect(StringRef name);
 
   /// Get a registered IR dialect for the given derived dialect type. The
   /// derived type must provide a static 'getDialectNamespace' method.
-  template <typename T> T *getRegisteredDialect() {
-    return static_cast<T *>(getRegisteredDialect(T::getDialectNamespace()));
+  template <typename T>
+  T *getLoadedDialect() {
+    return static_cast<T *>(getLoadedDialect(T::getDialectNamespace()));
   }
 
   /// Get (or create) a dialect for the given derived dialect type. The derived
   /// type must provide a static 'getDialectNamespace' method.
   template <typename T>
-  T *getOrCreateDialect() {
-    return static_cast<T *>(getOrCreateDialect(
-        T::getDialectNamespace(), TypeID::get<T>(), [this]() {
+  T *getOrLoadDialect() {
+    return static_cast<T *>(
+        getOrLoadDialect(T::getDialectNamespace(), TypeID::get<T>(), [this]() {
           std::unique_ptr<T> dialect(new T(this));
-          dialect->dialectID = TypeID::get<T>();
           return dialect;
         }));
   }
 
+  /// Load a dialect in the context.
+  template <typename Dialect>
+  void loadDialect() {
+    getOrLoadDialect<Dialect>();
+  }
+
+  /// Load a list dialects in the context.
+  template <typename Dialect, typename OtherDialect, typename... MoreDialects>
+  void loadDialect() {
+    getOrLoadDialect<Dialect>();
+    loadDialect<OtherDialect, MoreDialects...>();
+  }
+
+  /// Deprecated: load all globally registered dialects into this context.
+  /// This method will be removed soon, it can be used temporarily as we're
+  /// phasing out the global registry.
+  void loadAllGloballyRegisteredDialects();
+
+  /// Get (or create) a dialect for the given derived dialect name.
+  /// The dialect will be loaded from the registry if no dialect is found.
+  /// If no dialect is loaded for this name and none is available in the
+  /// registry, returns nullptr.
+  Dialect *getOrLoadDialect(StringRef name);
+
   /// Return true if we allow to create operation for unregistered dialects.
   bool allowsUnregisteredDialects();
 
@@ -123,10 +160,12 @@ private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
   /// Get a dialect for the provided namespace and TypeID: abort the program if
-  /// a dialect exist for this namespace with different TypeID. Returns a
-  /// pointer to the dialect owned by the context.
-  Dialect *getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
-                              function_ref<std::unique_ptr<Dialect>()> ctor);
+  /// a dialect exist for this namespace with different TypeID. If a dialect has
+  /// not been loaded for this namespace/TypeID yet, use the provided ctor to
+  /// create one on the fly and load it. Returns a pointer to the dialect owned
+  /// by the context.
+  Dialect *getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
+                            function_ref<std::unique_ptr<Dialect>()> ctor);
 
   MLIRContext(const MLIRContext &) = delete;
   void operator=(const MLIRContext &) = delete;
index 9cc57a6..a28410f 100644 (file)
@@ -244,6 +244,11 @@ class Dialect {
   // The description of the dialect.
   string description = ?;
 
+  // A list of dialects this dialect will load on construction as dependencies.
+  // These are dialects that this dialect may involved in canonicalization
+  // pattern or interfaces.
+  list<string> dependentDialects = [];
+
   // The C++ namespace that ops of this dialect should be placed into.
   //
   // By default, uses the name of the dialect as the only namespace. To avoid
index b76b26f..147ecec 100644 (file)
 
 namespace mlir {
 
+// Add all the MLIR dialects to the provided registry.
+inline void registerAllDialects(DialectRegistry &registry) {
+  // clang-format off
+  registry.insert<acc::OpenACCDialect,
+                  AffineDialect,
+                  avx512::AVX512Dialect,
+                  gpu::GPUDialect,
+                  LLVM::LLVMAVX512Dialect,
+                  LLVM::LLVMDialect,
+                  linalg::LinalgDialect,
+                  scf::SCFDialect,
+                  omp::OpenMPDialect,
+                  quant::QuantizationDialect,
+                  spirv::SPIRVDialect,
+                  StandardOpsDialect,
+                  vector::VectorDialect,
+                  NVVM::NVVMDialect,
+                  ROCDL::ROCDLDialect,
+                  SDBMDialect,
+                  shape::ShapeDialect>();
+  // clang-format on
+}
+
 // This function should be called before creating any MLIRContext if one expect
 // all the possible dialects to be made available to the context automatically.
 inline void registerAllDialects() {
-  static bool init_once = []() {
-    registerDialect<acc::OpenACCDialect>();
-    registerDialect<AffineDialect>();
-    registerDialect<avx512::AVX512Dialect>();
-    registerDialect<gpu::GPUDialect>();
-    registerDialect<LLVM::LLVMAVX512Dialect>();
-    registerDialect<LLVM::LLVMDialect>();
-    registerDialect<linalg::LinalgDialect>();
-    registerDialect<scf::SCFDialect>();
-    registerDialect<omp::OpenMPDialect>();
-    registerDialect<quant::QuantizationDialect>();
-    registerDialect<spirv::SPIRVDialect>();
-    registerDialect<StandardOpsDialect>();
-    registerDialect<vector::VectorDialect>();
-    registerDialect<NVVM::NVVMDialect>();
-    registerDialect<ROCDL::ROCDLDialect>();
-    registerDialect<SDBMDialect>();
-    registerDialect<shape::ShapeDialect>();
-    return true;
-  }();
-  (void)init_once;
+  static bool initOnce =
+      ([]() { registerAllDialects(getGlobalDialectRegistry()); }(), true);
+  (void)initOnce;
 }
 } // namespace mlir
 
index 31ca025..a1771da 100644 (file)
@@ -28,7 +28,7 @@ void registerAVX512ToLLVMIRTranslation();
 // expects all the possible translations to be made available to the context
 // automatically.
 inline void registerAllTranslations() {
-  static bool init_once = []() {
+  static bool initOnce = []() {
     registerFromLLVMIRTranslation();
     registerFromSPIRVTranslation();
     registerToLLVMIRTranslation();
@@ -38,7 +38,7 @@ inline void registerAllTranslations() {
     registerAVX512ToLLVMIRTranslation();
     return true;
   }();
-  (void)init_once;
+  (void)initOnce;
 }
 } // namespace mlir
 
index 8de31d9..cd4c06a 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_PASS_PASS_H
 #define MLIR_PASS_PASS_H
 
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Function.h"
 #include "mlir/Pass/AnalysisManager.h"
 #include "mlir/Pass/PassRegistry.h"
@@ -57,6 +58,13 @@ public:
   /// Returns the derived pass name.
   virtual StringRef getName() const = 0;
 
+  /// Register dependent dialects for the current pass.
+  /// A pass is expected to register the dialects it will create entities for
+  /// (Operations, Types, Attributes), other than dialect that exists in the
+  /// input. For example, a pass that converts from Linalg to Affine would
+  /// register the Affine dialect but does not need to register Linalg.
+  virtual void getDependentDialects(DialectRegistry &registry) const {}
+
   /// Returns the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
   virtual StringRef getArgument() const {
index 54b4403..7a2feff 100644 (file)
@@ -78,6 +78,9 @@ class PassBase<string passArg, string base> {
   // A C++ constructor call to create an instance of this pass.
   code constructor = [{}];
 
+  // A list of dialects this pass may produce entities in.
+  list<string> dependentDialects = [];
+
   // A set of options provided by this pass.
   list<Option> options = [];
 
index 9cbfb0b..29e7c07 100644 (file)
@@ -9,6 +9,7 @@
 #ifndef MLIR_PASS_PASSMANAGER_H
 #define MLIR_PASS_PASSMANAGER_H
 
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/OperationSupport.h"
 #include "mlir/Support/LogicalResult.h"
 #include "llvm/ADT/Optional.h"
@@ -58,6 +59,14 @@ public:
   pass_iterator end();
   iterator_range<pass_iterator> getPasses() { return {begin(), end()}; }
 
+  using const_pass_iterator = llvm::pointee_iterator<
+      std::vector<std::unique_ptr<Pass>>::const_iterator>;
+  const_pass_iterator begin() const;
+  const_pass_iterator end() const;
+  iterator_range<const_pass_iterator> getPasses() const {
+    return {begin(), end()};
+  }
+
   /// Run the held passes over the given operation.
   LogicalResult run(Operation *op, AnalysisManager am);
 
@@ -100,6 +109,11 @@ public:
   /// Merge the pass statistics of this class into 'other'.
   void mergeStatisticsInto(OpPassManager &other);
 
+  /// Register dependent dialects for the current pass manager.
+  /// This is forwarding to every pass in this PassManager, see the
+  /// documentation for the same method on the Pass class.
+  void getDependentDialects(DialectRegistry &dialects) const;
+
 private:
   OpPassManager(OperationName name, bool verifyPasses);
 
index 137cf66..da03bae 100644 (file)
@@ -21,12 +21,14 @@ class MemoryBuffer;
 } // end namespace llvm
 
 namespace mlir {
+class DialectRegistry;
 class PassPipelineCLParser;
 
 /// Perform the core processing behind `mlir-opt`:
 /// - outputStream is the stream where the resulting IR is printed.
 /// - buffer is the in-memory file to parser and process.
 /// - passPipeline is the specification of the pipeline that will be applied.
+/// - registry should contain all the dialects that can be parsed in the source.
 /// - splitInputFile will look for a "-----" marker in the input file, and load
 /// each chunk in an individual ModuleOp processed separately.
 /// - verifyDiagnostics enables a verification mode where comments starting with
@@ -35,13 +37,25 @@ class PassPipelineCLParser;
 /// - verifyPasses enables the IR verifier in-between each pass in the pipeline.
 /// - allowUnregisteredDialects allows to parse and create operation without
 /// registering the Dialect in the MLIRContext.
+/// - preloadDialectsInContext will trigger the upfront loading of all
+///   dialects from the global registry in the MLIRContext. This option is
+///   deprecated and will be removed soon.
 LogicalResult MlirOptMain(llvm::raw_ostream &outputStream,
                           std::unique_ptr<llvm::MemoryBuffer> buffer,
                           const PassPipelineCLParser &passPipeline,
-                          bool splitInputFile, bool verifyDiagnostics,
-                          bool verifyPasses, bool allowUnregisteredDialects);
+                          DialectRegistry &registry, bool splitInputFile,
+                          bool verifyDiagnostics, bool verifyPasses,
+                          bool allowUnregisteredDialects,
+                          bool preloadDialectsInContext = true);
 
 /// Implementation for tools like `mlir-opt`.
-LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName);
+/// - toolName is used for the header displayed by `--help`.
+/// - registry should contain all the dialects that can be parsed in the source.
+/// - preloadDialectsInContext will trigger the upfront loading of all
+///   dialects from the global registry in the MLIRContext. This option is
+///   deprecated and will be removed soon.
+LogicalResult MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
+                          DialectRegistry &registry,
+                          bool preloadDialectsInContext = true);
 
 } // end namespace mlir
index 5e85806..99217d8 100644 (file)
@@ -14,6 +14,7 @@
 
 #include "mlir/Support/LLVM.h"
 #include <string>
+#include <vector>
 
 namespace llvm {
 class Record;
@@ -25,7 +26,7 @@ namespace tblgen {
 // and provides helper methods for accessing them.
 class Dialect {
 public:
-  explicit Dialect(const llvm::Record *def) : def(def) {}
+  explicit Dialect(const llvm::Record *def);
 
   // Returns the name of this dialect.
   StringRef getName() const;
@@ -43,6 +44,10 @@ public:
   // Returns the description of the dialect. Returns empty string if none.
   StringRef getDescription() const;
 
+  // Returns the list of dialect (class names) that this dialect depends on.
+  // These are dialects that will be loaded on construction of this dialect.
+  ArrayRef<StringRef> getDependentDialects() const;
+
   // Returns the dialects extra class declaration code.
   llvm::Optional<StringRef> getExtraClassDeclaration() const;
 
@@ -70,6 +75,7 @@ public:
 
 private:
   const llvm::Record *def;
+  std::vector<StringRef> dependentDialects;
 };
 } // end namespace tblgen
 } // end namespace mlir
index 02427e4..968c854 100644 (file)
@@ -94,6 +94,9 @@ public:
   /// Return the C++ constructor call to create an instance of this pass.
   StringRef getConstructor() const;
 
+  /// Return the dialects this pass needs to be registered.
+  ArrayRef<StringRef> getDependentDialects() const;
+
   /// Return the options provided by this pass.
   ArrayRef<PassOption> getOptions() const;
 
@@ -104,6 +107,7 @@ public:
 
 private:
   const llvm::Record *def;
+  std::vector<StringRef> dependentDialects;
   std::vector<PassOption> options;
   std::vector<PassStatistic> statistics;
 };
index 7787805..3292d5e 100644 (file)
@@ -162,6 +162,8 @@ def BufferPlacement : FunctionPass<"buffer-placement"> {
 
   }];
   let constructor = "mlir::createBufferPlacementPass()";
+  // TODO: this pass likely shouldn't depend on Linalg?
+  let dependentDialects = ["linalg::LinalgDialect"];
 }
 
 def Canonicalizer : Pass<"canonicalize"> {
index 1ba1a6a..9d3028d 100644 (file)
@@ -10,6 +10,7 @@
 
 #include "mlir/CAPI/IR.h"
 #include "mlir/IR/Attributes.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/IR/Types.h"
@@ -50,7 +51,7 @@ private:
 /* ========================================================================== */
 
 MlirContext mlirContextCreate() {
-  auto *context = new MLIRContext;
+  auto *context = new MLIRContext(/*loadAllDialects=*/false);
   return wrap(context);
 }
 
index 400a4eb..1d6294d 100644 (file)
@@ -8,6 +8,11 @@
 
 #include "mlir-c/Registration.h"
 
+#include "mlir/CAPI/IR.h"
 #include "mlir/InitAllDialects.h"
 
-void mlirRegisterAllDialects() { mlir::registerAllDialects(); }
+void mlirRegisterAllDialects(MlirContext context) {
+  registerAllDialects(unwrap(context)->getDialectRegistry());
+  // TODO: we may not want to eagerly load here.
+  unwrap(context)->getDialectRegistry().loadAll(unwrap(context));
+}
index 1ebf481..4267393 100644 (file)
@@ -16,6 +16,7 @@
 #include "../PassDetail.h"
 #include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/SPIRV/SPIRVOps.h"
 #include "mlir/Dialect/SPIRV/Serialization.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
index 7b57854..0460d98 100644 (file)
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
 #include "mlir/Dialect/Linalg/Passes.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
index 6da0bc8..7fa5a5a 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class AffineDialect;
+class StandardOpsDialect;
+
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
 
 namespace gpu {
+class GPUDialect;
 class GPUModuleOp;
 } // end namespace gpu
 
+namespace LLVM {
+class LLVMDialect;
+class LLVMAVX512Dialect;
+} // end namespace LLVM
+
+namespace NVVM {
+class NVVMDialect;
+} // end namespace NVVM
+
+namespace ROCDL {
+class ROCDLDialect;
+} // end namespace ROCDL
+
+namespace scf {
+class SCFDialect;
+} // end namespace scf
+
+namespace spirv {
+class SPIRVDialect;
+} // end namespace spirv
+
+namespace vector {
+class VectorDialect;
+} // end namespace vector
+
 #define GEN_PASS_CLASSES
 #include "mlir/Conversion/Passes.h.inc"
 
index 0ee1166..44d912b 100644 (file)
@@ -125,7 +125,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
 /// Create an LLVMTypeConverter using custom LowerToLLVMOptions.
 LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
                                      const LowerToLLVMOptions &options)
-    : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
+    : llvmDialect(ctx->getOrLoadDialect<LLVM::LLVMDialect>()),
       options(options) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
   if (options.indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
index 19643d2..a2e608d 100644 (file)
@@ -14,6 +14,7 @@
 #include "../PassDetail.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRV.h"
 #include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
index 3bae059..da8f7ac 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
+namespace vector {
+class VectorDialect;
+} // end namespace vector
 
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Affine/Passes.h.inc"
index 308272d..5d52d09 100644 (file)
@@ -1244,6 +1244,7 @@ template <typename NamedStructuredOpType>
 static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                           OperationState &result) {
   SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
+  result.getContext()->getOrLoadDialect<StandardOpsDialect>();
 
   // Optional attributes may be added.
   if (parser.parseOperandList(operandsInfo) ||
index 7fa05ff..0415aeb 100644 (file)
@@ -9,9 +9,18 @@
 #ifndef DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
 #define DIALECT_LINALG_TRANSFORMS_PASSDETAIL_H_
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+namespace scf {
+class SCFDialect;
+} // end namespace scf
 
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/Linalg/Passes.h.inc"
index 95f8636..6fa7f22 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+class AffineDialect;
 
 #define GEN_PASS_CLASSES
 #include "mlir/Dialect/SCF/Passes.h.inc"
index 8da6c40..5adcbcc 100644 (file)
@@ -516,7 +516,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
 
     SDBMDialect *dialect;
   } converter;
-  converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
+  converter.dialect = affine.getContext()->getOrLoadDialect<SDBMDialect>();
 
   if (auto result = converter.visit(affine))
     return result;
index 7959183..2b18adb 100644 (file)
@@ -259,7 +259,9 @@ int mlir::JitRunnerMain(
     }
   }
 
-  MLIRContext context;
+  MLIRContext context(/*loadAllDialects=*/false);
+  registerAllDialects(&context);
+
   auto m = parseMLIRInput(options.inputFilename, &context);
   if (!m) {
     llvm::errs() << "could not parse the input IR\n";
index 555bb2b..44b0ee6 100644 (file)
@@ -27,21 +27,29 @@ DialectAsmParser::~DialectAsmParser() {}
 //===----------------------------------------------------------------------===//
 
 /// Registry for all dialect allocation functions.
-static llvm::ManagedStatic<llvm::MapVector<TypeID, DialectAllocatorFunction>>
-    dialectRegistry;
-
-void Dialect::registerDialectAllocator(
-    TypeID typeID, const DialectAllocatorFunction &function) {
-  assert(function &&
-         "Attempting to register an empty dialect initialize function");
-  dialectRegistry->insert({typeID, function});
-}
+static llvm::ManagedStatic<DialectRegistry> dialectRegistry;
+DialectRegistry &mlir::getGlobalDialectRegistry() { return *dialectRegistry; }
 
-/// Registers all dialects and hooks from the global registries with the
-/// specified MLIRContext.
 void mlir::registerAllDialects(MLIRContext *context) {
-  for (const auto &it : *dialectRegistry)
-    it.second(context);
+  dialectRegistry->appendTo(context->getDialectRegistry());
+}
+
+Dialect *DialectRegistry::loadByName(StringRef name, MLIRContext *context) {
+  auto it = registry.find(name.str());
+  if (it == registry.end())
+    return nullptr;
+  return it->second.second(context);
+}
+
+void DialectRegistry::insert(TypeID typeID, StringRef name,
+                             DialectAllocatorFunction ctor) {
+  auto inserted =
+      registry.insert(std::make_pair(name, std::make_pair(typeID, ctor)));
+  if (!inserted.second && inserted.first->second.first != typeID) {
+    llvm::report_fatal_error(
+        "Trying to register different dialects for the same namespace: " +
+        name);
+  }
 }
 
 //===----------------------------------------------------------------------===//
@@ -119,7 +127,7 @@ DialectInterface::~DialectInterface() {}
 
 DialectInterfaceCollectionBase::DialectInterfaceCollectionBase(
     MLIRContext *ctx, TypeID interfaceKind) {
-  for (auto *dialect : ctx->getRegisteredDialects()) {
+  for (auto *dialect : ctx->getLoadedDialects()) {
     if (auto *interface = dialect->getRegisteredInterface(interfaceKind)) {
       interfaces.insert(interface);
       orderedInterfaces.push_back(interface);
index a86f27a..b47c143 100644 (file)
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/Allocator.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/RWMutex.h"
 #include "llvm/Support/raw_ostream.h"
 #include <memory>
 
+#define DEBUG_TYPE "mlircontext"
+
 using namespace mlir;
 using namespace mlir::detail;
 
@@ -274,7 +277,8 @@ public:
 
   /// This is a list of dialects that are created referring to this context.
   /// The MLIRContext owns the objects.
-  std::vector<std::unique_ptr<Dialect>> dialects;
+  DenseMap<StringRef, std::unique_ptr<Dialect>> loadedDialects;
+  DialectRegistry dialectsRegistry;
 
   /// This is a mapping from operation name to AbstractOperation for registered
   /// operations.
@@ -345,7 +349,7 @@ public:
 };
 } // end namespace mlir
 
-MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
+MLIRContext::MLIRContext(bool loadAllDialects) : impl(new MLIRContextImpl()) {
   // Initialize values based on the command line flags if they were provided.
   if (clOptions.isConstructed()) {
     disableMultithreading(clOptions->disableThreading);
@@ -354,8 +358,9 @@ MLIRContext::MLIRContext() : impl(new MLIRContextImpl()) {
   }
 
   // Register dialects with this context.
-  getOrCreateDialect<BuiltinDialect>();
-  registerAllDialects(this);
+  getOrLoadDialect<BuiltinDialect>();
+  if (loadAllDialects)
+    loadAllGloballyRegisteredDialects();
 
   // Initialize several common attributes and types to avoid the need to lock
   // the context when accessing them.
@@ -431,54 +436,72 @@ DiagnosticEngine &MLIRContext::getDiagEngine() { return getImpl().diagEngine; }
 // Dialect and Operation Registration
 //===----------------------------------------------------------------------===//
 
+DialectRegistry &MLIRContext::getDialectRegistry() {
+  return impl->dialectsRegistry;
+}
+
 /// Return information about all registered IR dialects.
-std::vector<Dialect *> MLIRContext::getRegisteredDialects() {
+std::vector<Dialect *> MLIRContext::getLoadedDialects() {
   std::vector<Dialect *> result;
-  result.reserve(impl->dialects.size());
-  for (auto &dialect : impl->dialects)
-    result.push_back(dialect.get());
+  result.reserve(impl->loadedDialects.size());
+  for (auto &dialect : impl->loadedDialects)
+    result.push_back(dialect.second.get());
+  llvm::array_pod_sort(result.begin(), result.end(),
+                       [](Dialect *const *lhs, Dialect *const *rhs) -> int {
+                         return (*lhs)->getNamespace() < (*rhs)->getNamespace();
+                       });
+  return result;
+}
+std::vector<StringRef> MLIRContext::getAvailableDialects() {
+  std::vector<StringRef> result;
+  for (auto &dialect : impl->dialectsRegistry)
+    result.push_back(dialect.first);
   return result;
 }
 
 /// Get a registered IR dialect with the given namespace. If none is found,
 /// then return nullptr.
-Dialect *MLIRContext::getRegisteredDialect(StringRef name) {
+Dialect *MLIRContext::getLoadedDialect(StringRef name) {
   // Dialects are sorted by name, so we can use binary search for lookup.
-  auto it = llvm::lower_bound(
-      impl->dialects, name,
-      [](const auto &lhs, StringRef rhs) { return lhs->getNamespace() < rhs; });
-  return (it != impl->dialects.end() && (*it)->getNamespace() == name)
-             ? (*it).get()
-             : nullptr;
+  auto it = impl->loadedDialects.find(name);
+  return (it != impl->loadedDialects.end()) ? it->second.get() : nullptr;
+}
+
+Dialect *MLIRContext::getOrLoadDialect(StringRef name) {
+  Dialect *dialect = getLoadedDialect(name);
+  if (dialect)
+    return dialect;
+  return impl->dialectsRegistry.loadByName(name, this);
 }
 
 /// Get a dialect for the provided namespace and TypeID: abort the program if a
 /// dialect exist for this namespace with different TypeID. Returns a pointer to
 /// the dialect owned by the context.
 Dialect *
-MLIRContext::getOrCreateDialect(StringRef dialectNamespace, TypeID dialectID,
-                                function_ref<std::unique_ptr<Dialect>()> ctor) {
+MLIRContext::getOrLoadDialect(StringRef dialectNamespace, TypeID dialectID,
+                              function_ref<std::unique_ptr<Dialect>()> ctor) {
   auto &impl = getImpl();
   // Get the correct insertion position sorted by namespace.
-  auto insertPt =
-      llvm::lower_bound(impl.dialects, nullptr,
-                        [&](const std::unique_ptr<Dialect> &lhs,
-                            const std::unique_ptr<Dialect> &rhs) {
-                          if (!lhs)
-                            return dialectNamespace < rhs->getNamespace();
-                          return lhs->getNamespace() < dialectNamespace;
-                        });
+  std::unique_ptr<Dialect> &dialect = impl.loadedDialects[dialectNamespace];
+
+  if (!dialect) {
+    LLVM_DEBUG(llvm::dbgs()
+               << "Load new dialect in Context" << dialectNamespace);
+    dialect = ctor();
+    assert(dialect && "dialect ctor failed");
+    return dialect.get();
+  }
 
   // Abort if dialect with namespace has already been registered.
-  if (insertPt != impl.dialects.end() &&
-      (*insertPt)->getNamespace() == dialectNamespace) {
-    if ((*insertPt)->getTypeID() == dialectID)
-      return insertPt->get();
+  if (dialect->getTypeID() != dialectID)
     llvm::report_fatal_error("a dialect with namespace '" + dialectNamespace +
                              "' has already been registered");
-  }
-  auto it = impl.dialects.insert(insertPt, ctor());
-  return &**it;
+
+  return dialect.get();
+}
+
+void MLIRContext::loadAllGloballyRegisteredDialects() {
+  getGlobalDialectRegistry().loadAll(this);
 }
 
 bool MLIRContext::allowsUnregisteredDialects() {
index 4ddc3df..67249b8 100644 (file)
@@ -214,7 +214,7 @@ Dialect *Operation::getDialect() {
 
   // If this operation hasn't been registered or doesn't have abstract
   // operation, try looking up the dialect name in the context.
-  return getContext()->getRegisteredDialect(getName().getDialect());
+  return getContext()->getLoadedDialect(getName().getDialect());
 }
 
 Region *Operation::getParentRegion() {
index b1aed88..4caf989 100644 (file)
@@ -50,7 +50,7 @@ public:
   Dialect *getDialectForAttribute(const NamedAttribute &attr) {
     assert(attr.first.strref().contains('.') && "expected dialect attribute");
     auto dialectNamePair = attr.first.strref().split('.');
-    return ctx->getRegisteredDialect(dialectNamePair.first);
+    return ctx->getLoadedDialect(dialectNamePair.first);
   }
 
 private:
@@ -218,7 +218,7 @@ LogicalResult OperationVerifier::verifyOperation(Operation &op) {
   auto it = dialectAllowsUnknownOps.find(dialectPrefix);
   if (it == dialectAllowsUnknownOps.end()) {
     // If the operation dialect is registered, query it directly.
-    if (auto *dialect = ctx->getRegisteredDialect(dialectPrefix))
+    if (auto *dialect = ctx->getLoadedDialect(dialectPrefix))
       it = dialectAllowsUnknownOps
                .try_emplace(dialectPrefix, dialect->allowsUnknownOperations())
                .first;
index 1c1261e..37ee938 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "Parser.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/StandardTypes.h"
 #include "llvm/ADT/StringExtras.h"
@@ -246,6 +247,11 @@ ParseResult Parser::parseAttributeDict(NamedAttrList &attributes) {
       return emitError("duplicate key in dictionary attribute");
     consumeToken();
 
+    // Lazy load a dialect in the context if there is a possible namespace.
+    auto splitName = nameId->strref().split('.');
+    if (!splitName.second.empty())
+      getContext()->getOrLoadDialect(splitName.first);
+
     // Try to parse the '=' for the attribute value.
     if (!consumeIf(Token::equal)) {
       // If there is no '=', we treat this as a unit attribute.
@@ -817,7 +823,9 @@ Attribute Parser::parseOpaqueElementsAttr(Type attrType) {
     return (emitError("expected dialect namespace"), nullptr);
 
   auto name = getToken().getStringValue();
-  auto *dialect = builder.getContext()->getRegisteredDialect(name);
+  // Lazy load a dialect in the context if there is a possible namespace.
+  Dialect *dialect = builder.getContext()->getOrLoadDialect(name);
+
   // TODO: Allow for having an unknown dialect on an opaque
   // attribute. Otherwise, it can't be roundtripped without having the dialect
   // registered.
index 3b522a8..d45ddf0 100644 (file)
@@ -526,7 +526,8 @@ Attribute Parser::parseExtendedAttr(Type type) {
           return Attribute();
 
         // If we found a registered dialect, then ask it to parse the attribute.
-        if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+        if (Dialect *dialect =
+                builder.getContext()->getOrLoadDialect(dialectName)) {
           return parseSymbol<Attribute>(
               symbolData, state.context, state.symbols, [&](Parser &parser) {
                 CustomDialectAsmParser customParser(symbolData, parser);
@@ -563,7 +564,9 @@ Type Parser::parseExtendedType() {
       [&](StringRef dialectName, StringRef symbolData,
           llvm::SMLoc loc) -> Type {
         // If we found a registered dialect, then ask it to parse the type.
-        if (auto *dialect = state.context->getRegisteredDialect(dialectName)) {
+        auto *dialect = state.context->getOrLoadDialect(dialectName);
+
+        if (dialect) {
           return parseSymbol<Type>(
               symbolData, state.context, state.symbols, [&](Parser &parser) {
                 CustomDialectAsmParser customParser(symbolData, parser);
index 3a995a4..837b08c 100644 (file)
@@ -12,6 +12,7 @@
 
 #include "Parser.h"
 #include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Dialect.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/Verifier.h"
 #include "mlir/Parser.h"
@@ -727,7 +728,7 @@ Operation *OperationParser::parseGenericOperation() {
   // Get location information for the operation.
   auto srcLocation = getEncodedSourceLocation(getToken().getLoc());
 
-  auto name = getToken().getStringValue();
+  std::string name = getToken().getStringValue();
   if (name.empty())
     return (emitError("empty operation name is invalid"), nullptr);
   if (name.find('\0') != StringRef::npos)
@@ -737,6 +738,15 @@ Operation *OperationParser::parseGenericOperation() {
 
   OperationState result(srcLocation, name);
 
+  // Lazy load dialects in the context as needed.
+  if (!result.name.getAbstractOperation()) {
+    StringRef dialectName = StringRef(name).split('.').first;
+    if (!getContext()->getLoadedDialect(dialectName) &&
+        getContext()->getOrLoadDialect(dialectName)) {
+      result.name = OperationName(name, getContext());
+    }
+  }
+
   // Parse the operand list.
   SmallVector<SSAUseInfo, 8> operandInfos;
   if (parseToken(Token::l_paren, "expected '(' to start operand list") ||
@@ -1442,17 +1452,28 @@ private:
 
 Operation *
 OperationParser::parseCustomOperation(ArrayRef<ResultRecord> resultIDs) {
-  auto opLoc = getToken().getLoc();
-  auto opName = getTokenSpelling();
+  llvm::SMLoc opLoc = getToken().getLoc();
+  StringRef opName = getTokenSpelling();
 
   auto *opDefinition = AbstractOperation::lookup(opName, getContext());
-  if (!opDefinition && !opName.contains('.')) {
-    // If the operation name has no namespace prefix we treat it as a standard
-    // operation and prefix it with "std".
-    // TODO: Would it be better to just build a mapping of the registered
-    // operations in the standard dialect?
-    opDefinition =
-        AbstractOperation::lookup(Twine("std." + opName).str(), getContext());
+  if (!opDefinition) {
+    if (opName.contains('.')) {
+      // This op has a dialect, we try to check if we can register it in the
+      // context on the fly.
+      StringRef dialectName = opName.split('.').first;
+      if (!getContext()->getLoadedDialect(dialectName) &&
+          getContext()->getOrLoadDialect(dialectName)) {
+        opDefinition = AbstractOperation::lookup(opName, getContext());
+      }
+    } else {
+      // If the operation name has no namespace prefix we treat it as a standard
+      // operation and prefix it with "std".
+      // TODO: Would it be better to just build a mapping of the registered
+      // operations in the standard dialect?
+      if (getContext()->getOrLoadDialect("std"))
+        opDefinition = AbstractOperation::lookup(Twine("std." + opName).str(),
+                                                 getContext());
+    }
   }
 
   if (!opDefinition) {
index b791bf4..9bc23c2 100644 (file)
@@ -290,6 +290,13 @@ OpPassManager::pass_iterator OpPassManager::begin() {
 }
 OpPassManager::pass_iterator OpPassManager::end() { return impl->passes.end(); }
 
+OpPassManager::const_pass_iterator OpPassManager::begin() const {
+  return impl->passes.begin();
+}
+OpPassManager::const_pass_iterator OpPassManager::end() const {
+  return impl->passes.end();
+}
+
 /// Run all of the passes in this manager over the current operation.
 LogicalResult OpPassManager::run(Operation *op, AnalysisManager am) {
   // Run each of the held passes.
@@ -346,6 +353,16 @@ void OpPassManager::printAsTextualPipeline(raw_ostream &os) {
   ::printAsTextualPipeline(impl->passes, os);
 }
 
+static void registerDialectsForPipeline(const OpPassManager &pm,
+                                        DialectRegistry &dialects) {
+  for (const Pass &pass : pm.getPasses())
+    pass.getDependentDialects(dialects);
+}
+
+void OpPassManager::getDependentDialects(DialectRegistry &dialects) const {
+  registerDialectsForPipeline(*this, dialects);
+}
+
 //===----------------------------------------------------------------------===//
 // OpToOpPassAdaptor
 //===----------------------------------------------------------------------===//
@@ -378,6 +395,11 @@ OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
   mgrs.emplace_back(std::move(mgr));
 }
 
+void OpToOpPassAdaptor::getDependentDialects(DialectRegistry &dialects) const {
+  for (auto &pm : mgrs)
+    pm.getDependentDialects(dialects);
+}
+
 /// Merge the current pass adaptor into given 'rhs'.
 void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) {
   for (auto &pm : mgrs) {
@@ -721,6 +743,11 @@ LogicalResult PassManager::run(ModuleOp module) {
   // pipeline.
   getImpl().coalesceAdjacentAdaptorPasses();
 
+  // Register all dialects for the current pipeline.
+  DialectRegistry dependentDialects;
+  getDependentDialects(dependentDialects);
+  dependentDialects.loadAll(module.getContext());
+
   // Construct an analysis manager for the pipeline.
   ModuleAnalysisManager am(module, instrumentor.get());
 
index 2342a1a..f69701d 100644 (file)
@@ -43,6 +43,10 @@ public:
   /// Returns the pass managers held by this adaptor.
   MutableArrayRef<OpPassManager> getPassManagers() { return mgrs; }
 
+  /// Populate the set of dependent dialects for the passes in the current
+  /// adaptor.
+  void getDependentDialects(DialectRegistry &dialects) const override;
+
   /// Return the async pass managers held by this parallel adaptor.
   MutableArrayRef<SmallVector<OpPassManager, 1>> getParallelPassManagers() {
     return asyncExecutors;
index 699eded..77b0760 100644 (file)
@@ -81,13 +81,18 @@ static LogicalResult processBuffer(raw_ostream &os,
                                    std::unique_ptr<MemoryBuffer> ownedBuffer,
                                    bool verifyDiagnostics, bool verifyPasses,
                                    bool allowUnregisteredDialects,
-                                   const PassPipelineCLParser &passPipeline) {
+                                   bool preloadDialectsInContext,
+                                   const PassPipelineCLParser &passPipeline,
+                                   DialectRegistry &registry) {
   // Tell sourceMgr about this buffer, which is what the parser will pick up.
   SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(ownedBuffer), SMLoc());
 
   // Parse the input file.
-  MLIRContext context;
+  MLIRContext context(/*loadAllDialects=*/preloadDialectsInContext);
+  registry.appendTo(context.getDialectRegistry());
+  if (preloadDialectsInContext)
+    registry.loadAll(&context);
   context.allowUnregisteredDialects(allowUnregisteredDialects);
   context.printOpOnDiagnostic(!verifyDiagnostics);
 
@@ -115,9 +120,10 @@ static LogicalResult processBuffer(raw_ostream &os,
 LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
                                 std::unique_ptr<MemoryBuffer> buffer,
                                 const PassPipelineCLParser &passPipeline,
-                                bool splitInputFile, bool verifyDiagnostics,
-                                bool verifyPasses,
-                                bool allowUnregisteredDialects) {
+                                DialectRegistry &registry, bool splitInputFile,
+                                bool verifyDiagnostics, bool verifyPasses,
+                                bool allowUnregisteredDialects,
+                                bool preloadDialectsInContext) {
   // The split-input-file mode is a very specific mode that slices the file
   // up into small pieces and checks each independently.
   if (splitInputFile)
@@ -126,15 +132,19 @@ LogicalResult mlir::MlirOptMain(raw_ostream &outputStream,
         [&](std::unique_ptr<MemoryBuffer> chunkBuffer, raw_ostream &os) {
           return processBuffer(os, std::move(chunkBuffer), verifyDiagnostics,
                                verifyPasses, allowUnregisteredDialects,
-                               passPipeline);
+                               preloadDialectsInContext, passPipeline,
+                               registry);
         },
         outputStream);
 
   return processBuffer(outputStream, std::move(buffer), verifyDiagnostics,
-                       verifyPasses, allowUnregisteredDialects, passPipeline);
+                       verifyPasses, allowUnregisteredDialects,
+                       preloadDialectsInContext, passPipeline, registry);
 }
 
-LogicalResult mlir::MlirOptMain(int argc, char **argv, StringRef toolName) {
+LogicalResult mlir::MlirOptMain(int argc, char **argv, llvm::StringRef toolName,
+                                DialectRegistry &registry,
+                                bool preloadDialectsInContext) {
   static cl::opt<std::string> inputFilename(
       cl::Positional, cl::desc("<input file>"), cl::init("-"));
 
@@ -180,25 +190,19 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, StringRef toolName) {
   {
     llvm::raw_string_ostream os(helpHeader);
     MLIRContext context;
-    interleaveComma(context.getRegisteredDialects(), os, [&](Dialect *dialect) {
-      StringRef name = dialect->getNamespace();
-      // filter the builtin dialect.
-      if (name.empty())
-        os << "<builtin>";
-      else
-        os << name;
+    interleaveComma(registry, os, [&](auto &registryEntry) {
+      StringRef name = registryEntry.first;
+      os << name;
     });
   }
   // Parse pass names in main to ensure static initialization completed.
   cl::ParseCommandLineOptions(argc, argv, helpHeader);
 
   if (showDialects) {
-    llvm::outs() << "Registered Dialects:\n";
-    MLIRContext context;
+    llvm::outs() << "Available Dialects:\n";
     interleave(
-        context.getRegisteredDialects(), llvm::outs(),
-        [](Dialect *dialect) { llvm::outs() << dialect->getNamespace(); },
-        "\n");
+        registry, llvm::outs(),
+        [](auto &registryEntry) { llvm::outs() << registryEntry.first; }, "\n");
     return success();
   }
 
@@ -216,9 +220,9 @@ LogicalResult mlir::MlirOptMain(int argc, char **argv, StringRef toolName) {
     return failure();
   }
 
-  if (failed(MlirOptMain(output->os(), std::move(file), passPipeline,
+  if (failed(MlirOptMain(output->os(), std::move(file), passPipeline, registry,
                          splitInputFile, verifyDiagnostics, verifyPasses,
-                         allowUnregisteredDialects)))
+                         allowUnregisteredDialects, preloadDialectsInContext)))
     return failure();
 
   // Keep the output file if the invocation of MlirOptMain was successful.
index 6af77e7..2b5f7e5 100644 (file)
 
 using namespace mlir;
 using namespace mlir::tblgen;
+Dialect::Dialect(const llvm::Record *def) : def(def) {
+  for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
+    dependentDialects.push_back(dialect);
+}
 
 StringRef Dialect::getName() const { return def->getValueAsString("name"); }
 
@@ -46,6 +50,10 @@ StringRef Dialect::getDescription() const {
   return getAsStringOrEmpty(*def, "description");
 }
 
+ArrayRef<StringRef> Dialect::getDependentDialects() const {
+  return dependentDialects;
+}
+
 llvm::Optional<StringRef> Dialect::getExtraClassDeclaration() const {
   auto value = def->getValueAsString("extraClassDeclaration");
   return value.empty() ? llvm::Optional<StringRef>() : value;
index 4bc46b6..f961806 100644 (file)
@@ -69,6 +69,8 @@ Pass::Pass(const llvm::Record *def) : def(def) {
     options.push_back(PassOption(init));
   for (auto *init : def->getValueAsListOfDefs("statistics"))
     statistics.push_back(PassStatistic(init));
+  for (StringRef dialect : def->getValueAsListOfStrings("dependentDialects"))
+    dependentDialects.push_back(dialect);
 }
 
 StringRef Pass::getArgument() const {
@@ -88,6 +90,9 @@ StringRef Pass::getDescription() const {
 StringRef Pass::getConstructor() const {
   return def->getValueAsString("constructor");
 }
+ArrayRef<StringRef> Pass::getDependentDialects() const {
+  return dependentDialects;
+}
 
 ArrayRef<PassOption> Pass::getOptions() const { return options; }
 
index 470044b..a5d8338 100644 (file)
@@ -836,6 +836,7 @@ LogicalResult Importer::processBasicBlock(llvm::BasicBlock *bb, Block *block) {
 OwningModuleRef
 mlir::translateLLVMIRToModule(std::unique_ptr<llvm::Module> llvmModule,
                               MLIRContext *context) {
+  context->loadDialect<LLVMDialect>();
   OwningModuleRef module(ModuleOp::create(
       FileLineColLoc::get("", /*line=*/0, /*column=*/0, context)));
 
index f8277d1..21f5201 100644 (file)
@@ -302,8 +302,7 @@ ModuleTranslation::ModuleTranslation(Operation *module,
     : mlirModule(module), llvmModule(std::move(llvmModule)),
       debugTranslation(
           std::make_unique<DebugTranslation>(module, *this->llvmModule)),
-      ompDialect(
-          module->getContext()->getRegisteredDialect<omp::OpenMPDialect>()),
+      ompDialect(module->getContext()->getOrLoadDialect<omp::OpenMPDialect>()),
       typeTranslator(this->llvmModule->getContext()) {
   assert(satisfiesLLVMModule(mlirModule) &&
          "mlirModule should honor LLVM's module semantics.");
@@ -944,8 +943,8 @@ ModuleTranslation::lookupValues(ValueRange values) {
 
 std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(
     Operation *m, llvm::LLVMContext &llvmContext, StringRef name) {
+  m->getContext()->getOrLoadDialect<LLVM::LLVMDialect>();
   auto llvmModule = std::make_unique<llvm::Module>(name, llvmContext);
-
   if (auto dataLayoutAttr =
           m->getAttr(LLVM::LLVMDialect::getDataLayoutAttrName()))
     llvmModule->setDataLayout(dataLayoutAttr.cast<StringAttr>().getValue());
index c6f7e22..220ed1a 100644 (file)
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+// Forward declaration from Dialect.h
+template <typename ConcreteDialect>
+void registerDialect(DialectRegistry &registry);
+
+namespace linalg {
+class LinalgDialect;
+} // end namespace linalg
 
 #define GEN_PASS_CLASSES
 #include "mlir/Transforms/Passes.h.inc"
index 56b7ecd..12dc100 100644 (file)
@@ -381,8 +381,8 @@ static int printStandardTypes(MlirContext ctx) {
 }
 
 int main() {
-  mlirRegisterAllDialects();
   MlirContext ctx = mlirContextCreate();
+  mlirRegisterAllDialects(ctx);
   MlirLocation location = mlirLocationUnknownGet(ctx);
 
   MlirModule moduleOp = makeAdd(ctx, location);
index 3fcfcf2..062e4b5 100644 (file)
@@ -36,16 +36,18 @@ using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 
 static MLIRContext &globalContext() {
-  static bool init_once = []() {
-    registerDialect<AffineDialect>();
-    registerDialect<linalg::LinalgDialect>();
-    registerDialect<scf::SCFDialect>();
-    registerDialect<StandardOpsDialect>();
-    registerDialect<vector::VectorDialect>();
+  static thread_local MLIRContext context(/*loadAllDialects=*/false);
+  static thread_local bool initOnce = [&]() {
+    // clang-format off
+    context.loadDialect<AffineDialect,
+                        scf::SCFDialect,
+                        linalg::LinalgDialect,
+                        StandardOpsDialect,
+                        vector::VectorDialect>();
+    // clang-format on
     return true;
   }();
-  (void)init_once;
-  static thread_local MLIRContext context;
+  (void)initOnce;
   context.allowUnregisteredDialects();
   return context;
 }
index 0b58e29..ddefc52 100644 (file)
 
 using namespace mlir;
 
-// Load the SDBM dialect
-static DialectRegistration<SDBMDialect> SDBMRegistration;
 
 static MLIRContext *ctx() {
-  static thread_local MLIRContext context;
+  static thread_local MLIRContext context(/*loadAllDialects=*/false);
+  static thread_local bool once =
+      (context.getOrLoadDialect<SDBMDialect>(), true);
+  (void)once;
   return &context;
 }
 
 static SDBMDialect *dialect() {
   static thread_local SDBMDialect *d = nullptr;
   if (!d) {
-    d = ctx()->getRegisteredDialect<SDBMDialect>();
+    d = ctx()->getOrLoadDialect<SDBMDialect>();
   }
   return d;
 }
index a6719b0..cfac2dc 100644 (file)
@@ -14,6 +14,7 @@
 #include "mlir/Analysis/NestedMatcher.h"
 #include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Diagnostics.h"
@@ -72,6 +73,9 @@ struct VectorizerTestPass
     : public PassWrapper<VectorizerTestPass, FunctionPass> {
   static constexpr auto kTestAffineMapOpName = "test_affine_map";
   static constexpr auto kTestAffineMapAttrName = "affine_map";
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<vector::VectorDialect>();
+  }
 
   void runOnFunction() override;
   void testVectorShapeRatio(llvm::raw_ostream &outs);
index 0c1069f..03c425d 100644 (file)
@@ -30,7 +30,7 @@ void PrintOpAvailability::runOnFunction() {
   auto f = getFunction();
   llvm::outs() << f.getName() << "\n";
 
-  Dialect *spvDialect = getContext().getRegisteredDialect("spv");
+  Dialect *spvDialect = getContext().getLoadedDialect("spv");
 
   f.getOperation()->walk([&](Operation *op) {
     if (op->getDialect() != spvDialect)
index 7bea72d..bccdfb7 100644 (file)
 
 using namespace mlir;
 
+void mlir::registerTestDialect(DialectRegistry &registry) {
+  registry.insert<TestDialect>();
+}
+
 //===----------------------------------------------------------------------===//
 // TestDialect Interfaces
 //===----------------------------------------------------------------------===//
index fd1914c..34fc1a9 100644 (file)
@@ -37,6 +37,8 @@ namespace mlir {
 #define GET_OP_CLASSES
 #include "TestOps.h.inc"
 
+void registerTestDialect(DialectRegistry &registry);
+
 } // end namespace mlir
 
 #endif // MLIR_TESTDIALECT_H
index f2a17a9..be5d799 100644 (file)
@@ -768,6 +768,10 @@ struct TestTypeConversionProducer
 
 struct TestTypeConversionDriver
     : public PassWrapper<TestTypeConversionDriver, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<TestDialect>();
+  }
+
   void runOnOperation() override {
     // Initialize the type converter.
     TypeConverter converter;
index c043d0f..0c72b6c 100644 (file)
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "mlir/Dialect/GPU/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
 
@@ -19,6 +20,9 @@ using namespace mlir;
 namespace {
 struct TestAllReduceLoweringPass
     : public PassWrapper<TestAllReduceLoweringPass, OperationPass<ModuleOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<StandardOpsDialect>();
+  }
   void runOnOperation() override {
     OwningRewritePatternList patterns;
     populateGpuRewritePatterns(&getContext(), patterns);
index 5ad441a..6cc0924 100644 (file)
@@ -116,6 +116,10 @@ struct TestBufferPlacementPreparationPass
     patterns->insert<GenericOpConverter>(context, placer, converter);
   }
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<linalg::LinalgDialect>();
+  }
+
   void runOnOperation() override {
     MLIRContext &context = this->getContext();
     ConversionTarget target(context);
index 08862dd..3c2b933 100644 (file)
@@ -13,6 +13,9 @@
 
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/GPU/MemoryPromotion.h"
+#include "mlir/Dialect/SCF/SCF.h"
+#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/Pass/Pass.h"
 
@@ -26,6 +29,10 @@ namespace {
 class TestGpuMemoryPromotionPass
     : public PassWrapper<TestGpuMemoryPromotionPass,
                          OperationPass<gpu::GPUFuncOp>> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<StandardOpsDialect, scf::SCFDialect>();
+  }
+
   void runOnOperation() override {
     gpu::GPUFuncOp op = getOperation();
     for (unsigned i = 0, e = op.getNumArguments(); i < e; ++i) {
index d1e478f..5d4031f 100644 (file)
@@ -10,6 +10,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Hoisting.h"
 #include "mlir/Pass/Pass.h"
@@ -22,6 +23,9 @@ struct TestLinalgHoisting
     : public PassWrapper<TestLinalgHoisting, FunctionPass> {
   TestLinalgHoisting() = default;
   TestLinalgHoisting(const TestLinalgHoisting &pass) {}
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect>();
+  }
 
   void runOnFunction() override;
 
index dffe4f2..4fc880a 100644 (file)
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -30,6 +31,16 @@ struct TestLinalgTransforms
   TestLinalgTransforms() = default;
   TestLinalgTransforms(const TestLinalgTransforms &pass) {}
 
+  void getDependentDialects(DialectRegistry &registry) const override {
+    // clang-format off
+    registry.insert<AffineDialect,
+                    scf::SCFDialect,
+                    StandardOpsDialect,
+                    vector::VectorDialect,
+                    gpu::GPUDialect>();
+    // clang-format on
+  }
+
   void runOnFunction() override;
 
   Option<bool> testPatterns{*this, "test-patterns",
index 9da3156..ab84603 100644 (file)
@@ -8,6 +8,9 @@
 
 #include <type_traits>
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
+#include "mlir/Dialect/SCF/SCF.h"
 #include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/Dialect/Vector/VectorTransforms.h"
@@ -128,6 +131,11 @@ struct TestVectorTransferFullPartialSplitPatterns
   TestVectorTransferFullPartialSplitPatterns() = default;
   TestVectorTransferFullPartialSplitPatterns(
       const TestVectorTransferFullPartialSplitPatterns &pass) {}
+
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<AffineDialect, linalg::LinalgDialect, scf::SCFDialect>();
+  }
+
   Option<bool> useLinalgOps{
       *this, "use-linalg-copy",
       llvm::cl::desc("Split using a unmasked vector.transfer + linalg.fill + "
index f99a68d..4cf6ea9 100644 (file)
@@ -1,5 +1,5 @@
 // RUN: mlir-opt --show-dialects | FileCheck %s
-// CHECK: Registered Dialects:
+// CHECK: Available Dialects:
 // CHECK: affine
 // CHECK: gpu
 // CHECK: linalg
index 12e6aee..92efef6 100644 (file)
@@ -1703,7 +1703,7 @@ int main(int argc, char **argv) {
   if (testEmitIncludeTdHeader)
     output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
 
-  MLIRContext context;
+  MLIRContext context(/*loadAllDialects=*/false);
   llvm::SourceMgr mgr;
   mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
   Parser parser(mgr, &context);
index f5f5a47..6f27949 100644 (file)
@@ -48,6 +48,7 @@ void registerTestConstantFold();
 void registerTestConvertGPUKernelToCubinPass();
 void registerTestConvertGPUKernelToHsacoPass();
 void registerTestDominancePass();
+void registerTestDialect(DialectRegistry &);
 void registerTestExpandTanhPass();
 void registerTestFunc();
 void registerTestGpuMemoryPromotionPass();
@@ -130,5 +131,10 @@ int main(int argc, char **argv) {
 #ifdef MLIR_INCLUDE_TESTS
   registerTestPasses();
 #endif
-  return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver"));
+  DialectRegistry registry;
+  registerAllDialects(registry);
+  registerTestDialect(registry);
+  return failed(MlirOptMain(argc, argv, "MLIR modular optimizer driver\n",
+                            registry,
+                            /*preloadDialectsInContext=*/false));
 }
index 13421c4..3a19379 100644 (file)
@@ -61,11 +61,14 @@ filterForDialect(ArrayRef<llvm::Record *> records, Dialect &dialect) {
 ///
 /// {0}: The name of the dialect class.
 /// {1}: The dialect namespace.
+/// {2}: initialization code that is emitted in the ctor body before calling
+/// initialize()
 static const char *const dialectDeclBeginStr = R"(
 class {0} : public ::mlir::Dialect {
   explicit {0}(::mlir::MLIRContext *context)
     : ::mlir::Dialect(getDialectNamespace(), context,
       ::mlir::TypeID::get<{0}>()) {{
+    {2}
     initialize();
   }
   void initialize();
@@ -74,6 +77,12 @@ public:
   static ::llvm::StringRef getDialectNamespace() { return "{1}"; }
 )";
 
+/// Registration for a single dependent dialect: to be inserted in the ctor
+/// above for each dependent dialect.
+const char *const dialectRegistrationTemplate = R"(
+    getContext()->getOrLoadDialect<{0}>();
+)";
+
 /// The code block for the attribute parser/printer hooks.
 static const char *const attrParserDecl = R"(
   /// Parse an attribute registered to this dialect.
@@ -136,9 +145,18 @@ static void emitDialectDecl(Dialect &dialect,
                             iterator_range<DialectFilterIterator> dialectAttrs,
                             iterator_range<DialectFilterIterator> dialectTypes,
                             raw_ostream &os) {
+  /// Build the list of dependent dialects
+  std::string dependentDialectRegistrations;
+  {
+    llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
+    for (StringRef dependentDialect : dialect.getDependentDialects())
+      dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+                                  dependentDialect);
+  }
   // Emit the start of the decl.
   std::string cppName = dialect.getCppClassName();
-  os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName());
+  os << llvm::formatv(dialectDeclBeginStr, cppName, dialect.getName(),
+                      dependentDialectRegistrations);
 
   // Check for any attributes/types registered to this dialect.  If there are,
   // add the hooks for parsing/printing.
index c2dcdb8..c1664a0 100644 (file)
@@ -36,6 +36,7 @@ static llvm::cl::opt<std::string>
 /// {0}: The def name of the pass record.
 /// {1}: The base class for the pass.
 /// {2): The command line argument for the pass.
+/// {3}: The dependent dialects registration.
 const char *const passDeclBegin = R"(
 //===----------------------------------------------------------------------===//
 // {0}
@@ -63,9 +64,20 @@ public:
     return std::make_unique<DerivedT>(*static_cast<const DerivedT *>(this));
   }
 
+  /// Return the dialect that must be loaded in the context before this pass.
+  void getDependentDialects(::mlir::DialectRegistry &registry) const override {
+    {3}
+  }
+
 protected:
 )";
 
+/// Registration for a single dependent dialect, to be inserted for each
+/// dependent dialect in the `getDependentDialects` above.
+const char *const dialectRegistrationTemplate = R"(
+  registry.insert<{0}>();
+)";
+
 /// Emit the declarations for each of the pass options.
 static void emitPassOptionDecls(const Pass &pass, raw_ostream &os) {
   for (const PassOption &opt : pass.getOptions()) {
@@ -94,8 +106,15 @@ static void emitPassStatisticDecls(const Pass &pass, raw_ostream &os) {
 
 static void emitPassDecl(const Pass &pass, raw_ostream &os) {
   StringRef defName = pass.getDef()->getName();
+  std::string dependentDialectRegistrations;
+  {
+    llvm::raw_string_ostream dialectsOs(dependentDialectRegistrations);
+    for (StringRef dependentDialect : pass.getDependentDialects())
+      dialectsOs << llvm::formatv(dialectRegistrationTemplate,
+                                  dependentDialect);
+  }
   os << llvm::formatv(passDeclBegin, defName, pass.getBaseClass(),
-                      pass.getArgument());
+                      pass.getArgument(), dependentDialectRegistrations);
   emitPassOptionDecls(pass, os);
   emitPassStatisticDecls(pass, os);
   os << "};\n";
index 914bd34..0d67286 100644 (file)
@@ -88,7 +88,8 @@ int main(int argc, char **argv) {
   // Processes the memory buffer with a new MLIRContext.
   auto processBuffer = [&](std::unique_ptr<llvm::MemoryBuffer> ownedBuffer,
                            raw_ostream &os) {
-    MLIRContext context;
+    MLIRContext context(false);
+    registerAllDialects(&context);
     context.allowUnregisteredDialects();
     context.printOpOnDiagnostic(!verifyDiagnostics);
     llvm::SourceMgr sourceMgr;
index 97c94a5..bae95e1 100644 (file)
@@ -17,9 +17,6 @@
 using namespace mlir;
 using namespace mlir::quant;
 
-// Load the quant dialect
-static DialectRegistration<QuantizationDialect> QuantOpsRegistration;
-
 namespace {
 
 // Test UniformQuantizedValueConverter converts all APFloat to a magic number 5.
@@ -78,7 +75,8 @@ UniformQuantizedType getTestQuantizedType(Type storageType, MLIRContext *ctx) {
 }
 
 TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
-  MLIRContext ctx;
+  MLIRContext ctx(/*loadAllDialects=*/false);
+  ctx.getOrLoadDialect<QuantizationDialect>();
   IntegerType convertedType = IntegerType::get(8, &ctx);
   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
   TestUniformQuantizedValueConverter converter(quantizedType);
@@ -95,7 +93,8 @@ TEST(QuantizationUtilsTest, convertFloatAttrUniform) {
 }
 
 TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
-  MLIRContext ctx;
+  MLIRContext ctx(/*loadAllDialects=*/false);
+  ctx.getOrLoadDialect<QuantizationDialect>();
   IntegerType convertedType = IntegerType::get(8, &ctx);
   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
   TestUniformQuantizedValueConverter converter(quantizedType);
@@ -119,7 +118,8 @@ TEST(QuantizationUtilsTest, convertRankedDenseAttrUniform) {
 }
 
 TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
-  MLIRContext ctx;
+  MLIRContext ctx(/*loadAllDialects=*/false);
+  ctx.getOrLoadDialect<QuantizationDialect>();
   IntegerType convertedType = IntegerType::get(8, &ctx);
   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
   TestUniformQuantizedValueConverter converter(quantizedType);
@@ -143,7 +143,8 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
 }
 
 TEST(QuantizationUtilsTest, convertRankedSparseAttrUniform) {
-  MLIRContext ctx;
+  MLIRContext ctx(/*loadAllDialects=*/false);
+  ctx.getOrLoadDialect<QuantizationDialect>();
   IntegerType convertedType = IntegerType::get(8, &ctx);
   auto quantizedType = getTestQuantizedType(convertedType, &ctx);
   TestUniformQuantizedValueConverter converter(quantizedType);
index fe5632d..4aa2ffe 100644 (file)
@@ -38,7 +38,8 @@ using ::testing::StrEq;
 /// diagnostic checking utilities.
 class DeserializationTest : public ::testing::Test {
 protected:
-  DeserializationTest() {
+  DeserializationTest() : context(/*loadAllDialects=*/false) {
+    context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
     // Register a diagnostic handler to capture the diagnostic so that we can
     // check it later.
     context.getDiagEngine().registerHandler([&](Diagnostic &diag) {
index 3d57e55..cb89cd6 100644 (file)
@@ -36,7 +36,10 @@ using namespace mlir;
 
 class SerializationTest : public ::testing::Test {
 protected:
-  SerializationTest() { createModuleOp(); }
+  SerializationTest() : context(/*loadAllDialects=*/false) {
+    context.getOrLoadDialect<mlir::spirv::SPIRVDialect>();
+    createModuleOp();
+  }
 
   void createModuleOp() {
     OpBuilder builder(&context);
index df449a0..78f7dd5 100644 (file)
@@ -32,7 +32,7 @@ static void testSplat(Type eltType, const EltTy &splatElt) {
 
 namespace {
 TEST(DenseSplatTest, BoolSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   IntegerType boolTy = IntegerType::get(1, &context);
   RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
 
@@ -57,7 +57,7 @@ TEST(DenseSplatTest, BoolSplat) {
 TEST(DenseSplatTest, LargeBoolSplat) {
   constexpr int64_t boolCount = 56;
 
-  MLIRContext context;
+  MLIRContext context(false);
   IntegerType boolTy = IntegerType::get(1, &context);
   RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
 
@@ -80,7 +80,7 @@ TEST(DenseSplatTest, LargeBoolSplat) {
 }
 
 TEST(DenseSplatTest, BoolNonSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   IntegerType boolTy = IntegerType::get(1, &context);
   RankedTensorType shape = RankedTensorType::get({6}, boolTy);
 
@@ -92,7 +92,7 @@ TEST(DenseSplatTest, BoolNonSplat) {
 
 TEST(DenseSplatTest, OddIntSplat) {
   // Test detecting a splat with an odd(non 8-bit) integer bitwidth.
-  MLIRContext context;
+  MLIRContext context(false);
   constexpr size_t intWidth = 19;
   IntegerType intTy = IntegerType::get(intWidth, &context);
   APInt value(intWidth, 10);
@@ -101,7 +101,7 @@ TEST(DenseSplatTest, OddIntSplat) {
 }
 
 TEST(DenseSplatTest, Int32Splat) {
-  MLIRContext context;
+  MLIRContext context(false);
   IntegerType intTy = IntegerType::get(32, &context);
   int value = 64;
 
@@ -109,7 +109,7 @@ TEST(DenseSplatTest, Int32Splat) {
 }
 
 TEST(DenseSplatTest, IntAttrSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   IntegerType intTy = IntegerType::get(85, &context);
   Attribute value = IntegerAttr::get(intTy, 109);
 
@@ -117,7 +117,7 @@ TEST(DenseSplatTest, IntAttrSplat) {
 }
 
 TEST(DenseSplatTest, F32Splat) {
-  MLIRContext context;
+  MLIRContext context(false);
   FloatType floatTy = FloatType::getF32(&context);
   float value = 10.0;
 
@@ -125,7 +125,7 @@ TEST(DenseSplatTest, F32Splat) {
 }
 
 TEST(DenseSplatTest, F64Splat) {
-  MLIRContext context;
+  MLIRContext context(false);
   FloatType floatTy = FloatType::getF64(&context);
   double value = 10.0;
 
@@ -133,7 +133,7 @@ TEST(DenseSplatTest, F64Splat) {
 }
 
 TEST(DenseSplatTest, FloatAttrSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   FloatType floatTy = FloatType::getF32(&context);
   Attribute value = FloatAttr::get(floatTy, 10.0);
 
@@ -141,7 +141,7 @@ TEST(DenseSplatTest, FloatAttrSplat) {
 }
 
 TEST(DenseSplatTest, BF16Splat) {
-  MLIRContext context;
+  MLIRContext context(false);
   FloatType floatTy = FloatType::getBF16(&context);
   Attribute value = FloatAttr::get(floatTy, 10.0);
 
@@ -149,7 +149,7 @@ TEST(DenseSplatTest, BF16Splat) {
 }
 
 TEST(DenseSplatTest, StringSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   Type stringType =
       OpaqueType::get(Identifier::get("test", &context), "string", &context);
   StringRef value = "test-string";
@@ -157,7 +157,7 @@ TEST(DenseSplatTest, StringSplat) {
 }
 
 TEST(DenseSplatTest, StringAttrSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   Type stringType =
       OpaqueType::get(Identifier::get("test", &context), "string", &context);
   Attribute stringAttr = StringAttr::get("test-string", stringType);
@@ -165,28 +165,28 @@ TEST(DenseSplatTest, StringAttrSplat) {
 }
 
 TEST(DenseComplexTest, ComplexFloatSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
   std::complex<float> value(10.0, 15.0);
   testSplat(complexType, value);
 }
 
 TEST(DenseComplexTest, ComplexIntSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
   std::complex<int64_t> value(10, 15);
   testSplat(complexType, value);
 }
 
 TEST(DenseComplexTest, ComplexAPFloatSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   ComplexType complexType = ComplexType::get(FloatType::getF32(&context));
   std::complex<APFloat> value(APFloat(10.0f), APFloat(15.0f));
   testSplat(complexType, value);
 }
 
 TEST(DenseComplexTest, ComplexAPIntSplat) {
-  MLIRContext context;
+  MLIRContext context(false);
   ComplexType complexType = ComplexType::get(IntegerType::get(64, &context));
   std::complex<APInt> value(APInt(64, 10), APInt(64, 15));
   testSplat(complexType, value);
index bc389ce..5a0a229 100644 (file)
@@ -26,12 +26,12 @@ struct AnotherTestDialect : public Dialect {
 };
 
 TEST(DialectDeathTest, MultipleDialectsWithSameNamespace) {
-  MLIRContext context;
+  MLIRContext context(false);
 
   // Registering a dialect with the same namespace twice should result in a
   // failure.
-  context.getOrCreateDialect<TestDialect>();
-  ASSERT_DEATH(context.getOrCreateDialect<AnotherTestDialect>(), "");
+  context.loadDialect<TestDialect>();
+  ASSERT_DEATH(context.loadDialect<AnotherTestDialect>(), "");
 }
 
 } // end namespace
index 95ddccc..9669330 100644 (file)
@@ -25,7 +25,7 @@ static Operation *createOp(MLIRContext *context,
 
 namespace {
 TEST(OperandStorageTest, NonResizable) {
-  MLIRContext context;
+  MLIRContext context(false);
   Builder builder(&context);
 
   Operation *useOp =
@@ -49,7 +49,7 @@ TEST(OperandStorageTest, NonResizable) {
 }
 
 TEST(OperandStorageTest, Resizable) {
-  MLIRContext context;
+  MLIRContext context(false);
   Builder builder(&context);
 
   Operation *useOp =
@@ -77,7 +77,7 @@ TEST(OperandStorageTest, Resizable) {
 }
 
 TEST(OperandStorageTest, RangeReplace) {
-  MLIRContext context;
+  MLIRContext context(false);
   Builder builder(&context);
 
   Operation *useOp =
@@ -113,7 +113,7 @@ TEST(OperandStorageTest, RangeReplace) {
 }
 
 TEST(OperandStorageTest, MutableRange) {
-  MLIRContext context;
+  MLIRContext context(false);
   Builder builder(&context);
 
   Operation *useOp =
index 41a9064..918c898 100644 (file)
@@ -29,7 +29,7 @@ struct OpSpecificAnalysis {
 };
 
 TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
-  MLIRContext context;
+  MLIRContext context(false);
 
   // Test fine grain invalidation of the module analysis manager.
   OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
@@ -50,7 +50,7 @@ TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
 }
 
 TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
-  MLIRContext context;
+  MLIRContext context(false);
   Builder builder(&context);
 
   // Create a function and a module.
@@ -79,7 +79,7 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
 }
 
 TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
-  MLIRContext context;
+  MLIRContext context(false);
   Builder builder(&context);
 
   // Create a function and a module.
@@ -122,7 +122,7 @@ struct CustomInvalidatingAnalysis {
 };
 
 TEST(AnalysisManagerTest, CustomInvalidation) {
-  MLIRContext context;
+  MLIRContext context(false);
   Builder builder(&context);
 
   // Create a function and a module.
index 61d6706..bbe87e3 100644 (file)
 
 using namespace mlir;
 
-/// Load the SDBM dialect.
-static DialectRegistration<SDBMDialect> SDBMRegistration;
 
 static MLIRContext *ctx() {
-  static thread_local MLIRContext context;
+  static thread_local MLIRContext context(false);
+  context.getOrLoadDialect<SDBMDialect>();
   return &context;
 }
 
 static SDBMDialect *dialect() {
   static thread_local SDBMDialect *d = nullptr;
   if (!d) {
-    d = ctx()->getRegisteredDialect<SDBMDialect>();
+    d = ctx()->getOrLoadDialect<SDBMDialect>();
   }
   return d;
 }
index 3e3256e..46a37da 100644 (file)
@@ -25,11 +25,16 @@ namespace mlir {
 // Test Fixture
 //===----------------------------------------------------------------------===//
 
+static MLIRContext &getContext() {
+  static MLIRContext ctx(false);
+  ctx.getOrLoadDialect<TestDialect>();
+  return ctx;
+}
 /// Test fixture for providing basic utilities for testing.
 class OpBuildGenTest : public ::testing::Test {
 protected:
   OpBuildGenTest()
-      : ctx{}, builder(&ctx), loc(builder.getUnknownLoc()),
+      : ctx(getContext()), builder(&ctx), loc(builder.getUnknownLoc()),
         i32Ty(builder.getI32Type()), f32Ty(builder.getF32Type()),
         cstI32(builder.create<TableGenConstant>(loc, i32Ty)),
         cstF32(builder.create<TableGenConstant>(loc, f32Ty)),
@@ -86,7 +91,7 @@ protected:
   }
 
 protected:
-  MLIRContext ctx;
+  MLIRContext &ctx;
   OpBuilder builder;
   Location loc;
   Type i32Ty;
index c58fedb..14b0abc 100644 (file)
@@ -42,7 +42,7 @@ static test::TestStruct getTestStruct(mlir::MLIRContext *context) {
 /// Validates that test::TestStruct::classof correctly identifies a valid
 /// test::TestStruct.
 TEST(StructsGenTest, ClassofTrue) {
-  mlir::MLIRContext context;
+  mlir::MLIRContext context(false);
   auto structAttr = getTestStruct(&context);
   ASSERT_TRUE(test::TestStruct::classof(structAttr));
 }