[mlir] Expose getNearestSymbolTable as SymbolTable class method
authorLei Zhang <antiagainst@google.com>
Sun, 26 Jan 2020 15:55:17 +0000 (10:55 -0500)
committerLei Zhang <antiagainst@google.com>
Sun, 26 Jan 2020 22:35:26 +0000 (17:35 -0500)
This is a generally useful utility function for interacting with
symbol tables.

Differential Revision: https://reviews.llvm.org/D73433

mlir/include/mlir/IR/SymbolTable.h
mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp
mlir/lib/IR/SymbolTable.cpp

index 2fca74c..bee27ad 100644 (file)
@@ -84,6 +84,10 @@ public:
   /// Sets the visibility of the given symbol operation.
   static void setSymbolVisibility(Operation *symbol, Visibility vis);
 
+  /// Returns the nearest symbol table from a given operation `from`. Returns
+  /// nullptr if no valid parent symbol table could be found.
+  static Operation *getNearestSymbolTable(Operation *from);
+
   /// Returns the operation registered with the given symbol name with the
   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
   /// with the 'OpTrait::SymbolTable' trait.
index aacb0f6..770623a 100644 (file)
@@ -278,9 +278,7 @@ getOrInsertBuiltinVariable(Block &body, Location loc, spirv::BuiltIn builtin,
 Value mlir::spirv::getBuiltinVariableValue(Operation *op,
                                            spirv::BuiltIn builtin,
                                            OpBuilder &builder) {
-  Operation *parent = op->getParentOp();
-  while (parent && !parent->hasTrait<OpTrait::SymbolTable>())
-    parent = parent->getParentOp();
+  Operation *parent = SymbolTable::getNearestSymbolTable(op->getParentOp());
   if (!parent) {
     op->emitError("expected operation to be within a module-like op");
     return nullptr;
index 8a1a6b8..4265ace 100644 (file)
@@ -20,23 +20,6 @@ static bool isPotentiallyUnknownSymbolTable(Operation *op) {
   return !op->getDialect() && op->getNumRegions() == 1;
 }
 
-/// Returns the nearest symbol table from a given operation `from`. Returns
-/// nullptr if no valid parent symbol table could be found.
-static Operation *getNearestSymbolTable(Operation *from) {
-  assert(from && "expected valid operation");
-  if (isPotentiallyUnknownSymbolTable(from))
-    return nullptr;
-
-  while (!from->hasTrait<OpTrait::SymbolTable>()) {
-    from = from->getParentOp();
-
-    // Check that this is a valid op and isn't an unknown symbol table.
-    if (!from || isPotentiallyUnknownSymbolTable(from))
-      return nullptr;
-  }
-  return from;
-}
-
 /// Returns the string name of the given symbol, or None if this is not a
 /// symbol.
 static Optional<StringRef> getNameIfSymbol(Operation *symbol) {
@@ -212,6 +195,23 @@ void SymbolTable::setSymbolVisibility(Operation *symbol, Visibility vis) {
   symbol->setAttr(getVisibilityAttrName(), StringAttr::get(visName, ctx));
 }
 
+/// Returns the nearest symbol table from a given operation `from`. Returns
+/// nullptr if no valid parent symbol table could be found.
+Operation *SymbolTable::getNearestSymbolTable(Operation *from) {
+  assert(from && "expected valid operation");
+  if (isPotentiallyUnknownSymbolTable(from))
+    return nullptr;
+
+  while (!from->hasTrait<OpTrait::SymbolTable>()) {
+    from = from->getParentOp();
+
+    // Check that this is a valid op and isn't an unknown symbol table.
+    if (!from || isPotentiallyUnknownSymbolTable(from))
+      return nullptr;
+  }
+  return from;
+}
+
 /// Returns the operation registered with the given symbol name with the
 /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
 /// with the 'OpTrait::SymbolTable' trait. Returns nullptr if no valid symbol
@@ -466,7 +466,7 @@ static Optional<WalkResult> walkSymbolScopes(
     if (limitAncestor == symbol) {
       // Check that the nearest symbol table is 'symbol's parent. SymbolRefAttr
       // doesn't support parent references.
-      if (getNearestSymbolTable(limit) != symbol->getParentOp())
+      if (SymbolTable::getNearestSymbolTable(limit) != symbol->getParentOp())
         return WalkResult::advance();
       return callback(SymbolRefAttr::get(symbolName, symbol->getContext()),
                       limit);