[mlir] Allow passing AsmState when printing Attributes and Types
authorRiver Riddle <riddleriver@gmail.com>
Sat, 3 Sep 2022 04:23:47 +0000 (21:23 -0700)
committerRiver Riddle <riddleriver@gmail.com>
Tue, 6 Sep 2022 21:45:12 +0000 (14:45 -0700)
This allows for extracting assembly information when printing an attribute
or type, such as the dialect resources referenced. This functionality is used in
a followup that adds resource support to the bytecode. This change also results
in a nice cleanup of AsmPrinter now that we don't need to awkwardly workaround
optional AsmStates.

Differential Revision: https://reviews.llvm.org/D132728

mlir/include/mlir/IR/AsmState.h
mlir/include/mlir/IR/Attributes.h
mlir/include/mlir/IR/Types.h
mlir/lib/IR/AsmPrinter.cpp

index 51a6631..d3ef630 100644 (file)
@@ -21,6 +21,7 @@
 
 namespace mlir {
 class AsmResourcePrinter;
+class AsmDialectResourceHandle;
 class Operation;
 
 namespace detail {
@@ -455,6 +456,9 @@ public:
   AsmState(Operation *op,
            const OpPrintingFlags &printerFlags = OpPrintingFlags(),
            LocationMap *locationMap = nullptr);
+  AsmState(MLIRContext *ctx,
+           const OpPrintingFlags &printerFlags = OpPrintingFlags(),
+           LocationMap *locationMap = nullptr);
   ~AsmState();
 
   /// Get the printer flags.
@@ -480,6 +484,11 @@ public:
         name, std::forward<CallableT>(printFn)));
   }
 
+  /// Returns a map of dialect resources that were referenced when using this
+  /// state to print IR.
+  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
+  getDialectResources() const;
+
 private:
   AsmState() = delete;
 
index 6ebb044..a3ee5d2 100644 (file)
@@ -13,6 +13,7 @@
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
+class AsmState;
 class StringAttr;
 
 /// Attributes are known-constant values of operations.
@@ -76,6 +77,7 @@ public:
 
   /// Print the attribute.
   void print(raw_ostream &os) const;
+  void print(raw_ostream &os, AsmState &state) const;
   void dump() const;
 
   /// Get an opaque pointer to the attribute.
index 5cac1e2..28cccd1 100644 (file)
@@ -15,6 +15,8 @@
 #include "llvm/Support/PointerLikeTypeTraits.h"
 
 namespace mlir {
+class AsmState;
+
 /// Instances of the Type class are uniqued, have an immutable identifier and an
 /// optional mutable component.  They wrap a pointer to the storage object owned
 /// by MLIRContext.  Therefore, instances of Type are passed around by value.
@@ -162,6 +164,7 @@ public:
 
   /// Print the current type.
   void print(raw_ostream &os) const;
+  void print(raw_ostream &os, AsmState &state) const;
   void dump() const;
 
   friend ::llvm::hash_code hash_value(Type arg);
index 9cf9501..841cc56 100644 (file)
@@ -853,6 +853,7 @@ public:
   enum : unsigned { NameSentinel = ~0U };
 
   SSANameState(Operation *op, const OpPrintingFlags &printerFlags);
+  SSANameState() = default;
 
   /// Print the SSA identifier for the given value to 'stream'. If
   /// 'printResultNo' is true, it also presents the result number ('#' number)
@@ -1282,6 +1283,9 @@ public:
                         AsmState::LocationMap *locationMap)
       : interfaces(op->getContext()), nameState(op, printerFlags),
         printerFlags(printerFlags), locationMap(locationMap) {}
+  explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
+                        AsmState::LocationMap *locationMap)
+      : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {}
 
   /// Initialize the alias state to enable the printing of aliases.
   void initializeAliases(Operation *op) {
@@ -1315,6 +1319,12 @@ public:
       (*locationMap)[op] = std::make_pair(line, col);
   }
 
+  /// Return the referenced dialect resources within the printer.
+  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
+  getDialectResources() {
+    return dialectResources;
+  }
+
 private:
   /// Collection of OpAsm interfaces implemented in the context.
   DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
@@ -1322,6 +1332,9 @@ private:
   /// A collection of non-dialect resource printers.
   SmallVector<std::unique_ptr<AsmResourcePrinter>> externalResourcePrinters;
 
+  /// A set of dialect resources that were referenced during printing.
+  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
+
   /// The state used for attribute and type aliases.
   AliasState aliasState;
 
@@ -1379,6 +1392,9 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags,
                    LocationMap *locationMap)
     : impl(std::make_unique<AsmStateImpl>(
           op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {}
+AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags,
+                   LocationMap *locationMap)
+    : impl(std::make_unique<AsmStateImpl>(ctx, printerFlags, locationMap)) {}
 AsmState::~AsmState() = default;
 
 const OpPrintingFlags &AsmState::getPrinterFlags() const {
@@ -1390,6 +1406,11 @@ void AsmState::attachResourcePrinter(
   impl->externalResourcePrinters.emplace_back(std::move(printer));
 }
 
+DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> &
+AsmState::getDialectResources() const {
+  return impl->getDialectResources();
+}
+
 //===----------------------------------------------------------------------===//
 // AsmPrinter::Impl
 //===----------------------------------------------------------------------===//
@@ -1397,11 +1418,9 @@ void AsmState::attachResourcePrinter(
 namespace mlir {
 class AsmPrinter::Impl {
 public:
-  Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None,
-       AsmStateImpl *state = nullptr)
-      : os(os), printerFlags(flags), state(state) {}
-  explicit Impl(Impl &other)
-      : Impl(other.os, other.printerFlags, other.state) {}
+  Impl(raw_ostream &os, AsmStateImpl &state)
+      : os(os), state(state), printerFlags(state.getPrinterFlags()) {}
+  explicit Impl(Impl &other) : Impl(other.os, other.state) {}
 
   /// Returns the output stream of the printer.
   raw_ostream &getStream() { return os; }
@@ -1446,7 +1465,7 @@ public:
   void printResourceHandle(const AsmDialectResourceHandle &resource) {
     auto *interface = cast<OpAsmDialectInterface>(resource.getDialect());
     os << interface->getResourceKey(resource);
-    dialectResources[resource.getDialect()].insert(resource);
+    state.getDialectResources()[resource.getDialect()].insert(resource);
   }
 
   void printAffineMap(AffineMap map);
@@ -1503,17 +1522,14 @@ protected:
   /// The output stream for the printer.
   raw_ostream &os;
 
+  /// An underlying assembly printer state.
+  AsmStateImpl &state;
+
   /// A set of flags to control the printer's behavior.
   OpPrintingFlags printerFlags;
 
-  /// An optional printer state for the module.
-  AsmStateImpl *state;
-
   /// A tracker for the number of new lines emitted during printing.
   NewLineCounter newLine;
-
-  /// A set of dialect resources that were referenced during printing.
-  DenseMap<Dialect *, SetVector<AsmDialectResourceHandle>> dialectResources;
 };
 } // namespace mlir
 
@@ -1647,7 +1663,7 @@ void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) {
     return printLocationInternal(loc, /*pretty=*/true);
 
   os << "loc(";
-  if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os)))
+  if (!allowAlias || failed(printAlias(loc)))
     printLocationInternal(loc);
   os << ')';
 }
@@ -1734,11 +1750,11 @@ static void printElidedElementsAttr(raw_ostream &os) {
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) {
-  return success(state && succeeded(state->getAliasState().getAlias(attr, os)));
+  return state.getAliasState().getAlias(attr, os);
 }
 
 LogicalResult AsmPrinter::Impl::printAlias(Type type) {
-  return success(state && succeeded(state->getAliasState().getAlias(type, os)));
+  return state.getAliasState().getAlias(type, os);
 }
 
 void AsmPrinter::Impl::printAttribute(Attribute attr,
@@ -2068,7 +2084,7 @@ void AsmPrinter::Impl::printType(Type type) {
   }
 
   // Try to print an alias for this type.
-  if (state && succeeded(state->getAliasState().getAlias(type, os)))
+  if (succeeded(printAlias(type)))
     return;
 
   TypeSwitch<Type>(type)
@@ -2242,14 +2258,9 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) {
   std::string attrName;
   {
     llvm::raw_string_ostream attrNameStr(attrName);
-    Impl subPrinter(attrNameStr, printerFlags, state);
+    Impl subPrinter(attrNameStr, state);
     DialectAsmPrinter printer(subPrinter);
     dialect.printAttribute(attr, printer);
-
-    // FIXME: Delete this when we no longer require a nested printer.
-    for (auto &it : subPrinter.dialectResources)
-      for (const auto &resource : it.second)
-        dialectResources[it.first].insert(resource);
   }
   printDialectSymbol(os, "#", dialect.getNamespace(), attrName);
 }
@@ -2261,14 +2272,9 @@ void AsmPrinter::Impl::printDialectType(Type type) {
   std::string typeName;
   {
     llvm::raw_string_ostream typeNameStr(typeName);
-    Impl subPrinter(typeNameStr, printerFlags, state);
+    Impl subPrinter(typeNameStr, state);
     DialectAsmPrinter printer(subPrinter);
     dialect.printType(type, printer);
-
-    // FIXME: Delete this when we no longer require a nested printer.
-    for (auto &it : subPrinter.dialectResources)
-      for (const auto &resource : it.second)
-        dialectResources[it.first].insert(resource);
   }
   printDialectSymbol(os, "!", dialect.getNamespace(), typeName);
 }
@@ -2561,8 +2567,7 @@ public:
   using Impl::printType;
 
   explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state)
-      : Impl(os, state.getPrinterFlags(), &state),
-        OpAsmPrinter(static_cast<Impl &>(*this)) {}
+      : Impl(os, state), OpAsmPrinter(static_cast<Impl &>(*this)) {}
 
   /// Print the given top-level operation.
   void printTopLevelOperation(Operation *op);
@@ -2646,7 +2651,7 @@ public:
   /// operations. If any entry in namesToUse is null, the corresponding
   /// argument name is left alone.
   void shadowRegionArgs(Region &region, ValueRange namesToUse) override {
-    state->getSSANameState().shadowRegionArgs(region, namesToUse);
+    state.getSSANameState().shadowRegionArgs(region, namesToUse);
   }
 
   /// Print the given affine map with the symbol and dimension operands printed
@@ -2736,14 +2741,14 @@ private:
 
 void OperationPrinter::printTopLevelOperation(Operation *op) {
   // Output the aliases at the top level that can't be deferred.
-  state->getAliasState().printNonDeferredAliases(os, newLine);
+  state.getAliasState().printNonDeferredAliases(os, newLine);
 
   // Print the module.
   print(op);
   os << newLine;
 
   // Output the aliases at the top level that can be deferred.
-  state->getAliasState().printDeferredAliases(os, newLine);
+  state.getAliasState().printDeferredAliases(os, newLine);
 
   // Output any file level metadata.
   printFileMetadataDictionary(op);
@@ -2795,7 +2800,8 @@ void OperationPrinter::printResourceFileMetadata(
 
   // Print the `dialect_resources` section if we have any dialects with
   // resources.
-  for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) {
+  for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) {
+    auto &dialectResources = state.getDialectResources();
     StringRef name = interface.getDialect()->getNamespace();
     auto it = dialectResources.find(interface.getDialect());
     if (it != dialectResources.end())
@@ -2810,7 +2816,7 @@ void OperationPrinter::printResourceFileMetadata(
   // Print the `external_resources` section if we have any external clients with
   // resources.
   hadResource = false;
-  for (const auto &printer : state->getResourcePrinters())
+  for (const auto &printer : state.getResourcePrinters())
     processProvider("external", printer.getName(), printer);
   if (hadResource)
     os << newLine << "  }";
@@ -2836,7 +2842,7 @@ void OperationPrinter::printRegionArgument(BlockArgument arg,
 
 void OperationPrinter::print(Operation *op) {
   // Track the location of this operation.
-  state->registerOperationLocation(op, newLine.curLine, currentIndent);
+  state.registerOperationLocation(op, newLine.curLine, currentIndent);
 
   os.indent(currentIndent);
   printOperation(op);
@@ -2854,7 +2860,7 @@ void OperationPrinter::printOperation(Operation *op) {
     };
 
     // Check to see if this operation has multiple result groups.
-    ArrayRef<int> resultGroups = state->getSSANameState().getOpResultGroups(op);
+    ArrayRef<int> resultGroups = state.getSSANameState().getOpResultGroups(op);
     if (!resultGroups.empty()) {
       // Interleave the groups excluding the last one, this one will be handled
       // separately.
@@ -3010,7 +3016,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) {
 }
 
 void OperationPrinter::printBlockName(Block *block) {
-  os << state->getSSANameState().getBlockInfo(block).name;
+  os << state.getSSANameState().getBlockInfo(block).name;
 }
 
 void OperationPrinter::print(Block *block, bool printBlockArgs,
@@ -3048,7 +3054,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
       // whatever order the use-list is in, so gather and sort them.
       SmallVector<BlockInfo, 4> predIDs;
       for (auto *pred : block->getPredecessors())
-        predIDs.push_back(state->getSSANameState().getBlockInfo(pred));
+        predIDs.push_back(state.getSSANameState().getBlockInfo(pred));
       llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) {
         return lhs.ordering < rhs.ordering;
       });
@@ -3084,14 +3090,14 @@ void OperationPrinter::print(Block *block, bool printBlockArgs,
 
 void OperationPrinter::printValueID(Value value, bool printResultNo,
                                     raw_ostream *streamOverride) const {
-  state->getSSANameState().printValueID(value, printResultNo,
-                                        streamOverride ? *streamOverride : os);
+  state.getSSANameState().printValueID(value, printResultNo,
+                                       streamOverride ? *streamOverride : os);
 }
 
 void OperationPrinter::printOperationID(Operation *op,
                                         raw_ostream *streamOverride) const {
-  state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride
-                                                               : os);
+  state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride
+                                                              : os);
 }
 
 void OperationPrinter::printSuccessor(Block *successor) {
@@ -3176,7 +3182,16 @@ void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr,
 //===----------------------------------------------------------------------===//
 
 void Attribute::print(raw_ostream &os) const {
-  AsmPrinter::Impl(os).printAttribute(*this);
+  if (!*this) {
+    os << "<<NULL ATTRIBUTE>>";
+    return;
+  }
+
+  AsmState state(getContext());
+  print(os, state);
+}
+void Attribute::print(raw_ostream &os, AsmState &state) const {
+  AsmPrinter::Impl(os, state.getImpl()).printAttribute(*this);
 }
 
 void Attribute::dump() const {
@@ -3185,7 +3200,16 @@ void Attribute::dump() const {
 }
 
 void Type::print(raw_ostream &os) const {
-  AsmPrinter::Impl(os).printType(*this);
+  if (!*this) {
+    os << "<<NULL TYPE>>";
+    return;
+  }
+
+  AsmState state(getContext());
+  print(os, state);
+}
+void Type::print(raw_ostream &os, AsmState &state) const {
+  AsmPrinter::Impl(os, state.getImpl()).printType(*this);
 }
 
 void Type::dump() const { print(llvm::errs()); }
@@ -3205,7 +3229,8 @@ void AffineExpr::print(raw_ostream &os) const {
     os << "<<NULL AFFINE EXPR>>";
     return;
   }
-  AsmPrinter::Impl(os).printAffineExpr(*this);
+  AsmState state(getContext());
+  AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this);
 }
 
 void AffineExpr::dump() const {
@@ -3218,11 +3243,13 @@ void AffineMap::print(raw_ostream &os) const {
     os << "<<NULL AFFINE MAP>>";
     return;
   }
-  AsmPrinter::Impl(os).printAffineMap(*this);
+  AsmState state(getContext());
+  AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this);
 }
 
 void IntegerSet::print(raw_ostream &os) const {
-  AsmPrinter::Impl(os).printIntegerSet(*this);
+  AsmState state(getContext());
+  AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this);
 }
 
 void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); }