NFC: Refactor Module to be value typed.
authorRiver Riddle <riverriddle@google.com>
Tue, 2 Jul 2019 17:49:17 +0000 (10:49 -0700)
committerMehdi Amini <aminim@google.com>
Tue, 2 Jul 2019 23:43:36 +0000 (16:43 -0700)
As with Functions, Module will soon become an operation, which are value-typed. This eases the transition from Module to ModuleOp. A new class, OwningModuleRef is provided to allow for owning a reference to a Module, and will auto-delete the held module on destruction.

PiperOrigin-RevId: 256196193

70 files changed:
mlir/bindings/python/pybind.cpp
mlir/examples/Linalg/Linalg1/include/linalg1/Common.h
mlir/examples/Linalg/Linalg1/include/linalg1/ConvertToLLVMDialect.h
mlir/examples/Linalg/Linalg1/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg2/Example.cpp
mlir/examples/Linalg/Linalg3/Conversion.cpp
mlir/examples/Linalg/Linalg3/Example.cpp
mlir/examples/Linalg/Linalg3/Execution.cpp
mlir/examples/Linalg/Linalg3/include/linalg3/ConvertToLLVMDialect.h
mlir/examples/Linalg/Linalg3/lib/ConvertToLLVMDialect.cpp
mlir/examples/Linalg/Linalg4/Example.cpp
mlir/examples/toy/Ch2/include/toy/MLIRGen.h
mlir/examples/toy/Ch2/mlir/MLIRGen.cpp
mlir/examples/toy/Ch2/toyc.cpp
mlir/examples/toy/Ch3/include/toy/MLIRGen.h
mlir/examples/toy/Ch3/mlir/MLIRGen.cpp
mlir/examples/toy/Ch3/toyc.cpp
mlir/examples/toy/Ch4/include/toy/MLIRGen.h
mlir/examples/toy/Ch4/mlir/MLIRGen.cpp
mlir/examples/toy/Ch4/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch4/toyc.cpp
mlir/examples/toy/Ch5/include/toy/MLIRGen.h
mlir/examples/toy/Ch5/mlir/LateLowering.cpp
mlir/examples/toy/Ch5/mlir/MLIRGen.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/examples/toy/Ch5/toyc.cpp
mlir/g3doc/WritingAPass.md
mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
mlir/include/mlir/ExecutionEngine/ExecutionEngine.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/IR/Function.h
mlir/include/mlir/IR/Module.h
mlir/include/mlir/IR/SymbolTable.h
mlir/include/mlir/Parser.h
mlir/include/mlir/Pass/AnalysisManager.h
mlir/include/mlir/Pass/Pass.h
mlir/include/mlir/Pass/PassManager.h
mlir/include/mlir/Target/LLVMIR.h
mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
mlir/include/mlir/Target/NVVMIR.h
mlir/include/mlir/Transforms/DialectConversion.h
mlir/include/mlir/Translation.h
mlir/lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp
mlir/lib/Conversion/GPUToCUDA/ConvertLaunchFuncToCudaCalls.cpp
mlir/lib/Conversion/GPUToCUDA/GenerateCubinAccessors.cpp
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/ExecutionEngine/ExecutionEngine.cpp
mlir/lib/GPU/IR/GPUDialect.cpp
mlir/lib/GPU/Transforms/KernelOutlining.cpp
mlir/lib/IR/AsmPrinter.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/IR/Function.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Pass/IRPrinting.cpp
mlir/lib/Pass/Pass.cpp
mlir/lib/Pass/PassDetail.h
mlir/lib/SPIRV/Serialization/ConvertFromBinary.cpp
mlir/lib/SPIRV/Serialization/ConvertToBinary.cpp
mlir/lib/StandardOps/Ops.cpp
mlir/lib/Support/MlirOptMain.cpp
mlir/lib/Support/TranslateClParser.cpp
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/EDSC/builder-api-test.cpp
mlir/tools/mlir-cpu-runner/mlir-cpu-runner-lib.cpp
mlir/unittests/Pass/AnalysisManagerTest.cpp

index cdf4a7f..f730f8e 100644 (file)
@@ -146,8 +146,8 @@ struct PythonFunction {
 /// Trivial C++ wrappers make use of the EDSC C API.
 struct PythonMLIRModule {
   PythonMLIRModule()
-      : mlirContext(), module(new mlir::Module(&mlirContext)),
-        moduleManager(module.get()) {}
+      : mlirContext(), module(mlir::Module::create(&mlirContext)),
+        moduleManager(*module) {}
 
   PythonType makeScalarType(const std::string &mlirElemType,
                             unsigned bitwidth) {
@@ -197,12 +197,12 @@ struct PythonMLIRModule {
     manager.addPass(mlir::createCSEPass());
     manager.addPass(mlir::createLowerAffinePass());
     manager.addPass(mlir::createConvertToLLVMIRPass());
-    if (failed(manager.run(module.get()))) {
+    if (failed(manager.run(*module))) {
       llvm::errs() << "conversion to the LLVM IR dialect failed\n";
       return;
     }
 
-    auto created = mlir::ExecutionEngine::create(module.get());
+    auto created = mlir::ExecutionEngine::create(*module);
     llvm::handleAllErrors(created.takeError(),
                           [](const llvm::ErrorInfoBase &b) {
                             b.log(llvm::errs());
@@ -235,7 +235,7 @@ struct PythonMLIRModule {
 private:
   mlir::MLIRContext mlirContext;
   // One single module in a python-exposed MLIRContext for now.
-  std::unique_ptr<mlir::Module> module;
+  mlir::OwningModuleRef module;
   mlir::ModuleManager moduleManager;
   std::unique_ptr<mlir::ExecutionEngine> engine;
 };
index 1f129c6..38b304e 100644 (file)
@@ -57,7 +57,7 @@ inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
 }
 
 /// A basic function builder
-inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name,
+inline mlir::Function makeFunction(mlir::Module module, llvm::StringRef name,
                                    llvm::ArrayRef<mlir::Type> types,
                                    llvm::ArrayRef<mlir::Type> resultTypes) {
   auto *context = module.getContext();
@@ -92,7 +92,7 @@ inline void cleanupAndPrintFunction(mlir::Function f) {
     }
   };
   auto pm = cleanupPassManager();
-  check(f.getModule()->verify());
+  check(f.getModule().verify());
   check(pm->run(f.getModule()));
   if (printToOuts)
     f.print(llvm::outs());
index e341705..fe77f4e 100644 (file)
@@ -51,7 +51,7 @@ void populateLinalg1ToLLVMConversionPatterns(
 /// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations
 /// to the LLVM IR dialect types and operations in the given `module`.  This is
 /// the main entry point to the conversion.
-void convertToLLVM(mlir::Module &module);
+void convertToLLVM(mlir::Module module);
 } // end namespace linalg
 
 #endif // LINALG1_CONVERTTOLLVMDIALECT_H_
index 0033107..5d20063 100644 (file)
@@ -406,11 +406,11 @@ struct LinalgTypeConverter : public LLVMTypeConverter {
 };
 } // end anonymous namespace
 
-void linalg::convertToLLVM(mlir::Module &module) {
+void linalg::convertToLLVM(mlir::Module module) {
   // Remove affine constructs if any by using an existing pass.
   PassManager pm;
   pm.addPass(createLowerAffinePass());
-  auto rr = pm.run(&module);
+  auto rr = pm.run(module);
   (void)rr;
   assert(succeeded(rr) && "affine loop lowering failed");
 
index 9534711..cb93b96 100644 (file)
@@ -34,10 +34,10 @@ using namespace linalg::intrinsics;
 
 TEST_FUNC(linalg_ops) {
   MLIRContext context;
-  Module module(&context);
+  OwningModuleRef module = Module::create(&context);
   auto indexType = mlir::IndexType::get(&context);
-  mlir::Function f =
-      makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
+  mlir::Function f = makeFunction(*module, "linalg_ops",
+                                  {indexType, indexType, indexType}, {});
 
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
@@ -73,9 +73,9 @@ TEST_FUNC(linalg_ops) {
 
 TEST_FUNC(linalg_ops_folded_slices) {
   MLIRContext context;
-  Module module(&context);
+  OwningModuleRef module = Module::create(&context);
   auto indexType = mlir::IndexType::get(&context);
-  mlir::Function f = makeFunction(module, "linalg_ops_folded_slices",
+  mlir::Function f = makeFunction(*module, "linalg_ops_folded_slices",
                                   {indexType, indexType, indexType}, {});
 
   OpBuilder builder(f.getBody());
index 23d1cfe..6bd428f 100644 (file)
@@ -37,7 +37,7 @@ using namespace linalg;
 using namespace linalg::common;
 using namespace linalg::intrinsics;
 
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
   MLIRContext *context = module.getContext();
   auto dynamic2DMemRefType = floatMemRefType<2>(context);
   mlir::Function f = linalg::common::makeFunction(
@@ -66,11 +66,11 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
 
 TEST_FUNC(foo) {
   MLIRContext context;
-  Module module(&context);
-  mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+  OwningModuleRef module = Module::create(&context);
+  mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
   lowerToLoops(f);
 
-  convertLinalg3ToLLVM(module);
+  convertLinalg3ToLLVM(*module);
 
   // clang-format off
   // CHECK:      {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
@@ -104,7 +104,7 @@ TEST_FUNC(foo) {
   // CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
   // CHECK-NEXT: llvm.store {{.*}}, {{.*}} : !llvm<"float*">
   // clang-format on
-  module.print(llvm::outs());
+  module->print(llvm::outs());
 }
 
 int main() {
index 8b04344..4ac6a00 100644 (file)
@@ -34,7 +34,7 @@ using namespace linalg;
 using namespace linalg::common;
 using namespace linalg::intrinsics;
 
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
   MLIRContext *context = module.getContext();
   auto dynamic2DMemRefType = floatMemRefType<2>(context);
   mlir::Function f = linalg::common::makeFunction(
@@ -63,7 +63,7 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
 
 TEST_FUNC(matmul_as_matvec) {
   MLIRContext context;
-  Module module(&context);
+  Module module = Module::create(&context);
   mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
   lowerToFinerGrainedTensorContraction(f);
   composeSliceOps(f);
@@ -81,7 +81,7 @@ TEST_FUNC(matmul_as_matvec) {
 
 TEST_FUNC(matmul_as_dot) {
   MLIRContext context;
-  Module module(&context);
+  Module module = Module::create(&context);
   mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
   lowerToFinerGrainedTensorContraction(f);
   lowerToFinerGrainedTensorContraction(f);
@@ -102,7 +102,7 @@ TEST_FUNC(matmul_as_dot) {
 
 TEST_FUNC(matmul_as_loops) {
   MLIRContext context;
-  Module module(&context);
+  Module module = Module::create(&context);
   mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
   lowerToLoops(f);
   composeSliceOps(f);
@@ -134,7 +134,7 @@ TEST_FUNC(matmul_as_loops) {
 
 TEST_FUNC(matmul_as_matvec_as_loops) {
   MLIRContext context;
-  Module module(&context);
+  Module module = Module::create(&context);
   mlir::Function f =
       makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
   lowerToFinerGrainedTensorContraction(f);
@@ -165,7 +165,7 @@ TEST_FUNC(matmul_as_matvec_as_loops) {
 
 TEST_FUNC(matmul_as_matvec_as_affine) {
   MLIRContext context;
-  Module module(&context);
+  Module module = Module::create(&context);
   mlir::Function f =
       makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine");
   lowerToFinerGrainedTensorContraction(f);
index 94b233a..4b10787 100644 (file)
@@ -37,7 +37,7 @@ using namespace linalg;
 using namespace linalg::common;
 using namespace linalg::intrinsics;
 
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
   MLIRContext *context = module.getContext();
   auto dynamic2DMemRefType = floatMemRefType<2>(context);
   mlir::Function f = linalg::common::makeFunction(
@@ -109,14 +109,14 @@ TEST_FUNC(execution) {
   // linalg.matmul operation and lower it all the way down to the LLVM IR
   // dialect through partial conversions.
   MLIRContext context;
-  Module module(&context);
-  mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+  OwningModuleRef module = Module::create(&context);
+  mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
   lowerToLoops(f);
-  convertLinalg3ToLLVM(module);
+  convertLinalg3ToLLVM(*module);
 
   // Create an MLIR execution engine. The execution engine eagerly JIT-compiles
   // the module.
-  auto maybeEngine = mlir::ExecutionEngine::create(&module);
+  auto maybeEngine = mlir::ExecutionEngine::create(*module);
   assert(maybeEngine && "failed to construct an execution engine");
   auto &engine = maybeEngine.get();
 
index 8f122e0..a1854ae 100644 (file)
@@ -23,7 +23,7 @@ class Module;
 } // end namespace mlir
 
 namespace linalg {
-void convertLinalg3ToLLVM(mlir::Module &module);
+void convertLinalg3ToLLVM(mlir::Module module);
 } // end namespace linalg
 
 #endif // LINALG3_CONVERTTOLLVMDIALECT_H_
index 96b0f37..a01a7fd 100644 (file)
@@ -146,7 +146,7 @@ static void populateLinalg3ToLLVMConversionPatterns(
                                                                  context);
 }
 
-void linalg::convertLinalg3ToLLVM(Module &module) {
+void linalg::convertLinalg3ToLLVM(Module module) {
   // Remove affine constructs.
   for (auto func : module) {
     auto rr = lowerAffineConstructs(func);
index 873e57e..d8ad7c6 100644 (file)
@@ -34,7 +34,7 @@ using namespace linalg;
 using namespace linalg::common;
 using namespace linalg::intrinsics;
 
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
   MLIRContext *context = module.getContext();
   auto dynamic2DMemRefType = floatMemRefType<2>(context);
   mlir::Function f = linalg::common::makeFunction(
@@ -64,8 +64,8 @@ Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
 
 TEST_FUNC(matmul_tiled_loops) {
   MLIRContext context;
-  Module module(&context);
-  mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
+  OwningModuleRef module = Module::create(&context);
+  mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops");
   lowerToTiledLoops(f, {8, 9});
   PassManager pm;
   pm.addPass(createLowerLinalgLoadStorePass());
@@ -95,8 +95,8 @@ TEST_FUNC(matmul_tiled_loops) {
 
 TEST_FUNC(matmul_tiled_views) {
   MLIRContext context;
-  Module module(&context);
-  mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
+  OwningModuleRef module = Module::create(&context);
+  mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_views");
   OpBuilder b(f.getBody());
   lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
                         b.create<ConstantIndexOp>(f.getLoc(), 9)});
@@ -124,9 +124,9 @@ TEST_FUNC(matmul_tiled_views) {
 
 TEST_FUNC(matmul_tiled_views_as_loops) {
   MLIRContext context;
-  Module module(&context);
+  OwningModuleRef module = Module::create(&context);
   mlir::Function f =
-      makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops");
+      makeFunctionWithAMatmulOp(*module, "matmul_tiled_views_as_loops");
   OpBuilder b(f.getBody());
   lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
                         b.create<ConstantIndexOp>(f.getLoc(), 9)});
index 21637bc..287f432 100644 (file)
@@ -27,7 +27,7 @@
 
 namespace mlir {
 class MLIRContext;
-class Module;
+class OwningModuleRef;
 } // namespace mlir
 
 namespace toy {
@@ -35,8 +35,7 @@ class ModuleAST;
 
 /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
 /// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
 } // namespace toy
 
 #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
index 73789fa..e3f1b19 100644 (file)
@@ -66,22 +66,22 @@ public:
 
   /// Public API: convert the AST for a Toy module (source file) to an MLIR
   /// Module.
-  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+  mlir::Module mlirGen(ModuleAST &moduleAST) {
     // We create an empty MLIR module and codegen functions one at a time and
     // add them to the module.
-    theModule = make_unique<mlir::Module>(&context);
+    theModule = mlir::Module::create(&context);
 
     for (FunctionAST &F : moduleAST) {
       auto func = mlirGen(F);
       if (!func)
         return nullptr;
-      theModule->push_back(func);
+      theModule.push_back(func);
     }
 
     // FIXME: (in the next chapter...) without registering a dialect in MLIR,
     // this won't do much, but it should at least check some structural
     // properties.
-    if (failed(theModule->verify())) {
+    if (failed(theModule.verify())) {
       emitError(mlir::UnknownLoc::get(&context), "Module verification error");
       return nullptr;
     }
@@ -96,7 +96,7 @@ private:
   mlir::MLIRContext &context;
 
   /// A "module" matches a source file: it contains a list of functions.
-  std::unique_ptr<mlir::Module> theModule;
+  mlir::Module theModule;
 
   /// The builder is a helper class to create IR inside a function. It is
   /// re-initialized every time we enter a function and kept around as a
@@ -500,8 +500,8 @@ private:
 namespace toy {
 
 // The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+                              ModuleAST &moduleAST) {
   return MLIRGenImpl(context).mlirGen(moduleAST);
 }
 
index 9846764..b541486 100644 (file)
@@ -75,7 +75,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
 
 int dumpMLIR() {
   mlir::MLIRContext context;
-  std::unique_ptr<mlir::Module> module;
+  mlir::OwningModuleRef module;
   if (inputType == InputType::MLIR ||
       llvm::StringRef(inputFilename).endswith(".mlir")) {
     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
@@ -86,7 +86,7 @@ int dumpMLIR() {
     }
     llvm::SourceMgr sourceMgr;
     sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
-    module.reset(mlir::parseSourceFile(sourceMgr, &context));
+    module = mlir::parseSourceFile(sourceMgr, &context);
     if (!module) {
       llvm::errs() << "Error can't load file " << inputFilename << "\n";
       return 3;
index 21637bc..287f432 100644 (file)
@@ -27,7 +27,7 @@
 
 namespace mlir {
 class MLIRContext;
-class Module;
+class OwningModuleRef;
 } // namespace mlir
 
 namespace toy {
@@ -35,8 +35,7 @@ class ModuleAST;
 
 /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
 /// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
 } // namespace toy
 
 #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
index 23cb853..5d2e3af 100644 (file)
@@ -67,10 +67,10 @@ public:
 
   /// Public API: convert the AST for a Toy module (source file) to an MLIR
   /// Module.
-  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+  mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
     // We create an empty MLIR module and codegen functions one at a time and
     // add them to the module.
-    theModule = make_unique<mlir::Module>(&context);
+    theModule = mlir::Module::create(&context);
 
     for (FunctionAST &F : moduleAST) {
       auto func = mlirGen(F);
@@ -97,7 +97,7 @@ private:
   mlir::MLIRContext &context;
 
   /// A "module" matches a source file: it contains a list of functions.
-  std::unique_ptr<mlir::Module> theModule;
+  mlir::OwningModuleRef theModule;
 
   /// The builder is a helper class to create IR inside a function. It is
   /// re-initialized every time we enter a function and kept around as a
@@ -469,8 +469,8 @@ private:
 namespace toy {
 
 // The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+                              ModuleAST &moduleAST) {
   return MLIRGenImpl(context).mlirGen(moduleAST);
 }
 
index 3d18417..864dc42 100644 (file)
@@ -79,7 +79,7 @@ int dumpMLIR() {
   mlir::registerDialect<ToyDialect>();
 
   mlir::MLIRContext context;
-  std::unique_ptr<mlir::Module> module;
+  mlir::OwningModuleRef module;
   if (inputType == InputType::MLIR ||
       llvm::StringRef(inputFilename).endswith(".mlir")) {
     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
@@ -90,7 +90,7 @@ int dumpMLIR() {
     }
     llvm::SourceMgr sourceMgr;
     sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
-    module.reset(mlir::parseSourceFile(sourceMgr, &context));
+    module = mlir::parseSourceFile(sourceMgr, &context);
     if (!module) {
       llvm::errs() << "Error can't load file " << inputFilename << "\n";
       return 3;
index 21637bc..287f432 100644 (file)
@@ -27,7 +27,7 @@
 
 namespace mlir {
 class MLIRContext;
-class Module;
+class OwningModuleRef;
 } // namespace mlir
 
 namespace toy {
@@ -35,8 +35,7 @@ class ModuleAST;
 
 /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
 /// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
 } // namespace toy
 
 #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
index f2132c2..a482124 100644 (file)
@@ -67,10 +67,10 @@ public:
 
   /// Public API: convert the AST for a Toy module (source file) to an MLIR
   /// Module.
-  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+  mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
     // We create an empty MLIR module and codegen functions one at a time and
     // add them to the module.
-    theModule = make_unique<mlir::Module>(&context);
+    theModule = mlir::Module::create(&context);
 
     for (FunctionAST &F : moduleAST) {
       auto func = mlirGen(F);
@@ -97,7 +97,7 @@ private:
   mlir::MLIRContext &context;
 
   /// A "module" matches a source file: it contains a list of functions.
-  std::unique_ptr<mlir::Module> theModule;
+  mlir::OwningModuleRef theModule;
 
   /// The builder is a helper class to create IR inside a function. It is
   /// re-initialized every time we enter a function and kept around as a
@@ -469,8 +469,8 @@ private:
 namespace toy {
 
 // The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+                              ModuleAST &moduleAST) {
   return MLIRGenImpl(context).mlirGen(moduleAST);
 }
 
index f237fd9..3650e5f 100644 (file)
@@ -119,7 +119,7 @@ public:
   };
 
   void runOnModule() override {
-    auto &module = getModule();
+    auto module = getModule();
     auto main = module.getNamedFunction("main");
     if (!main) {
       emitError(mlir::UnknownLoc::get(module.getContext()),
index 77c039d..8dbc6f8 100644 (file)
@@ -78,7 +78,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
   return parser.ParseModule();
 }
 
-mlir::LogicalResult optimize(mlir::Module &module) {
+mlir::LogicalResult optimize(mlir::Module module) {
   mlir::PassManager pm;
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(createShapeInferencePass());
@@ -86,7 +86,7 @@ mlir::LogicalResult optimize(mlir::Module &module) {
   // Apply any generic pass manager command line options.
   applyPassManagerCLOptions(pm);
 
-  return pm.run(&module);
+  return pm.run(module);
 }
 
 int dumpMLIR() {
@@ -97,7 +97,7 @@ int dumpMLIR() {
   mlir::registerPassManagerCLOptions();
 
   mlir::MLIRContext context;
-  std::unique_ptr<mlir::Module> module;
+  mlir::OwningModuleRef module;
   if (inputType == InputType::MLIR ||
       llvm::StringRef(inputFilename).endswith(".mlir")) {
     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
@@ -108,7 +108,7 @@ int dumpMLIR() {
     }
     llvm::SourceMgr sourceMgr;
     sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
-    module.reset(mlir::parseSourceFile(sourceMgr, &context));
+    module = mlir::parseSourceFile(sourceMgr, &context);
     if (!module) {
       llvm::errs() << "Error can't load file " << inputFilename << "\n";
       return 3;
index 21637bc..287f432 100644 (file)
@@ -27,7 +27,7 @@
 
 namespace mlir {
 class MLIRContext;
-class Module;
+class OwningModuleRef;
 } // namespace mlir
 
 namespace toy {
@@ -35,8 +35,7 @@ class ModuleAST;
 
 /// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
 /// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
 } // namespace toy
 
 #endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
index 60a8b5a..1c22080 100644 (file)
@@ -136,7 +136,7 @@ public:
   PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                                      PatternRewriter &rewriter) const override {
     // Get or create the declaration of the printf function in the module.
-    Function printfFunc = getPrintf(*op->getFunction().getModule());
+    Function printfFunc = getPrintf(op->getFunction().getModule());
 
     auto print = cast<toy::PrintOp>(op);
     auto loc = print.getLoc();
@@ -205,13 +205,13 @@ private:
 
   /// Return the prototype declaration for printf in the module, create it if
   /// necessary.
-  Function getPrintf(Module &module) const {
+  Function getPrintf(Module module) const {
     auto printfFunc = module.getNamedFunction("printf");
     if (printfFunc)
       return printfFunc;
 
     // Create a function declaration for printf, signature is `i32 (i8*, ...)`
-    Builder builder(&module);
+    Builder builder(module);
     auto *dialect =
         module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
 
index 9ebfeb4..e0e8883 100644 (file)
@@ -67,10 +67,10 @@ public:
 
   /// Public API: convert the AST for a Toy module (source file) to an MLIR
   /// Module.
-  std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+  mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
     // We create an empty MLIR module and codegen functions one at a time and
     // add them to the module.
-    theModule = make_unique<mlir::Module>(&context);
+    theModule = mlir::Module::create(&context);
 
     for (FunctionAST &F : moduleAST) {
       auto func = mlirGen(F);
@@ -97,7 +97,7 @@ private:
   mlir::MLIRContext &context;
 
   /// A "module" matches a source file: it contains a list of functions.
-  std::unique_ptr<mlir::Module> theModule;
+  mlir::OwningModuleRef theModule;
 
   /// The builder is a helper class to create IR inside a function. It is
   /// re-initialized every time we enter a function and kept around as a
@@ -469,8 +469,8 @@ private:
 namespace toy {
 
 // The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
-                                      ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+                              ModuleAST &moduleAST) {
   return MLIRGenImpl(context).mlirGen(moduleAST);
 }
 
index 0abcb4b..971cf0a 100644 (file)
@@ -119,8 +119,8 @@ public:
   };
 
   void runOnModule() override {
-    auto &module = getModule();
-    mlir::ModuleManager moduleManager(&module);
+    auto module = getModule();
+    mlir::ModuleManager moduleManager(module);
     auto main = moduleManager.getNamedFunction("main");
     if (!main) {
       emitError(mlir::UnknownLoc::get(module.getContext()),
index 9637d72..b5bdde8 100644 (file)
@@ -101,7 +101,7 @@ std::unique_ptr<toy::ModuleAST> parseInputFile(llvm::StringRef filename) {
   return parser.ParseModule();
 }
 
-mlir::LogicalResult optimize(mlir::Module &module) {
+mlir::LogicalResult optimize(mlir::Module module) {
   mlir::PassManager pm;
   pm.addPass(mlir::createCanonicalizerPass());
   pm.addPass(createShapeInferencePass());
@@ -111,10 +111,10 @@ mlir::LogicalResult optimize(mlir::Module &module) {
   // Apply any generic pass manager command line options.
   applyPassManagerCLOptions(pm);
 
-  return pm.run(&module);
+  return pm.run(module);
 }
 
-mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) {
+mlir::LogicalResult lowerDialect(mlir::Module module, bool OnlyLinalg) {
   mlir::PassManager pm;
   pm.addPass(createEarlyLoweringPass());
   pm.addPass(mlir::createCanonicalizerPass());
@@ -127,14 +127,14 @@ mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) {
   // Apply any generic pass manager command line options.
   applyPassManagerCLOptions(pm);
 
-  return pm.run(&module);
+  return pm.run(module);
 }
 
-std::unique_ptr<mlir::Module> loadFileAndProcessModule(
+mlir::OwningModuleRef loadFileAndProcessModule(
     mlir::MLIRContext &context, bool EnableLinalgLowering = false,
     bool EnableLLVMLowering = false, bool EnableOpt = false) {
 
-  std::unique_ptr<mlir::Module> module;
+  mlir::OwningModuleRef module;
   if (inputType == InputType::MLIR ||
       llvm::StringRef(inputFilename).endswith(".mlir")) {
     llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
@@ -145,7 +145,7 @@ std::unique_ptr<mlir::Module> loadFileAndProcessModule(
     }
     llvm::SourceMgr sourceMgr;
     sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
-    module.reset(mlir::parseSourceFile(sourceMgr, &context));
+    module = mlir::parseSourceFile(sourceMgr, &context);
     if (!module) {
       llvm::errs() << "Error can't load file " << inputFilename << "\n";
       return nullptr;
@@ -252,7 +252,7 @@ int runJit() {
   // the module.
   auto optPipeline = mlir::makeOptimizingTransformer(
       /* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0);
-  auto maybeEngine = mlir::ExecutionEngine::create(module.get(), optPipeline);
+  auto maybeEngine = mlir::ExecutionEngine::create(*module, optPipeline);
   assert(maybeEngine && "failed to construct an execution engine");
   auto &engine = maybeEngine.get();
 
index 5b53a8a..691fe7c 100644 (file)
@@ -54,10 +54,10 @@ namespace {
 struct MyFunctionPass : public FunctionPass<MyFunctionPass> {
   void runOnFunction() override {
     // Get the current function being operated on.
-    Function *f = getFunction();
+    Function f = getFunction();
 
     // Operate on the operations within the function.
-    f->walk([](Operation *inst) {
+    f.walk([](Operation *inst) {
       ....
     });
   }
@@ -94,10 +94,10 @@ namespace {
 struct MyModulePass : public ModulePass<MyModulePass> {
   void runOnModule() override {
     // Get the current module being operated on.
-    Module *m = getModule();
+    Module m = getModule();
 
     // Operate on the functions within the module.
-    for (auto &func : *m) {
+    for (auto func : m) {
       ....
     }
   }
@@ -149,7 +149,7 @@ struct MyFunctionAnalysis {
 /// An interesting module analysis.
 struct MyModuleAnalysis {
   // Compute this analysis with the provided module.
-  MyModuleAnalysis(Module *module);
+  MyModuleAnalysis(Module module);
 };
 
 void MyFunctionPass::runOnFunction() {
@@ -181,7 +181,7 @@ void MyModulePass::runOnModule() {
 
   // Query MyFunctionAnalysis for a child function of the current module. It
   // will be computed if it doesn't exist.
-  auto *fn = &*getModule().begin();
+  auto fn = *getModule().begin();
   MyFunctionAnalysis &myAnalysis = getFunctionAnalysis<MyFunctionAnalysis>(fn);
 }
 ```
@@ -255,7 +255,7 @@ pm.addPass(new MyFunctionPass3());
 pm.addPass(new MyModulePass2());
 
 // Run the pass manager on a module.
-Module *m = ...;
+Module m = ...;
 if (failed(pm.run(m)))
     ... // One of the passes signaled a failure.
 ```
@@ -384,7 +384,7 @@ unsigned domInfoCount;
 pm.addInstrumentation(new DominanceCounterInstrumentation(domInfoCount));
 
 // Run the pass manager on a module.
-Module *m = ...;
+Module m = ...;
 if (failed(pm.run(m)))
     ...
 
index f7f4caa..a949a87 100644 (file)
@@ -48,7 +48,7 @@ namespace LLVM {
 /// support different values coming from the same predecessor.  If a block has
 /// another block as a successor more than once with different values, insert
 /// a new dummy block for LLVM PHI nodes to tell the sources apart.
-void ensureDistinctSuccessors(Module *m);
+void ensureDistinctSuccessors(Module m);
 } // namespace LLVM
 
 } // namespace mlir
index 2ee2945..cbbfffd 100644 (file)
@@ -62,7 +62,7 @@ public:
   /// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
   /// and link the shared libraries for symbol resolution.
   static llvm::Expected<std::unique_ptr<ExecutionEngine>>
-  create(Module *m, std::function<llvm::Error(llvm::Module *)> transformer = {},
+  create(Module m, std::function<llvm::Error(llvm::Module *)> transformer = {},
          ArrayRef<StringRef> sharedLibPaths = {});
 
   /// Looks up a packed-argument function with the given name and returns a
index e5c8c03..b954d46 100644 (file)
@@ -57,12 +57,12 @@ class UnitAttr;
 class Builder {
 public:
   explicit Builder(MLIRContext *context) : context(context) {}
-  explicit Builder(Module *module);
+  explicit Builder(Module module);
 
   MLIRContext *getContext() const { return context; }
 
   Identifier getIdentifier(StringRef str);
-  Module *createModule();
+  Module createModule();
 
   // Locations.
   Location getUnknownLoc();
index 8c66dea..0fee88c 100644 (file)
@@ -34,10 +34,12 @@ class MLIRContext;
 class Module;
 
 namespace detail {
+class ModuleStorage;
+
 /// This class represents all of the internal state of a Function. This allows
 /// for the Function class to be value typed.
 class FunctionStorage
-    : public llvm::ilist_node_with_parent<FunctionStorage, Module> {
+    : public llvm::ilist_node_with_parent<FunctionStorage, ModuleStorage> {
   FunctionStorage(Location location, StringRef name, FunctionType type,
                   ArrayRef<NamedAttribute> attrs = {});
   FunctionStorage(Location location, StringRef name, FunctionType type,
@@ -47,7 +49,7 @@ class FunctionStorage
   Identifier name;
 
   /// The module this function is embedded into.
-  Module *module = nullptr;
+  ModuleStorage *module = nullptr;
 
   /// The source location the function was defined or derived from.
   Location location;
@@ -116,7 +118,7 @@ public:
   }
 
   MLIRContext *getContext();
-  Module *getModule() { return impl->module; }
+  Module getModule();
 
   /// Add an entry block to an empty function, and set up the block arguments
   /// to match the signature of the function.
@@ -541,7 +543,7 @@ struct ilist_traits<::mlir::detail::FunctionStorage>
                              function_iterator first, function_iterator last);
 
 private:
-  mlir::Module *getContainingModule();
+  mlir::detail::ModuleStorage *getContainingModule();
 };
 
 // Functions hash just like pointers.
index d8a4789..a77653b 100644 (file)
 #include "llvm/ADT/ilist.h"
 
 namespace mlir {
+class Module;
+
+namespace detail {
+class ModuleStorage {
+  explicit ModuleStorage(MLIRContext *context) : context(context) {}
+
+  /// getSublistAccess() - Returns pointer to member of function list
+  static llvm::iplist<FunctionStorage> ModuleStorage::*
+  getSublistAccess(FunctionStorage *) {
+    return &ModuleStorage::functions;
+  }
+
+  /// The context attached to this module.
+  MLIRContext *context;
+
+  /// This is the actual list of functions the module contains.
+  llvm::iplist<FunctionStorage> functions;
+
+  friend Module;
+  friend struct llvm::ilist_traits<FunctionStorage>;
+  friend FunctionStorage;
+  friend Function;
+};
+} // end namespace detail
 
 class Module {
 public:
-  explicit Module(MLIRContext *context) : context(context) {}
+  Module(detail::ModuleStorage *impl = nullptr) : impl(impl) {}
+
+  /// Construct a new module object with the given context.
+  static Module create(MLIRContext *context) {
+    return new detail::ModuleStorage(context);
+  }
 
-  MLIRContext *getContext() { return context; }
+  MLIRContext *getContext() { return impl->context; }
+
+  /// Allow converting a Module to bool for null checks.
+  operator bool() const { return impl; }
+  bool operator==(Module other) const { return impl == other.impl; }
+  bool operator!=(Module other) const { return !(*this == other); }
 
   /// An iterator class used to iterate over the held functions.
   class iterator : public llvm::mapped_iterator<
@@ -56,14 +90,14 @@ public:
   llvm::iterator_range<iterator> getFunctions() { return {begin(), end()}; }
 
   // Iteration over the functions in the module.
-  iterator begin() { return functions.begin(); }
-  iterator end() { return functions.end(); }
-  Function front() { return &functions.front(); }
-  Function back() { return &functions.back(); }
+  iterator begin() { return impl->functions.begin(); }
+  iterator end() { return impl->functions.end(); }
+  Function front() { return &impl->functions.front(); }
+  Function back() { return &impl->functions.back(); }
 
-  void push_back(Function fn) { functions.push_back(fn.impl); }
+  void push_back(Function fn) { impl->functions.push_back(fn.impl); }
   void insert(iterator insertPt, Function fn) {
-    functions.insert(insertPt.getCurrent(), fn.impl);
+    impl->functions.insert(insertPt.getCurrent(), fn.impl);
   }
 
   // Interfaces for working with the symbol table.
@@ -79,6 +113,7 @@ public:
   /// name exists. Function names never include the @ on them. Note: This
   /// performs a linear scan of held symbols.
   Function getNamedFunction(Identifier name) {
+    auto &functions = impl->functions;
     auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) {
       return Function(&fn).getName() == name;
     });
@@ -93,22 +128,27 @@ public:
   void print(raw_ostream &os);
   void dump();
 
-private:
-  friend struct llvm::ilist_traits<detail::FunctionStorage>;
-  friend detail::FunctionStorage;
-  friend Function;
+  /// Erase the current module.
+  void erase() {
+    assert(impl && "expected valid module");
+    delete impl;
+  }
 
-  /// getSublistAccess() - Returns pointer to member of function list
-  static llvm::iplist<detail::FunctionStorage> Module::*
-  getSublistAccess(detail::FunctionStorage *) {
-    return &Module::functions;
+  /// Methods for supporting PointerLikeTypeTraits.
+  const void *getAsOpaquePointer() const {
+    return static_cast<const void *>(impl);
+  }
+  static Module getFromOpaquePointer(const void *pointer) {
+    return reinterpret_cast<detail::ModuleStorage *>(
+        const_cast<void *>(pointer));
   }
 
-  /// The context attached to this module.
-  MLIRContext *context;
+private:
+  friend detail::FunctionStorage;
+  friend Function;
 
-  /// This is the actual list of functions the module contains.
-  llvm::iplist<detail::FunctionStorage> functions;
+  /// The internal impl storage object.
+  detail::ModuleStorage *impl = nullptr;
 };
 
 /// A class used to manage the symbols held by a module. This class handles
@@ -116,7 +156,7 @@ private:
 /// efficent named lookup to held symbols.
 class ModuleManager {
 public:
-  ModuleManager(Module *module) : module(module), symbolTable(module) {}
+  ModuleManager(Module module) : module(module), symbolTable(module) {}
 
   /// Look up a symbol with the specified name, returning null if no such
   /// name exists. Names must never include the @ on them.
@@ -127,11 +167,11 @@ public:
   /// Insert a new symbol into the module, auto-renaming it as necessary.
   void insert(Function function) {
     symbolTable.insert(function);
-    module->push_back(function);
+    module.push_back(function);
   }
   void insert(Module::iterator insertPt, Function function) {
     symbolTable.insert(function);
-    module->insert(insertPt, function);
+    module.insert(insertPt, function);
   }
 
   /// Remove the given symbol from the module symbol table and then erase it.
@@ -141,16 +181,53 @@ public:
   }
 
   /// Return the internally held module.
-  Module *getModule() const { return module; }
+  Module getModule() const { return module; }
 
   /// Return the context of the internal module.
-  MLIRContext *getContext() const { return module->getContext(); }
+  MLIRContext *getContext() const { return getModule().getContext(); }
 
 private:
-  Module *module;
+  Module module;
   SymbolTable symbolTable;
 };
 
+/// This class acts as an owning reference to a Module, and will automatically
+/// destory the held Module if valid.
+class OwningModuleRef {
+public:
+  OwningModuleRef(std::nullptr_t = nullptr) {}
+  OwningModuleRef(Module module) : module(module) {}
+  OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
+  ~OwningModuleRef() {
+    if (module)
+      module.erase();
+  }
+
+  // Assign from another module reference.
+  OwningModuleRef &operator=(OwningModuleRef &&other) {
+    if (module)
+      module.erase();
+    module = other.release();
+    return *this;
+  }
+
+  /// Allow accessing the internal module.
+  Module get() const { return module; }
+  Module operator*() const { return module; }
+  Module *operator->() { return &module; }
+  explicit operator bool() const { return module; }
+
+  /// Release the referenced module.
+  Module release() {
+    Module released;
+    std::swap(released, module);
+    return released;
+  }
+
+private:
+  Module module;
+};
+
 //===--------------------------------------------------------------------===//
 // Module Operation.
 //===--------------------------------------------------------------------===//
@@ -196,4 +273,20 @@ public:
 
 } // end namespace mlir
 
+namespace llvm {
+
+/// Allow stealing the low bits of ModuleStorage.
+template <> struct PointerLikeTypeTraits<mlir::Module> {
+public:
+  static inline void *getAsVoidPointer(mlir::Module I) {
+    return const_cast<void *>(I.getAsOpaquePointer());
+  }
+  static inline mlir::Module getFromVoidPointer(void *P) {
+    return mlir::Module::getFromOpaquePointer(P);
+  }
+  enum { NumLowBitsAvailable = 3 };
+};
+
+} // end namespace llvm
+
 #endif // MLIR_IR_MODULE_H
index a351f66..10d3a5e 100644 (file)
@@ -31,7 +31,7 @@ class MLIRContext;
 class SymbolTable {
 public:
   /// Build a symbol table with the symbols within the given module.
-  SymbolTable(Module *module);
+  SymbolTable(Module module);
 
   /// Look up a symbol with the specified name, returning null if no such
   /// name exists. Names never include the @ on them.
index a2673ca..f347ff1 100644 (file)
@@ -37,24 +37,24 @@ class Type;
 /// This parses the file specified by the indicated SourceMgr and returns an
 /// MLIR module if it was valid.  If not, the error message is emitted through
 /// the error handler registered in the context, and a null pointer is returned.
-Module *parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
+Module parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
 
 /// This parses the file specified by the indicated filename and returns an
 /// MLIR module if it was valid.  If not, the error message is emitted through
 /// the error handler registered in the context, and a null pointer is returned.
-Module *parseSourceFile(llvm::StringRef filename, MLIRContext *context);
+Module parseSourceFile(llvm::StringRef filename, MLIRContext *context);
 
 /// This parses the file specified by the indicated filename using the provided
 /// SourceMgr and returns an MLIR module if it was valid.  If not, the error
 /// message is emitted through the error handler registered in the context, and
 /// a null pointer is returned.
-Module *parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr,
-                        MLIRContext *context);
+Module parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr,
+                       MLIRContext *context);
 
 /// This parses the module string to a MLIR module if it was valid.  If not, the
 /// error message is emitted through the error handler registered in the
 /// context, and a null pointer is returned.
-Module *parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
+Module parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
 
 /// This parses a single MLIR type to an MLIR context if it was valid.  If not,
 /// an error message is emitted through a new SourceMgrDiagnosticHandler
index c44f88f..18ba7a8 100644 (file)
@@ -223,7 +223,7 @@ private:
 /// An analysis manager for a specific module instance.
 class ModuleAnalysisManager {
 public:
-  ModuleAnalysisManager(Module *module, PassInstrumentor *passInstrumentor)
+  ModuleAnalysisManager(Module module, PassInstrumentor *passInstrumentor)
       : moduleAnalyses(module), passInstrumentor(passInstrumentor) {}
   ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
   ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
@@ -273,7 +273,7 @@ private:
       functionAnalyses;
 
   /// The analyses for the owning module.
-  detail::AnalysisMap<Module *> moduleAnalyses;
+  detail::AnalysisMap<Module> moduleAnalyses;
 
   /// An optional instrumentation object.
   PassInstrumentor *passInstrumentor;
index 41d20cc..6ee78c5 100644 (file)
@@ -138,8 +138,7 @@ private:
 /// Pass to transform a module. Derived passes should not inherit from this
 /// class directly, and instead should use the CRTP ModulePass class.
 class ModulePassBase : public Pass {
-  using PassStateT =
-      detail::PassExecutionState<Module *, ModuleAnalysisManager>;
+  using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
 
 public:
   static bool classof(const Pass *pass) {
@@ -153,7 +152,7 @@ protected:
   virtual void runOnModule() = 0;
 
   /// Return the current module being transformed.
-  Module &getModule() { return *getPassState().irAndPassFailed.getPointer(); }
+  Module getModule() { return getPassState().irAndPassFailed.getPointer(); }
 
   /// Return the MLIR context for the current module being transformed.
   MLIRContext &getContext() { return *getModule().getContext(); }
@@ -172,7 +171,7 @@ protected:
 private:
   /// Forwarding function to execute this pass.
   LLVM_NODISCARD
-  LogicalResult run(Module *module, ModuleAnalysisManager &mam);
+  LogicalResult run(Module module, ModuleAnalysisManager &mam);
 
   /// The current execution state for the pass.
   llvm::Optional<PassStateT> passState;
index 32d5fd6..3277d36 100644 (file)
@@ -60,7 +60,7 @@ public:
 
   /// Run the passes within this manager on the provided module.
   LLVM_NODISCARD
-  LogicalResult run(Module *module);
+  LogicalResult run(Module module);
 
   //===--------------------------------------------------------------------===//
   // Pipeline Building
index a4f45ff..e227c3b 100644 (file)
@@ -38,7 +38,7 @@ class Module;
 /// from the registered LLVM IR dialect.  In case of error, report it
 /// to the error handler registered with the MLIR context, if any (obtained from
 /// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m);
+std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module m);
 
 } // namespace mlir
 
index 493b0ca..d76776e 100644 (file)
@@ -24,7 +24,7 @@
 #define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
 
 #include "mlir/IR/Block.h"
-#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
 #include "mlir/IR/Value.h"
 
 #include "llvm/IR/BasicBlock.h"
@@ -48,7 +48,7 @@ namespace LLVM {
 class ModuleTranslation {
 public:
   template <typename T = ModuleTranslation>
-  static std::unique_ptr<llvm::Module> translateModule(Module &m) {
+  static std::unique_ptr<llvm::Module> translateModule(Module m) {
     auto llvmModule = prepareLLVMModule(m);
 
     T translator(m);
@@ -63,17 +63,17 @@ protected:
   // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
   // LLVM IR module.  The MLIR LLVM IR dialect holds a pointer to an
   // LLVMContext, the LLVM IR module will be created in that context.
-  explicit ModuleTranslation(Module &module) : mlirModule(module) {}
+  explicit ModuleTranslation(Module module) : mlirModule(module) {}
   virtual ~ModuleTranslation() {}
 
   virtual bool convertOperation(Operation &op, llvm::IRBuilder<> &builder);
-  static std::unique_ptr<llvm::Module> prepareLLVMModule(Module &m);
+  static std::unique_ptr<llvm::Module> prepareLLVMModule(Module m);
 
 private:
 
   bool convertFunctions();
-  bool convertOneFunction(Function &func);
-  void connectPHINodes(Function &func);
+  bool convertOneFunction(Function func);
+  void connectPHINodes(Function func);
   bool convertBlock(Block &bb, bool ignoreArguments);
 
   template <typename Range>
@@ -83,7 +83,7 @@ private:
                                   Location loc);
 
   // Original and translated module.
-  Module &mlirModule;
+  Module mlirModule;
   std::unique_ptr<llvm::Module> llvmModule;
 
 protected:
index 27f964e..ba46947 100644 (file)
@@ -30,7 +30,6 @@ class Module;
 } // namespace llvm
 
 namespace mlir {
-
 class Module;
 
 /// Convert the given MLIR module into NVVM IR. This conversion requires the
@@ -38,7 +37,7 @@ class Module;
 /// from the registered LLVM IR dialect.  In case of error, report it
 /// to the error handler registered with the MLIR context, if any (obtained from
 /// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Module &m);
+std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Module m);
 
 } // namespace mlir
 
index c8ede78..0101673 100644 (file)
@@ -339,7 +339,7 @@ private:
 /// conversion object. This function returns failure if a type conversion
 /// failed.
 LLVM_NODISCARD LogicalResult applyConversionPatterns(
-    Module &module, ConversionTarget &target, TypeConverter &converter,
+    Module module, ConversionTarget &target, TypeConverter &converter,
     OwningRewritePatternList &&patterns);
 
 /// Convert the given functions with the provided conversion patterns. This
index 78518e9..6336083 100644 (file)
 namespace mlir {
 class MLIRContext;
 class Module;
+class OwningModuleRef;
 
 /// Interface of the function that translates a file to MLIR.  The
 /// implementation should create a new MLIR Module in the given context and
 /// return a pointer to it, or a nullptr in case of any error.
 using TranslateToMLIRFunction =
-    std::function<std::unique_ptr<Module>(llvm::StringRef, MLIRContext *)>;
+    std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
 /// Interface of the function that translates MLIR to a different format and
 /// outputs the result to a file.  The implementation should return "true" on
 /// error and "false" otherwise.  It is allowed to modify the module.
-using TranslateFromMLIRFunction =
-    std::function<bool(Module *, llvm::StringRef)>;
+using TranslateFromMLIRFunction = std::function<bool(Module, llvm::StringRef)>;
 
 /// Use Translate[To|From]MLIRRegistration as a global initialiser that
 /// registers a function and associates it with name. This requires that a
index 022d8c7..246bd54 100644 (file)
@@ -139,7 +139,7 @@ LogicalResult
 GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
   Builder builder(function.getContext());
 
-  std::unique_ptr<Module> module(builder.createModule());
+  OwningModuleRef module = builder.createModule();
 
   // TODO(herhut): Also handle called functions.
   module->push_back(function.clone());
@@ -147,8 +147,9 @@ GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
   auto llvmModule = translateModuleToNVVMIR(*module);
   auto cubin = convertModuleToCubin(*llvmModule, function);
 
-  if (!cubin)
+  if (!cubin) {
     return function.emitError("Translation to CUDA binary failed.");
+  }
 
   function.setAttr(kCubinAnnotation,
                    builder.getStringAttr({cubin->data(), cubin->size()}));
index f9d5899..0759324 100644 (file)
@@ -152,8 +152,8 @@ private:
 // The types in comments give the actual types expected/returned but the API
 // uses void pointers. This is fine as they have the same linkage in C.
 void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
-  Module &module = getModule();
-  Builder builder(&module);
+  Module module = getModule();
+  Builder builder(module);
   if (!module.getNamedFunction(cuModuleLoadName)) {
     module.push_back(
         Function::create(loc, cuModuleLoadName,
@@ -343,7 +343,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
                                ArrayRef<Value *>{cuModule, data.getResult(0)});
   // Get the function from the module. The name corresponds to the name of
   // the kernel function.
-  auto cuModuleRef =
+  auto cuOwningModuleRef =
       builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
   auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
   auto cuFunction = allocatePointer(builder, loc);
@@ -352,7 +352,7 @@ void GpuLaunchFuncToCudaCallsPass::translateGpuLaunchCalls(
   builder.create<LLVM::CallOp>(
       loc, ArrayRef<Type>{getCUResultType()},
       builder.getFunctionAttr(cuModuleGetFunction),
-      ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName});
+      ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
   // Grab the global stream needed for execution.
   Function cuGetStreamHelper =
       getModule().getNamedFunction(cuGetStreamHelperName);
index 97790a5..550491b 100644 (file)
@@ -115,7 +115,7 @@ public:
   void runOnModule() override {
     llvmDialect =
         getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
-    auto &module = getModule();
+    auto module = getModule();
     Builder builder(&getContext());
 
     auto functions = module.getFunctions();
index e849f6f..a0b911e 100644 (file)
@@ -442,13 +442,13 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
 
     // Insert the `malloc` declaration if it is not already present.
     Function mallocFunc =
-        op->getFunction().getModule()->getNamedFunction("malloc");
+        op->getFunction().getModule().getNamedFunction("malloc");
     if (!mallocFunc) {
       auto mallocType =
           rewriter.getFunctionType(getIndexType(), getVoidPtrType());
       mallocFunc =
           Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
-      op->getFunction().getModule()->push_back(mallocFunc);
+      op->getFunction().getModule().push_back(mallocFunc);
     }
 
     // Allocate the underlying buffer and store a pointer to it in the MemRef
@@ -503,11 +503,11 @@ struct DeallocOpLowering : public LLVMLegalizationPattern<DeallocOp> {
     OperandAdaptor<DeallocOp> transformed(operands);
 
     // Insert the `free` declaration if it is not already present.
-    Function freeFunc = op->getFunction().getModule()->getNamedFunction("free");
+    Function freeFunc = op->getFunction().getModule().getNamedFunction("free");
     if (!freeFunc) {
       auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
       freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
-      op->getFunction().getModule()->push_back(freeFunc);
+      op->getFunction().getModule().push_back(freeFunc);
     }
 
     auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
@@ -936,8 +936,8 @@ static void ensureDistinctSuccessors(Block &bb) {
   }
 }
 
-void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
-  for (auto f : *m) {
+void mlir::LLVM::ensureDistinctSuccessors(Module m) {
+  for (auto f : m) {
     for (auto &bb : f.getBlocks()) {
       ::ensureDistinctSuccessors(bb);
     }
@@ -1010,8 +1010,8 @@ namespace {
 struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
   // Run the dialect converter on the module.
   void runOnModule() override {
-    Module &m = getModule();
-    LLVM::ensureDistinctSuccessors(&m);
+    Module m = getModule();
+    LLVM::ensureDistinctSuccessors(m);
 
     LLVMTypeConverter converter(&getContext());
     OwningRewritePatternList patterns;
index 5a11512..ea9788c 100644 (file)
@@ -322,7 +322,7 @@ void packFunctionArguments(llvm::Module *module) {
 ExecutionEngine::~ExecutionEngine() = default;
 
 Expected<std::unique_ptr<ExecutionEngine>>
-ExecutionEngine::create(Module *m,
+ExecutionEngine::create(Module m,
                         std::function<llvm::Error(llvm::Module *)> transformer,
                         ArrayRef<StringRef> sharedLibPaths) {
   auto engine = llvm::make_unique<ExecutionEngine>();
@@ -330,7 +330,7 @@ ExecutionEngine::create(Module *m,
   if (!expectedJIT)
     return expectedJIT.takeError();
 
-  auto llvmModule = translateModuleToLLVMIR(*m);
+  auto llvmModule = translateModuleToLLVMIR(m);
   if (!llvmModule)
     return make_string_error("could not convert to LLVM IR");
   // FIXME: the triple should be passed to the translation or dialect conversion
index eadd5d6..6cf57b4 100644 (file)
@@ -426,8 +426,8 @@ LogicalResult LaunchFuncOp::verify() {
     return emitOpError("attribute 'kernel' must be a function");
   }
 
-  auto *module = getOperation()->getFunction().getModule();
-  Function kernelFunc = module->getNamedFunction(kernel());
+  auto module = getOperation()->getFunction().getModule();
+  Function kernelFunc = module.getNamedFunction(kernel());
   if (!kernelFunc)
     return emitError() << "kernel function '" << kernelAttr << "' is undefined";
 
index f93febc..6cb920a 100644 (file)
@@ -97,7 +97,7 @@ namespace {
 class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
 public:
   void runOnModule() override {
-    ModuleManager moduleManager(&getModule());
+    ModuleManager moduleManager(getModule());
     for (auto func : getModule()) {
       func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
         Function outlinedFunc = outlineKernelFunc(op);
index 52f54fe..12e24df 100644 (file)
@@ -91,7 +91,7 @@ public:
   explicit ModuleState(MLIRContext *context) : context(context) {}
 
   // Initializes module state, populating affine map state.
-  void initialize(Module *module);
+  void initialize(Module module);
 
   Twine getAttributeAlias(Attribute attr) const {
     auto alias = attrToAlias.find(attr);
@@ -301,12 +301,12 @@ void ModuleState::initializeSymbolAliases() {
 }
 
 // Initializes module state, populating affine map and integer set state.
-void ModuleState::initialize(Module *module) {
+void ModuleState::initialize(Module module) {
   // Initialize the symbol aliases.
   initializeSymbolAliases();
 
   // Walk the module and visit each operation.
-  for (auto fn : *module) {
+  for (auto fn : module) {
     visitType(fn.getType());
     for (auto attr : fn.getAttrs())
       ModuleState::visitAttribute(attr.second);
@@ -331,7 +331,7 @@ public:
     interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
   }
 
-  void print(Module *module);
+  void print(Module module);
 
   /// Print the given attribute. If 'mayElideType' is true, some attributes are
   /// printed without the type when the type matches the default used in the
@@ -451,13 +451,13 @@ void ModulePrinter::printLocationInternal(LocationAttr loc, bool pretty) {
   }
 }
 
-void ModulePrinter::print(Module *module) {
+void ModulePrinter::print(Module module) {
   // Output the aliases at the top level.
   state.printAttributeAliases(os);
   state.printTypeAliases(os);
 
   // Print the module.
-  for (auto fn : *module)
+  for (auto fn : module)
     print(fn);
 }
 
@@ -1784,8 +1784,8 @@ void Function::dump() { print(llvm::errs()); }
 
 void Module::print(raw_ostream &os) {
   ModuleState state(getContext());
-  state.initialize(this);
-  ModulePrinter(os, state).print(this);
+  state.initialize(*this);
+  ModulePrinter(os, state).print(*this);
 }
 
 void Module::dump() { print(llvm::errs()); }
index 89df642..6d0df6d 100644 (file)
 #include "mlir/Support/Functional.h"
 using namespace mlir;
 
-Builder::Builder(Module *module) : context(module->getContext()) {}
+Builder::Builder(Module module) : context(module.getContext()) {}
 
 Identifier Builder::getIdentifier(StringRef str) {
   return Identifier::get(str, context);
 }
 
-Module *Builder::createModule() { return new Module(context); }
+Module Builder::createModule() { return Module::create(context); }
 
 //===----------------------------------------------------------------------===//
 // Locations.
index 9c665e7..77425c7 100644 (file)
@@ -43,12 +43,14 @@ FunctionStorage::FunctionStorage(Location location, StringRef name,
       type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
 
 MLIRContext *Function::getContext() { return getType().getContext(); }
+Module Function::getModule() { return impl->module; }
 
-Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
-  size_t Offset(
-      size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
+ModuleStorage *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
+  size_t Offset(size_t(
+      &((ModuleStorage *)nullptr->*ModuleStorage::getSublistAccess(nullptr))));
   iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
-  return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
+  return reinterpret_cast<ModuleStorage *>(reinterpret_cast<char *>(Anchor) -
+                                           Offset);
 }
 
 /// This is a trait method invoked when a Function is added to a Module.  We
@@ -74,7 +76,7 @@ void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
     function_iterator last) {
   // If we are transferring functions within the same module, the Module
   // pointer doesn't need to be updated.
-  Module *curParent = getContainingModule();
+  ModuleStorage *curParent = getContainingModule();
   if (curParent == otherList.getContainingModule())
     return;
 
@@ -87,8 +89,8 @@ void llvm::ilist_traits<FunctionStorage>::transferNodesFromList(
 
 /// Unlink this function from its Module and delete it.
 void Function::erase() {
-  if (auto *module = getModule())
-    getModule()->functions.erase(impl);
+  if (auto module = getModule())
+    module.impl->functions.erase(impl);
   else
     delete impl;
 }
index dafbd48..02721b5 100644 (file)
@@ -21,8 +21,8 @@
 using namespace mlir;
 
 /// Build a symbol table with the symbols within the given module.
-SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
-  for (auto func : *module) {
+SymbolTable::SymbolTable(Module module) : context(module.getContext()) {
+  for (auto func : module) {
     auto inserted = symbolTable.insert({func.getName(), func});
     (void)inserted;
     assert(inserted.second &&
index 5fe4f07..7c51570 100644 (file)
@@ -170,13 +170,13 @@ public:
         LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
     auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
     // Insert the `malloc` declaration if it is not already present.
-    auto *module = op->getFunction().getModule();
-    Function mallocFunc = module->getNamedFunction("malloc");
+    auto module = op->getFunction().getModule();
+    Function mallocFunc = module.getNamedFunction("malloc");
     if (!mallocFunc) {
       auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
       mallocFunc =
           Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
-      module->push_back(mallocFunc);
+      module.push_back(mallocFunc);
     }
 
     // Get MLIR types for injecting element pointer.
@@ -231,12 +231,12 @@ public:
     auto voidPtrTy =
         LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
     // Insert the `free` declaration if it is not already present.
-    auto *module = op->getFunction().getModule();
-    Function freeFunc = module->getNamedFunction("free");
+    auto module = op->getFunction().getModule();
+    Function freeFunc = module.getNamedFunction("free");
     if (!freeFunc) {
       auto freeType = rewriter.getFunctionType(voidPtrTy, {});
       freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
-      module->push_back(freeFunc);
+      module.push_back(freeFunc);
     }
 
     // Get MLIR types for extracting element pointer.
@@ -576,7 +576,7 @@ public:
 static Function getLLVMLibraryCallImplDefinition(Function libFn) {
   auto implFnName = (libFn.getName().str() + "_impl");
   auto module = libFn.getModule();
-  if (auto f = module->getNamedFunction(implFnName)) {
+  if (auto f = module.getNamedFunction(implFnName)) {
     return f;
   }
   SmallVector<Type, 4> fnArgTypes;
@@ -590,7 +590,7 @@ static Function getLLVMLibraryCallImplDefinition(Function libFn) {
 
   // Insert the implementation function definition.
   auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
-  module->push_back(implFnDefn);
+  module.push_back(implFnDefn);
   return implFnDefn;
 }
 
@@ -603,7 +603,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op,
   assert(isa<LinalgOp>(op));
   auto fnName = LinalgOp::getLibraryCallName();
   auto module = op->getFunction().getModule();
-  if (auto f = module->getNamedFunction(fnName)) {
+  if (auto f = module.getNamedFunction(fnName)) {
     return f;
   }
 
@@ -620,7 +620,7 @@ static Function getLLVMLibraryCallDeclaration(Operation *op,
          "have void return types");
   auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
   auto libFn = Function::create(op->getLoc(), fnName, libFnType);
-  module->push_back(libFn);
+  module.push_back(libFn);
   // Return after creating the function definition. The body will be created
   // later.
   return libFn;
@@ -802,7 +802,7 @@ static void lowerLinalgForToCFG(Function &f) {
 }
 
 void LowerLinalgToLLVMPass::runOnModule() {
-  auto &module = getModule();
+  auto module = getModule();
 
   for (auto f : module.getFunctions()) {
     lowerLinalgSubViewOps(f);
index 677b6eb..5422b8f 100644 (file)
@@ -3857,7 +3857,7 @@ class ModuleParser : public Parser {
 public:
   explicit ModuleParser(ParserState &state) : Parser(state) {}
 
-  ParseResult parseModule(Module *module);
+  ParseResult parseModule(Module module);
 
 private:
   /// Parse an attribute alias declaration.
@@ -3875,7 +3875,7 @@ private:
       StringRef &name, FunctionType &type,
       SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
       SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
-  ParseResult parseFunc(Module *module);
+  ParseResult parseFunc(Module module);
 };
 } // end anonymous namespace
 
@@ -4039,7 +4039,7 @@ ParseResult ModuleParser::parseFunctionSignature(
 ///   function-body ::= `{` block+ `}`
 ///   function-attributes ::= `attributes` attribute-dict
 ///
-ParseResult ModuleParser::parseFunc(Module *module) {
+ParseResult ModuleParser::parseFunc(Module module) {
   consumeToken();
 
   StringRef name;
@@ -4061,7 +4061,7 @@ ParseResult ModuleParser::parseFunc(Module *module) {
   // Okay, the function signature was parsed correctly, create the function now.
   auto function =
       Function::create(getEncodedSourceLocation(loc), name, type, attrs);
-  module->push_back(function);
+  module.push_back(function);
 
   // Parse an optional trailing location.
   if (parseOptionalTrailingLocation(function))
@@ -4097,7 +4097,7 @@ ParseResult ModuleParser::parseFunc(Module *module) {
 }
 
 /// This is the top-level module parser.
-ParseResult ModuleParser::parseModule(Module *module) {
+ParseResult ModuleParser::parseModule(Module module) {
   while (1) {
     switch (getToken().getKind()) {
     default:
@@ -4139,16 +4139,15 @@ ParseResult ModuleParser::parseModule(Module *module) {
 /// This parses the file specified by the indicated SourceMgr and returns an
 /// MLIR module if it was valid.  If not, it emits diagnostics and returns
 /// null.
-Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
-                              MLIRContext *context) {
+Module mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
+                             MLIRContext *context) {
 
   // This is the result module we are parsing into.
-  std::unique_ptr<Module> module(new Module(context));
+  OwningModuleRef module(Module::create(context));
 
   ParserState state(sourceMgr, context);
-  if (ModuleParser(state).parseModule(module.get())) {
+  if (ModuleParser(state).parseModule(*module))
     return nullptr;
-  }
 
   // Make sure the parse module has no other structural problems detected by
   // the verifier.
@@ -4161,7 +4160,7 @@ Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
 /// This parses the file specified by the indicated filename and returns an
 /// MLIR module if it was valid.  If not, the error message is emitted through
 /// the error handler registered in the context, and a null pointer is returned.
-Module *mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
+Module mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
   llvm::SourceMgr sourceMgr;
   return parseSourceFile(filename, sourceMgr, context);
 }
@@ -4170,8 +4169,8 @@ Module *mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
 /// SourceMgr and returns an MLIR module if it was valid.  If not, the error
 /// message is emitted through the error handler registered in the context, and
 /// a null pointer is returned.
-Module *mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
-                              MLIRContext *context) {
+Module mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
+                             MLIRContext *context) {
   if (sourceMgr.getNumBuffers() != 0) {
     // TODO(b/136086478): Extend to support multiple buffers.
     emitError(mlir::UnknownLoc::get(context),
@@ -4192,7 +4191,7 @@ Module *mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
 
 /// This parses the program string to a MLIR module if it was valid. If not,
 /// it emits diagnostics and returns null.
-Module *mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
+Module mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
   auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr);
   if (!memBuffer)
     return nullptr;
index 057f265..aef16ff 100644 (file)
@@ -66,7 +66,7 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
 
     // Print the function name and a newline before the Module.
     out << " (function: " << function.getName() << ")\n";
-    function.getModule()->print(out);
+    function.getModule().print(out);
     return;
   }
 
@@ -80,8 +80,8 @@ static void printIR(const llvm::Any &ir, bool printModuleScope,
   }
 
   // Print the given module.
-  assert(llvm::any_isa<Module *>(ir) && "unexpected IR unit");
-  llvm::any_cast<Module *>(ir)->print(out);
+  assert(llvm::any_isa<Module>(ir) && "unexpected IR unit");
+  llvm::any_cast<Module>(ir).print(out);
 }
 
 /// Instrumentation hooks.
index 27ec74c..feaf2bb 100644 (file)
@@ -75,7 +75,7 @@ LogicalResult FunctionPassBase::run(Function fn, FunctionAnalysisManager &fam) {
 }
 
 /// Forwarding function to execute this pass.
-LogicalResult ModulePassBase::run(Module *module, ModuleAnalysisManager &mam) {
+LogicalResult ModulePassBase::run(Module module, ModuleAnalysisManager &mam) {
   // Initialize the pass state.
   passState.emplace(module, mam);
 
@@ -124,7 +124,7 @@ LogicalResult detail::FunctionPassExecutor::run(Function function,
 }
 
 /// Run all of the passes in this manager over the current module.
-LogicalResult detail::ModulePassExecutor::run(Module *module,
+LogicalResult detail::ModulePassExecutor::run(Module module,
                                               ModuleAnalysisManager &mam) {
   // Run each of the held passes.
   for (auto &pass : passes)
@@ -261,7 +261,7 @@ PassManager::PassManager(bool verifyPasses)
 PassManager::~PassManager() {}
 
 /// Run the passes within this manager on the provided module.
-LogicalResult PassManager::run(Module *module) {
+LogicalResult PassManager::run(Module module) {
   ModuleAnalysisManager mam(module, instrumentor.get());
   return mpe->run(module, mam);
 }
index d2563fb..b0cd228 100644 (file)
@@ -76,7 +76,7 @@ public:
   ModulePassExecutor &operator=(const ModulePassExecutor &) = delete;
 
   /// Run the executor on the given module.
-  LogicalResult run(Module *module, ModuleAnalysisManager &mam);
+  LogicalResult run(Module module, ModuleAnalysisManager &mam);
 
   /// Add a pass to the current executor. This takes ownership over the provided
   /// pass pointer.
index 543b730..688d1f9 100644 (file)
@@ -34,10 +34,10 @@ using namespace mlir;
 
 // Adds a one-block function named as `spirv_module` to `module` and returns the
 // block. The created block will be terminated by `std.return`.
-Block *createOneBlockFunction(Builder builder, Module *module) {
+Block *createOneBlockFunction(Builder builder, Module module) {
   auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
   auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType);
-  module->push_back(fn);
+  module.push_back(fn);
 
   auto *block = new Block();
   fn.push_back(block);
@@ -51,8 +51,8 @@ Block *createOneBlockFunction(Builder builder, Module *module) {
 
 // Deserializes the SPIR-V binary module stored in the file named as
 // `inputFilename` and returns a module containing the SPIR-V module.
-std::unique_ptr<Module> deserializeModule(llvm::StringRef inputFilename,
-                                          MLIRContext *context) {
+OwningModuleRef deserializeModule(llvm::StringRef inputFilename,
+                                  MLIRContext *context) {
   Builder builder(context);
 
   std::string errorMessage;
@@ -83,7 +83,7 @@ std::unique_ptr<Module> deserializeModule(llvm::StringRef inputFilename,
   // converted SPIR-V ModuleOp inside a MLIR module. This should be changed to
   // return the SPIR-V ModuleOp directly after module and function are migrated
   // to be general ops.
-  std::unique_ptr<Module> module(builder.createModule());
+  OwningModuleRef module(builder.createModule());
   Block *block = createOneBlockFunction(builder, module.get());
   block->push_front(spirvModule->getOperation());
 
index 33572d5..7cf9bdb 100644 (file)
@@ -31,7 +31,7 @@
 
 using namespace mlir;
 
-LogicalResult serializeModule(Module *module, StringRef outputFilename) {
+LogicalResult serializeModule(Module module, StringRef outputFilename) {
   if (!module)
     return failure();
 
@@ -45,7 +45,7 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
   // wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
   // to take in the SPIR-V ModuleOp directly after module and function are
   // migrated to be general ops.
-  for (auto fn : *module) {
+  for (auto fn : module) {
     fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
       if (done) {
         spirvModule.emitError("found more than one 'spv.module' op");
@@ -73,6 +73,6 @@ LogicalResult serializeModule(Module *module, StringRef outputFilename) {
 
 static TranslateFromMLIRRegistration
     registration("serialize-spirv",
-                 [](Module *module, StringRef outputFilename) {
+                 [](Module module, StringRef outputFilename) {
                    return failed(serializeModule(module, outputFilename));
                  });
index 92a2389..a527e85 100644 (file)
@@ -440,7 +440,7 @@ static LogicalResult verify(CallOp op) {
   auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
   if (!fnAttr)
     return op.emitOpError("requires a 'callee' function attribute");
-  auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
+  auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
       fnAttr.getValue());
   if (!fn)
     return op.emitOpError() << "'" << fnAttr.getValue()
@@ -1107,7 +1107,7 @@ static LogicalResult verify(ConstantOp &op) {
       return op.emitOpError("requires 'value' to be a function reference");
 
     // Try to find the referenced function.
-    auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
+    auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
         fnAttr.getValue());
     if (!fn)
       return op.emitOpError("reference to undefined function 'bar'");
index 15b148a..d2957c1 100644 (file)
@@ -50,7 +50,7 @@ static LogicalResult
 performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
                SourceMgr &sourceMgr, MLIRContext *context,
                const std::vector<const mlir::PassRegistryEntry *> &passList) {
-  std::unique_ptr<Module> module(parseSourceFile(sourceMgr, context));
+  OwningModuleRef module(parseSourceFile(sourceMgr, context));
   if (!module)
     return failure();
 
@@ -63,7 +63,7 @@ performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
   applyPassManagerCLOptions(pm);
 
   // Run the pipeline.
-  if (failed(pm.run(module.get())))
+  if (failed(pm.run(*module)))
     return failure();
 
   // Print the output.
index eb18acb..dcb55d6 100644 (file)
@@ -37,7 +37,7 @@ using namespace mlir;
 // Storage for the translation function wrappers that survive the parser.
 static llvm::SmallVector<TranslateFunction, 16> wrapperStorage;
 
-static LogicalResult printMLIROutput(Module &module,
+static LogicalResult printMLIROutput(Module module,
                                      llvm::StringRef outputFilename) {
   if (failed(module.verify()))
     return failure();
@@ -62,7 +62,7 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
     TranslateFunction wrapper = [function](StringRef inputFilename,
                                            StringRef outputFilename,
                                            MLIRContext *context) {
-      std::unique_ptr<Module> module = function(inputFilename, context);
+      OwningModuleRef module = function(inputFilename, context);
       if (!module)
         return failure();
       return printMLIROutput(*module, outputFilename);
@@ -79,8 +79,8 @@ TranslationParser::TranslationParser(llvm::cl::Option &opt)
                                            MLIRContext *context) {
       llvm::SourceMgr sourceMgr;
       SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
-      auto module = std::unique_ptr<Module>(
-          parseSourceFile(inputFilename, sourceMgr, context));
+      auto module =
+          OwningModuleRef(parseSourceFile(inputFilename, sourceMgr, context));
       if (!module)
         return failure();
       return failure(function(module.get(), outputFilename));
index 34dd374..49431d4 100644 (file)
 
 using namespace mlir;
 
-std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module m) {
   return LLVM::ModuleTranslation::translateModule<>(m);
 }
 
 static TranslateFromMLIRRegistration registration(
-    "mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) {
+    "mlir-to-llvmir", [](Module module, llvm::StringRef outputFilename) {
       if (!module)
         return true;
 
-      auto llvmModule = LLVM::ModuleTranslation::translateModule<>(*module);
+      auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module);
       if (!llvmModule)
         return true;
 
index 1e84092..bcff6e4 100644 (file)
@@ -47,8 +47,7 @@ static llvm::Value *createIntrinsicCall(llvm::IRBuilder<> &builder,
 class ModuleTranslation : public LLVM::ModuleTranslation {
 
 public:
-  explicit ModuleTranslation(Module &module)
-      : LLVM::ModuleTranslation(module) {}
+  explicit ModuleTranslation(Module module) : LLVM::ModuleTranslation(module) {}
   ~ModuleTranslation() override {}
 
 protected:
@@ -62,7 +61,7 @@ protected:
 };
 } // namespace
 
-std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module m) {
   ModuleTranslation translation(m);
   auto llvmModule =
       LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
@@ -91,11 +90,11 @@ std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
 
 static TranslateFromMLIRRegistration
     registration("mlir-to-nvvmir",
-                 [](Module *module, llvm::StringRef outputFilename) {
+                 [](Module module, llvm::StringRef outputFilename) {
                    if (!module)
                      return true;
 
-                   auto llvmModule = mlir::translateModuleToNVVMIR(*module);
+                   auto llvmModule = mlir::translateModuleToNVVMIR(module);
                    if (!llvmModule)
                      return true;
 
index 4a68ac7..df2bd57 100644 (file)
@@ -275,7 +275,7 @@ static Value *getPHISourceValue(Block *current, Block *pred,
              : terminator.getSuccessorOperand(1, index);
 }
 
-void ModuleTranslation::connectPHINodes(Function &func) {
+void ModuleTranslation::connectPHINodes(Function func) {
   // Skip the first block, it cannot be branched to and its arguments correspond
   // to the arguments of the LLVM function.
   for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
@@ -306,7 +306,7 @@ static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
 }
 
 // Sort function blocks topologically.
-static llvm::SetVector<Block *> topologicalSort(Function &f) {
+static llvm::SetVector<Block *> topologicalSort(Function f) {
   // For each blocks that has not been visited yet (i.e. that has no
   // predecessors), add it to the list and traverse its successors in DFS
   // preorder.
@@ -320,7 +320,7 @@ static llvm::SetVector<Block *> topologicalSort(Function &f) {
   return blocks;
 }
 
-bool ModuleTranslation::convertOneFunction(Function &func) {
+bool ModuleTranslation::convertOneFunction(Function func) {
   // Clear the block and value mappings, they are only relevant within one
   // function.
   blockMapping.clear();
@@ -404,7 +404,7 @@ bool ModuleTranslation::convertFunctions() {
   return false;
 }
 
-std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(Module &m) {
+std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(Module m) {
   auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
   assert(dialect && "LLVM dialect must be registered");
 
index 9916c9e..9375a7b 100644 (file)
@@ -1128,7 +1128,7 @@ auto ConversionTarget::getOpAction(OperationName op) const
 /// conversion object. If conversion fails for specific functions, those
 /// functions remains unmodified.
 LogicalResult
-mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
+mlir::applyConversionPatterns(Module module, ConversionTarget &target,
                               TypeConverter &converter,
                               OwningRewritePatternList &&patterns) {
   SmallVector<Function, 32> allFunctions(module.getFunctions());
index a88312d..788857c 100644 (file)
@@ -555,8 +555,8 @@ TEST_FUNC(vectorize_2d) {
       makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
 
   mlir::Function f = owningF;
-  mlir::Module module(&globalContext());
-  module.push_back(f);
+  mlir::OwningModuleRef module = Module::create(&globalContext());
+  module->push_back(f);
 
   OpBuilder builder(f.getBody());
   ScopedContext scope(builder, f.getLoc());
index 1ac6c40..1b1fbcc 100644 (file)
@@ -89,8 +89,8 @@ static llvm::cl::list<std::string>
                  llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
                  llvm::cl::cat(clOptionsCategory));
 
-static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
-                                              MLIRContext *context) {
+static OwningModuleRef parseMLIRInput(StringRef inputFilename,
+                                      MLIRContext *context) {
   // Set up the input file.
   std::string errorMessage;
   auto file = openInputFile(inputFilename, &errorMessage);
@@ -101,7 +101,7 @@ static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
 
   llvm::SourceMgr sourceMgr;
   sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
-  return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
+  return OwningModuleRef(parseSourceFile(sourceMgr, context));
 }
 
 // Initialize the relevant subsystems of LLVM.
@@ -151,7 +151,7 @@ static void printMemRefArguments(ArrayRef<Type> argTypes,
 // - canonicalization
 // - affine to standard lowering
 // - standard to llvm lowering
-static LogicalResult convertAffineStandardToLLVMIR(Module *module) {
+static LogicalResult convertAffineStandardToLLVMIR(Module module) {
   PassManager manager;
   manager.addPass(mlir::createCanonicalizerPass());
   manager.addPass(mlir::createCSEPass());
@@ -161,9 +161,9 @@ static LogicalResult convertAffineStandardToLLVMIR(Module *module) {
 }
 
 static Error compileAndExecuteFunctionWithMemRefs(
-    Module *module, StringRef entryPoint,
+    Module module, StringRef entryPoint,
     std::function<llvm::Error(llvm::Module *)> transformer) {
-  Function mainFunction = module->getNamedFunction(entryPoint);
+  Function mainFunction = module.getNamedFunction(entryPoint);
   if (!mainFunction || mainFunction.getBlocks().empty()) {
     return make_string_error("entry point not found");
   }
@@ -204,9 +204,9 @@ static Error compileAndExecuteFunctionWithMemRefs(
 }
 
 static Error compileAndExecuteSingleFloatReturnFunction(
-    Module *module, StringRef entryPoint,
+    Module module, StringRef entryPoint,
     std::function<llvm::Error(llvm::Module *)> transformer) {
-  Function mainFunction = module->getNamedFunction(entryPoint);
+  Function mainFunction = module.getNamedFunction(entryPoint);
   if (!mainFunction || mainFunction.isExternal()) {
     return make_string_error("entry point not found");
   }
index d2a8237..0464498 100644 (file)
@@ -26,19 +26,19 @@ namespace {
 /// Minimal class definitions for two analyses.
 struct MyAnalysis {
   MyAnalysis(Function) {}
-  MyAnalysis(Module *) {}
+  MyAnalysis(Module) {}
 };
 struct OtherAnalysis {
   OtherAnalysis(Function) {}
-  OtherAnalysis(Module *) {}
+  OtherAnalysis(Module) {}
 };
 
 TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
   MLIRContext context;
 
   // Test fine grain invalidation of the module analysis manager.
-  std::unique_ptr<Module> module(new Module(&context));
-  ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
+  OwningModuleRef module(Module::create(&context));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
 
   // Query two different analyses, but only preserve one before invalidating.
   mam.getAnalysis<MyAnalysis>();
@@ -58,14 +58,14 @@ TEST(AnalysisManagerTest, FineGrainFunctionAnalysisPreservation) {
   Builder builder(&context);
 
   // Create a function and a module.
-  std::unique_ptr<Module> module(new Module(&context));
+  OwningModuleRef module(Module::create(&context));
   Function func1 =
       Function::create(builder.getUnknownLoc(), "foo",
                        builder.getFunctionType(llvm::None, llvm::None));
   module->push_back(func1);
 
   // Test fine grain invalidation of the function analysis manager.
-  ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
   FunctionAnalysisManager fam = mam.slice(func1);
 
   // Query two different analyses, but only preserve one before invalidating.
@@ -86,7 +86,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
   Builder builder(&context);
 
   // Create a function and a module.
-  std::unique_ptr<Module> module(new Module(&context));
+  OwningModuleRef module(Module::create(&context));
   Function func1 =
       Function::create(builder.getUnknownLoc(), "foo",
                        builder.getFunctionType(llvm::None, llvm::None));
@@ -94,7 +94,7 @@ TEST(AnalysisManagerTest, FineGrainChildFunctionAnalysisPreservation) {
 
   // Test fine grain invalidation of a function analysis from within a module
   // analysis manager.
-  ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
 
   // Query two different analyses, but only preserve one before invalidating.
   mam.getFunctionAnalysis<MyAnalysis>(func1);