[mlir] Implement replacement of SymbolRefAttrs in Dialect attributes using SubElement...
authorMarkus Böck <markus.boeck02@gmail.com>
Thu, 28 Oct 2021 17:08:10 +0000 (19:08 +0200)
committerMarkus Böck <markus.boeck02@gmail.com>
Thu, 28 Oct 2021 17:08:20 +0000 (19:08 +0200)
This patch extends the SubElementAttr interface to allow replacing a contained sub attribute. The attribute that should be replaced is identified by an index which denotes the n-th element returned by the accompanying walkImmediateSubElements method.

Using this addition the patch implements replacing SymbolRefAttrs contained within any dialect attributes.

Differential Revision: https://reviews.llvm.org/D111357

mlir/include/mlir/IR/BuiltinAttributes.td
mlir/include/mlir/IR/SubElementInterfaces.td
mlir/lib/IR/BuiltinAttributes.cpp
mlir/lib/IR/SymbolTable.cpp
mlir/test/IR/test-symbol-rauw.mlir
mlir/test/lib/Dialect/Test/TestAttrDefs.td
mlir/test/lib/Dialect/Test/TestAttributes.cpp

index fcd6082..51ac32d 100644 (file)
@@ -71,7 +71,8 @@ def Builtin_AffineMapAttr : Builtin_Attr<"AffineMap", [
 //===----------------------------------------------------------------------===//
 
 def Builtin_ArrayAttr : Builtin_Attr<"Array", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "A collection of other Attribute values";
   let description = [{
@@ -345,7 +346,8 @@ def Builtin_DenseStringElementsAttr : Builtin_Attr<
 //===----------------------------------------------------------------------===//
 
 def Builtin_DictionaryAttr : Builtin_Attr<"Dictionary", [
-    DeclareAttrInterfaceMethods<SubElementAttrInterface>
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
   ]> {
   let summary = "An dictionary of named Attribute values";
   let description = [{
@@ -954,10 +956,11 @@ def Builtin_SymbolRefAttr : Builtin_Attr<"SymbolRef"> {
     symbol nested within a different symbol table.
 
     This attribute can only be held internally by
-    [array attributes](#array-attribute) and
+    [array attributes](#array-attribute),
     [dictionary attributes](#dictionary-attribute)(including the top-level
-    operation attribute dictionary), i.e. no other attribute kinds such as
-    Locations or extended attribute kinds.
+    operation attribute dictionary) as well as attributes exposing it via
+    the `SubElementAttrInterface` interface. Symbol reference attributes
+    nested in types are currently not supported.
 
     **Rationale:** Identifying accesses to global data is critical to
     enabling efficient multi-threaded compilation. Restricting global
index 8a48852..9ee9051 100644 (file)
@@ -33,6 +33,20 @@ class SubElementInterfaceBase<string interfaceName, string derivedValue> {
       (ins "llvm::function_ref<void(mlir::Attribute)>":$walkAttrsFn,
            "llvm::function_ref<void(mlir::Type)>":$walkTypesFn)
     >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Replace the attributes identified by the indices with the corresponding
+        value. The index is derived from the order of the attributes returned by
+        the attribute callback of `walkImmediateSubElements`. An index of 0 would
+        replace the very first attribute given by `walkImmediateSubElements`.
+        The new instance with the values replaced is returned.
+      }], cppNamespace # "::" # interfaceName, "replaceImmediateSubAttribute",
+      (ins "::llvm::ArrayRef<std::pair<size_t, ::mlir::Attribute>>":$replacements),
+      [{}],
+      /*defaultImplementation=*/[{
+        llvm_unreachable("Attribute or Type does not support replacing attributes");
+      }]
+    >,
   ];
 
   code extraClassDeclaration = [{
index fe8f6a5..72891d9 100644 (file)
@@ -53,6 +53,15 @@ void ArrayAttr::walkImmediateSubElements(
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface ArrayAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  std::vector<Attribute> vector = getValue().vec();
+  for (auto &it : replacements) {
+    vector[it.first] = it.second;
+  }
+  return get(getContext(), vector);
+}
+
 //===----------------------------------------------------------------------===//
 // DictionaryAttr
 //===----------------------------------------------------------------------===//
@@ -217,6 +226,17 @@ void DictionaryAttr::walkImmediateSubElements(
     walkAttrsFn(attr);
 }
 
+SubElementAttrInterface DictionaryAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  std::vector<NamedAttribute> vec = getValue().vec();
+  for (auto &it : replacements) {
+    vec[it.first].second = it.second;
+  }
+  // The above only modifies the mapped value, but not the key, and therefore
+  // not the order of the elements. It remains sorted
+  return getWithSorted(getContext(), vec);
+}
+
 //===----------------------------------------------------------------------===//
 // StringAttr
 //===----------------------------------------------------------------------===//
index ad4f083..6634eab 100644 (file)
@@ -485,16 +485,30 @@ static WalkResult walkSymbolRefs(
 
   // A worklist of a container attribute and the current index into the held
   // attribute list.
-  SmallVector<Attribute, 1> attrWorklist(1, attrDict);
+  struct WorklistItem {
+    SubElementAttrInterface container;
+    SmallVector<Attribute> immediateSubElements;
+
+    explicit WorklistItem(SubElementAttrInterface container) {
+      SmallVector<Attribute> subElements;
+      container.walkImmediateSubElements(
+          [&](Attribute attr) { subElements.push_back(attr); }, [](Type) {});
+      immediateSubElements = std::move(subElements);
+    }
+  };
+
+  SmallVector<WorklistItem, 1> attrWorklist(1, WorklistItem(attrDict));
   SmallVector<int, 1> curAccessChain(1, /*Value=*/-1);
 
   // Process the symbol references within the given nested attribute range.
-  auto processAttrs = [&](int &index, auto attrRange) -> WalkResult {
-    for (Attribute attr : llvm::drop_begin(attrRange, index)) {
+  auto processAttrs = [&](int &index,
+                          WorklistItem &worklistItem) -> WalkResult {
+    for (Attribute attr :
+         llvm::drop_begin(worklistItem.immediateSubElements, index)) {
       /// Check for a nested container attribute, these will also need to be
       /// walked.
-      if (attr.isa<ArrayAttr, DictionaryAttr>()) {
-        attrWorklist.push_back(attr);
+      if (auto interface = attr.dyn_cast<SubElementAttrInterface>()) {
+        attrWorklist.emplace_back(interface);
         curAccessChain.push_back(-1);
         return WalkResult::advance();
       }
@@ -517,15 +531,12 @@ static WalkResult walkSymbolRefs(
 
   WalkResult result = WalkResult::advance();
   do {
-    Attribute attr = attrWorklist.back();
+    WorklistItem &item = attrWorklist.back();
     int &index = curAccessChain.back();
     ++index;
 
     // Process the given attribute, which is guaranteed to be a container.
-    if (auto dict = attr.dyn_cast<DictionaryAttr>())
-      result = processAttrs(index, make_second_range(dict.getValue()));
-    else
-      result = processAttrs(index, attr.cast<ArrayAttr>().getValue());
+    result = processAttrs(index, item);
   } while (!attrWorklist.empty() && !result.wasInterrupted());
   return result;
 }
@@ -811,48 +822,46 @@ bool SymbolTable::symbolKnownUseEmpty(Operation *symbol, Region *from) {
 
 /// Rebuild the given attribute container after replacing all references to a
 /// symbol with the updated attribute in 'accesses'.
-static Attribute rebuildAttrAfterRAUW(
-    Attribute container,
+static SubElementAttrInterface rebuildAttrAfterRAUW(
+    SubElementAttrInterface container,
     ArrayRef<std::pair<SmallVector<int, 1>, SymbolRefAttr>> accesses,
     unsigned depth) {
   // Given a range of Attributes, update the ones referred to by the given
   // access chains to point to the new symbol attribute.
-  auto updateAttrs = [&](auto &&attrRange) {
-    auto attrBegin = std::begin(attrRange);
-    for (unsigned i = 0, e = accesses.size(); i != e;) {
-      ArrayRef<int> access = accesses[i].first;
-      Attribute &attr = *std::next(attrBegin, access[depth]);
-
-      // Check to see if this is a leaf access, i.e. a SymbolRef.
-      if (access.size() == depth + 1) {
-        attr = accesses[i].second;
-        ++i;
-        continue;
-      }
 
-      // Otherwise, this is a container. Collect all of the accesses for this
-      // index and recurse. The recursion here is bounded by the size of the
-      // largest access array.
-      auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
-        ArrayRef<int> nextAccess = it.first;
-        return nextAccess.size() > depth + 1 &&
-               nextAccess[depth] == access[depth];
-      });
-      attr = rebuildAttrAfterRAUW(attr, nestedAccesses, depth + 1);
-
-      // Skip over all of the accesses that refer to the nested container.
-      i += nestedAccesses.size();
+  SmallVector<std::pair<size_t, Attribute>> replacements;
+
+  SmallVector<Attribute> subElements;
+  container.walkImmediateSubElements(
+      [&](Attribute attribute) { subElements.push_back(attribute); },
+      [](Type) {});
+  for (unsigned i = 0, e = accesses.size(); i != e;) {
+    ArrayRef<int> access = accesses[i].first;
+
+    // Check to see if this is a leaf access, i.e. a SymbolRef.
+    if (access.size() == depth + 1) {
+      replacements.emplace_back(access.back(), accesses[i].second);
+      ++i;
+      continue;
     }
-  };
 
-  if (auto dictAttr = container.dyn_cast<DictionaryAttr>()) {
-    auto newAttrs = llvm::to_vector<4>(dictAttr.getValue());
-    updateAttrs(make_second_range(newAttrs));
-    return DictionaryAttr::get(dictAttr.getContext(), newAttrs);
+    // Otherwise, this is a container. Collect all of the accesses for this
+    // index and recurse. The recursion here is bounded by the size of the
+    // largest access array.
+    auto nestedAccesses = accesses.drop_front(i).take_while([&](auto &it) {
+      ArrayRef<int> nextAccess = it.first;
+      return nextAccess.size() > depth + 1 &&
+             nextAccess[depth] == access[depth];
+    });
+    auto result = rebuildAttrAfterRAUW(subElements[access[depth]],
+                                       nestedAccesses, depth + 1);
+    replacements.emplace_back(access[depth], result);
+
+    // Skip over all of the accesses that refer to the nested container.
+    i += nestedAccesses.size();
   }
-  auto newAttrs = llvm::to_vector<4>(container.cast<ArrayAttr>().getValue());
-  updateAttrs(newAttrs);
-  return ArrayAttr::get(container.getContext(), newAttrs);
+
+  return container.replaceImmediateSubAttribute(replacements);
 }
 
 /// Generates a new symbol reference attribute with a new leaf reference.
index 5d50bc0..931c26b 100644 (file)
@@ -73,3 +73,24 @@ module {
   "foo.possibly_unknown_symbol_table"() ({
   }) : () -> ()
 }
+
+// -----
+
+// Check that replacement works in any implementations of SubElementsAttrInterface
+module {
+    // CHECK: func private @replaced_foo
+    func private @symbol_foo() attributes {sym.new_name = "replaced_foo" }
+
+    // CHECK: func @symbol_bar
+    func @symbol_bar() {
+      // CHECK: foo.op
+      // CHECK-SAME: non_symbol_attr,
+      // CHECK-SAME: use = [#test.sub_elements_access<[@replaced_foo], @symbol_bar, @replaced_foo>],
+      // CHECK-SAME: z_non_symbol_attr_3
+      "foo.op"() {
+        non_symbol_attr,
+        use = [#test.sub_elements_access<[@symbol_foo],@symbol_bar,@symbol_foo>],
+        z_non_symbol_attr_3
+      } : () -> ()
+    }
+}
index 8e36f63..3062fd6 100644 (file)
@@ -16,6 +16,7 @@
 // To get the test dialect definition.
 include "TestOps.td"
 include "mlir/IR/BuiltinAttributeInterfaces.td"
+include "mlir/IR/SubElementInterfaces.td"
 
 // All of the attributes will extend this class.
 class Test_Attr<string name, list<Trait> traits = []>
@@ -101,4 +102,18 @@ def TestI64ElementsAttr : Test_Attr<"TestI64Elements", [
   let genVerifyDecl = 1;
 }
 
+def TestSubElementsAccessAttr : Test_Attr<"TestSubElementsAccess", [
+    DeclareAttrInterfaceMethods<SubElementAttrInterface,
+        ["replaceImmediateSubAttribute"]>
+  ]> {
+
+  let mnemonic = "sub_elements_access";
+
+  let parameters = (ins
+    "::mlir::Attribute":$first,
+    "::mlir::Attribute":$second,
+    "::mlir::Attribute":$third
+  );
+}
+
 #endif // TEST_ATTRDEFS
index 29b7023..9cd9c57 100644 (file)
@@ -128,6 +128,57 @@ TestI64ElementsAttr::verify(function_ref<InFlightDiagnostic()> emitError,
 }
 
 //===----------------------------------------------------------------------===//
+// TestSubElementsAccessAttr
+//===----------------------------------------------------------------------===//
+
+Attribute TestSubElementsAccessAttr::parse(::mlir::DialectAsmParser &parser,
+                                           ::mlir::Type type) {
+  Attribute first, second, third;
+  if (parser.parseLess() || parser.parseAttribute(first) ||
+      parser.parseComma() || parser.parseAttribute(second) ||
+      parser.parseComma() || parser.parseAttribute(third) ||
+      parser.parseGreater()) {
+    return {};
+  }
+  return get(parser.getContext(), first, second, third);
+}
+
+void TestSubElementsAccessAttr::print(
+    ::mlir::DialectAsmPrinter &printer) const {
+  printer << getMnemonic() << "<" << getFirst() << ", " << getSecond() << ", "
+          << getThird() << ">";
+}
+
+void TestSubElementsAccessAttr::walkImmediateSubElements(
+    llvm::function_ref<void(mlir::Attribute)> walkAttrsFn,
+    llvm::function_ref<void(mlir::Type)> walkTypesFn) const {
+  walkAttrsFn(getFirst());
+  walkAttrsFn(getSecond());
+  walkAttrsFn(getThird());
+}
+
+SubElementAttrInterface TestSubElementsAccessAttr::replaceImmediateSubAttribute(
+    ArrayRef<std::pair<size_t, Attribute>> replacements) const {
+  Attribute first = getFirst();
+  Attribute second = getSecond();
+  Attribute third = getThird();
+  for (auto &it : replacements) {
+    switch (it.first) {
+    case 0:
+      first = it.second;
+      break;
+    case 1:
+      second = it.second;
+      break;
+    case 2:
+      third = it.second;
+      break;
+    }
+  }
+  return get(getContext(), first, second, third);
+}
+
+//===----------------------------------------------------------------------===//
 // Tablegen Generated Definitions
 //===----------------------------------------------------------------------===//