Add support for walking the uses of a symbol.
authorRiver Riddle <riverriddle@google.com>
Tue, 8 Oct 2019 17:21:26 +0000 (10:21 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 8 Oct 2019 17:21:59 +0000 (10:21 -0700)
MLIR uses symbol references to model references to many global entities, such as functions/variables/etc. Before this change, there is no way to actually reason about the uses of such entities. This change provides a walker for symbol references(via SymbolTable::walkSymbolUses), as well as 'use_empty' support(via SymbolTable::symbol_use_empty). It also resolves some deficiencies in the LangRef definition of SymbolRefAttr, namely the restrictions on where a SymbolRefAttr can be stored, ArrayAttr and DictionaryAttr, and the relationship with operations containing the SymbolTable trait.

PiperOrigin-RevId: 273549331

mlir/g3doc/LangRef.md
mlir/include/mlir/IR/SymbolTable.h
mlir/lib/IR/SymbolTable.cpp
mlir/test/IR/test-symbol-uses.mlir [new file with mode: 0644]
mlir/test/lib/CMakeLists.txt
mlir/test/lib/IR/CMakeLists.txt [new file with mode: 0644]
mlir/test/lib/IR/TestSymbolUses.cpp [new file with mode: 0644]
mlir/tools/mlir-opt/CMakeLists.txt

index da57c9e..4fdcbba 100644 (file)
@@ -804,7 +804,7 @@ memref<16x32xf32, #identity, memspace0>
 // The memref index space is of size %M x %N, while %B1 and %B2 bind to the
 // symbols s0, s1 respectively of the layout map #tiled_dynamic. Data tiles of
 // size %B1 x %B2 in the logical space will be stored contiguously in memory.
-// The allocation size will be (%M ceildiv %B1) * %B1 * (%N ceildiv %B2) * %B2 
+// The allocation size will be (%M ceildiv %B1) * %B1 * (%N ceildiv %B2) * %B2
 // f32 elements.
 %T = alloc(%M, %N) [%B1, %B2] : memref<?x?xf32, #tiled_dynamic>
 
@@ -860,7 +860,6 @@ integral. In addition, an index map must specify the size of each of its range
 dimensions onto which it maps. Index map symbols must be listed in order with
 symbols for dynamic dimension sizes first, followed by other required symbols.
 
-
 ##### Layout Map
 
 A layout map is a [semi-affine map](Dialects/Affine.md#semi-affine-maps) which
@@ -1360,7 +1359,22 @@ symbol-ref-attribute ::= symbol-ref-id
 ```
 
 A symbol reference attribute is a literal attribute that represents a named
-reference to a given operation.
+reference to an operation that is nested within an operation with the
+`OpTrait::SymbolTable` trait. As such, this reference is given meaning by the
+nearest parent operation containing the `OpTrait::SymbolTable` trait.
+
+This attribute can only be held internally by
+[array attributes](#array-attribute) and
+[dictionary attributes](#dictionary-attribute)(including the top-level operation
+attribute dictionary), i.e. no other attribute kinds such as Locations or
+extended attribute kinds. If a reference to a symbol is necessary from outside
+of the symbol table that the symbol is defined in, a
+[string attribute](string-attribute) can be used to refer to the symbol name.
+
+**Rationale:** Given that MLIR models global accesses with symbol references, to
+enable efficient multi-threading, it becomes difficult to effectively reason
+about their uses. By restricting the places that can legally hold a symbol
+reference, we can always opaquely reason about a symbols usage characteristics.
 
 #### Type Attribute
 
index 3b89d73..04925be 100644 (file)
@@ -53,6 +53,10 @@ public:
   /// Return the name of the attribute used for symbol names.
   static StringRef getSymbolAttrName() { return "sym_name"; }
 
+  //===--------------------------------------------------------------------===//
+  // Symbol Utilities
+  //===--------------------------------------------------------------------===//
+
   /// Returns the operation registered with the given symbol name with the
   /// regions of 'symbolTableOp'. 'symbolTableOp' is required to be an operation
   /// with the 'OpTrait::SymbolTable' trait.
@@ -64,6 +68,47 @@ public:
   /// found.
   static Operation *lookupNearestSymbolFrom(Operation *from, StringRef symbol);
 
+  /// This class represents a specific symbol use.
+  class SymbolUse {
+  public:
+    SymbolUse(Operation *op, SymbolRefAttr symbolRef)
+        : owner(op), symbolRef(symbolRef) {}
+
+    /// Return the operation user of this symbol reference.
+    Operation *getUser() const { return owner; }
+
+    /// Return the symbol reference that this use represents.
+    SymbolRefAttr getSymbolRef() const { return symbolRef; }
+
+  private:
+    /// The operation that this access is held by.
+    Operation *owner;
+
+    /// The symbol reference that this use represents.
+    SymbolRefAttr symbolRef;
+  };
+
+  /// Walk all of the uses, for any symbol, that are nested within the given
+  /// operation 'from', invoking the provided callback for each. This does not
+  /// traverse into any nested symbol tables, and will also only return uses on
+  /// 'from' if it does not also define a symbol table.
+  static WalkResult
+  walkSymbolUses(Operation *from, function_ref<WalkResult(SymbolUse)> callback);
+
+  /// Walk all of the uses of the given symbol that are nested within the given
+  /// operation 'from', invoking the provided callback for each. This does not
+  /// traverse into any nested symbol tables, and will also only return uses on
+  /// 'from' if it does not also define a symbol table.
+  static WalkResult
+  walkSymbolUses(StringRef symbol, Operation *from,
+                 function_ref<WalkResult(SymbolUse)> callback);
+
+  /// Return if the given symbol has no uses that are nested within the given
+  /// operation 'from'. This does not traverse into any nested symbol tables,
+  /// and will also only count uses on 'from' if it does not also define a
+  /// symbol table.
+  static bool symbol_use_empty(StringRef symbol, Operation *from);
+
 private:
   MLIRContext *context;
 
index 9221906..e442d64 100644 (file)
@@ -142,3 +142,160 @@ LogicalResult OpTrait::impl::verifySymbolTable(Operation *op) {
   }
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// SymbolTable Trait Types
+//===----------------------------------------------------------------------===//
+
+/// A utility result for walking a nested attribute for symbol uses.
+enum HandlerResult {
+  /// The walk of the containter can continue.
+  Continue = 0,
+  /// The walk should recurse into the given attribute, as it is a container.
+  RecurseNestedAttribute,
+  /// The walk should end immediately, as an interrupt has been signaled.
+  Interrupt
+};
+
+/// Utility function used to handle a nested attribute during a walk of symbol
+/// uses. It returns the above HandlerResult signaling the next action for the
+/// walk.
+HandlerResult handleAttrDuringSymbolWalk(
+    Operation *op, Attribute attr,
+    SmallVectorImpl<std::pair<Attribute, unsigned>> &worklist,
+    function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
+  switch (attr.getKind()) {
+  /// Check for a nested container attribute, these will also need to be
+  /// walked.
+  case StandardAttributes::Array:
+  case StandardAttributes::Dictionary: {
+    worklist.push_back({attr, /*index*/ 0});
+    return HandlerResult::RecurseNestedAttribute;
+  }
+
+  // Invoke the provided callback if we find a symbol use and check for a
+  // requested interrupt.
+  case StandardAttributes::SymbolRef: {
+    SymbolTable::SymbolUse use(op, attr.cast<SymbolRefAttr>());
+    return callback(use).wasInterrupted() ? HandlerResult::Interrupt
+                                          : HandlerResult::Continue;
+  }
+  default:
+    return HandlerResult::Continue;
+  }
+}
+
+/// Walk all of the symbol references within the given operation, invoking the
+/// provided callback for each found use.
+static WalkResult
+walkSymbolRefs(Operation *op,
+               function_ref<WalkResult(SymbolTable::SymbolUse)> callback) {
+  // Check to see if the operation has any attributes.
+  DictionaryAttr attrDict = op->getAttrList().getDictionary();
+  if (!attrDict)
+    return WalkResult::advance();
+
+  // A worklist of a container attribute and the current index into the held
+  // attribute list.
+  SmallVector<std::pair<Attribute, unsigned>, 1> worklist;
+  worklist.push_back({attrDict, /*index*/ 0});
+  while (!worklist.empty()) {
+    Attribute attr = worklist.back().first;
+    unsigned &index = worklist.back().second;
+
+    // Iterate over the given attribute, which is guaranteed to be a container.
+    HandlerResult handlerResult = HandlerResult::Continue;
+    if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+      ArrayRef<Attribute> attrs = arrayAttr.getValue();
+      unsigned attrSize = attrs.size();
+      while (index != attrSize)
+        if ((handlerResult = handleAttrDuringSymbolWalk(op, attrs[index++],
+                                                        worklist, callback)))
+          break;
+    } else {
+      auto dictAttr = attr.cast<DictionaryAttr>();
+      ArrayRef<NamedAttribute> attrs = dictAttr.getValue();
+      unsigned attrSize = attrs.size();
+      while (index != attrSize)
+        if ((handlerResult = handleAttrDuringSymbolWalk(
+                 op, attrs[index++].second, worklist, callback)))
+          break;
+    }
+    if (handlerResult == HandlerResult::Interrupt)
+      return WalkResult::interrupt();
+
+    // If we didn't encounter a nested attribute, pop the last item from the
+    // worklist.
+    if (handlerResult != HandlerResult::RecurseNestedAttribute)
+      worklist.pop_back();
+  }
+  return WalkResult::advance();
+}
+
+/// Walk all of the uses, for any symbol, that are nested within the given
+/// operation 'from', invoking the provided callback for each. This does not
+/// traverse into any nested symbol tables, and will also only return uses on
+/// 'from' if it does not also define a symbol table.
+WalkResult
+SymbolTable::walkSymbolUses(Operation *from,
+                            function_ref<WalkResult(SymbolUse)> callback) {
+  // If from is not a symbol table, check for uses. A symbol table defines a new
+  // scope, so we can't walk the attributes from the symbol table op.
+  if (!from->hasTrait<OpTrait::SymbolTable>()) {
+    if (walkSymbolRefs(from, callback).wasInterrupted())
+      return WalkResult::interrupt();
+  }
+
+  SmallVector<Region *, 1> worklist;
+  worklist.reserve(from->getNumRegions());
+  for (Region &region : from->getRegions())
+    worklist.push_back(&region);
+
+  while (!worklist.empty()) {
+    Region *region = worklist.pop_back_val();
+    for (Block &block : *region) {
+      for (Operation &op : block) {
+        if (walkSymbolRefs(&op, callback).wasInterrupted())
+          return WalkResult::interrupt();
+
+        // If this op defines a new symbol table scope, we can't traverse. Any
+        // symbol references nested within 'op' are different semantically.
+        if (!op.hasTrait<OpTrait::SymbolTable>()) {
+          for (Region &region : op.getRegions())
+            worklist.push_back(&region);
+        }
+      }
+    }
+  }
+  return WalkResult::advance();
+}
+
+/// Walk all of the uses, for any symbol, that are nested within the given
+/// operation 'from', invoking the provided callback for each. This does not
+/// traverse into any nested symbol tables, and will also only return uses on
+/// 'from' if it does not also define a symbol table.
+WalkResult
+SymbolTable::walkSymbolUses(StringRef symbol, Operation *from,
+                            function_ref<WalkResult(SymbolUse)> callback) {
+  SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
+  return walkSymbolUses(from, [&](SymbolUse symbolUse) {
+    if (symbolUse.getSymbolRef() != symbolRefAttr)
+      return WalkResult::advance();
+    return callback(std::move(symbolUse));
+  });
+}
+
+/// Return if the given symbol has no uses that are nested within the given
+/// operation 'from'. This does not traverse into any nested symbol tables,
+/// and will also only count uses on 'from' if it does not also define a
+/// symbol table.
+bool SymbolTable::symbol_use_empty(StringRef symbol, Operation *from) {
+  SymbolRefAttr symbolRefAttr = SymbolRefAttr::get(symbol, from->getContext());
+
+  // Walk all of the symbol uses looking for a reference to 'symbol'.
+  auto walkResult = walkSymbolUses(from, [&](SymbolUse symbolUse) {
+    return symbolUse.getSymbolRef() == symbolRefAttr ? WalkResult::interrupt()
+                                                     : WalkResult::advance();
+  });
+  return !walkResult.wasInterrupted();
+}
diff --git a/mlir/test/IR/test-symbol-uses.mlir b/mlir/test/IR/test-symbol-uses.mlir
new file mode 100644 (file)
index 0000000..874b907
--- /dev/null
@@ -0,0 +1,29 @@
+// RUN: mlir-opt %s -test-symbol-uses -verify-diagnostics
+
+
+// Symbol references to the module itself don't affect uses of symbols within
+// its table.
+module attributes {sym.outside_use = @symbol_foo } {
+  // expected-remark@+1 {{function has 2 uses}}
+  func @symbol_foo()
+
+  // expected-remark@+3 {{function has no uses}}
+  // expected-remark@+2 {{found use of function : @symbol_foo}}
+  // expected-remark@+1 {{function contains 2 nested references}}
+  func @symbol_bar() attributes {sym.use = @symbol_foo} {
+    // expected-remark@+1 {{found use of function : @symbol_foo}}
+    "foo.op"() {
+      non_symbol_attr,
+      use = [{ nested_symbol = [@symbol_foo]}],
+      z_other_non_symbol_attr
+    } : () -> ()
+  }
+
+  // expected-remark@+1 {{function has 1 use}}
+  func @symbol_baz()
+
+  // expected-remark@+1 {{found use of function : @symbol_baz}}
+  module attributes {test.reference = @symbol_baz} {
+    "foo.op"() {test.nested_reference = @symbol_baz} : () -> ()
+  }
+}
index 091dfc9..de7d50a 100644 (file)
@@ -1,3 +1,4 @@
+add_subdirectory(IR)
 add_subdirectory(Pass)
 add_subdirectory(TestDialect)
 add_subdirectory(Transforms)
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
new file mode 100644 (file)
index 0000000..c8ff810
--- /dev/null
@@ -0,0 +1,8 @@
+add_llvm_library(MLIRTestIR
+  TestSymbolUses.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  )
+target_link_libraries(MLIRTestIR
+  MLIRPass
+  )
diff --git a/mlir/test/lib/IR/TestSymbolUses.cpp b/mlir/test/lib/IR/TestSymbolUses.cpp
new file mode 100644 (file)
index 0000000..e8ca39f
--- /dev/null
@@ -0,0 +1,63 @@
+//===- TestSymbolUses.cpp - Pass to test symbol uselists ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/Function.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// This is a symbol test pass that tests the symbol uselist functionality
+/// provided by the symbol table.
+struct SymbolUsesPass : public ModulePass<SymbolUsesPass> {
+  void runOnModule() override {
+    auto module = getModule();
+
+    for (FuncOp func : module.getOps<FuncOp>()) {
+      // Test computing uses on a non symboltable op.
+      unsigned numUses = 0;
+      SymbolTable::walkSymbolUses(func, [&](SymbolTable::SymbolUse) {
+        ++numUses;
+        return WalkResult::advance();
+      });
+      if (numUses != 0)
+        func.emitRemark() << "function contains " << numUses
+                          << " nested references";
+
+      // Test the functionality of symbol_use_empty.
+      if (SymbolTable::symbol_use_empty(func.getName(), module)) {
+        func.emitRemark() << "function has no uses";
+        continue;
+      }
+
+      // Test the functionality of walkSymbolUses.
+      numUses = 0;
+      SymbolTable::walkSymbolUses(
+          func.getName(), module, [&](SymbolTable::SymbolUse symbolUse) {
+            symbolUse.getUser()->emitRemark()
+                << "found use of function : " << symbolUse.getSymbolRef();
+            ++numUses;
+            return WalkResult::advance();
+          });
+      func.emitRemark() << "function has " << numUses << " uses";
+    }
+  }
+};
+} // end anonymous namespace
+
+static PassRegistration<SymbolUsesPass> pass("test-symbol-uses",
+                                             "Test detection of symbol uses");
index 6868b90..6b8a3ae 100644 (file)
@@ -42,6 +42,7 @@ set(LIBS
   MLIRStandardToLLVM
   MLIRTransforms
   MLIRTestDialect
+  MLIRTestIR
   MLIRTestPass
   MLIRTestTransforms
   MLIRSupport