[mlir] Use SymbolUserOpInterface in LLVM::AddressOfOp verifier
authorEugene Zhulenev <ezhulenev@google.com>
Fri, 5 Aug 2022 17:35:39 +0000 (10:35 -0700)
committerEugene Zhulenev <ezhulenev@google.com>
Fri, 5 Aug 2022 17:51:30 +0000 (10:51 -0700)
Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D131271

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td
mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp
mlir/test/Dialect/LLVMIR/global.mlir

index 1eb861c..e658271 100644 (file)
@@ -974,7 +974,8 @@ def UnnamedAddr : LLVM_EnumAttr<
   let cppNamespace = "::mlir::LLVM";
 }
 
-def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", [NoSideEffect]> {
+def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof",
+    [NoSideEffect, DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
   let arguments = (ins FlatSymbolRefAttr:$global_name);
   let results = (outs LLVM_AnyPointer:$res);
 
@@ -1036,7 +1037,6 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", [NoSideEffect]> {
   }];
 
   let assemblyFormat = "$global_name attr-dict `:` type($res)";
-  let hasVerifier = 1;
 }
 
 def LLVM_MetadataOp : LLVM_Op<"metadata", [
index 76cb05a..3d9ec17 100644 (file)
@@ -1729,27 +1729,28 @@ LogicalResult ResumeOp::verify() {
 // Verifier for LLVM::AddressOfOp.
 //===----------------------------------------------------------------------===//
 
-static Operation *lookupSymbolInModule(Operation *parent, StringRef name) {
-  Operation *module = parent;
+static Operation *parentLLVMModule(Operation *op) {
+  Operation *module = op->getParentOp();
   while (module && !satisfiesLLVMModule(module))
     module = module->getParentOp();
   assert(module && "unexpected operation outside of a module");
-  return mlir::SymbolTable::lookupSymbolIn(module, name);
+  return module;
 }
 
 GlobalOp AddressOfOp::getGlobal() {
   return dyn_cast_or_null<GlobalOp>(
-      lookupSymbolInModule((*this)->getParentOp(), getGlobalName()));
+      SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
 }
 
 LLVMFuncOp AddressOfOp::getFunction() {
   return dyn_cast_or_null<LLVMFuncOp>(
-      lookupSymbolInModule((*this)->getParentOp(), getGlobalName()));
+      SymbolTable::lookupSymbolIn(parentLLVMModule(*this), getGlobalName()));
 }
 
-LogicalResult AddressOfOp::verify() {
+LogicalResult
+AddressOfOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
   Operation *symbol =
-      lookupSymbolInModule((*this)->getParentOp(), getGlobalName());
+      symbolTable.lookupSymbolIn(parentLLVMModule(*this), getGlobalNameAttr());
 
   auto global = dyn_cast_or_null<GlobalOp>(symbol);
   auto function = dyn_cast_or_null<LLVMFuncOp>(symbol);
index b53f6d4..9454662 100644 (file)
@@ -155,6 +155,7 @@ func.func @foo() {
   // them to trigger the attribute type mismatch error.
   // expected-error @+1 {{invalid kind of attribute specified}}
   llvm.mlir.addressof "foo" : i64 : !llvm.ptr<func<void ()>>
+  llvm.return
 }
 
 // -----
@@ -162,6 +163,7 @@ func.func @foo() {
 func.func @foo() {
   // expected-error @+1 {{must reference a global defined by 'llvm.mlir.global'}}
   llvm.mlir.addressof @foo : !llvm.ptr<func<void ()>>
+  llvm.return
 }
 
 // -----
@@ -171,6 +173,7 @@ llvm.mlir.global internal @foo(0: i32) : i32
 func.func @bar() {
   // expected-error @+1 {{the type must be a pointer to the type of the referenced global}}
   llvm.mlir.addressof @foo : !llvm.ptr<i64>
+  llvm.return
 }
 
 // -----
@@ -180,6 +183,7 @@ llvm.func @foo()
 llvm.func @bar() {
   // expected-error @+1 {{the type must be a pointer to the type of the referenced function}}
   llvm.mlir.addressof @foo : !llvm.ptr<i8>
+  llvm.return
 }
 
 // -----
@@ -211,6 +215,7 @@ llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : i64
 func.func @mismatch_addr_space_implicit_global() {
   // expected-error @+1 {{pointer address space must match address space of the referenced global}}
   llvm.mlir.addressof @g : !llvm.ptr<i64>
+  llvm.return
 }
 
 // -----
@@ -219,6 +224,7 @@ llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : i64
 func.func @mismatch_addr_space() {
   // expected-error @+1 {{pointer address space must match address space of the referenced global}}
   llvm.mlir.addressof @g : !llvm.ptr<i64, 4>
+  llvm.return
 }
 // -----
 
@@ -227,6 +233,7 @@ llvm.mlir.global internal @g(32 : i64) {addr_space = 3: i32} : i64
 func.func @mismatch_addr_space_opaque() {
   // expected-error @+1 {{pointer address space must match address space of the referenced global}}
   llvm.mlir.addressof @g : !llvm.ptr<4>
+  llvm.return
 }
 
 // -----