From f3502afe852693a19848e9e328f2c2a55fc9e9bb Mon Sep 17 00:00:00 2001 From: River Riddle Date: Fri, 2 Sep 2022 21:23:47 -0700 Subject: [PATCH] [mlir] Allow passing AsmState when printing Attributes and Types This allows for extracting assembly information when printing an attribute or type, such as the dialect resources referenced. This functionality is used in a followup that adds resource support to the bytecode. This change also results in a nice cleanup of AsmPrinter now that we don't need to awkwardly workaround optional AsmStates. Differential Revision: https://reviews.llvm.org/D132728 --- mlir/include/mlir/IR/AsmState.h | 9 +++ mlir/include/mlir/IR/Attributes.h | 2 + mlir/include/mlir/IR/Types.h | 3 + mlir/lib/IR/AsmPrinter.cpp | 123 +++++++++++++++++++++++--------------- 4 files changed, 89 insertions(+), 48 deletions(-) diff --git a/mlir/include/mlir/IR/AsmState.h b/mlir/include/mlir/IR/AsmState.h index 51a6631..d3ef630 100644 --- a/mlir/include/mlir/IR/AsmState.h +++ b/mlir/include/mlir/IR/AsmState.h @@ -21,6 +21,7 @@ namespace mlir { class AsmResourcePrinter; +class AsmDialectResourceHandle; class Operation; namespace detail { @@ -455,6 +456,9 @@ public: AsmState(Operation *op, const OpPrintingFlags &printerFlags = OpPrintingFlags(), LocationMap *locationMap = nullptr); + AsmState(MLIRContext *ctx, + const OpPrintingFlags &printerFlags = OpPrintingFlags(), + LocationMap *locationMap = nullptr); ~AsmState(); /// Get the printer flags. @@ -480,6 +484,11 @@ public: name, std::forward(printFn))); } + /// Returns a map of dialect resources that were referenced when using this + /// state to print IR. + DenseMap> & + getDialectResources() const; + private: AsmState() = delete; diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 6ebb044..a3ee5d2 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -13,6 +13,7 @@ #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { +class AsmState; class StringAttr; /// Attributes are known-constant values of operations. @@ -76,6 +77,7 @@ public: /// Print the attribute. void print(raw_ostream &os) const; + void print(raw_ostream &os, AsmState &state) const; void dump() const; /// Get an opaque pointer to the attribute. diff --git a/mlir/include/mlir/IR/Types.h b/mlir/include/mlir/IR/Types.h index 5cac1e2..28cccd1 100644 --- a/mlir/include/mlir/IR/Types.h +++ b/mlir/include/mlir/IR/Types.h @@ -15,6 +15,8 @@ #include "llvm/Support/PointerLikeTypeTraits.h" namespace mlir { +class AsmState; + /// Instances of the Type class are uniqued, have an immutable identifier and an /// optional mutable component. They wrap a pointer to the storage object owned /// by MLIRContext. Therefore, instances of Type are passed around by value. @@ -162,6 +164,7 @@ public: /// Print the current type. void print(raw_ostream &os) const; + void print(raw_ostream &os, AsmState &state) const; void dump() const; friend ::llvm::hash_code hash_value(Type arg); diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 9cf9501..841cc56 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -853,6 +853,7 @@ public: enum : unsigned { NameSentinel = ~0U }; SSANameState(Operation *op, const OpPrintingFlags &printerFlags); + SSANameState() = default; /// Print the SSA identifier for the given value to 'stream'. If /// 'printResultNo' is true, it also presents the result number ('#' number) @@ -1282,6 +1283,9 @@ public: AsmState::LocationMap *locationMap) : interfaces(op->getContext()), nameState(op, printerFlags), printerFlags(printerFlags), locationMap(locationMap) {} + explicit AsmStateImpl(MLIRContext *ctx, const OpPrintingFlags &printerFlags, + AsmState::LocationMap *locationMap) + : interfaces(ctx), printerFlags(printerFlags), locationMap(locationMap) {} /// Initialize the alias state to enable the printing of aliases. void initializeAliases(Operation *op) { @@ -1315,6 +1319,12 @@ public: (*locationMap)[op] = std::make_pair(line, col); } + /// Return the referenced dialect resources within the printer. + DenseMap> & + getDialectResources() { + return dialectResources; + } + private: /// Collection of OpAsm interfaces implemented in the context. DialectInterfaceCollection interfaces; @@ -1322,6 +1332,9 @@ private: /// A collection of non-dialect resource printers. SmallVector> externalResourcePrinters; + /// A set of dialect resources that were referenced during printing. + DenseMap> dialectResources; + /// The state used for attribute and type aliases. AliasState aliasState; @@ -1379,6 +1392,9 @@ AsmState::AsmState(Operation *op, const OpPrintingFlags &printerFlags, LocationMap *locationMap) : impl(std::make_unique( op, verifyOpAndAdjustFlags(op, printerFlags), locationMap)) {} +AsmState::AsmState(MLIRContext *ctx, const OpPrintingFlags &printerFlags, + LocationMap *locationMap) + : impl(std::make_unique(ctx, printerFlags, locationMap)) {} AsmState::~AsmState() = default; const OpPrintingFlags &AsmState::getPrinterFlags() const { @@ -1390,6 +1406,11 @@ void AsmState::attachResourcePrinter( impl->externalResourcePrinters.emplace_back(std::move(printer)); } +DenseMap> & +AsmState::getDialectResources() const { + return impl->getDialectResources(); +} + //===----------------------------------------------------------------------===// // AsmPrinter::Impl //===----------------------------------------------------------------------===// @@ -1397,11 +1418,9 @@ void AsmState::attachResourcePrinter( namespace mlir { class AsmPrinter::Impl { public: - Impl(raw_ostream &os, OpPrintingFlags flags = llvm::None, - AsmStateImpl *state = nullptr) - : os(os), printerFlags(flags), state(state) {} - explicit Impl(Impl &other) - : Impl(other.os, other.printerFlags, other.state) {} + Impl(raw_ostream &os, AsmStateImpl &state) + : os(os), state(state), printerFlags(state.getPrinterFlags()) {} + explicit Impl(Impl &other) : Impl(other.os, other.state) {} /// Returns the output stream of the printer. raw_ostream &getStream() { return os; } @@ -1446,7 +1465,7 @@ public: void printResourceHandle(const AsmDialectResourceHandle &resource) { auto *interface = cast(resource.getDialect()); os << interface->getResourceKey(resource); - dialectResources[resource.getDialect()].insert(resource); + state.getDialectResources()[resource.getDialect()].insert(resource); } void printAffineMap(AffineMap map); @@ -1503,17 +1522,14 @@ protected: /// The output stream for the printer. raw_ostream &os; + /// An underlying assembly printer state. + AsmStateImpl &state; + /// A set of flags to control the printer's behavior. OpPrintingFlags printerFlags; - /// An optional printer state for the module. - AsmStateImpl *state; - /// A tracker for the number of new lines emitted during printing. NewLineCounter newLine; - - /// A set of dialect resources that were referenced during printing. - DenseMap> dialectResources; }; } // namespace mlir @@ -1647,7 +1663,7 @@ void AsmPrinter::Impl::printLocation(LocationAttr loc, bool allowAlias) { return printLocationInternal(loc, /*pretty=*/true); os << "loc("; - if (!allowAlias || !state || failed(state->getAliasState().getAlias(loc, os))) + if (!allowAlias || failed(printAlias(loc))) printLocationInternal(loc); os << ')'; } @@ -1734,11 +1750,11 @@ static void printElidedElementsAttr(raw_ostream &os) { } LogicalResult AsmPrinter::Impl::printAlias(Attribute attr) { - return success(state && succeeded(state->getAliasState().getAlias(attr, os))); + return state.getAliasState().getAlias(attr, os); } LogicalResult AsmPrinter::Impl::printAlias(Type type) { - return success(state && succeeded(state->getAliasState().getAlias(type, os))); + return state.getAliasState().getAlias(type, os); } void AsmPrinter::Impl::printAttribute(Attribute attr, @@ -2068,7 +2084,7 @@ void AsmPrinter::Impl::printType(Type type) { } // Try to print an alias for this type. - if (state && succeeded(state->getAliasState().getAlias(type, os))) + if (succeeded(printAlias(type))) return; TypeSwitch(type) @@ -2242,14 +2258,9 @@ void AsmPrinter::Impl::printDialectAttribute(Attribute attr) { std::string attrName; { llvm::raw_string_ostream attrNameStr(attrName); - Impl subPrinter(attrNameStr, printerFlags, state); + Impl subPrinter(attrNameStr, state); DialectAsmPrinter printer(subPrinter); dialect.printAttribute(attr, printer); - - // FIXME: Delete this when we no longer require a nested printer. - for (auto &it : subPrinter.dialectResources) - for (const auto &resource : it.second) - dialectResources[it.first].insert(resource); } printDialectSymbol(os, "#", dialect.getNamespace(), attrName); } @@ -2261,14 +2272,9 @@ void AsmPrinter::Impl::printDialectType(Type type) { std::string typeName; { llvm::raw_string_ostream typeNameStr(typeName); - Impl subPrinter(typeNameStr, printerFlags, state); + Impl subPrinter(typeNameStr, state); DialectAsmPrinter printer(subPrinter); dialect.printType(type, printer); - - // FIXME: Delete this when we no longer require a nested printer. - for (auto &it : subPrinter.dialectResources) - for (const auto &resource : it.second) - dialectResources[it.first].insert(resource); } printDialectSymbol(os, "!", dialect.getNamespace(), typeName); } @@ -2561,8 +2567,7 @@ public: using Impl::printType; explicit OperationPrinter(raw_ostream &os, AsmStateImpl &state) - : Impl(os, state.getPrinterFlags(), &state), - OpAsmPrinter(static_cast(*this)) {} + : Impl(os, state), OpAsmPrinter(static_cast(*this)) {} /// Print the given top-level operation. void printTopLevelOperation(Operation *op); @@ -2646,7 +2651,7 @@ public: /// operations. If any entry in namesToUse is null, the corresponding /// argument name is left alone. void shadowRegionArgs(Region ®ion, ValueRange namesToUse) override { - state->getSSANameState().shadowRegionArgs(region, namesToUse); + state.getSSANameState().shadowRegionArgs(region, namesToUse); } /// Print the given affine map with the symbol and dimension operands printed @@ -2736,14 +2741,14 @@ private: void OperationPrinter::printTopLevelOperation(Operation *op) { // Output the aliases at the top level that can't be deferred. - state->getAliasState().printNonDeferredAliases(os, newLine); + state.getAliasState().printNonDeferredAliases(os, newLine); // Print the module. print(op); os << newLine; // Output the aliases at the top level that can be deferred. - state->getAliasState().printDeferredAliases(os, newLine); + state.getAliasState().printDeferredAliases(os, newLine); // Output any file level metadata. printFileMetadataDictionary(op); @@ -2795,7 +2800,8 @@ void OperationPrinter::printResourceFileMetadata( // Print the `dialect_resources` section if we have any dialects with // resources. - for (const OpAsmDialectInterface &interface : state->getDialectInterfaces()) { + for (const OpAsmDialectInterface &interface : state.getDialectInterfaces()) { + auto &dialectResources = state.getDialectResources(); StringRef name = interface.getDialect()->getNamespace(); auto it = dialectResources.find(interface.getDialect()); if (it != dialectResources.end()) @@ -2810,7 +2816,7 @@ void OperationPrinter::printResourceFileMetadata( // Print the `external_resources` section if we have any external clients with // resources. hadResource = false; - for (const auto &printer : state->getResourcePrinters()) + for (const auto &printer : state.getResourcePrinters()) processProvider("external", printer.getName(), printer); if (hadResource) os << newLine << " }"; @@ -2836,7 +2842,7 @@ void OperationPrinter::printRegionArgument(BlockArgument arg, void OperationPrinter::print(Operation *op) { // Track the location of this operation. - state->registerOperationLocation(op, newLine.curLine, currentIndent); + state.registerOperationLocation(op, newLine.curLine, currentIndent); os.indent(currentIndent); printOperation(op); @@ -2854,7 +2860,7 @@ void OperationPrinter::printOperation(Operation *op) { }; // Check to see if this operation has multiple result groups. - ArrayRef resultGroups = state->getSSANameState().getOpResultGroups(op); + ArrayRef resultGroups = state.getSSANameState().getOpResultGroups(op); if (!resultGroups.empty()) { // Interleave the groups excluding the last one, this one will be handled // separately. @@ -3010,7 +3016,7 @@ void OperationPrinter::printGenericOp(Operation *op, bool printOpName) { } void OperationPrinter::printBlockName(Block *block) { - os << state->getSSANameState().getBlockInfo(block).name; + os << state.getSSANameState().getBlockInfo(block).name; } void OperationPrinter::print(Block *block, bool printBlockArgs, @@ -3048,7 +3054,7 @@ void OperationPrinter::print(Block *block, bool printBlockArgs, // whatever order the use-list is in, so gather and sort them. SmallVector predIDs; for (auto *pred : block->getPredecessors()) - predIDs.push_back(state->getSSANameState().getBlockInfo(pred)); + predIDs.push_back(state.getSSANameState().getBlockInfo(pred)); llvm::sort(predIDs, [](BlockInfo lhs, BlockInfo rhs) { return lhs.ordering < rhs.ordering; }); @@ -3084,14 +3090,14 @@ void OperationPrinter::print(Block *block, bool printBlockArgs, void OperationPrinter::printValueID(Value value, bool printResultNo, raw_ostream *streamOverride) const { - state->getSSANameState().printValueID(value, printResultNo, - streamOverride ? *streamOverride : os); + state.getSSANameState().printValueID(value, printResultNo, + streamOverride ? *streamOverride : os); } void OperationPrinter::printOperationID(Operation *op, raw_ostream *streamOverride) const { - state->getSSANameState().printOperationID(op, streamOverride ? *streamOverride - : os); + state.getSSANameState().printOperationID(op, streamOverride ? *streamOverride + : os); } void OperationPrinter::printSuccessor(Block *successor) { @@ -3176,7 +3182,16 @@ void OperationPrinter::printAffineExprOfSSAIds(AffineExpr expr, //===----------------------------------------------------------------------===// void Attribute::print(raw_ostream &os) const { - AsmPrinter::Impl(os).printAttribute(*this); + if (!*this) { + os << "<>"; + return; + } + + AsmState state(getContext()); + print(os, state); +} +void Attribute::print(raw_ostream &os, AsmState &state) const { + AsmPrinter::Impl(os, state.getImpl()).printAttribute(*this); } void Attribute::dump() const { @@ -3185,7 +3200,16 @@ void Attribute::dump() const { } void Type::print(raw_ostream &os) const { - AsmPrinter::Impl(os).printType(*this); + if (!*this) { + os << "<>"; + return; + } + + AsmState state(getContext()); + print(os, state); +} +void Type::print(raw_ostream &os, AsmState &state) const { + AsmPrinter::Impl(os, state.getImpl()).printType(*this); } void Type::dump() const { print(llvm::errs()); } @@ -3205,7 +3229,8 @@ void AffineExpr::print(raw_ostream &os) const { os << "<>"; return; } - AsmPrinter::Impl(os).printAffineExpr(*this); + AsmState state(getContext()); + AsmPrinter::Impl(os, state.getImpl()).printAffineExpr(*this); } void AffineExpr::dump() const { @@ -3218,11 +3243,13 @@ void AffineMap::print(raw_ostream &os) const { os << "<>"; return; } - AsmPrinter::Impl(os).printAffineMap(*this); + AsmState state(getContext()); + AsmPrinter::Impl(os, state.getImpl()).printAffineMap(*this); } void IntegerSet::print(raw_ostream &os) const { - AsmPrinter::Impl(os).printIntegerSet(*this); + AsmState state(getContext()); + AsmPrinter::Impl(os, state.getImpl()).printIntegerSet(*this); } void Value::print(raw_ostream &os) { print(os, OpPrintingFlags()); } -- 2.7.4