/// Trivial C++ wrappers make use of the EDSC C API.
struct PythonMLIRModule {
- PythonMLIRModule() : mlirContext(), module(new mlir::Module(&mlirContext)) {}
+ PythonMLIRModule()
+ : mlirContext(), module(new mlir::Module(&mlirContext)),
+ moduleManager(module.get()) {}
PythonType makeScalarType(const std::string &mlirElemType,
unsigned bitwidth) {
}
PythonFunction getNamedFunction(const std::string &name) {
- return module->getNamedFunction(name);
+ return moduleManager.getNamedFunction(name);
}
PythonFunctionContext
mlir::MLIRContext mlirContext;
// One single module in a python-exposed MLIRContext for now.
std::unique_ptr<mlir::Module> module;
+ mlir::ModuleManager moduleManager;
std::unique_ptr<mlir::ExecutionEngine> engine;
};
UnknownLoc::get(&mlirContext), name,
mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
inputAttrs);
- module->getFunctions().push_back(func);
+ moduleManager.insert(func);
return func;
}
void runOnModule() override {
auto &module = getModule();
- auto *main = module.getNamedFunction("main");
+ mlir::ModuleManager moduleManager(&module);
+ auto *main = moduleManager.getNamedFunction("main");
if (!main) {
emitError(mlir::UnknownLoc::get(module.getContext()),
"Shape inference failed: can't find a main function\n");
SmallVector<FunctionToSpecialize, 8> worklist;
worklist.push_back({main, "", {}});
while (!worklist.empty()) {
- if (failed(specialize(worklist)))
+ if (failed(specialize(worklist, moduleManager)))
return;
}
/// Run inference on a function. If a mangledName is provided, we need to
/// specialize the function: to this end clone it first.
mlir::LogicalResult
- specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
+ specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
+ mlir::ModuleManager &moduleManager) {
FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
mlir::Function *f = functionToSpecialize.function;
// We will create a new function with the concrete types for the parameters
// and clone the body into it.
if (!functionToSpecialize.mangledName.empty()) {
- if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
+ if (moduleManager.getNamedFunction(functionToSpecialize.mangledName)) {
funcWorklist.pop_back();
// Function already specialized, move on.
return mlir::success();
&getContext());
auto *newFunction = new mlir::Function(
f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
- getModule().getFunctions().push_back(newFunction);
+ moduleManager.insert(newFunction);
// Clone the function body
mlir::BlockAndValueMapping mapper;
// restart after the callee is processed.
if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
auto calleeName = callOp.getCalleeName();
- auto *callee = getModule().getNamedFunction(calleeName);
+ auto *callee = moduleManager.getNamedFunction(calleeName);
if (!callee) {
signalPassFailure();
return f->emitError("Shape inference failed, call to unknown '")
auto mangledName = mangle(calleeName, op->getOpOperands());
LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
<< "', mangled: '" << mangledName << "'\n");
- auto *mangledCallee = getModule().getNamedFunction(mangledName);
+ auto *mangledCallee = moduleManager.getNamedFunction(mangledName);
if (!mangledCallee) {
// Can't find the target, this is where we queue the request for the
// callee and stop the inference for the current function now.
class Module {
public:
- explicit Module(MLIRContext *context) : symbolTable(context) {}
+ explicit Module(MLIRContext *context) : context(context) {}
- MLIRContext *getContext() { return symbolTable.getContext(); }
+ MLIRContext *getContext() { return context; }
/// This is the list of functions in the module.
using FunctionListType = llvm::iplist<Function>;
// Interfaces for working with the symbol table.
/// Look up a function with the specified name, returning null if no such
- /// name exists. Function names never include the @ on them.
+ /// name exists. Function names never include the @ on them. Note: This
+ /// performs a linear scan of held symbols.
Function *getNamedFunction(StringRef name) {
- return symbolTable.lookup(name);
+ return getNamedFunction(Identifier::get(name, getContext()));
}
/// Look up a function with the specified name, returning null if no such
- /// name exists. Function names never include the @ on them.
+ /// name exists. Function names never include the @ on them. Note: This
+ /// performs a linear scan of held symbols.
Function *getNamedFunction(Identifier name) {
- return symbolTable.lookup(name);
+ auto it = llvm::find_if(
+ functions, [name](Function &fn) { return fn.getName() == name; });
+ return it == functions.end() ? nullptr : &*it;
}
/// Perform (potentially expensive) checks of invariants, used to detect
return &Module::functions;
}
- /// The symbol table used for functions.
- SymbolTable symbolTable;
+ /// The context attached to this module.
+ MLIRContext *context;
/// This is the actual list of functions the module contains.
FunctionListType functions;
};
+/// A class used to manage the symbols held by a module. This class handles
+/// ensures that symbols inserted into a module have a unique name, and provides
+/// efficent named lookup to held symbols.
+class ModuleManager {
+public:
+ ModuleManager(Module *module) : module(module), symbolTable(module) {}
+
+ /// Look up a symbol with the specified name, returning null if no such
+ /// name exists. Names must never include the @ on them.
+ template <typename NameTy> Function *getNamedFunction(NameTy &&name) const {
+ return symbolTable.lookup(name);
+ }
+
+ /// Insert a new symbol into the module, auto-renaming it as necessary.
+ void insert(Function *function) {
+ symbolTable.insert(function);
+ module->getFunctions().push_back(function);
+ }
+ void insert(Module::iterator insertPt, Function *function) {
+ symbolTable.insert(function);
+ module->getFunctions().insert(insertPt, function);
+ }
+
+ /// Remove the given symbol from the module symbol table and then erase it.
+ void erase(Function *function) {
+ symbolTable.erase(function);
+ function->erase();
+ }
+
+ /// Return the internally held module.
+ Module *getModule() const { return module; }
+
+ /// Return the context of the internal module.
+ MLIRContext *getContext() const { return module->getContext(); }
+
+private:
+ Module *module;
+ SymbolTable symbolTable;
+};
+
//===--------------------------------------------------------------------===//
// Module Operation.
//===--------------------------------------------------------------------===//
namespace mlir {
class Function;
+class Module;
class MLIRContext;
/// This class represents the symbol table used by a module for function
/// symbols.
class SymbolTable {
public:
- SymbolTable(MLIRContext *ctx) : context(ctx) {}
+ /// Build a symbol table with the symbols within the given module.
+ SymbolTable(Module *module);
/// Look up a symbol with the specified name, returning null if no such
/// name exists. Names never include the @ on them.
/// compiler bugs. On error, this reports the error through the MLIRContext and
/// returns failure.
LogicalResult Module::verify() {
- /// Check that each function is correct.
+ // Check that all functions are uniquely named.
+ llvm::StringMap<Location> nameToOrigLoc;
+ for (auto &fn : *this) {
+ auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc());
+ if (!it.second)
+ return fn.emitError()
+ .append("redefinition of symbol named '", fn.getName(), "'")
+ .attachNote(it.first->second)
+ .append("see existing symbol definition here");
+ }
+
+ // Check that each function is correct.
for (auto &fn : *this)
if (failed(fn.verify()))
return failure();
using namespace mlir;
-namespace {
-
template <typename OpTy>
-void createForAllDimensions(OpBuilder &builder, Location loc,
- SmallVectorImpl<Value *> &values) {
+static void createForAllDimensions(OpBuilder &builder, Location loc,
+ SmallVectorImpl<Value *> &values) {
for (StringRef dim : {"x", "y", "z"}) {
Value *v = builder.create<OpTy>(loc, builder.getIndexType(),
builder.getStringAttr(dim));
// Add operations generating block/thread ids and gird/block dimensions at the
// beginning of `kernelFunc` and replace uses of the respective function args.
-void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
+static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
OpBuilder OpBuilder(kernelFunc.getBody());
SmallVector<Value *, 12> indexOps;
createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
// Outline the `gpu.launch` operation body into a kernel function. Replace
// `gpu.return` operations by `std.return` in the generated functions.
-Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) {
+static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
Location loc = launchOp.getLoc();
SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
FunctionType type =
- FunctionType::get(kernelOperandTypes, {}, module.getContext());
+ FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
std::string kernelFuncName =
Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str();
Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type);
outlinedFunc->getBody().takeBody(launchOp.getBody());
- Builder builder(&module);
+ Builder builder(launchOp.getContext());
outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
builder.getUnitAttr());
injectGpuIndexOperations(loc, *outlinedFunc);
replacer.create<ReturnOp>(op.getLoc());
op.erase();
});
- module.getFunctions().push_back(outlinedFunc);
return outlinedFunc;
}
// Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
// `kernelFunc`.
-void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, Function &kernelFunc) {
+static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
+ Function &kernelFunc) {
OpBuilder builder(launchOp);
SmallVector<Value *, 4> kernelOperandValues(
launchOp.getKernelOperandValues());
launchOp.erase();
}
-} // namespace
+namespace {
class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
public:
void runOnModule() override {
+ ModuleManager moduleManager(&getModule());
for (auto &func : getModule()) {
func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
- Function *outlinedFunc = outlineKernelFunc(getModule(), op);
+ Function *outlinedFunc = outlineKernelFunc(op);
+ moduleManager.insert(outlinedFunc);
convertToLaunchFuncOp(op, *outlinedFunc);
});
}
}
};
+} // namespace
+
ModulePassBase *mlir::createGpuKernelOutliningPass() {
return new GpuKernelOutliningPass();
}
/// keep the module pointer and module symbol table up to date.
void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
assert(!function->getModule() && "already in a module!");
- auto *module = getContainingModule();
- function->module = module;
-
- // Add this function to the symbol table of the module.
- module->symbolTable.insert(function);
+ function->module = getContainingModule();
}
/// This is a trait method invoked when a Function is removed from a Module.
/// We keep the module pointer up to date.
void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
assert(function->module && "not already in a module!");
-
- // Remove the symbol table entry.
- function->module->symbolTable.erase(function);
function->module = nullptr;
}
// =============================================================================
#include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
using namespace mlir;
+/// Build a symbol table with the symbols within the given module.
+SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
+ for (auto &func : *module) {
+ auto inserted = symbolTable.insert({func.getName(), &func});
+ (void)inserted;
+ assert(inserted.second &&
+ "expected module to contain uniquely named functions");
+ }
+}
+
/// Look up a symbol with the specified name, returning null if no such name
/// exists. Names never include the @ on them.
Function *SymbolTable::lookup(StringRef name) const {
new Function(getEncodedSourceLocation(loc), name, type, attrs);
module->getFunctions().push_back(function);
- // Verify no name collision / redefinition.
- if (function->getName() != name)
- return emitError(loc, "redefinition of function named '") << name << "'";
-
// Parse an optional trailing location.
if (parseOptionalTrailingLocation(function))
return failure();
// -----
-func @redef()
-func @redef() // expected-error {{redefinition of function named 'redef'}}
+func @redef() // expected-note {{see existing symbol definition here}}
+func @redef() // expected-error {{redefinition of symbol named 'redef'}}
// -----