/// Trivial C++ wrappers make use of the EDSC C API.
struct PythonMLIRModule {
PythonMLIRModule()
- : mlirContext(), module(new mlir::Module(&mlirContext)),
- moduleManager(module.get()) {}
+ : mlirContext(), module(mlir::Module::create(&mlirContext)),
+ moduleManager(*module) {}
PythonType makeScalarType(const std::string &mlirElemType,
unsigned bitwidth) {
manager.addPass(mlir::createCSEPass());
manager.addPass(mlir::createLowerAffinePass());
manager.addPass(mlir::createConvertToLLVMIRPass());
- if (failed(manager.run(module.get()))) {
+ if (failed(manager.run(*module))) {
llvm::errs() << "conversion to the LLVM IR dialect failed\n";
return;
}
- auto created = mlir::ExecutionEngine::create(module.get());
+ auto created = mlir::ExecutionEngine::create(*module);
llvm::handleAllErrors(created.takeError(),
[](const llvm::ErrorInfoBase &b) {
b.log(llvm::errs());
private:
mlir::MLIRContext mlirContext;
// One single module in a python-exposed MLIRContext for now.
- std::unique_ptr<mlir::Module> module;
+ mlir::OwningModuleRef module;
mlir::ModuleManager moduleManager;
std::unique_ptr<mlir::ExecutionEngine> engine;
};
}
/// A basic function builder
-inline mlir::Function makeFunction(mlir::Module &module, llvm::StringRef name,
+inline mlir::Function makeFunction(mlir::Module module, llvm::StringRef name,
llvm::ArrayRef<mlir::Type> types,
llvm::ArrayRef<mlir::Type> resultTypes) {
auto *context = module.getContext();
}
};
auto pm = cleanupPassManager();
- check(f.getModule()->verify());
+ check(f.getModule().verify());
check(pm->run(f.getModule()));
if (printToOuts)
f.print(llvm::outs());
/// Convert the Linalg dialect types and RangeOp, ViewOp and SliceOp operations
/// to the LLVM IR dialect types and operations in the given `module`. This is
/// the main entry point to the conversion.
-void convertToLLVM(mlir::Module &module);
+void convertToLLVM(mlir::Module module);
} // end namespace linalg
#endif // LINALG1_CONVERTTOLLVMDIALECT_H_
};
} // end anonymous namespace
-void linalg::convertToLLVM(mlir::Module &module) {
+void linalg::convertToLLVM(mlir::Module module) {
// Remove affine constructs if any by using an existing pass.
PassManager pm;
pm.addPass(createLowerAffinePass());
- auto rr = pm.run(&module);
+ auto rr = pm.run(module);
(void)rr;
assert(succeeded(rr) && "affine loop lowering failed");
TEST_FUNC(linalg_ops) {
MLIRContext context;
- Module module(&context);
+ OwningModuleRef module = Module::create(&context);
auto indexType = mlir::IndexType::get(&context);
- mlir::Function f =
- makeFunction(module, "linalg_ops", {indexType, indexType, indexType}, {});
+ mlir::Function f = makeFunction(*module, "linalg_ops",
+ {indexType, indexType, indexType}, {});
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
TEST_FUNC(linalg_ops_folded_slices) {
MLIRContext context;
- Module module(&context);
+ OwningModuleRef module = Module::create(&context);
auto indexType = mlir::IndexType::get(&context);
- mlir::Function f = makeFunction(module, "linalg_ops_folded_slices",
+ mlir::Function f = makeFunction(*module, "linalg_ops_folded_slices",
{indexType, indexType, indexType}, {});
OpBuilder builder(f.getBody());
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function f = linalg::common::makeFunction(
TEST_FUNC(foo) {
MLIRContext context;
- Module module(&context);
- mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+ OwningModuleRef module = Module::create(&context);
+ mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
lowerToLoops(f);
- convertLinalg3ToLLVM(module);
+ convertLinalg3ToLLVM(*module);
// clang-format off
// CHECK: {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
// CHECK-NEXT: {{.*}} = llvm.getelementptr {{.*}}[{{.*}}] : (!llvm<"float*">, !llvm.i64) -> !llvm<"float*">
// CHECK-NEXT: llvm.store {{.*}}, {{.*}} : !llvm<"float*">
// clang-format on
- module.print(llvm::outs());
+ module->print(llvm::outs());
}
int main() {
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function f = linalg::common::makeFunction(
TEST_FUNC(matmul_as_matvec) {
MLIRContext context;
- Module module(&context);
+ Module module = Module::create(&context);
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_matvec");
lowerToFinerGrainedTensorContraction(f);
composeSliceOps(f);
TEST_FUNC(matmul_as_dot) {
MLIRContext context;
- Module module(&context);
+ Module module = Module::create(&context);
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_dot");
lowerToFinerGrainedTensorContraction(f);
lowerToFinerGrainedTensorContraction(f);
TEST_FUNC(matmul_as_loops) {
MLIRContext context;
- Module module(&context);
+ Module module = Module::create(&context);
mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
lowerToLoops(f);
composeSliceOps(f);
TEST_FUNC(matmul_as_matvec_as_loops) {
MLIRContext context;
- Module module(&context);
+ Module module = Module::create(&context);
mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_loops");
lowerToFinerGrainedTensorContraction(f);
TEST_FUNC(matmul_as_matvec_as_affine) {
MLIRContext context;
- Module module(&context);
+ Module module = Module::create(&context);
mlir::Function f =
makeFunctionWithAMatmulOp(module, "matmul_as_matvec_as_affine");
lowerToFinerGrainedTensorContraction(f);
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function f = linalg::common::makeFunction(
// linalg.matmul operation and lower it all the way down to the LLVM IR
// dialect through partial conversions.
MLIRContext context;
- Module module(&context);
- mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_as_loops");
+ OwningModuleRef module = Module::create(&context);
+ mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_as_loops");
lowerToLoops(f);
- convertLinalg3ToLLVM(module);
+ convertLinalg3ToLLVM(*module);
// Create an MLIR execution engine. The execution engine eagerly JIT-compiles
// the module.
- auto maybeEngine = mlir::ExecutionEngine::create(&module);
+ auto maybeEngine = mlir::ExecutionEngine::create(*module);
assert(maybeEngine && "failed to construct an execution engine");
auto &engine = maybeEngine.get();
} // end namespace mlir
namespace linalg {
-void convertLinalg3ToLLVM(mlir::Module &module);
+void convertLinalg3ToLLVM(mlir::Module module);
} // end namespace linalg
#endif // LINALG3_CONVERTTOLLVMDIALECT_H_
context);
}
-void linalg::convertLinalg3ToLLVM(Module &module) {
+void linalg::convertLinalg3ToLLVM(Module module) {
// Remove affine constructs.
for (auto func : module) {
auto rr = lowerAffineConstructs(func);
using namespace linalg::common;
using namespace linalg::intrinsics;
-Function makeFunctionWithAMatmulOp(Module &module, StringRef name) {
+Function makeFunctionWithAMatmulOp(Module module, StringRef name) {
MLIRContext *context = module.getContext();
auto dynamic2DMemRefType = floatMemRefType<2>(context);
mlir::Function f = linalg::common::makeFunction(
TEST_FUNC(matmul_tiled_loops) {
MLIRContext context;
- Module module(&context);
- mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_loops");
+ OwningModuleRef module = Module::create(&context);
+ mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_loops");
lowerToTiledLoops(f, {8, 9});
PassManager pm;
pm.addPass(createLowerLinalgLoadStorePass());
TEST_FUNC(matmul_tiled_views) {
MLIRContext context;
- Module module(&context);
- mlir::Function f = makeFunctionWithAMatmulOp(module, "matmul_tiled_views");
+ OwningModuleRef module = Module::create(&context);
+ mlir::Function f = makeFunctionWithAMatmulOp(*module, "matmul_tiled_views");
OpBuilder b(f.getBody());
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
b.create<ConstantIndexOp>(f.getLoc(), 9)});
TEST_FUNC(matmul_tiled_views_as_loops) {
MLIRContext context;
- Module module(&context);
+ OwningModuleRef module = Module::create(&context);
mlir::Function f =
- makeFunctionWithAMatmulOp(module, "matmul_tiled_views_as_loops");
+ makeFunctionWithAMatmulOp(*module, "matmul_tiled_views_as_loops");
OpBuilder b(f.getBody());
lowerToTiledViews(f, {b.create<ConstantIndexOp>(f.getLoc(), 8),
b.create<ConstantIndexOp>(f.getLoc(), 9)});
namespace mlir {
class MLIRContext;
-class Module;
+class OwningModuleRef;
} // namespace mlir
namespace toy {
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
/// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module.
- std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+ mlir::Module mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
- theModule = make_unique<mlir::Module>(&context);
+ theModule = mlir::Module::create(&context);
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
if (!func)
return nullptr;
- theModule->push_back(func);
+ theModule.push_back(func);
}
// FIXME: (in the next chapter...) without registering a dialect in MLIR,
// this won't do much, but it should at least check some structural
// properties.
- if (failed(theModule->verify())) {
+ if (failed(theModule.verify())) {
emitError(mlir::UnknownLoc::get(&context), "Module verification error");
return nullptr;
}
mlir::MLIRContext &context;
/// A "module" matches a source file: it contains a list of functions.
- std::unique_ptr<mlir::Module> theModule;
+ mlir::Module theModule;
/// The builder is a helper class to create IR inside a function. It is
/// re-initialized every time we enter a function and kept around as a
namespace toy {
// The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+ ModuleAST &moduleAST) {
return MLIRGenImpl(context).mlirGen(moduleAST);
}
int dumpMLIR() {
mlir::MLIRContext context;
- std::unique_ptr<mlir::Module> module;
+ mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
- module.reset(mlir::parseSourceFile(sourceMgr, &context));
+ module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3;
namespace mlir {
class MLIRContext;
-class Module;
+class OwningModuleRef;
} // namespace mlir
namespace toy {
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
/// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module.
- std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+ mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
- theModule = make_unique<mlir::Module>(&context);
+ theModule = mlir::Module::create(&context);
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
mlir::MLIRContext &context;
/// A "module" matches a source file: it contains a list of functions.
- std::unique_ptr<mlir::Module> theModule;
+ mlir::OwningModuleRef theModule;
/// The builder is a helper class to create IR inside a function. It is
/// re-initialized every time we enter a function and kept around as a
namespace toy {
// The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+ ModuleAST &moduleAST) {
return MLIRGenImpl(context).mlirGen(moduleAST);
}
mlir::registerDialect<ToyDialect>();
mlir::MLIRContext context;
- std::unique_ptr<mlir::Module> module;
+ mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
- module.reset(mlir::parseSourceFile(sourceMgr, &context));
+ module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3;
namespace mlir {
class MLIRContext;
-class Module;
+class OwningModuleRef;
} // namespace mlir
namespace toy {
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
/// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module.
- std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+ mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
- theModule = make_unique<mlir::Module>(&context);
+ theModule = mlir::Module::create(&context);
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
mlir::MLIRContext &context;
/// A "module" matches a source file: it contains a list of functions.
- std::unique_ptr<mlir::Module> theModule;
+ mlir::OwningModuleRef theModule;
/// The builder is a helper class to create IR inside a function. It is
/// re-initialized every time we enter a function and kept around as a
namespace toy {
// The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+ ModuleAST &moduleAST) {
return MLIRGenImpl(context).mlirGen(moduleAST);
}
};
void runOnModule() override {
- auto &module = getModule();
+ auto module = getModule();
auto main = module.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
return parser.ParseModule();
}
-mlir::LogicalResult optimize(mlir::Module &module) {
+mlir::LogicalResult optimize(mlir::Module module) {
mlir::PassManager pm;
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(createShapeInferencePass());
// Apply any generic pass manager command line options.
applyPassManagerCLOptions(pm);
- return pm.run(&module);
+ return pm.run(module);
}
int dumpMLIR() {
mlir::registerPassManagerCLOptions();
mlir::MLIRContext context;
- std::unique_ptr<mlir::Module> module;
+ mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
- module.reset(mlir::parseSourceFile(sourceMgr, &context));
+ module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return 3;
namespace mlir {
class MLIRContext;
-class Module;
+class OwningModuleRef;
} // namespace mlir
namespace toy {
/// Emit IR for the given Toy moduleAST, returns a newly created MLIR module
/// or nullptr on failure.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST);
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context, ModuleAST &moduleAST);
} // namespace toy
#endif // MLIR_TUTORIAL_TOY_MLIRGEN_H_
PatternMatchResult matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
PatternRewriter &rewriter) const override {
// Get or create the declaration of the printf function in the module.
- Function printfFunc = getPrintf(*op->getFunction().getModule());
+ Function printfFunc = getPrintf(op->getFunction().getModule());
auto print = cast<toy::PrintOp>(op);
auto loc = print.getLoc();
/// Return the prototype declaration for printf in the module, create it if
/// necessary.
- Function getPrintf(Module &module) const {
+ Function getPrintf(Module module) const {
auto printfFunc = module.getNamedFunction("printf");
if (printfFunc)
return printfFunc;
// Create a function declaration for printf, signature is `i32 (i8*, ...)`
- Builder builder(&module);
+ Builder builder(module);
auto *dialect =
module.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
/// Public API: convert the AST for a Toy module (source file) to an MLIR
/// Module.
- std::unique_ptr<mlir::Module> mlirGen(ModuleAST &moduleAST) {
+ mlir::OwningModuleRef mlirGen(ModuleAST &moduleAST) {
// We create an empty MLIR module and codegen functions one at a time and
// add them to the module.
- theModule = make_unique<mlir::Module>(&context);
+ theModule = mlir::Module::create(&context);
for (FunctionAST &F : moduleAST) {
auto func = mlirGen(F);
mlir::MLIRContext &context;
/// A "module" matches a source file: it contains a list of functions.
- std::unique_ptr<mlir::Module> theModule;
+ mlir::OwningModuleRef theModule;
/// The builder is a helper class to create IR inside a function. It is
/// re-initialized every time we enter a function and kept around as a
namespace toy {
// The public API for codegen.
-std::unique_ptr<mlir::Module> mlirGen(mlir::MLIRContext &context,
- ModuleAST &moduleAST) {
+mlir::OwningModuleRef mlirGen(mlir::MLIRContext &context,
+ ModuleAST &moduleAST) {
return MLIRGenImpl(context).mlirGen(moduleAST);
}
};
void runOnModule() override {
- auto &module = getModule();
- mlir::ModuleManager moduleManager(&module);
+ auto module = getModule();
+ mlir::ModuleManager moduleManager(module);
auto main = moduleManager.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
return parser.ParseModule();
}
-mlir::LogicalResult optimize(mlir::Module &module) {
+mlir::LogicalResult optimize(mlir::Module module) {
mlir::PassManager pm;
pm.addPass(mlir::createCanonicalizerPass());
pm.addPass(createShapeInferencePass());
// Apply any generic pass manager command line options.
applyPassManagerCLOptions(pm);
- return pm.run(&module);
+ return pm.run(module);
}
-mlir::LogicalResult lowerDialect(mlir::Module &module, bool OnlyLinalg) {
+mlir::LogicalResult lowerDialect(mlir::Module module, bool OnlyLinalg) {
mlir::PassManager pm;
pm.addPass(createEarlyLoweringPass());
pm.addPass(mlir::createCanonicalizerPass());
// Apply any generic pass manager command line options.
applyPassManagerCLOptions(pm);
- return pm.run(&module);
+ return pm.run(module);
}
-std::unique_ptr<mlir::Module> loadFileAndProcessModule(
+mlir::OwningModuleRef loadFileAndProcessModule(
mlir::MLIRContext &context, bool EnableLinalgLowering = false,
bool EnableLLVMLowering = false, bool EnableOpt = false) {
- std::unique_ptr<mlir::Module> module;
+ mlir::OwningModuleRef module;
if (inputType == InputType::MLIR ||
llvm::StringRef(inputFilename).endswith(".mlir")) {
llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> fileOrErr =
}
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(*fileOrErr), llvm::SMLoc());
- module.reset(mlir::parseSourceFile(sourceMgr, &context));
+ module = mlir::parseSourceFile(sourceMgr, &context);
if (!module) {
llvm::errs() << "Error can't load file " << inputFilename << "\n";
return nullptr;
// the module.
auto optPipeline = mlir::makeOptimizingTransformer(
/* optLevel=*/EnableOpt ? 3 : 0, /* sizeLevel=*/0);
- auto maybeEngine = mlir::ExecutionEngine::create(module.get(), optPipeline);
+ auto maybeEngine = mlir::ExecutionEngine::create(*module, optPipeline);
assert(maybeEngine && "failed to construct an execution engine");
auto &engine = maybeEngine.get();
struct MyFunctionPass : public FunctionPass<MyFunctionPass> {
void runOnFunction() override {
// Get the current function being operated on.
- Function *f = getFunction();
+ Function f = getFunction();
// Operate on the operations within the function.
- f->walk([](Operation *inst) {
+ f.walk([](Operation *inst) {
....
});
}
struct MyModulePass : public ModulePass<MyModulePass> {
void runOnModule() override {
// Get the current module being operated on.
- Module *m = getModule();
+ Module m = getModule();
// Operate on the functions within the module.
- for (auto &func : *m) {
+ for (auto func : m) {
....
}
}
/// An interesting module analysis.
struct MyModuleAnalysis {
// Compute this analysis with the provided module.
- MyModuleAnalysis(Module *module);
+ MyModuleAnalysis(Module module);
};
void MyFunctionPass::runOnFunction() {
// Query MyFunctionAnalysis for a child function of the current module. It
// will be computed if it doesn't exist.
- auto *fn = &*getModule().begin();
+ auto fn = *getModule().begin();
MyFunctionAnalysis &myAnalysis = getFunctionAnalysis<MyFunctionAnalysis>(fn);
}
```
pm.addPass(new MyModulePass2());
// Run the pass manager on a module.
-Module *m = ...;
+Module m = ...;
if (failed(pm.run(m)))
... // One of the passes signaled a failure.
```
pm.addInstrumentation(new DominanceCounterInstrumentation(domInfoCount));
// Run the pass manager on a module.
-Module *m = ...;
+Module m = ...;
if (failed(pm.run(m)))
...
/// support different values coming from the same predecessor. If a block has
/// another block as a successor more than once with different values, insert
/// a new dummy block for LLVM PHI nodes to tell the sources apart.
-void ensureDistinctSuccessors(Module *m);
+void ensureDistinctSuccessors(Module m);
} // namespace LLVM
} // namespace mlir
/// If `sharedLibPaths` are provided, the underlying JIT-compilation will open
/// and link the shared libraries for symbol resolution.
static llvm::Expected<std::unique_ptr<ExecutionEngine>>
- create(Module *m, std::function<llvm::Error(llvm::Module *)> transformer = {},
+ create(Module m, std::function<llvm::Error(llvm::Module *)> transformer = {},
ArrayRef<StringRef> sharedLibPaths = {});
/// Looks up a packed-argument function with the given name and returns a
class Builder {
public:
explicit Builder(MLIRContext *context) : context(context) {}
- explicit Builder(Module *module);
+ explicit Builder(Module module);
MLIRContext *getContext() const { return context; }
Identifier getIdentifier(StringRef str);
- Module *createModule();
+ Module createModule();
// Locations.
Location getUnknownLoc();
class Module;
namespace detail {
+class ModuleStorage;
+
/// This class represents all of the internal state of a Function. This allows
/// for the Function class to be value typed.
class FunctionStorage
- : public llvm::ilist_node_with_parent<FunctionStorage, Module> {
+ : public llvm::ilist_node_with_parent<FunctionStorage, ModuleStorage> {
FunctionStorage(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
FunctionStorage(Location location, StringRef name, FunctionType type,
Identifier name;
/// The module this function is embedded into.
- Module *module = nullptr;
+ ModuleStorage *module = nullptr;
/// The source location the function was defined or derived from.
Location location;
}
MLIRContext *getContext();
- Module *getModule() { return impl->module; }
+ Module getModule();
/// Add an entry block to an empty function, and set up the block arguments
/// to match the signature of the function.
function_iterator first, function_iterator last);
private:
- mlir::Module *getContainingModule();
+ mlir::detail::ModuleStorage *getContainingModule();
};
// Functions hash just like pointers.
#include "llvm/ADT/ilist.h"
namespace mlir {
+class Module;
+
+namespace detail {
+class ModuleStorage {
+ explicit ModuleStorage(MLIRContext *context) : context(context) {}
+
+ /// getSublistAccess() - Returns pointer to member of function list
+ static llvm::iplist<FunctionStorage> ModuleStorage::*
+ getSublistAccess(FunctionStorage *) {
+ return &ModuleStorage::functions;
+ }
+
+ /// The context attached to this module.
+ MLIRContext *context;
+
+ /// This is the actual list of functions the module contains.
+ llvm::iplist<FunctionStorage> functions;
+
+ friend Module;
+ friend struct llvm::ilist_traits<FunctionStorage>;
+ friend FunctionStorage;
+ friend Function;
+};
+} // end namespace detail
class Module {
public:
- explicit Module(MLIRContext *context) : context(context) {}
+ Module(detail::ModuleStorage *impl = nullptr) : impl(impl) {}
+
+ /// Construct a new module object with the given context.
+ static Module create(MLIRContext *context) {
+ return new detail::ModuleStorage(context);
+ }
- MLIRContext *getContext() { return context; }
+ MLIRContext *getContext() { return impl->context; }
+
+ /// Allow converting a Module to bool for null checks.
+ operator bool() const { return impl; }
+ bool operator==(Module other) const { return impl == other.impl; }
+ bool operator!=(Module other) const { return !(*this == other); }
/// An iterator class used to iterate over the held functions.
class iterator : public llvm::mapped_iterator<
llvm::iterator_range<iterator> getFunctions() { return {begin(), end()}; }
// Iteration over the functions in the module.
- iterator begin() { return functions.begin(); }
- iterator end() { return functions.end(); }
- Function front() { return &functions.front(); }
- Function back() { return &functions.back(); }
+ iterator begin() { return impl->functions.begin(); }
+ iterator end() { return impl->functions.end(); }
+ Function front() { return &impl->functions.front(); }
+ Function back() { return &impl->functions.back(); }
- void push_back(Function fn) { functions.push_back(fn.impl); }
+ void push_back(Function fn) { impl->functions.push_back(fn.impl); }
void insert(iterator insertPt, Function fn) {
- functions.insert(insertPt.getCurrent(), fn.impl);
+ impl->functions.insert(insertPt.getCurrent(), fn.impl);
}
// Interfaces for working with the symbol table.
/// name exists. Function names never include the @ on them. Note: This
/// performs a linear scan of held symbols.
Function getNamedFunction(Identifier name) {
+ auto &functions = impl->functions;
auto it = llvm::find_if(functions, [name](detail::FunctionStorage &fn) {
return Function(&fn).getName() == name;
});
void print(raw_ostream &os);
void dump();
-private:
- friend struct llvm::ilist_traits<detail::FunctionStorage>;
- friend detail::FunctionStorage;
- friend Function;
+ /// Erase the current module.
+ void erase() {
+ assert(impl && "expected valid module");
+ delete impl;
+ }
- /// getSublistAccess() - Returns pointer to member of function list
- static llvm::iplist<detail::FunctionStorage> Module::*
- getSublistAccess(detail::FunctionStorage *) {
- return &Module::functions;
+ /// Methods for supporting PointerLikeTypeTraits.
+ const void *getAsOpaquePointer() const {
+ return static_cast<const void *>(impl);
+ }
+ static Module getFromOpaquePointer(const void *pointer) {
+ return reinterpret_cast<detail::ModuleStorage *>(
+ const_cast<void *>(pointer));
}
- /// The context attached to this module.
- MLIRContext *context;
+private:
+ friend detail::FunctionStorage;
+ friend Function;
- /// This is the actual list of functions the module contains.
- llvm::iplist<detail::FunctionStorage> functions;
+ /// The internal impl storage object.
+ detail::ModuleStorage *impl = nullptr;
};
/// A class used to manage the symbols held by a module. This class handles
/// efficent named lookup to held symbols.
class ModuleManager {
public:
- ModuleManager(Module *module) : module(module), symbolTable(module) {}
+ ModuleManager(Module module) : module(module), symbolTable(module) {}
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names must never include the @ on them.
/// Insert a new symbol into the module, auto-renaming it as necessary.
void insert(Function function) {
symbolTable.insert(function);
- module->push_back(function);
+ module.push_back(function);
}
void insert(Module::iterator insertPt, Function function) {
symbolTable.insert(function);
- module->insert(insertPt, function);
+ module.insert(insertPt, function);
}
/// Remove the given symbol from the module symbol table and then erase it.
}
/// Return the internally held module.
- Module *getModule() const { return module; }
+ Module getModule() const { return module; }
/// Return the context of the internal module.
- MLIRContext *getContext() const { return module->getContext(); }
+ MLIRContext *getContext() const { return getModule().getContext(); }
private:
- Module *module;
+ Module module;
SymbolTable symbolTable;
};
+/// This class acts as an owning reference to a Module, and will automatically
+/// destory the held Module if valid.
+class OwningModuleRef {
+public:
+ OwningModuleRef(std::nullptr_t = nullptr) {}
+ OwningModuleRef(Module module) : module(module) {}
+ OwningModuleRef(OwningModuleRef &&other) : module(other.release()) {}
+ ~OwningModuleRef() {
+ if (module)
+ module.erase();
+ }
+
+ // Assign from another module reference.
+ OwningModuleRef &operator=(OwningModuleRef &&other) {
+ if (module)
+ module.erase();
+ module = other.release();
+ return *this;
+ }
+
+ /// Allow accessing the internal module.
+ Module get() const { return module; }
+ Module operator*() const { return module; }
+ Module *operator->() { return &module; }
+ explicit operator bool() const { return module; }
+
+ /// Release the referenced module.
+ Module release() {
+ Module released;
+ std::swap(released, module);
+ return released;
+ }
+
+private:
+ Module module;
+};
+
//===--------------------------------------------------------------------===//
// Module Operation.
//===--------------------------------------------------------------------===//
} // end namespace mlir
+namespace llvm {
+
+/// Allow stealing the low bits of ModuleStorage.
+template <> struct PointerLikeTypeTraits<mlir::Module> {
+public:
+ static inline void *getAsVoidPointer(mlir::Module I) {
+ return const_cast<void *>(I.getAsOpaquePointer());
+ }
+ static inline mlir::Module getFromVoidPointer(void *P) {
+ return mlir::Module::getFromOpaquePointer(P);
+ }
+ enum { NumLowBitsAvailable = 3 };
+};
+
+} // end namespace llvm
+
#endif // MLIR_IR_MODULE_H
class SymbolTable {
public:
/// Build a symbol table with the symbols within the given module.
- SymbolTable(Module *module);
+ SymbolTable(Module module);
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, the error message is emitted through
/// the error handler registered in the context, and a null pointer is returned.
-Module *parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
+Module parseSourceFile(const llvm::SourceMgr &sourceMgr, MLIRContext *context);
/// This parses the file specified by the indicated filename and returns an
/// MLIR module if it was valid. If not, the error message is emitted through
/// the error handler registered in the context, and a null pointer is returned.
-Module *parseSourceFile(llvm::StringRef filename, MLIRContext *context);
+Module parseSourceFile(llvm::StringRef filename, MLIRContext *context);
/// This parses the file specified by the indicated filename using the provided
/// SourceMgr and returns an MLIR module if it was valid. If not, the error
/// message is emitted through the error handler registered in the context, and
/// a null pointer is returned.
-Module *parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr,
- MLIRContext *context);
+Module parseSourceFile(llvm::StringRef filename, llvm::SourceMgr &sourceMgr,
+ MLIRContext *context);
/// This parses the module string to a MLIR module if it was valid. If not, the
/// error message is emitted through the error handler registered in the
/// context, and a null pointer is returned.
-Module *parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
+Module parseSourceString(llvm::StringRef moduleStr, MLIRContext *context);
/// This parses a single MLIR type to an MLIR context if it was valid. If not,
/// an error message is emitted through a new SourceMgrDiagnosticHandler
/// An analysis manager for a specific module instance.
class ModuleAnalysisManager {
public:
- ModuleAnalysisManager(Module *module, PassInstrumentor *passInstrumentor)
+ ModuleAnalysisManager(Module module, PassInstrumentor *passInstrumentor)
: moduleAnalyses(module), passInstrumentor(passInstrumentor) {}
ModuleAnalysisManager(const ModuleAnalysisManager &) = delete;
ModuleAnalysisManager &operator=(const ModuleAnalysisManager &) = delete;
functionAnalyses;
/// The analyses for the owning module.
- detail::AnalysisMap<Module *> moduleAnalyses;
+ detail::AnalysisMap<Module> moduleAnalyses;
/// An optional instrumentation object.
PassInstrumentor *passInstrumentor;
/// Pass to transform a module. Derived passes should not inherit from this
/// class directly, and instead should use the CRTP ModulePass class.
class ModulePassBase : public Pass {
- using PassStateT =
- detail::PassExecutionState<Module *, ModuleAnalysisManager>;
+ using PassStateT = detail::PassExecutionState<Module, ModuleAnalysisManager>;
public:
static bool classof(const Pass *pass) {
virtual void runOnModule() = 0;
/// Return the current module being transformed.
- Module &getModule() { return *getPassState().irAndPassFailed.getPointer(); }
+ Module getModule() { return getPassState().irAndPassFailed.getPointer(); }
/// Return the MLIR context for the current module being transformed.
MLIRContext &getContext() { return *getModule().getContext(); }
private:
/// Forwarding function to execute this pass.
LLVM_NODISCARD
- LogicalResult run(Module *module, ModuleAnalysisManager &mam);
+ LogicalResult run(Module module, ModuleAnalysisManager &mam);
/// The current execution state for the pass.
llvm::Optional<PassStateT> passState;
/// Run the passes within this manager on the provided module.
LLVM_NODISCARD
- LogicalResult run(Module *module);
+ LogicalResult run(Module module);
//===--------------------------------------------------------------------===//
// Pipeline Building
/// from the registered LLVM IR dialect. In case of error, report it
/// to the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module &m);
+std::unique_ptr<llvm::Module> translateModuleToLLVMIR(Module m);
} // namespace mlir
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
#include "mlir/IR/Block.h"
-#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
#include "mlir/IR/Value.h"
#include "llvm/IR/BasicBlock.h"
class ModuleTranslation {
public:
template <typename T = ModuleTranslation>
- static std::unique_ptr<llvm::Module> translateModule(Module &m) {
+ static std::unique_ptr<llvm::Module> translateModule(Module m) {
auto llvmModule = prepareLLVMModule(m);
T translator(m);
// Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
// LLVM IR module. The MLIR LLVM IR dialect holds a pointer to an
// LLVMContext, the LLVM IR module will be created in that context.
- explicit ModuleTranslation(Module &module) : mlirModule(module) {}
+ explicit ModuleTranslation(Module module) : mlirModule(module) {}
virtual ~ModuleTranslation() {}
virtual bool convertOperation(Operation &op, llvm::IRBuilder<> &builder);
- static std::unique_ptr<llvm::Module> prepareLLVMModule(Module &m);
+ static std::unique_ptr<llvm::Module> prepareLLVMModule(Module m);
private:
bool convertFunctions();
- bool convertOneFunction(Function &func);
- void connectPHINodes(Function &func);
+ bool convertOneFunction(Function func);
+ void connectPHINodes(Function func);
bool convertBlock(Block &bb, bool ignoreArguments);
template <typename Range>
Location loc);
// Original and translated module.
- Module &mlirModule;
+ Module mlirModule;
std::unique_ptr<llvm::Module> llvmModule;
protected:
} // namespace llvm
namespace mlir {
-
class Module;
/// Convert the given MLIR module into NVVM IR. This conversion requires the
/// from the registered LLVM IR dialect. In case of error, report it
/// to the error handler registered with the MLIR context, if any (obtained from
/// the MLIR module), and return `nullptr`.
-std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Module &m);
+std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Module m);
} // namespace mlir
/// conversion object. This function returns failure if a type conversion
/// failed.
LLVM_NODISCARD LogicalResult applyConversionPatterns(
- Module &module, ConversionTarget &target, TypeConverter &converter,
+ Module module, ConversionTarget &target, TypeConverter &converter,
OwningRewritePatternList &&patterns);
/// Convert the given functions with the provided conversion patterns. This
namespace mlir {
class MLIRContext;
class Module;
+class OwningModuleRef;
/// Interface of the function that translates a file to MLIR. The
/// implementation should create a new MLIR Module in the given context and
/// return a pointer to it, or a nullptr in case of any error.
using TranslateToMLIRFunction =
- std::function<std::unique_ptr<Module>(llvm::StringRef, MLIRContext *)>;
+ std::function<OwningModuleRef(llvm::StringRef, MLIRContext *)>;
/// Interface of the function that translates MLIR to a different format and
/// outputs the result to a file. The implementation should return "true" on
/// error and "false" otherwise. It is allowed to modify the module.
-using TranslateFromMLIRFunction =
- std::function<bool(Module *, llvm::StringRef)>;
+using TranslateFromMLIRFunction = std::function<bool(Module, llvm::StringRef)>;
/// Use Translate[To|From]MLIRRegistration as a global initialiser that
/// registers a function and associates it with name. This requires that a
GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(Function &function) {
Builder builder(function.getContext());
- std::unique_ptr<Module> module(builder.createModule());
+ OwningModuleRef module = builder.createModule();
// TODO(herhut): Also handle called functions.
module->push_back(function.clone());
auto llvmModule = translateModuleToNVVMIR(*module);
auto cubin = convertModuleToCubin(*llvmModule, function);
- if (!cubin)
+ if (!cubin) {
return function.emitError("Translation to CUDA binary failed.");
+ }
function.setAttr(kCubinAnnotation,
builder.getStringAttr({cubin->data(), cubin->size()}));
// The types in comments give the actual types expected/returned but the API
// uses void pointers. This is fine as they have the same linkage in C.
void GpuLaunchFuncToCudaCallsPass::declareCudaFunctions(Location loc) {
- Module &module = getModule();
- Builder builder(&module);
+ Module module = getModule();
+ Builder builder(module);
if (!module.getNamedFunction(cuModuleLoadName)) {
module.push_back(
Function::create(loc, cuModuleLoadName,
ArrayRef<Value *>{cuModule, data.getResult(0)});
// Get the function from the module. The name corresponds to the name of
// the kernel function.
- auto cuModuleRef =
+ auto cuOwningModuleRef =
builder.create<LLVM::LoadOp>(loc, getPointerType(), cuModule);
auto kernelName = generateKernelNameConstant(kernelFunction, loc, builder);
auto cuFunction = allocatePointer(builder, loc);
builder.create<LLVM::CallOp>(
loc, ArrayRef<Type>{getCUResultType()},
builder.getFunctionAttr(cuModuleGetFunction),
- ArrayRef<Value *>{cuFunction, cuModuleRef, kernelName});
+ ArrayRef<Value *>{cuFunction, cuOwningModuleRef, kernelName});
// Grab the global stream needed for execution.
Function cuGetStreamHelper =
getModule().getNamedFunction(cuGetStreamHelperName);
void runOnModule() override {
llvmDialect =
getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
- auto &module = getModule();
+ auto module = getModule();
Builder builder(&getContext());
auto functions = module.getFunctions();
// Insert the `malloc` declaration if it is not already present.
Function mallocFunc =
- op->getFunction().getModule()->getNamedFunction("malloc");
+ op->getFunction().getModule().getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType =
rewriter.getFunctionType(getIndexType(), getVoidPtrType());
mallocFunc =
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
- op->getFunction().getModule()->push_back(mallocFunc);
+ op->getFunction().getModule().push_back(mallocFunc);
}
// Allocate the underlying buffer and store a pointer to it in the MemRef
OperandAdaptor<DeallocOp> transformed(operands);
// Insert the `free` declaration if it is not already present.
- Function freeFunc = op->getFunction().getModule()->getNamedFunction("free");
+ Function freeFunc = op->getFunction().getModule().getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(getVoidPtrType(), {});
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
- op->getFunction().getModule()->push_back(freeFunc);
+ op->getFunction().getModule().push_back(freeFunc);
}
auto type = transformed.memref()->getType().cast<LLVM::LLVMType>();
}
}
-void mlir::LLVM::ensureDistinctSuccessors(Module *m) {
- for (auto f : *m) {
+void mlir::LLVM::ensureDistinctSuccessors(Module m) {
+ for (auto f : m) {
for (auto &bb : f.getBlocks()) {
::ensureDistinctSuccessors(bb);
}
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
// Run the dialect converter on the module.
void runOnModule() override {
- Module &m = getModule();
- LLVM::ensureDistinctSuccessors(&m);
+ Module m = getModule();
+ LLVM::ensureDistinctSuccessors(m);
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
ExecutionEngine::~ExecutionEngine() = default;
Expected<std::unique_ptr<ExecutionEngine>>
-ExecutionEngine::create(Module *m,
+ExecutionEngine::create(Module m,
std::function<llvm::Error(llvm::Module *)> transformer,
ArrayRef<StringRef> sharedLibPaths) {
auto engine = llvm::make_unique<ExecutionEngine>();
if (!expectedJIT)
return expectedJIT.takeError();
- auto llvmModule = translateModuleToLLVMIR(*m);
+ auto llvmModule = translateModuleToLLVMIR(m);
if (!llvmModule)
return make_string_error("could not convert to LLVM IR");
// FIXME: the triple should be passed to the translation or dialect conversion
return emitOpError("attribute 'kernel' must be a function");
}
- auto *module = getOperation()->getFunction().getModule();
- Function kernelFunc = module->getNamedFunction(kernel());
+ auto module = getOperation()->getFunction().getModule();
+ Function kernelFunc = module.getNamedFunction(kernel());
if (!kernelFunc)
return emitError() << "kernel function '" << kernelAttr << "' is undefined";
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
void runOnModule() override {
- ModuleManager moduleManager(&getModule());
+ ModuleManager moduleManager(getModule());
for (auto func : getModule()) {
func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
Function outlinedFunc = outlineKernelFunc(op);
explicit ModuleState(MLIRContext *context) : context(context) {}
// Initializes module state, populating affine map state.
- void initialize(Module *module);
+ void initialize(Module module);
Twine getAttributeAlias(Attribute attr) const {
auto alias = attrToAlias.find(attr);
}
// Initializes module state, populating affine map and integer set state.
-void ModuleState::initialize(Module *module) {
+void ModuleState::initialize(Module module) {
// Initialize the symbol aliases.
initializeSymbolAliases();
// Walk the module and visit each operation.
- for (auto fn : *module) {
+ for (auto fn : module) {
visitType(fn.getType());
for (auto attr : fn.getAttrs())
ModuleState::visitAttribute(attr.second);
interleave(c.begin(), c.end(), each_fn, [&]() { os << ", "; });
}
- void print(Module *module);
+ void print(Module module);
/// Print the given attribute. If 'mayElideType' is true, some attributes are
/// printed without the type when the type matches the default used in the
}
}
-void ModulePrinter::print(Module *module) {
+void ModulePrinter::print(Module module) {
// Output the aliases at the top level.
state.printAttributeAliases(os);
state.printTypeAliases(os);
// Print the module.
- for (auto fn : *module)
+ for (auto fn : module)
print(fn);
}
void Module::print(raw_ostream &os) {
ModuleState state(getContext());
- state.initialize(this);
- ModulePrinter(os, state).print(this);
+ state.initialize(*this);
+ ModulePrinter(os, state).print(*this);
}
void Module::dump() { print(llvm::errs()); }
#include "mlir/Support/Functional.h"
using namespace mlir;
-Builder::Builder(Module *module) : context(module->getContext()) {}
+Builder::Builder(Module module) : context(module.getContext()) {}
Identifier Builder::getIdentifier(StringRef str) {
return Identifier::get(str, context);
}
-Module *Builder::createModule() { return new Module(context); }
+Module Builder::createModule() { return Module::create(context); }
//===----------------------------------------------------------------------===//
// Locations.
type(type), attrs(attrs), argAttrs(argAttrs), body(this) {}
MLIRContext *Function::getContext() { return getType().getContext(); }
+Module Function::getModule() { return impl->module; }
-Module *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
- size_t Offset(
- size_t(&((Module *)nullptr->*Module::getSublistAccess(nullptr))));
+ModuleStorage *llvm::ilist_traits<FunctionStorage>::getContainingModule() {
+ size_t Offset(size_t(
+ &((ModuleStorage *)nullptr->*ModuleStorage::getSublistAccess(nullptr))));
iplist<FunctionStorage> *Anchor(static_cast<iplist<FunctionStorage> *>(this));
- return reinterpret_cast<Module *>(reinterpret_cast<char *>(Anchor) - Offset);
+ return reinterpret_cast<ModuleStorage *>(reinterpret_cast<char *>(Anchor) -
+ Offset);
}
/// This is a trait method invoked when a Function is added to a Module. We
function_iterator last) {
// If we are transferring functions within the same module, the Module
// pointer doesn't need to be updated.
- Module *curParent = getContainingModule();
+ ModuleStorage *curParent = getContainingModule();
if (curParent == otherList.getContainingModule())
return;
/// Unlink this function from its Module and delete it.
void Function::erase() {
- if (auto *module = getModule())
- getModule()->functions.erase(impl);
+ if (auto module = getModule())
+ module.impl->functions.erase(impl);
else
delete impl;
}
using namespace mlir;
/// Build a symbol table with the symbols within the given module.
-SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
- for (auto func : *module) {
+SymbolTable::SymbolTable(Module module) : context(module.getContext()) {
+ for (auto func : module) {
auto inserted = symbolTable.insert({func.getName(), func});
(void)inserted;
assert(inserted.second &&
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
auto int64Ty = lowering.convertType(rewriter.getIntegerType(64));
// Insert the `malloc` declaration if it is not already present.
- auto *module = op->getFunction().getModule();
- Function mallocFunc = module->getNamedFunction("malloc");
+ auto module = op->getFunction().getModule();
+ Function mallocFunc = module.getNamedFunction("malloc");
if (!mallocFunc) {
auto mallocType = rewriter.getFunctionType(int64Ty, voidPtrTy);
mallocFunc =
Function::create(rewriter.getUnknownLoc(), "malloc", mallocType);
- module->push_back(mallocFunc);
+ module.push_back(mallocFunc);
}
// Get MLIR types for injecting element pointer.
auto voidPtrTy =
LLVM::LLVMType::getInt8Ty(lowering.getDialect()).getPointerTo();
// Insert the `free` declaration if it is not already present.
- auto *module = op->getFunction().getModule();
- Function freeFunc = module->getNamedFunction("free");
+ auto module = op->getFunction().getModule();
+ Function freeFunc = module.getNamedFunction("free");
if (!freeFunc) {
auto freeType = rewriter.getFunctionType(voidPtrTy, {});
freeFunc = Function::create(rewriter.getUnknownLoc(), "free", freeType);
- module->push_back(freeFunc);
+ module.push_back(freeFunc);
}
// Get MLIR types for extracting element pointer.
static Function getLLVMLibraryCallImplDefinition(Function libFn) {
auto implFnName = (libFn.getName().str() + "_impl");
auto module = libFn.getModule();
- if (auto f = module->getNamedFunction(implFnName)) {
+ if (auto f = module.getNamedFunction(implFnName)) {
return f;
}
SmallVector<Type, 4> fnArgTypes;
// Insert the implementation function definition.
auto implFnDefn = Function::create(libFn.getLoc(), implFnName, implFnType);
- module->push_back(implFnDefn);
+ module.push_back(implFnDefn);
return implFnDefn;
}
assert(isa<LinalgOp>(op));
auto fnName = LinalgOp::getLibraryCallName();
auto module = op->getFunction().getModule();
- if (auto f = module->getNamedFunction(fnName)) {
+ if (auto f = module.getNamedFunction(fnName)) {
return f;
}
"have void return types");
auto libFnType = FunctionType::get(inputTypes, {}, op->getContext());
auto libFn = Function::create(op->getLoc(), fnName, libFnType);
- module->push_back(libFn);
+ module.push_back(libFn);
// Return after creating the function definition. The body will be created
// later.
return libFn;
}
void LowerLinalgToLLVMPass::runOnModule() {
- auto &module = getModule();
+ auto module = getModule();
for (auto f : module.getFunctions()) {
lowerLinalgSubViewOps(f);
public:
explicit ModuleParser(ParserState &state) : Parser(state) {}
- ParseResult parseModule(Module *module);
+ ParseResult parseModule(Module module);
private:
/// Parse an attribute alias declaration.
StringRef &name, FunctionType &type,
SmallVectorImpl<std::pair<SMLoc, StringRef>> &argNames,
SmallVectorImpl<SmallVector<NamedAttribute, 2>> &argAttrs);
- ParseResult parseFunc(Module *module);
+ ParseResult parseFunc(Module module);
};
} // end anonymous namespace
/// function-body ::= `{` block+ `}`
/// function-attributes ::= `attributes` attribute-dict
///
-ParseResult ModuleParser::parseFunc(Module *module) {
+ParseResult ModuleParser::parseFunc(Module module) {
consumeToken();
StringRef name;
// Okay, the function signature was parsed correctly, create the function now.
auto function =
Function::create(getEncodedSourceLocation(loc), name, type, attrs);
- module->push_back(function);
+ module.push_back(function);
// Parse an optional trailing location.
if (parseOptionalTrailingLocation(function))
}
/// This is the top-level module parser.
-ParseResult ModuleParser::parseModule(Module *module) {
+ParseResult ModuleParser::parseModule(Module module) {
while (1) {
switch (getToken().getKind()) {
default:
/// This parses the file specified by the indicated SourceMgr and returns an
/// MLIR module if it was valid. If not, it emits diagnostics and returns
/// null.
-Module *mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
- MLIRContext *context) {
+Module mlir::parseSourceFile(const llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) {
// This is the result module we are parsing into.
- std::unique_ptr<Module> module(new Module(context));
+ OwningModuleRef module(Module::create(context));
ParserState state(sourceMgr, context);
- if (ModuleParser(state).parseModule(module.get())) {
+ if (ModuleParser(state).parseModule(*module))
return nullptr;
- }
// Make sure the parse module has no other structural problems detected by
// the verifier.
/// This parses the file specified by the indicated filename and returns an
/// MLIR module if it was valid. If not, the error message is emitted through
/// the error handler registered in the context, and a null pointer is returned.
-Module *mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
+Module mlir::parseSourceFile(StringRef filename, MLIRContext *context) {
llvm::SourceMgr sourceMgr;
return parseSourceFile(filename, sourceMgr, context);
}
/// SourceMgr and returns an MLIR module if it was valid. If not, the error
/// message is emitted through the error handler registered in the context, and
/// a null pointer is returned.
-Module *mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
- MLIRContext *context) {
+Module mlir::parseSourceFile(StringRef filename, llvm::SourceMgr &sourceMgr,
+ MLIRContext *context) {
if (sourceMgr.getNumBuffers() != 0) {
// TODO(b/136086478): Extend to support multiple buffers.
emitError(mlir::UnknownLoc::get(context),
/// This parses the program string to a MLIR module if it was valid. If not,
/// it emits diagnostics and returns null.
-Module *mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
+Module mlir::parseSourceString(StringRef moduleStr, MLIRContext *context) {
auto memBuffer = MemoryBuffer::getMemBuffer(moduleStr);
if (!memBuffer)
return nullptr;
// Print the function name and a newline before the Module.
out << " (function: " << function.getName() << ")\n";
- function.getModule()->print(out);
+ function.getModule().print(out);
return;
}
}
// Print the given module.
- assert(llvm::any_isa<Module *>(ir) && "unexpected IR unit");
- llvm::any_cast<Module *>(ir)->print(out);
+ assert(llvm::any_isa<Module>(ir) && "unexpected IR unit");
+ llvm::any_cast<Module>(ir).print(out);
}
/// Instrumentation hooks.
}
/// Forwarding function to execute this pass.
-LogicalResult ModulePassBase::run(Module *module, ModuleAnalysisManager &mam) {
+LogicalResult ModulePassBase::run(Module module, ModuleAnalysisManager &mam) {
// Initialize the pass state.
passState.emplace(module, mam);
}
/// Run all of the passes in this manager over the current module.
-LogicalResult detail::ModulePassExecutor::run(Module *module,
+LogicalResult detail::ModulePassExecutor::run(Module module,
ModuleAnalysisManager &mam) {
// Run each of the held passes.
for (auto &pass : passes)
PassManager::~PassManager() {}
/// Run the passes within this manager on the provided module.
-LogicalResult PassManager::run(Module *module) {
+LogicalResult PassManager::run(Module module) {
ModuleAnalysisManager mam(module, instrumentor.get());
return mpe->run(module, mam);
}
ModulePassExecutor &operator=(const ModulePassExecutor &) = delete;
/// Run the executor on the given module.
- LogicalResult run(Module *module, ModuleAnalysisManager &mam);
+ LogicalResult run(Module module, ModuleAnalysisManager &mam);
/// Add a pass to the current executor. This takes ownership over the provided
/// pass pointer.
// Adds a one-block function named as `spirv_module` to `module` and returns the
// block. The created block will be terminated by `std.return`.
-Block *createOneBlockFunction(Builder builder, Module *module) {
+Block *createOneBlockFunction(Builder builder, Module module) {
auto fnType = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
auto fn = Function::create(builder.getUnknownLoc(), "spirv_module", fnType);
- module->push_back(fn);
+ module.push_back(fn);
auto *block = new Block();
fn.push_back(block);
// Deserializes the SPIR-V binary module stored in the file named as
// `inputFilename` and returns a module containing the SPIR-V module.
-std::unique_ptr<Module> deserializeModule(llvm::StringRef inputFilename,
- MLIRContext *context) {
+OwningModuleRef deserializeModule(llvm::StringRef inputFilename,
+ MLIRContext *context) {
Builder builder(context);
std::string errorMessage;
// converted SPIR-V ModuleOp inside a MLIR module. This should be changed to
// return the SPIR-V ModuleOp directly after module and function are migrated
// to be general ops.
- std::unique_ptr<Module> module(builder.createModule());
+ OwningModuleRef module(builder.createModule());
Block *block = createOneBlockFunction(builder, module.get());
block->push_front(spirvModule->getOperation());
using namespace mlir;
-LogicalResult serializeModule(Module *module, StringRef outputFilename) {
+LogicalResult serializeModule(Module module, StringRef outputFilename) {
if (!module)
return failure();
// wrapping the SPIR-V ModuleOp inside a MLIR module. This should be changed
// to take in the SPIR-V ModuleOp directly after module and function are
// migrated to be general ops.
- for (auto fn : *module) {
+ for (auto fn : module) {
fn.walk<spirv::ModuleOp>([&](spirv::ModuleOp spirvModule) {
if (done) {
spirvModule.emitError("found more than one 'spv.module' op");
static TranslateFromMLIRRegistration
registration("serialize-spirv",
- [](Module *module, StringRef outputFilename) {
+ [](Module module, StringRef outputFilename) {
return failed(serializeModule(module, outputFilename));
});
auto fnAttr = op.getAttrOfType<FunctionAttr>("callee");
if (!fnAttr)
return op.emitOpError("requires a 'callee' function attribute");
- auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
+ auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError() << "'" << fnAttr.getValue()
return op.emitOpError("requires 'value' to be a function reference");
// Try to find the referenced function.
- auto fn = op.getOperation()->getFunction().getModule()->getNamedFunction(
+ auto fn = op.getOperation()->getFunction().getModule().getNamedFunction(
fnAttr.getValue());
if (!fn)
return op.emitOpError("reference to undefined function 'bar'");
performActions(raw_ostream &os, bool verifyDiagnostics, bool verifyPasses,
SourceMgr &sourceMgr, MLIRContext *context,
const std::vector<const mlir::PassRegistryEntry *> &passList) {
- std::unique_ptr<Module> module(parseSourceFile(sourceMgr, context));
+ OwningModuleRef module(parseSourceFile(sourceMgr, context));
if (!module)
return failure();
applyPassManagerCLOptions(pm);
// Run the pipeline.
- if (failed(pm.run(module.get())))
+ if (failed(pm.run(*module)))
return failure();
// Print the output.
// Storage for the translation function wrappers that survive the parser.
static llvm::SmallVector<TranslateFunction, 16> wrapperStorage;
-static LogicalResult printMLIROutput(Module &module,
+static LogicalResult printMLIROutput(Module module,
llvm::StringRef outputFilename) {
if (failed(module.verify()))
return failure();
TranslateFunction wrapper = [function](StringRef inputFilename,
StringRef outputFilename,
MLIRContext *context) {
- std::unique_ptr<Module> module = function(inputFilename, context);
+ OwningModuleRef module = function(inputFilename, context);
if (!module)
return failure();
return printMLIROutput(*module, outputFilename);
MLIRContext *context) {
llvm::SourceMgr sourceMgr;
SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, context);
- auto module = std::unique_ptr<Module>(
- parseSourceFile(inputFilename, sourceMgr, context));
+ auto module =
+ OwningModuleRef(parseSourceFile(inputFilename, sourceMgr, context));
if (!module)
return failure();
return failure(function(module.get(), outputFilename));
using namespace mlir;
-std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module m) {
return LLVM::ModuleTranslation::translateModule<>(m);
}
static TranslateFromMLIRRegistration registration(
- "mlir-to-llvmir", [](Module *module, llvm::StringRef outputFilename) {
+ "mlir-to-llvmir", [](Module module, llvm::StringRef outputFilename) {
if (!module)
return true;
- auto llvmModule = LLVM::ModuleTranslation::translateModule<>(*module);
+ auto llvmModule = LLVM::ModuleTranslation::translateModule<>(module);
if (!llvmModule)
return true;
class ModuleTranslation : public LLVM::ModuleTranslation {
public:
- explicit ModuleTranslation(Module &module)
- : LLVM::ModuleTranslation(module) {}
+ explicit ModuleTranslation(Module module) : LLVM::ModuleTranslation(module) {}
~ModuleTranslation() override {}
protected:
};
} // namespace
-std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
+std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module m) {
ModuleTranslation translation(m);
auto llvmModule =
LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
static TranslateFromMLIRRegistration
registration("mlir-to-nvvmir",
- [](Module *module, llvm::StringRef outputFilename) {
+ [](Module module, llvm::StringRef outputFilename) {
if (!module)
return true;
- auto llvmModule = mlir::translateModuleToNVVMIR(*module);
+ auto llvmModule = mlir::translateModuleToNVVMIR(module);
if (!llvmModule)
return true;
: terminator.getSuccessorOperand(1, index);
}
-void ModuleTranslation::connectPHINodes(Function &func) {
+void ModuleTranslation::connectPHINodes(Function func) {
// Skip the first block, it cannot be branched to and its arguments correspond
// to the arguments of the LLVM function.
for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
}
// Sort function blocks topologically.
-static llvm::SetVector<Block *> topologicalSort(Function &f) {
+static llvm::SetVector<Block *> topologicalSort(Function f) {
// For each blocks that has not been visited yet (i.e. that has no
// predecessors), add it to the list and traverse its successors in DFS
// preorder.
return blocks;
}
-bool ModuleTranslation::convertOneFunction(Function &func) {
+bool ModuleTranslation::convertOneFunction(Function func) {
// Clear the block and value mappings, they are only relevant within one
// function.
blockMapping.clear();
return false;
}
-std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(Module &m) {
+std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(Module m) {
auto *dialect = m.getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
assert(dialect && "LLVM dialect must be registered");
/// conversion object. If conversion fails for specific functions, those
/// functions remains unmodified.
LogicalResult
-mlir::applyConversionPatterns(Module &module, ConversionTarget &target,
+mlir::applyConversionPatterns(Module module, ConversionTarget &target,
TypeConverter &converter,
OwningRewritePatternList &&patterns) {
SmallVector<Function, 32> allFunctions(module.getFunctions());
makeFunction("vectorize_2d", {}, {memrefType, memrefType, memrefType});
mlir::Function f = owningF;
- mlir::Module module(&globalContext());
- module.push_back(f);
+ mlir::OwningModuleRef module = Module::create(&globalContext());
+ module->push_back(f);
OpBuilder builder(f.getBody());
ScopedContext scope(builder, f.getLoc());
llvm::cl::ZeroOrMore, llvm::cl::MiscFlags::CommaSeparated,
llvm::cl::cat(clOptionsCategory));
-static std::unique_ptr<Module> parseMLIRInput(StringRef inputFilename,
- MLIRContext *context) {
+static OwningModuleRef parseMLIRInput(StringRef inputFilename,
+ MLIRContext *context) {
// Set up the input file.
std::string errorMessage;
auto file = openInputFile(inputFilename, &errorMessage);
llvm::SourceMgr sourceMgr;
sourceMgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
- return std::unique_ptr<Module>(parseSourceFile(sourceMgr, context));
+ return OwningModuleRef(parseSourceFile(sourceMgr, context));
}
// Initialize the relevant subsystems of LLVM.
// - canonicalization
// - affine to standard lowering
// - standard to llvm lowering
-static LogicalResult convertAffineStandardToLLVMIR(Module *module) {
+static LogicalResult convertAffineStandardToLLVMIR(Module module) {
PassManager manager;
manager.addPass(mlir::createCanonicalizerPass());
manager.addPass(mlir::createCSEPass());
}
static Error compileAndExecuteFunctionWithMemRefs(
- Module *module, StringRef entryPoint,
+ Module module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
- Function mainFunction = module->getNamedFunction(entryPoint);
+ Function mainFunction = module.getNamedFunction(entryPoint);
if (!mainFunction || mainFunction.getBlocks().empty()) {
return make_string_error("entry point not found");
}
}
static Error compileAndExecuteSingleFloatReturnFunction(
- Module *module, StringRef entryPoint,
+ Module module, StringRef entryPoint,
std::function<llvm::Error(llvm::Module *)> transformer) {
- Function mainFunction = module->getNamedFunction(entryPoint);
+ Function mainFunction = module.getNamedFunction(entryPoint);
if (!mainFunction || mainFunction.isExternal()) {
return make_string_error("entry point not found");
}
/// Minimal class definitions for two analyses.
struct MyAnalysis {
MyAnalysis(Function) {}
- MyAnalysis(Module *) {}
+ MyAnalysis(Module) {}
};
struct OtherAnalysis {
OtherAnalysis(Function) {}
- OtherAnalysis(Module *) {}
+ OtherAnalysis(Module) {}
};
TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) {
MLIRContext context;
// Test fine grain invalidation of the module analysis manager.
- std::unique_ptr<Module> module(new Module(&context));
- ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
+ OwningModuleRef module(Module::create(&context));
+ ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
// Query two different analyses, but only preserve one before invalidating.
mam.getAnalysis<MyAnalysis>();
Builder builder(&context);
// Create a function and a module.
- std::unique_ptr<Module> module(new Module(&context));
+ OwningModuleRef module(Module::create(&context));
Function func1 =
Function::create(builder.getUnknownLoc(), "foo",
builder.getFunctionType(llvm::None, llvm::None));
module->push_back(func1);
// Test fine grain invalidation of the function analysis manager.
- ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
+ ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
FunctionAnalysisManager fam = mam.slice(func1);
// Query two different analyses, but only preserve one before invalidating.
Builder builder(&context);
// Create a function and a module.
- std::unique_ptr<Module> module(new Module(&context));
+ OwningModuleRef module(Module::create(&context));
Function func1 =
Function::create(builder.getUnknownLoc(), "foo",
builder.getFunctionType(llvm::None, llvm::None));
// Test fine grain invalidation of a function analysis from within a module
// analysis manager.
- ModuleAnalysisManager mam(&*module, /*passInstrumentor=*/nullptr);
+ ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
// Query two different analyses, but only preserve one before invalidating.
mam.getFunctionAnalysis<MyAnalysis>(func1);