Extract the automatic function renaming and symbol table out of Module.
authorRiver Riddle <riverriddle@google.com>
Sun, 30 Jun 2019 17:44:07 +0000 (10:44 -0700)
committerjpienaar <jpienaar@google.com>
Mon, 1 Jul 2019 16:55:13 +0000 (09:55 -0700)
This functionality is now moved to a new class, ModuleManager. This class allows for inserting functions into a module, and will auto-rename them on insert to ensure a unique name. This now means that users adding new functions to a module must ensure that the function name is unique, as the Module will no longer do it automatically. This also means that Module::getNamedFunction now operates in O(N) instead of the O(c) time it did before. This simplifies the move of Modules to Operations as the ModuleOp will not be able to have this functionality.

PiperOrigin-RevId: 255846088

mlir/bindings/python/pybind.cpp
mlir/examples/toy/Ch5/mlir/ShapeInferencePass.cpp
mlir/include/mlir/IR/Module.h
mlir/include/mlir/IR/SymbolTable.h
mlir/lib/Analysis/Verifier.cpp
mlir/lib/GPU/Transforms/KernelOutlining.cpp
mlir/lib/IR/Function.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/lib/Parser/Parser.cpp
mlir/test/IR/invalid.mlir

index bfbe6f9..222ef52 100644 (file)
@@ -143,7 +143,9 @@ struct PythonFunction {
 
 /// Trivial C++ wrappers make use of the EDSC C API.
 struct PythonMLIRModule {
-  PythonMLIRModule() : mlirContext(), module(new mlir::Module(&mlirContext)) {}
+  PythonMLIRModule()
+      : mlirContext(), module(new mlir::Module(&mlirContext)),
+        moduleManager(module.get()) {}
 
   PythonType makeScalarType(const std::string &mlirElemType,
                             unsigned bitwidth) {
@@ -220,7 +222,7 @@ struct PythonMLIRModule {
   }
 
   PythonFunction getNamedFunction(const std::string &name) {
-    return module->getNamedFunction(name);
+    return moduleManager.getNamedFunction(name);
   }
 
   PythonFunctionContext
@@ -232,6 +234,7 @@ private:
   mlir::MLIRContext mlirContext;
   // One single module in a python-exposed MLIRContext for now.
   std::unique_ptr<mlir::Module> module;
+  mlir::ModuleManager moduleManager;
   std::unique_ptr<mlir::ExecutionEngine> engine;
 };
 
@@ -595,7 +598,7 @@ PythonMLIRModule::declareFunction(const std::string &name,
       UnknownLoc::get(&mlirContext), name,
       mlir::Type::getFromOpaquePointer(funcType).cast<FunctionType>(), attrs,
       inputAttrs);
-  module->getFunctions().push_back(func);
+  moduleManager.insert(func);
   return func;
 }
 
index 2552cbe..cad2ded 100644 (file)
@@ -120,7 +120,8 @@ public:
 
   void runOnModule() override {
     auto &module = getModule();
-    auto *main = module.getNamedFunction("main");
+    mlir::ModuleManager moduleManager(&module);
+    auto *main = moduleManager.getNamedFunction("main");
     if (!main) {
       emitError(mlir::UnknownLoc::get(module.getContext()),
                 "Shape inference failed: can't find a main function\n");
@@ -133,7 +134,7 @@ public:
     SmallVector<FunctionToSpecialize, 8> worklist;
     worklist.push_back({main, "", {}});
     while (!worklist.empty()) {
-      if (failed(specialize(worklist)))
+      if (failed(specialize(worklist, moduleManager)))
         return;
     }
 
@@ -151,7 +152,8 @@ public:
   /// Run inference on a function. If a mangledName is provided, we need to
   /// specialize the function: to this end clone it first.
   mlir::LogicalResult
-  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist) {
+  specialize(SmallVectorImpl<FunctionToSpecialize> &funcWorklist,
+             mlir::ModuleManager &moduleManager) {
     FunctionToSpecialize &functionToSpecialize = funcWorklist.back();
     mlir::Function *f = functionToSpecialize.function;
 
@@ -159,7 +161,7 @@ public:
     // We will create a new function with the concrete types for the parameters
     // and clone the body into it.
     if (!functionToSpecialize.mangledName.empty()) {
-      if (getModule().getNamedFunction(functionToSpecialize.mangledName)) {
+      if (moduleManager.getNamedFunction(functionToSpecialize.mangledName)) {
         funcWorklist.pop_back();
         // Function already specialized, move on.
         return mlir::success();
@@ -171,7 +173,7 @@ public:
                                           &getContext());
       auto *newFunction = new mlir::Function(
           f->getLoc(), functionToSpecialize.mangledName, type, f->getAttrs());
-      getModule().getFunctions().push_back(newFunction);
+      moduleManager.insert(newFunction);
 
       // Clone the function body
       mlir::BlockAndValueMapping mapper;
@@ -293,7 +295,7 @@ public:
       // restart after the callee is processed.
       if (auto callOp = llvm::dyn_cast<GenericCallOp>(op)) {
         auto calleeName = callOp.getCalleeName();
-        auto *callee = getModule().getNamedFunction(calleeName);
+        auto *callee = moduleManager.getNamedFunction(calleeName);
         if (!callee) {
           signalPassFailure();
           return f->emitError("Shape inference failed, call to unknown '")
@@ -302,7 +304,7 @@ public:
         auto mangledName = mangle(calleeName, op->getOpOperands());
         LLVM_DEBUG(llvm::dbgs() << "Found callee to infer: '" << calleeName
                                 << "', mangled: '" << mangledName << "'\n");
-        auto *mangledCallee = getModule().getNamedFunction(mangledName);
+        auto *mangledCallee = moduleManager.getNamedFunction(mangledName);
         if (!mangledCallee) {
           // Can't find the target, this is where we queue the request for the
           // callee and stop the inference for the current function now.
index a286c6e..8161a30 100644 (file)
@@ -30,9 +30,9 @@ namespace mlir {
 
 class Module {
 public:
-  explicit Module(MLIRContext *context) : symbolTable(context) {}
+  explicit Module(MLIRContext *context) : context(context) {}
 
-  MLIRContext *getContext() { return symbolTable.getContext(); }
+  MLIRContext *getContext() { return context; }
 
   /// This is the list of functions in the module.
   using FunctionListType = llvm::iplist<Function>;
@@ -50,15 +50,19 @@ public:
   // Interfaces for working with the symbol table.
 
   /// Look up a function with the specified name, returning null if no such
-  /// name exists.  Function names never include the @ on them.
+  /// name exists. Function names never include the @ on them. Note: This
+  /// performs a linear scan of held symbols.
   Function *getNamedFunction(StringRef name) {
-    return symbolTable.lookup(name);
+    return getNamedFunction(Identifier::get(name, getContext()));
   }
 
   /// Look up a function with the specified name, returning null if no such
-  /// name exists.  Function names never include the @ on them.
+  /// name exists. Function names never include the @ on them. Note: This
+  /// performs a linear scan of held symbols.
   Function *getNamedFunction(Identifier name) {
-    return symbolTable.lookup(name);
+    auto it = llvm::find_if(
+        functions, [name](Function &fn) { return fn.getName() == name; });
+    return it == functions.end() ? nullptr : &*it;
   }
 
   /// Perform (potentially expensive) checks of invariants, used to detect
@@ -78,13 +82,53 @@ private:
     return &Module::functions;
   }
 
-  /// The symbol table used for functions.
-  SymbolTable symbolTable;
+  /// The context attached to this module.
+  MLIRContext *context;
 
   /// This is the actual list of functions the module contains.
   FunctionListType functions;
 };
 
+/// A class used to manage the symbols held by a module. This class handles
+/// ensures that symbols inserted into a module have a unique name, and provides
+/// efficent named lookup to held symbols.
+class ModuleManager {
+public:
+  ModuleManager(Module *module) : module(module), symbolTable(module) {}
+
+  /// Look up a symbol with the specified name, returning null if no such
+  /// name exists. Names must never include the @ on them.
+  template <typename NameTy> Function *getNamedFunction(NameTy &&name) const {
+    return symbolTable.lookup(name);
+  }
+
+  /// Insert a new symbol into the module, auto-renaming it as necessary.
+  void insert(Function *function) {
+    symbolTable.insert(function);
+    module->getFunctions().push_back(function);
+  }
+  void insert(Module::iterator insertPt, Function *function) {
+    symbolTable.insert(function);
+    module->getFunctions().insert(insertPt, function);
+  }
+
+  /// Remove the given symbol from the module symbol table and then erase it.
+  void erase(Function *function) {
+    symbolTable.erase(function);
+    function->erase();
+  }
+
+  /// Return the internally held module.
+  Module *getModule() const { return module; }
+
+  /// Return the context of the internal module.
+  MLIRContext *getContext() const { return module->getContext(); }
+
+private:
+  Module *module;
+  SymbolTable symbolTable;
+};
+
 //===--------------------------------------------------------------------===//
 // Module Operation.
 //===--------------------------------------------------------------------===//
index 8c30dfe..3074958 100644 (file)
 
 namespace mlir {
 class Function;
+class Module;
 class MLIRContext;
 
 /// This class represents the symbol table used by a module for function
 /// symbols.
 class SymbolTable {
 public:
-  SymbolTable(MLIRContext *ctx) : context(ctx) {}
+  /// Build a symbol table with the symbols within the given module.
+  SymbolTable(Module *module);
 
   /// Look up a symbol with the specified name, returning null if no such
   /// name exists. Names never include the @ on them.
index 2092dbf..1330fe0 100644 (file)
@@ -367,7 +367,18 @@ LogicalResult Operation::verify() {
 /// compiler bugs.  On error, this reports the error through the MLIRContext and
 /// returns failure.
 LogicalResult Module::verify() {
-  /// Check that each function is correct.
+  // Check that all functions are uniquely named.
+  llvm::StringMap<Location> nameToOrigLoc;
+  for (auto &fn : *this) {
+    auto it = nameToOrigLoc.try_emplace(fn.getName(), fn.getLoc());
+    if (!it.second)
+      return fn.emitError()
+          .append("redefinition of symbol named '", fn.getName(), "'")
+          .attachNote(it.first->second)
+          .append("see existing symbol definition here");
+  }
+
+  // Check that each function is correct.
   for (auto &fn : *this)
     if (failed(fn.verify()))
       return failure();
index 56f5251..46363f0 100644 (file)
 
 using namespace mlir;
 
-namespace {
-
 template <typename OpTy>
-void createForAllDimensions(OpBuilder &builder, Location loc,
-                            SmallVectorImpl<Value *> &values) {
+static void createForAllDimensions(OpBuilder &builder, Location loc,
+                                   SmallVectorImpl<Value *> &values) {
   for (StringRef dim : {"x", "y", "z"}) {
     Value *v = builder.create<OpTy>(loc, builder.getIndexType(),
                                     builder.getStringAttr(dim));
@@ -42,7 +40,7 @@ void createForAllDimensions(OpBuilder &builder, Location loc,
 
 // Add operations generating block/thread ids and gird/block dimensions at the
 // beginning of `kernelFunc` and replace uses of the respective function args.
-void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
+static void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
   OpBuilder OpBuilder(kernelFunc.getBody());
   SmallVector<Value *, 12> indexOps;
   createForAllDimensions<gpu::BlockId>(OpBuilder, loc, indexOps);
@@ -60,16 +58,16 @@ void injectGpuIndexOperations(Location loc, Function &kernelFunc) {
 
 // Outline the `gpu.launch` operation body into a kernel function. Replace
 // `gpu.return` operations by `std.return` in the generated functions.
-Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) {
+static Function *outlineKernelFunc(gpu::LaunchOp launchOp) {
   Location loc = launchOp.getLoc();
   SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());
   FunctionType type =
-      FunctionType::get(kernelOperandTypes, {}, module.getContext());
+      FunctionType::get(kernelOperandTypes, {}, launchOp.getContext());
   std::string kernelFuncName =
       Twine(launchOp.getOperation()->getFunction()->getName(), "_kernel").str();
   Function *outlinedFunc = new mlir::Function(loc, kernelFuncName, type);
   outlinedFunc->getBody().takeBody(launchOp.getBody());
-  Builder builder(&module);
+  Builder builder(launchOp.getContext());
   outlinedFunc->setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                         builder.getUnitAttr());
   injectGpuIndexOperations(loc, *outlinedFunc);
@@ -78,13 +76,13 @@ Function *outlineKernelFunc(Module &module, gpu::LaunchOp &launchOp) {
     replacer.create<ReturnOp>(op.getLoc());
     op.erase();
   });
-  module.getFunctions().push_back(outlinedFunc);
   return outlinedFunc;
 }
 
 // Replace `gpu.launch` operations with an `gpu.launch_func` operation launching
 // `kernelFunc`.
-void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, Function &kernelFunc) {
+static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp,
+                                  Function &kernelFunc) {
   OpBuilder builder(launchOp);
   SmallVector<Value *, 4> kernelOperandValues(
       launchOp.getKernelOperandValues());
@@ -94,20 +92,24 @@ void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, Function &kernelFunc) {
   launchOp.erase();
 }
 
-} // namespace
+namespace {
 
 class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
 public:
   void runOnModule() override {
+    ModuleManager moduleManager(&getModule());
     for (auto &func : getModule()) {
       func.walk<mlir::gpu::LaunchOp>([&](mlir::gpu::LaunchOp op) {
-        Function *outlinedFunc = outlineKernelFunc(getModule(), op);
+        Function *outlinedFunc = outlineKernelFunc(op);
+        moduleManager.insert(outlinedFunc);
         convertToLaunchFuncOp(op, *outlinedFunc);
       });
     }
   }
 };
 
+} // namespace
+
 ModulePassBase *mlir::createGpuKernelOutliningPass() {
   return new GpuKernelOutliningPass();
 }
index 6f68397..7d17ed1 100644 (file)
@@ -52,20 +52,13 @@ Module *llvm::ilist_traits<Function>::getContainingModule() {
 /// keep the module pointer and module symbol table up to date.
 void llvm::ilist_traits<Function>::addNodeToList(Function *function) {
   assert(!function->getModule() && "already in a module!");
-  auto *module = getContainingModule();
-  function->module = module;
-
-  // Add this function to the symbol table of the module.
-  module->symbolTable.insert(function);
+  function->module = getContainingModule();
 }
 
 /// This is a trait method invoked when a Function is removed from a Module.
 /// We keep the module pointer up to date.
 void llvm::ilist_traits<Function>::removeNodeFromList(Function *function) {
   assert(function->module && "not already in a module!");
-
-  // Remove the symbol table entry.
-  function->module->symbolTable.erase(function);
   function->module = nullptr;
 }
 
index da9ff0f..a0819a7 100644 (file)
 // =============================================================================
 
 #include "mlir/IR/SymbolTable.h"
-#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
 
 using namespace mlir;
 
+/// Build a symbol table with the symbols within the given module.
+SymbolTable::SymbolTable(Module *module) : context(module->getContext()) {
+  for (auto &func : *module) {
+    auto inserted = symbolTable.insert({func.getName(), &func});
+    (void)inserted;
+    assert(inserted.second &&
+           "expected module to contain uniquely named functions");
+  }
+}
+
 /// Look up a symbol with the specified name, returning null if no such name
 /// exists. Names never include the @ on them.
 Function *SymbolTable::lookup(StringRef name) const {
index 2cf6d3f..44f0596 100644 (file)
@@ -4053,10 +4053,6 @@ ParseResult ModuleParser::parseFunc(Module *module) {
       new Function(getEncodedSourceLocation(loc), name, type, attrs);
   module->getFunctions().push_back(function);
 
-  // Verify no name collision / redefinition.
-  if (function->getName() != name)
-    return emitError(loc, "redefinition of function named '") << name << "'";
-
   // Parse an optional trailing location.
   if (parseOptionalTrailingLocation(function))
     return failure();
index f55fc5a..faafc6a 100644 (file)
@@ -461,8 +461,8 @@ func @return_inside_loop() {
 
 // -----
 
-func @redef()
-func @redef()  // expected-error {{redefinition of function named 'redef'}}
+func @redef()  // expected-note {{see existing symbol definition here}}
+func @redef()  // expected-error {{redefinition of symbol named 'redef'}}
 
 // -----