[mlir] NFC: Move the state for managing aliases out of ModuleState and into a new...
authorRiver Riddle <riverriddle@google.com>
Wed, 8 Jan 2020 18:11:56 +0000 (10:11 -0800)
committerRiver Riddle <riverriddle@google.com>
Wed, 8 Jan 2020 18:34:35 +0000 (10:34 -0800)
Summary: This reduces the complexity of ModuleState and simplifies the code. A future revision will mold ModuleState into something that can be used by users for caching of printer state, as well as for implementing printAsOperand style methods.

Reviewed By: antiagainst

Differential Revision: https://reviews.llvm.org/D72292

mlir/lib/IR/AsmPrinter.cpp

index afe7ab6..2eb4367 100644 (file)
@@ -155,98 +155,46 @@ bool OpPrintingFlags::shouldPrintGenericOpForm() const {
 bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
 
 //===----------------------------------------------------------------------===//
-// ModuleState
+// AliasState
 //===----------------------------------------------------------------------===//
 
 namespace {
-/// A special index constant used for non-kind attribute aliases.
-static constexpr int kNonAttrKindAlias = -1;
-
-class ModuleState {
+/// This class manages the state for type and attribute aliases.
+class AliasState {
 public:
-  explicit ModuleState(MLIRContext *context) : interfaces(context) {}
-  void initialize(Operation *op);
-
-  Twine getAttributeAlias(Attribute attr) const {
-    auto alias = attrToAlias.find(attr);
-    if (alias == attrToAlias.end())
-      return Twine();
-
-    // Return the alias for this attribute, along with the index if this was
-    // generated by a kind alias.
-    int kindIndex = alias->second.second;
-    return alias->second.first +
-           (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
-  }
-
-  void printAttributeAliases(raw_ostream &os) const {
-    auto printAlias = [&](StringRef alias, Attribute attr, int index) {
-      os << '#' << alias;
-      if (index != kNonAttrKindAlias)
-        os << index;
-      os << " = " << attr << '\n';
-    };
-
-    // Print all of the attribute kind aliases.
-    for (auto &kindAlias : attrKindToAlias) {
-      for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
-        printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
-      os << "\n";
-    }
+  // Initialize the internal aliases.
+  void
+  initialize(Operation *op,
+             DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
 
-    // In a second pass print all of the remaining attribute aliases that aren't
-    // kind aliases.
-    for (Attribute attr : usedAttributes) {
-      auto alias = attrToAlias.find(attr);
-      if (alias != attrToAlias.end() &&
-          alias->second.second == kNonAttrKindAlias)
-        printAlias(alias->second.first, attr, alias->second.second);
-    }
-  }
+  /// Return a name used for an attribute alias, or empty if there is no alias.
+  Twine getAttributeAlias(Attribute attr) const;
 
-  StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
+  /// Print all of the referenced attribute aliases.
+  void printAttributeAliases(raw_ostream &os) const;
 
-  void printTypeAliases(raw_ostream &os) const {
-    for (Type type : usedTypes) {
-      auto alias = typeToAlias.find(type);
-      if (alias != typeToAlias.end())
-        os << '!' << alias->second << " = type " << type << '\n';
-    }
-  }
+  /// Return a string to use as an alias for the given type, or empty if there
+  /// is no alias recorded.
+  StringRef getTypeAlias(Type ty) const;
 
-  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
-  /// null if one wasn't registered.
-  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
-    return interfaces.getInterfaceFor(dialect);
-  }
+  /// Print all of the referenced type aliases.
+  void printTypeAliases(raw_ostream &os) const;
 
 private:
-  void recordAttributeReference(Attribute attr) {
-    // Don't recheck attributes that have already been seen or those that
-    // already have an alias.
-    if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
-      return;
+  /// A special index constant used for non-kind attribute aliases.
+  enum { NonAttrKindAlias = -1 };
 
-    // If this attribute kind has an alias, then record one for this attribute.
-    auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
-    if (alias == attrKindToAlias.end())
-      return;
-    std::pair<StringRef, int> attrAlias(alias->second.first,
-                                        alias->second.second.size());
-    attrToAlias.insert({attr, attrAlias});
-    alias->second.second.push_back(attr);
-  }
+  /// Record a reference to the given attribute.
+  void recordAttributeReference(Attribute attr);
 
-  void recordTypeReference(Type ty) { usedTypes.insert(ty); }
+  /// Record a reference to the given type.
+  void recordTypeReference(Type ty);
 
   // Visit functions.
   void visitOperation(Operation *op);
   void visitType(Type type);
   void visitAttribute(Attribute attr);
 
-  // Initialize symbol aliases.
-  void initializeSymbolAliases();
-
   /// Set of attributes known to be used within the module.
   llvm::SetVector<Attribute> usedAttributes;
 
@@ -265,59 +213,9 @@ private:
 
   /// A mapping between a type and a given alias.
   DenseMap<Type, StringRef> typeToAlias;
-
-  /// Collection of OpAsm interfaces implemented in the context.
-  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
 };
 } // end anonymous namespace
 
-// TODO Support visiting other types/operations when implemented.
-void ModuleState::visitType(Type type) {
-  recordTypeReference(type);
-  if (auto funcType = type.dyn_cast<FunctionType>()) {
-    // Visit input and result types for functions.
-    for (auto input : funcType.getInputs())
-      visitType(input);
-    for (auto result : funcType.getResults())
-      visitType(result);
-    return;
-  }
-  if (auto memref = type.dyn_cast<MemRefType>()) {
-    // Visit affine maps in memref type.
-    for (auto map : memref.getAffineMaps())
-      recordAttributeReference(AffineMapAttr::get(map));
-  }
-  if (auto shapedType = type.dyn_cast<ShapedType>()) {
-    visitType(shapedType.getElementType());
-  }
-}
-
-void ModuleState::visitAttribute(Attribute attr) {
-  recordAttributeReference(attr);
-  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
-    for (auto elt : arrayAttr.getValue())
-      visitAttribute(elt);
-  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
-    visitType(typeAttr.getValue());
-  }
-}
-
-void ModuleState::visitOperation(Operation *op) {
-  // Visit all the types used in the operation.
-  for (auto type : op->getOperandTypes())
-    visitType(type);
-  for (auto type : op->getResultTypes())
-    visitType(type);
-  for (auto &region : op->getRegions())
-    for (auto &block : region)
-      for (auto arg : block.getArguments())
-        visitType(arg->getType());
-
-  // Visit each of the attributes.
-  for (auto elt : op->getAttrs())
-    visitAttribute(elt.second);
-}
-
 // Utility to generate a function to register a symbol alias.
 static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
   assert(!name.empty() && "expected alias name to be non-empty");
@@ -329,7 +227,9 @@ static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
   return !name.contains('.') && usedAliases.insert(name).second;
 }
 
-void ModuleState::initializeSymbolAliases() {
+void AliasState::initialize(
+    Operation *op,
+    DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
   // Track the identifiers in use for each symbol so that the same identifier
   // isn't used twice.
   llvm::StringSet<> usedAliases;
@@ -374,7 +274,7 @@ void ModuleState::initializeSymbolAliases() {
   for (auto &attrAliasPair : attributeAliases) {
     std::tie(attr, alias) = attrAliasPair;
     if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
-      attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
+      attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
   }
 
   // Clear the set of used identifiers as types can have the same identifiers as
@@ -385,16 +285,164 @@ void ModuleState::initializeSymbolAliases() {
   for (auto &typeAliasPair : typeAliases)
     if (canRegisterAlias(typeAliasPair.second, usedAliases))
       typeToAlias.insert(typeAliasPair);
+
+  // Traverse the given IR to generate the set of used attributes/types.
+  op->walk([&](Operation *op) { visitOperation(op); });
 }
 
-void ModuleState::initialize(Operation *op) {
-  // Initialize the symbol aliases.
-  initializeSymbolAliases();
+/// Return a name used for an attribute alias, or empty if there is no alias.
+Twine AliasState::getAttributeAlias(Attribute attr) const {
+  auto alias = attrToAlias.find(attr);
+  if (alias == attrToAlias.end())
+    return Twine();
 
-  // Visit each of the nested operations.
-  op->walk([&](Operation *op) { visitOperation(op); });
+  // Return the alias for this attribute, along with the index if this was
+  // generated by a kind alias.
+  int kindIndex = alias->second.second;
+  return alias->second.first +
+         (kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
 }
 
+/// Print all of the referenced attribute aliases.
+void AliasState::printAttributeAliases(raw_ostream &os) const {
+  auto printAlias = [&](StringRef alias, Attribute attr, int index) {
+    os << '#' << alias;
+    if (index != NonAttrKindAlias)
+      os << index;
+    os << " = " << attr << '\n';
+  };
+
+  // Print all of the attribute kind aliases.
+  for (auto &kindAlias : attrKindToAlias) {
+    auto &aliasAttrsPair = kindAlias.second;
+    for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i)
+      printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i);
+    os << "\n";
+  }
+
+  // In a second pass print all of the remaining attribute aliases that aren't
+  // kind aliases.
+  for (Attribute attr : usedAttributes) {
+    auto alias = attrToAlias.find(attr);
+    if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
+      printAlias(alias->second.first, attr, alias->second.second);
+  }
+}
+
+/// Return a string to use as an alias for the given type, or empty if there
+/// is no alias recorded.
+StringRef AliasState::getTypeAlias(Type ty) const {
+  return typeToAlias.lookup(ty);
+}
+
+/// Print all of the referenced type aliases.
+void AliasState::printTypeAliases(raw_ostream &os) const {
+  for (Type type : usedTypes) {
+    auto alias = typeToAlias.find(type);
+    if (alias != typeToAlias.end())
+      os << '!' << alias->second << " = type " << type << '\n';
+  }
+}
+
+/// Record a reference to the given attribute.
+void AliasState::recordAttributeReference(Attribute attr) {
+  // Don't recheck attributes that have already been seen or those that
+  // already have an alias.
+  if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
+    return;
+
+  // If this attribute kind has an alias, then record one for this attribute.
+  auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
+  if (alias == attrKindToAlias.end())
+    return;
+  std::pair<StringRef, int> attrAlias(alias->second.first,
+                                      alias->second.second.size());
+  attrToAlias.insert({attr, attrAlias});
+  alias->second.second.push_back(attr);
+}
+
+/// Record a reference to the given type.
+void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); }
+
+// TODO Support visiting other types/operations when implemented.
+void AliasState::visitType(Type type) {
+  recordTypeReference(type);
+
+  if (auto funcType = type.dyn_cast<FunctionType>()) {
+    // Visit input and result types for functions.
+    for (auto input : funcType.getInputs())
+      visitType(input);
+    for (auto result : funcType.getResults())
+      visitType(result);
+  } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
+    visitType(shapedType.getElementType());
+
+    // Visit affine maps in memref type.
+    if (auto memref = type.dyn_cast<MemRefType>())
+      for (auto map : memref.getAffineMaps())
+        recordAttributeReference(AffineMapAttr::get(map));
+  }
+}
+
+void AliasState::visitAttribute(Attribute attr) {
+  recordAttributeReference(attr);
+
+  if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+    for (auto elt : arrayAttr.getValue())
+      visitAttribute(elt);
+  } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
+    visitType(typeAttr.getValue());
+  }
+}
+
+void AliasState::visitOperation(Operation *op) {
+  // Visit all the types used in the operation.
+  for (auto type : op->getOperandTypes())
+    visitType(type);
+  for (auto type : op->getResultTypes())
+    visitType(type);
+  for (auto &region : op->getRegions())
+    for (auto &block : region)
+      for (auto arg : block.getArguments())
+        visitType(arg->getType());
+
+  // Visit each of the attributes.
+  for (auto elt : op->getAttrs())
+    visitAttribute(elt.second);
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleState
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModuleState {
+public:
+  explicit ModuleState(MLIRContext *context) : interfaces(context) {}
+
+  /// Initialize the alias state to enable the printing of aliases.
+  void initializeAliases(Operation *op) {
+    aliasState.initialize(op, interfaces);
+  }
+
+  /// Get an instance of the OpAsmDialectInterface for the given dialect, or
+  /// null if one wasn't registered.
+  const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
+    return interfaces.getInterfaceFor(dialect);
+  }
+
+  /// Get the state used for aliases.
+  AliasState &getAliasState() { return aliasState; }
+
+private:
+  /// Collection of OpAsm interfaces implemented in the context.
+  DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
+
+  /// The state used for attribute and type aliases.
+  AliasState aliasState;
+};
+} // end anonymous namespace
+
 //===----------------------------------------------------------------------===//
 // ModulePrinter
 //===----------------------------------------------------------------------===//
@@ -745,7 +793,7 @@ void ModulePrinter::printAttribute(Attribute attr, bool mayElideType) {
 
   // Check for an alias for this attribute.
   if (state) {
-    Twine alias = state->getAttributeAlias(attr);
+    Twine alias = state->getAliasState().getAttributeAlias(attr);
     if (!alias.isTriviallyEmpty()) {
       os << '#' << alias;
       return;
@@ -975,7 +1023,7 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
 void ModulePrinter::printType(Type type) {
   // Check for an alias for this type.
   if (state) {
-    StringRef alias = state->getTypeAlias(type);
+    StringRef alias = state->getAliasState().getTypeAlias(type);
     if (!alias.empty()) {
       os << '!' << alias;
       return;
@@ -1997,8 +2045,8 @@ void OperationPrinter::printSuccessorAndUseList(Operation *term,
 void ModulePrinter::print(ModuleOp module) {
   // Output the aliases at the top level.
   if (state) {
-    state->printAttributeAliases(os);
-    state->printTypeAliases(os);
+    state->getAliasState().printAttributeAliases(os);
+    state->getAliasState().printTypeAliases(os);
   }
 
   // Print the module.
@@ -2136,9 +2184,9 @@ void Block::printAsOperand(raw_ostream &os, bool printType) {
 
 void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
   ModuleState state(getContext());
-  // Skip initializing in local scope to avoid populating aliases.
+  // Don't populate aliases when printing at local scope.
   if (!flags.shouldUseLocalScope())
-    state.initialize(*this);
+    state.initializeAliases(*this);
   ModulePrinter(os, flags, &state).print(*this);
 }