[mlir] Fix ordering of intermixed attribute/type aliases
authorRiver Riddle <riddleriver@gmail.com>
Thu, 17 Nov 2022 00:26:18 +0000 (16:26 -0800)
committerRiver Riddle <riddleriver@gmail.com>
Fri, 18 Nov 2022 10:09:57 +0000 (02:09 -0800)
We properly order dependencies between attribute/type aliases,
but we currently always print attribute aliases separately from type
aliases. This creates problems if an attribute wants to use a type
alias during printing.

This commit refactors alias collection such that attribute/type aliases
are collected together and printed together.

Differential Revision: https://reviews.llvm.org/D138162

mlir/lib/IR/AsmPrinter.cpp
mlir/test/IR/print-attr-type-aliases.mlir
mlir/test/Target/LLVMIR/Import/debug-info.ll

index 32a2647..248d0f4 100644 (file)
@@ -448,16 +448,21 @@ namespace {
 /// This class represents a specific instance of a symbol Alias.
 class SymbolAlias {
 public:
-  SymbolAlias(StringRef name, uint32_t suffixIndex, bool isDeferrable)
-      : name(name), suffixIndex(suffixIndex), isDeferrable(isDeferrable) {}
+  SymbolAlias(StringRef name, uint32_t suffixIndex, bool isType,
+              bool isDeferrable)
+      : name(name), suffixIndex(suffixIndex), isType(isType),
+        isDeferrable(isDeferrable) {}
 
   /// Print this alias to the given stream.
   void print(raw_ostream &os) const {
-    os << name;
+    os << (isType ? "!" : "#") << name;
     if (suffixIndex)
       os << suffixIndex;
   }
 
+  /// Returns true if this is a type alias.
+  bool isTypeAlias() const { return isType; }
+
   /// Returns true if this alias supports deferred resolution when parsing.
   bool canBeDeferred() const { return isDeferrable; }
 
@@ -465,7 +470,9 @@ private:
   /// The main name of the alias.
   StringRef name;
   /// The suffix index of the alias.
-  uint32_t suffixIndex : 31;
+  uint32_t suffixIndex : 30;
+  /// A flag indicating whether this alias is for a type.
+  bool isType : 1;
   /// A flag indicating whether this alias may be deferred or not.
   bool isDeferrable : 1;
 };
@@ -482,31 +489,34 @@ public:
         aliasOS(aliasBuffer) {}
 
   void initialize(Operation *op, const OpPrintingFlags &printerFlags,
-                  llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
-                  llvm::MapVector<Type, SymbolAlias> &typeToAlias);
+                  llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias);
 
   /// Visit the given attribute to see if it has an alias. `canBeDeferred` is
   /// set to true if the originator of this attribute can resolve the alias
   /// after parsing has completed (e.g. in the case of operation locations).
   /// Returns the maximum alias depth of the attribute.
   size_t visit(Attribute attr, bool canBeDeferred = false) {
-    return visitImpl(attr, attrAliases, canBeDeferred);
+    return visitImpl(attr, aliases, canBeDeferred);
   }
 
   /// Visit the given type to see if it has an alias. Returns the maximum alias
   /// depth of the type.
-  size_t visit(Type type) { return visitImpl(type, typeAliases); }
+  size_t visit(Type type) { return visitImpl(type, aliases); }
 
 private:
   struct InProgressAliasInfo {
-    InProgressAliasInfo() : aliasDepth(0), canBeDeferred(false) {}
-    InProgressAliasInfo(StringRef alias, bool canBeDeferred)
-        : alias(alias), aliasDepth(0), canBeDeferred(canBeDeferred) {}
+    InProgressAliasInfo()
+        : aliasDepth(0), isType(false), canBeDeferred(false) {}
+    InProgressAliasInfo(StringRef alias, bool isType, bool canBeDeferred)
+        : alias(alias), aliasDepth(1), isType(isType),
+          canBeDeferred(canBeDeferred) {}
 
     bool operator<(const InProgressAliasInfo &rhs) const {
-      // Order first by depth, and then by name.
+      // Order first by depth, then by attr/type kind, and then by name.
       if (aliasDepth != rhs.aliasDepth)
         return aliasDepth < rhs.aliasDepth;
+      if (isType != rhs.isType)
+        return isType;
       return alias < rhs.alias;
     }
 
@@ -514,7 +524,9 @@ private:
     Optional<StringRef> alias;
     /// The alias depth of this attribute or type, i.e. an indication of the
     /// relative ordering of when to print this alias.
-    unsigned aliasDepth : 31;
+    unsigned aliasDepth : 30;
+    /// If this alias represents a type or an attribute.
+    bool isType : 1;
     /// If this alias can be deferred or not.
     bool canBeDeferred : 1;
   };
@@ -524,22 +536,20 @@ private:
   /// the alias after parsing has completed (e.g. in the case of operation
   /// locations). Returns the maximum alias depth of the value.
   template <typename T>
-  size_t visitImpl(T value, llvm::MapVector<T, InProgressAliasInfo> &aliases,
+  size_t visitImpl(T value,
+                   llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
                    bool canBeDeferred = false);
 
   /// Try to generate an alias for the provided symbol. If an alias is
   /// generated, the provided alias mapping and reverse mapping are updated.
-  /// Returns success if an alias was generated, failure otherwise.
   template <typename T>
-  LogicalResult generateAlias(T symbol, InProgressAliasInfo &alias,
-                              bool canBeDeferred);
+  void generateAlias(T symbol, InProgressAliasInfo &alias, bool canBeDeferred);
 
   /// Given a collection of aliases and symbols, initialize a mapping from a
   /// symbol to a given alias.
-  template <typename T>
-  static void
-  initializeAliases(llvm::MapVector<T, InProgressAliasInfo> &visitedSymbols,
-                    llvm::MapVector<T, SymbolAlias> &symbolToAlias);
+  static void initializeAliases(
+      llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
+      llvm::MapVector<const void *, SymbolAlias> &symbolToAlias);
 
   /// The set of asm interfaces within the context.
   DialectInterfaceCollection<OpAsmDialectInterface> &interfaces;
@@ -548,8 +558,7 @@ private:
   llvm::BumpPtrAllocator &aliasAllocator;
 
   /// The set of built aliases.
-  llvm::MapVector<Attribute, InProgressAliasInfo> attrAliases;
-  llvm::MapVector<Type, InProgressAliasInfo> typeAliases;
+  llvm::MapVector<const void *, InProgressAliasInfo> aliases;
 
   /// Storage and stream used when generating an alias.
   SmallString<32> aliasBuffer;
@@ -792,11 +801,10 @@ static StringRef sanitizeIdentifier(StringRef name, SmallString<16> &buffer,
 
 /// Given a collection of aliases and symbols, initialize a mapping from a
 /// symbol to a given alias.
-template <typename T>
 void AliasInitializer::initializeAliases(
-    llvm::MapVector<T, InProgressAliasInfo> &visitedSymbols,
-    llvm::MapVector<T, SymbolAlias> &symbolToAlias) {
-  std::vector<std::pair<T, InProgressAliasInfo>> unprocessedAliases =
+    llvm::MapVector<const void *, InProgressAliasInfo> &visitedSymbols,
+    llvm::MapVector<const void *, SymbolAlias> &symbolToAlias) {
+  std::vector<std::pair<const void *, InProgressAliasInfo>> unprocessedAliases =
       visitedSymbols.takeVector();
   llvm::stable_sort(unprocessedAliases, [](const auto &lhs, const auto &rhs) {
     return lhs.second < rhs.second;
@@ -809,31 +817,30 @@ void AliasInitializer::initializeAliases(
     StringRef alias = *aliasInfo.alias;
     unsigned nameIndex = nameCounts[alias]++;
     symbolToAlias.insert(
-        {symbol, SymbolAlias(alias, nameIndex, aliasInfo.canBeDeferred)});
+        {symbol, SymbolAlias(alias, nameIndex, aliasInfo.isType,
+                             aliasInfo.canBeDeferred)});
   }
 }
 
 void AliasInitializer::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
-    llvm::MapVector<Attribute, SymbolAlias> &attrToAlias,
-    llvm::MapVector<Type, SymbolAlias> &typeToAlias) {
+    llvm::MapVector<const void *, SymbolAlias> &attrTypeToAlias) {
   // Use a dummy printer when walking the IR so that we can collect the
   // attributes/types that will actually be used during printing when
   // considering aliases.
   DummyAliasOperationPrinter aliasPrinter(printerFlags, *this);
   aliasPrinter.printCustomOrGenericOp(op);
 
-  // Initialize the aliases sorted by name.
-  initializeAliases(attrAliases, attrToAlias);
-  initializeAliases(typeAliases, typeToAlias);
+  // Initialize the aliases.
+  initializeAliases(aliases, attrTypeToAlias);
 }
 
 template <typename T>
-size_t
-AliasInitializer::visitImpl(T value,
-                            llvm::MapVector<T, InProgressAliasInfo> &aliases,
-                            bool canBeDeferred) {
-  auto [it, inserted] = aliases.insert({value, InProgressAliasInfo()});
+size_t AliasInitializer::visitImpl(
+    T value, llvm::MapVector<const void *, InProgressAliasInfo> &aliases,
+    bool canBeDeferred) {
+  auto [it, inserted] =
+      aliases.insert({value.getAsOpaquePointer(), InProgressAliasInfo()});
   if (!inserted) {
     // Make sure that the alias isn't deferred if we don't permit it.
     if (!canBeDeferred)
@@ -842,7 +849,7 @@ AliasInitializer::visitImpl(T value,
   }
 
   // Try to generate an alias for this attribute.
-  bool hasAlias = succeeded(generateAlias(value, it->second, canBeDeferred));
+  generateAlias(value, it->second, canBeDeferred);
   size_t aliasIndex = std::distance(aliases.begin(), it);
 
   // Check for any sub elements.
@@ -852,17 +859,19 @@ AliasInitializer::visitImpl(T value,
   if (auto subElementInterface = dyn_cast<SubElementInterfaceT>(value)) {
     size_t maxAliasDepth = 0;
     auto visitSubElement = [&](auto element) {
-      if (Optional<size_t> depth = visit(element))
-        maxAliasDepth = std::max(maxAliasDepth, *depth + 1);
+      if (!element)
+        return;
+      if (size_t depth = visit(element))
+        maxAliasDepth = std::max(maxAliasDepth, depth + 1);
     };
-    subElementInterface.walkSubElements(visitSubElement, visitSubElement);
+    subElementInterface.walkImmediateSubElements(visitSubElement,
+                                                 visitSubElement);
 
     // Make sure to recompute `it` in case the map was reallocated.
     it = std::next(aliases.begin(), aliasIndex);
 
-    // If we had sub elements and an alias, update our main alias to account for
-    // the depth.
-    if (maxAliasDepth && hasAlias)
+    // If we had sub elements, update to account for the depth.
+    if (maxAliasDepth)
       it->second.aliasDepth = maxAliasDepth;
   }
 
@@ -871,9 +880,8 @@ AliasInitializer::visitImpl(T value,
 }
 
 template <typename T>
-LogicalResult AliasInitializer::generateAlias(T symbol,
-                                              InProgressAliasInfo &alias,
-                                              bool canBeDeferred) {
+void AliasInitializer::generateAlias(T symbol, InProgressAliasInfo &alias,
+                                     bool canBeDeferred) {
   SmallString<32> nameBuffer;
   for (const auto &interface : interfaces) {
     OpAsmDialectInterface::AliasResult result =
@@ -887,15 +895,15 @@ LogicalResult AliasInitializer::generateAlias(T symbol,
   }
 
   if (nameBuffer.empty())
-    return failure();
+    return;
 
   SmallString<16> tempBuffer;
   StringRef name =
       sanitizeIdentifier(nameBuffer, tempBuffer, /*allowedPunctChars=*/"$_-",
                          /*allowTrailingDigit=*/false);
   name = name.copy(aliasAllocator);
-  alias = InProgressAliasInfo(name, canBeDeferred);
-  return success();
+  alias = InProgressAliasInfo(name, /*isType=*/std::is_base_of_v<Type, T>,
+                              canBeDeferred);
 }
 
 //===----------------------------------------------------------------------===//
@@ -936,10 +944,8 @@ private:
   void printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
                     bool isDeferred);
 
-  /// Mapping between attribute and alias.
-  llvm::MapVector<Attribute, SymbolAlias> attrToAlias;
-  /// Mapping between type and alias.
-  llvm::MapVector<Type, SymbolAlias> typeToAlias;
+  /// Mapping between attribute/type and alias.
+  llvm::MapVector<const void *, SymbolAlias> attrTypeToAlias;
 
   /// An allocator used for alias names.
   llvm::BumpPtrAllocator aliasAllocator;
@@ -950,23 +956,23 @@ void AliasState::initialize(
     Operation *op, const OpPrintingFlags &printerFlags,
     DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
   AliasInitializer initializer(interfaces, aliasAllocator);
-  initializer.initialize(op, printerFlags, attrToAlias, typeToAlias);
+  initializer.initialize(op, printerFlags, attrTypeToAlias);
 }
 
 LogicalResult AliasState::getAlias(Attribute attr, raw_ostream &os) const {
-  auto it = attrToAlias.find(attr);
-  if (it == attrToAlias.end())
+  auto it = attrTypeToAlias.find(attr.getAsOpaquePointer());
+  if (it == attrTypeToAlias.end())
     return failure();
-  it->second.print(os << '#');
+  it->second.print(os);
   return success();
 }
 
 LogicalResult AliasState::getAlias(Type ty, raw_ostream &os) const {
-  auto it = typeToAlias.find(ty);
-  if (it == typeToAlias.end())
+  auto it = attrTypeToAlias.find(ty.getAsOpaquePointer());
+  if (it == attrTypeToAlias.end())
     return failure();
 
-  it->second.print(os << '!');
+  it->second.print(os);
   return success();
 }
 
@@ -975,27 +981,26 @@ void AliasState::printAliases(AsmPrinter::Impl &p, NewLineCounter &newLine,
   auto filterFn = [=](const auto &aliasIt) {
     return aliasIt.second.canBeDeferred() == isDeferred;
   };
-  for (auto &[attr, alias] : llvm::make_filter_range(attrToAlias, filterFn)) {
-    alias.print(p.getStream() << '#');
-    p.getStream() << " = ";
-
-    // TODO: Support nested aliases in mutable attributes.
-    if (attr.hasTrait<AttributeTrait::IsMutable>())
-      p.getStream() << attr;
-    else
-      p.printAttributeImpl(attr);
-
-    p.getStream() << newLine;
-  }
-  for (auto &[type, alias] : llvm::make_filter_range(typeToAlias, filterFn)) {
-    alias.print(p.getStream() << '!');
+  for (auto &[opaqueSymbol, alias] :
+       llvm::make_filter_range(attrTypeToAlias, filterFn)) {
+    alias.print(p.getStream());
     p.getStream() << " = ";
 
-    // TODO: Support nested aliases in mutable types.
-    if (type.hasTrait<TypeTrait::IsMutable>())
-      p.getStream() << type;
-    else
-      p.printTypeImpl(type);
+    if (alias.isTypeAlias()) {
+      // TODO: Support nested aliases in mutable types.
+      Type type = Type::getFromOpaquePointer(opaqueSymbol);
+      if (type.hasTrait<TypeTrait::IsMutable>())
+        p.getStream() << type;
+      else
+        p.printTypeImpl(type);
+    } else {
+      // TODO: Support nested aliases in mutable attributes.
+      Attribute attr = Attribute::getFromOpaquePointer(opaqueSymbol);
+      if (attr.hasTrait<AttributeTrait::IsMutable>())
+        p.getStream() << attr;
+      else
+        p.printAttributeImpl(attr);
+    }
 
     p.getStream() << newLine;
   }
index 4d65ae6..b9893f2 100644 (file)
@@ -1,6 +1,6 @@
-// RUN: mlir-opt %s | FileCheck %s
+// RUN: mlir-opt %s -split-input-file | FileCheck %s
 // Verify printer of type & attr aliases.
-// RUN: mlir-opt %s | mlir-opt | FileCheck %s
+// RUN: mlir-opt %s -split-input-file | mlir-opt -split-input-file | FileCheck %s
 
 // CHECK-DAG: #test2Ealias = "alias_test:dot_in_name"
 "test.op"() {alias_test = "alias_test:dot_in_name"} : () -> ()
 // CHECK-DAG: #loc2 = loc("nested")
 // CHECK-DAG: #loc3 = loc(fused<#loc2>["test.mlir":10:8])
 "test.op"() {alias_test = loc(fused<loc("nested")>["test.mlir":10:8])} : () -> ()
+
+// -----
+
+// Check proper ordering of intermixed attribute/type aliases.
+// CHECK: !tuple = tuple<
+// CHECK: #loc1 = loc(fused<!tuple
+"test.op"() {alias_test = loc(fused<tuple<i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32>>["test.mlir":10:8])} : () -> ()
index 575ac01..aa0a07d 100644 (file)
@@ -152,13 +152,13 @@ define void @derived_type() !dbg !3 {
 
 ; // -----
 
-; CHECK: #[[INT:.+]] = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "int">
-; CHECK: #[[FILE:.+]] = #llvm.di_file<"debug-info.ll" in "/">
-; CHECK: #[[COMP1:.+]] = #llvm.di_composite_type<tag = DW_TAG_array_type, name = "array1", line = 10, sizeInBits = 128, alignInBits = 32>
-; CHECK: #[[COMP2:.+]] = #llvm.di_composite_type<{{.*}}, file = #[[FILE]], line = 0, scope = #[[FILE]], baseType = #[[INT]], sizeInBits = 0, alignInBits = 0>
-; CHECK: #[[COMP3:.+]] = #llvm.di_composite_type<{{.*}}, flags = Vector, {{.*}}, elements = #llvm.di_subrange<count = 4 : i64>>
-; CHECK: #[[COMP4:.+]] = #llvm.di_composite_type<{{.*}}, elements = #llvm.di_subrange<lowerBound = 0 : i64, upperBound = 4 : i64, stride = 1 : i64>>
-; CHECK: #llvm.di_subroutine_type<argumentTypes = #[[COMP1]], #[[COMP2]], #[[COMP3]], #[[COMP4]]>
+; CHECK-DAG: #[[INT:.+]] = #llvm.di_basic_type<tag = DW_TAG_base_type, name = "int">
+; CHECK-DAG: #[[FILE:.+]] = #llvm.di_file<"debug-info.ll" in "/">
+; CHECK-DAG: #[[COMP1:.+]] = #llvm.di_composite_type<tag = DW_TAG_array_type, name = "array1", line = 10, sizeInBits = 128, alignInBits = 32>
+; CHECK-DAG: #[[COMP2:.+]] = #llvm.di_composite_type<{{.*}}, file = #[[FILE]], line = 0, scope = #[[FILE]], baseType = #[[INT]], sizeInBits = 0, alignInBits = 0>
+; CHECK-DAG: #[[COMP3:.+]] = #llvm.di_composite_type<{{.*}}, flags = Vector, {{.*}}, elements = #llvm.di_subrange<count = 4 : i64>>
+; CHECK-DAG: #[[COMP4:.+]] = #llvm.di_composite_type<{{.*}}, elements = #llvm.di_subrange<lowerBound = 0 : i64, upperBound = 4 : i64, stride = 1 : i64>>
+; CHECK-DAG: #llvm.di_subroutine_type<argumentTypes = #[[COMP1]], #[[COMP2]], #[[COMP3]], #[[COMP4]]>
 
 define void @composite_type() !dbg !3 {
   ret void