bool OpPrintingFlags::shouldUseLocalScope() const { return printLocalScope; }
//===----------------------------------------------------------------------===//
-// ModuleState
+// AliasState
//===----------------------------------------------------------------------===//
namespace {
-/// A special index constant used for non-kind attribute aliases.
-static constexpr int kNonAttrKindAlias = -1;
-
-class ModuleState {
+/// This class manages the state for type and attribute aliases.
+class AliasState {
public:
- explicit ModuleState(MLIRContext *context) : interfaces(context) {}
- void initialize(Operation *op);
-
- Twine getAttributeAlias(Attribute attr) const {
- auto alias = attrToAlias.find(attr);
- if (alias == attrToAlias.end())
- return Twine();
-
- // Return the alias for this attribute, along with the index if this was
- // generated by a kind alias.
- int kindIndex = alias->second.second;
- return alias->second.first +
- (kindIndex == kNonAttrKindAlias ? Twine() : Twine(kindIndex));
- }
-
- void printAttributeAliases(raw_ostream &os) const {
- auto printAlias = [&](StringRef alias, Attribute attr, int index) {
- os << '#' << alias;
- if (index != kNonAttrKindAlias)
- os << index;
- os << " = " << attr << '\n';
- };
-
- // Print all of the attribute kind aliases.
- for (auto &kindAlias : attrKindToAlias) {
- for (unsigned i = 0, e = kindAlias.second.second.size(); i != e; ++i)
- printAlias(kindAlias.second.first, kindAlias.second.second[i], i);
- os << "\n";
- }
+ // Initialize the internal aliases.
+ void
+ initialize(Operation *op,
+ DialectInterfaceCollection<OpAsmDialectInterface> &interfaces);
- // In a second pass print all of the remaining attribute aliases that aren't
- // kind aliases.
- for (Attribute attr : usedAttributes) {
- auto alias = attrToAlias.find(attr);
- if (alias != attrToAlias.end() &&
- alias->second.second == kNonAttrKindAlias)
- printAlias(alias->second.first, attr, alias->second.second);
- }
- }
+ /// Return a name used for an attribute alias, or empty if there is no alias.
+ Twine getAttributeAlias(Attribute attr) const;
- StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
+ /// Print all of the referenced attribute aliases.
+ void printAttributeAliases(raw_ostream &os) const;
- void printTypeAliases(raw_ostream &os) const {
- for (Type type : usedTypes) {
- auto alias = typeToAlias.find(type);
- if (alias != typeToAlias.end())
- os << '!' << alias->second << " = type " << type << '\n';
- }
- }
+ /// Return a string to use as an alias for the given type, or empty if there
+ /// is no alias recorded.
+ StringRef getTypeAlias(Type ty) const;
- /// Get an instance of the OpAsmDialectInterface for the given dialect, or
- /// null if one wasn't registered.
- const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
- return interfaces.getInterfaceFor(dialect);
- }
+ /// Print all of the referenced type aliases.
+ void printTypeAliases(raw_ostream &os) const;
private:
- void recordAttributeReference(Attribute attr) {
- // Don't recheck attributes that have already been seen or those that
- // already have an alias.
- if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
- return;
+ /// A special index constant used for non-kind attribute aliases.
+ enum { NonAttrKindAlias = -1 };
- // If this attribute kind has an alias, then record one for this attribute.
- auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
- if (alias == attrKindToAlias.end())
- return;
- std::pair<StringRef, int> attrAlias(alias->second.first,
- alias->second.second.size());
- attrToAlias.insert({attr, attrAlias});
- alias->second.second.push_back(attr);
- }
+ /// Record a reference to the given attribute.
+ void recordAttributeReference(Attribute attr);
- void recordTypeReference(Type ty) { usedTypes.insert(ty); }
+ /// Record a reference to the given type.
+ void recordTypeReference(Type ty);
// Visit functions.
void visitOperation(Operation *op);
void visitType(Type type);
void visitAttribute(Attribute attr);
- // Initialize symbol aliases.
- void initializeSymbolAliases();
-
/// Set of attributes known to be used within the module.
llvm::SetVector<Attribute> usedAttributes;
/// A mapping between a type and a given alias.
DenseMap<Type, StringRef> typeToAlias;
-
- /// Collection of OpAsm interfaces implemented in the context.
- DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
};
} // end anonymous namespace
-// TODO Support visiting other types/operations when implemented.
-void ModuleState::visitType(Type type) {
- recordTypeReference(type);
- if (auto funcType = type.dyn_cast<FunctionType>()) {
- // Visit input and result types for functions.
- for (auto input : funcType.getInputs())
- visitType(input);
- for (auto result : funcType.getResults())
- visitType(result);
- return;
- }
- if (auto memref = type.dyn_cast<MemRefType>()) {
- // Visit affine maps in memref type.
- for (auto map : memref.getAffineMaps())
- recordAttributeReference(AffineMapAttr::get(map));
- }
- if (auto shapedType = type.dyn_cast<ShapedType>()) {
- visitType(shapedType.getElementType());
- }
-}
-
-void ModuleState::visitAttribute(Attribute attr) {
- recordAttributeReference(attr);
- if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
- for (auto elt : arrayAttr.getValue())
- visitAttribute(elt);
- } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
- visitType(typeAttr.getValue());
- }
-}
-
-void ModuleState::visitOperation(Operation *op) {
- // Visit all the types used in the operation.
- for (auto type : op->getOperandTypes())
- visitType(type);
- for (auto type : op->getResultTypes())
- visitType(type);
- for (auto ®ion : op->getRegions())
- for (auto &block : region)
- for (auto arg : block.getArguments())
- visitType(arg->getType());
-
- // Visit each of the attributes.
- for (auto elt : op->getAttrs())
- visitAttribute(elt.second);
-}
-
// Utility to generate a function to register a symbol alias.
static bool canRegisterAlias(StringRef name, llvm::StringSet<> &usedAliases) {
assert(!name.empty() && "expected alias name to be non-empty");
return !name.contains('.') && usedAliases.insert(name).second;
}
-void ModuleState::initializeSymbolAliases() {
+void AliasState::initialize(
+ Operation *op,
+ DialectInterfaceCollection<OpAsmDialectInterface> &interfaces) {
// Track the identifiers in use for each symbol so that the same identifier
// isn't used twice.
llvm::StringSet<> usedAliases;
for (auto &attrAliasPair : attributeAliases) {
std::tie(attr, alias) = attrAliasPair;
if (!reservedAttrNames.match(alias) && canRegisterAlias(alias, usedAliases))
- attrToAlias.insert({attr, {alias, kNonAttrKindAlias}});
+ attrToAlias.insert({attr, {alias, NonAttrKindAlias}});
}
// Clear the set of used identifiers as types can have the same identifiers as
for (auto &typeAliasPair : typeAliases)
if (canRegisterAlias(typeAliasPair.second, usedAliases))
typeToAlias.insert(typeAliasPair);
+
+ // Traverse the given IR to generate the set of used attributes/types.
+ op->walk([&](Operation *op) { visitOperation(op); });
}
-void ModuleState::initialize(Operation *op) {
- // Initialize the symbol aliases.
- initializeSymbolAliases();
+/// Return a name used for an attribute alias, or empty if there is no alias.
+Twine AliasState::getAttributeAlias(Attribute attr) const {
+ auto alias = attrToAlias.find(attr);
+ if (alias == attrToAlias.end())
+ return Twine();
- // Visit each of the nested operations.
- op->walk([&](Operation *op) { visitOperation(op); });
+ // Return the alias for this attribute, along with the index if this was
+ // generated by a kind alias.
+ int kindIndex = alias->second.second;
+ return alias->second.first +
+ (kindIndex == NonAttrKindAlias ? Twine() : Twine(kindIndex));
}
+/// Print all of the referenced attribute aliases.
+void AliasState::printAttributeAliases(raw_ostream &os) const {
+ auto printAlias = [&](StringRef alias, Attribute attr, int index) {
+ os << '#' << alias;
+ if (index != NonAttrKindAlias)
+ os << index;
+ os << " = " << attr << '\n';
+ };
+
+ // Print all of the attribute kind aliases.
+ for (auto &kindAlias : attrKindToAlias) {
+ auto &aliasAttrsPair = kindAlias.second;
+ for (unsigned i = 0, e = aliasAttrsPair.second.size(); i != e; ++i)
+ printAlias(aliasAttrsPair.first, aliasAttrsPair.second[i], i);
+ os << "\n";
+ }
+
+ // In a second pass print all of the remaining attribute aliases that aren't
+ // kind aliases.
+ for (Attribute attr : usedAttributes) {
+ auto alias = attrToAlias.find(attr);
+ if (alias != attrToAlias.end() && alias->second.second == NonAttrKindAlias)
+ printAlias(alias->second.first, attr, alias->second.second);
+ }
+}
+
+/// Return a string to use as an alias for the given type, or empty if there
+/// is no alias recorded.
+StringRef AliasState::getTypeAlias(Type ty) const {
+ return typeToAlias.lookup(ty);
+}
+
+/// Print all of the referenced type aliases.
+void AliasState::printTypeAliases(raw_ostream &os) const {
+ for (Type type : usedTypes) {
+ auto alias = typeToAlias.find(type);
+ if (alias != typeToAlias.end())
+ os << '!' << alias->second << " = type " << type << '\n';
+ }
+}
+
+/// Record a reference to the given attribute.
+void AliasState::recordAttributeReference(Attribute attr) {
+ // Don't recheck attributes that have already been seen or those that
+ // already have an alias.
+ if (!usedAttributes.insert(attr) || attrToAlias.count(attr))
+ return;
+
+ // If this attribute kind has an alias, then record one for this attribute.
+ auto alias = attrKindToAlias.find(static_cast<unsigned>(attr.getKind()));
+ if (alias == attrKindToAlias.end())
+ return;
+ std::pair<StringRef, int> attrAlias(alias->second.first,
+ alias->second.second.size());
+ attrToAlias.insert({attr, attrAlias});
+ alias->second.second.push_back(attr);
+}
+
+/// Record a reference to the given type.
+void AliasState::recordTypeReference(Type ty) { usedTypes.insert(ty); }
+
+// TODO Support visiting other types/operations when implemented.
+void AliasState::visitType(Type type) {
+ recordTypeReference(type);
+
+ if (auto funcType = type.dyn_cast<FunctionType>()) {
+ // Visit input and result types for functions.
+ for (auto input : funcType.getInputs())
+ visitType(input);
+ for (auto result : funcType.getResults())
+ visitType(result);
+ } else if (auto shapedType = type.dyn_cast<ShapedType>()) {
+ visitType(shapedType.getElementType());
+
+ // Visit affine maps in memref type.
+ if (auto memref = type.dyn_cast<MemRefType>())
+ for (auto map : memref.getAffineMaps())
+ recordAttributeReference(AffineMapAttr::get(map));
+ }
+}
+
+void AliasState::visitAttribute(Attribute attr) {
+ recordAttributeReference(attr);
+
+ if (auto arrayAttr = attr.dyn_cast<ArrayAttr>()) {
+ for (auto elt : arrayAttr.getValue())
+ visitAttribute(elt);
+ } else if (auto typeAttr = attr.dyn_cast<TypeAttr>()) {
+ visitType(typeAttr.getValue());
+ }
+}
+
+void AliasState::visitOperation(Operation *op) {
+ // Visit all the types used in the operation.
+ for (auto type : op->getOperandTypes())
+ visitType(type);
+ for (auto type : op->getResultTypes())
+ visitType(type);
+ for (auto ®ion : op->getRegions())
+ for (auto &block : region)
+ for (auto arg : block.getArguments())
+ visitType(arg->getType());
+
+ // Visit each of the attributes.
+ for (auto elt : op->getAttrs())
+ visitAttribute(elt.second);
+}
+
+//===----------------------------------------------------------------------===//
+// ModuleState
+//===----------------------------------------------------------------------===//
+
+namespace {
+class ModuleState {
+public:
+ explicit ModuleState(MLIRContext *context) : interfaces(context) {}
+
+ /// Initialize the alias state to enable the printing of aliases.
+ void initializeAliases(Operation *op) {
+ aliasState.initialize(op, interfaces);
+ }
+
+ /// Get an instance of the OpAsmDialectInterface for the given dialect, or
+ /// null if one wasn't registered.
+ const OpAsmDialectInterface *getOpAsmInterface(Dialect *dialect) {
+ return interfaces.getInterfaceFor(dialect);
+ }
+
+ /// Get the state used for aliases.
+ AliasState &getAliasState() { return aliasState; }
+
+private:
+ /// Collection of OpAsm interfaces implemented in the context.
+ DialectInterfaceCollection<OpAsmDialectInterface> interfaces;
+
+ /// The state used for attribute and type aliases.
+ AliasState aliasState;
+};
+} // end anonymous namespace
+
//===----------------------------------------------------------------------===//
// ModulePrinter
//===----------------------------------------------------------------------===//
// Check for an alias for this attribute.
if (state) {
- Twine alias = state->getAttributeAlias(attr);
+ Twine alias = state->getAliasState().getAttributeAlias(attr);
if (!alias.isTriviallyEmpty()) {
os << '#' << alias;
return;
void ModulePrinter::printType(Type type) {
// Check for an alias for this type.
if (state) {
- StringRef alias = state->getTypeAlias(type);
+ StringRef alias = state->getAliasState().getTypeAlias(type);
if (!alias.empty()) {
os << '!' << alias;
return;
void ModulePrinter::print(ModuleOp module) {
// Output the aliases at the top level.
if (state) {
- state->printAttributeAliases(os);
- state->printTypeAliases(os);
+ state->getAliasState().printAttributeAliases(os);
+ state->getAliasState().printTypeAliases(os);
}
// Print the module.
void ModuleOp::print(raw_ostream &os, OpPrintingFlags flags) {
ModuleState state(getContext());
- // Skip initializing in local scope to avoid populating aliases.
+ // Don't populate aliases when printing at local scope.
if (!flags.shouldUseLocalScope())
- state.initialize(*this);
+ state.initializeAliases(*this);
ModulePrinter(os, flags, &state).print(*this);
}