Add SymbolTable trait to spirv::ModuleOp.
authorMahesh Ravishankar <ravishankarm@google.com>
Thu, 8 Aug 2019 21:18:39 +0000 (14:18 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 8 Aug 2019 21:20:05 +0000 (14:20 -0700)
Adding the SymbolTable trait allows looking up the name of the
functions using the symbol table while verifying EntryPointOps instead
of manually tracking the function names.

PiperOrigin-RevId: 262431220

mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td
mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

index 1d52860..b44d8ef 100644 (file)
@@ -31,7 +31,8 @@ include "mlir/SPIRV/SPIRVBase.td"
 #endif // SPIRV_BASE
 
 def SPV_ModuleOp : SPV_Op<"module",
-                          [SingleBlockImplicitTerminator<"ModuleEndOp">]> {
+                          [SingleBlockImplicitTerminator<"ModuleEndOp">,
+                           NativeOpTrait<"SymbolTable">]> {
   let summary = "The top-level op that defines a SPIR-V module";
 
   let description = [{
index 9366acc..a7574a5 100644 (file)
@@ -766,22 +766,19 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
   auto &op = *moduleOp.getOperation();
   auto *dialect = op.getDialect();
   auto &body = op.getRegion(0).front();
-  llvm::StringMap<FuncOp> funcNames;
   llvm::DenseMap<std::pair<FuncOp, spirv::ExecutionModel>, spirv::EntryPointOp>
       entryPoints;
 
   for (auto &op : body) {
     if (op.getDialect() == dialect) {
-      // For EntryPoint op, check that the function name is one of the specified
-      // func ops already specified, and that the function and execution model
-      // is not duplicated in EntryPointOps
+      // For EntryPoint op, check that the function and execution model is not
+      // duplicated in EntryPointOps
       if (auto entryPointOp = llvm::dyn_cast<spirv::EntryPointOp>(op)) {
-        auto it = funcNames.find(entryPointOp.fn());
-        if (it == funcNames.end()) {
+        auto funcOp = moduleOp.lookupSymbol(entryPointOp.fn());
+        if (!funcOp) {
           return entryPointOp.emitError("function '")
                  << entryPointOp.fn() << "' not found in 'spv.module'";
         }
-        auto funcOp = it->second;
         auto key = std::pair<FuncOp, spirv::ExecutionModel>(
             funcOp, entryPointOp.execution_model());
         auto entryPtIt = entryPoints.find(key);
@@ -797,8 +794,6 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
     if (!funcOp)
       return op.emitError("'spv.module' can only contain func and spv.* ops");
 
-    funcNames[funcOp.getName()] = funcOp;
-
     if (funcOp.isExternal())
       return op.emitError("'spv.module' cannot contain external functions");