#include "mlir/IR/InstVisitor.h"
#include "mlir/IR/Instructions.h"
#include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Regex.h"
using namespace mlir;
void Identifier::print(raw_ostream &os) const { os << str(); }
// Initializes module state, populating affine map state.
void initialize(const Module *module);
+ StringRef getAffineMapAlias(AffineMap affineMap) const {
+ return affineMapToAlias.lookup(affineMap);
+ }
+
int getAffineMapId(AffineMap affineMap) const {
auto it = affineMapIds.find(affineMap);
if (it == affineMapIds.end()) {
ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; }
+ StringRef getIntegerSetAlias(IntegerSet integerSet) const {
+ return integerSetToAlias.lookup(integerSet);
+ }
+
int getIntegerSetId(IntegerSet integerSet) const {
auto it = integerSetIds.find(integerSet);
if (it == integerSetIds.end()) {
ArrayRef<IntegerSet> getIntegerSetIds() const { return integerSetsById; }
+ StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
+
+ ArrayRef<Type> getTypeIds() const { return usedTypes.getArrayRef(); }
+
private:
void recordAffineMapReference(AffineMap affineMap) {
if (affineMapIds.count(affineMap) == 0) {
}
}
+ void recordTypeReference(Type ty) { usedTypes.insert(ty); }
+
// Return true if this map could be printed using the shorthand form.
static bool hasShorthandForm(AffineMap boundMap) {
if (boundMap.isSingleConstant())
void visitType(Type type);
void visitAttribute(Attribute attr);
+ // Initialize symbol aliases.
+ void initializeSymbolAliases();
+
DenseMap<AffineMap, int> affineMapIds;
std::vector<AffineMap> affineMapsById;
+ DenseMap<AffineMap, StringRef> affineMapToAlias;
DenseMap<IntegerSet, int> integerSetIds;
std::vector<IntegerSet> integerSetsById;
+ DenseMap<IntegerSet, StringRef> integerSetToAlias;
+
+ llvm::SetVector<Type> usedTypes;
+ DenseMap<Type, StringRef> typeToAlias;
};
} // end anonymous namespace
// TODO Support visiting other types/instructions when implemented.
void ModuleState::visitType(Type type) {
+ recordTypeReference(type);
if (auto funcType = type.dyn_cast<FunctionType>()) {
// Visit input and result types for functions.
for (auto input : funcType.getInputs())
for (auto map : memref.getAffineMaps()) {
recordAffineMapReference(map);
}
+ } else if (auto vecOrTensor = type.dyn_cast<VectorOrTensorType>()) {
+ visitType(vecOrTensor.getElementType());
}
}
}
}
+// Utility to generate a function to register a symbol alias.
+template <typename SymbolsInModuleSetTy, typename SymbolTy>
+static void registerSymbolAlias(StringRef name, SymbolTy sym,
+ SymbolsInModuleSetTy &symbolsInModuleSet,
+ llvm::StringSet<> &usedAliases,
+ DenseMap<SymbolTy, StringRef> &symToAlias) {
+ assert(!name.empty() && "expected alias name to be non-empty");
+ assert(sym && "expected alias symbol to be non-null");
+ // TODO(riverriddle) Assert that the provided alias name can be lexed as
+ // an identifier.
+
+ // Check if the symbol is not referenced by the module or the name is
+ // already used by another alias.
+ if (!symbolsInModuleSet.count(sym) || !usedAliases.insert(name).second)
+ return;
+ symToAlias.try_emplace(sym, name);
+}
+
+void ModuleState::initializeSymbolAliases() {
+ // Track the identifiers in use for each symbol so that the same identifier
+ // isn't used twice.
+ llvm::StringSet<> usedAliases;
+
+ // Get the currently registered dialects.
+ auto dialects = context->getRegisteredDialects();
+
+ // Collect the set of aliases from each dialect.
+ SmallVector<std::pair<StringRef, AffineMap>, 8> affineMapAliases;
+ SmallVector<std::pair<StringRef, IntegerSet>, 8> integerSetAliases;
+ SmallVector<std::pair<StringRef, Type>, 16> typeAliases;
+ for (auto *dialect : dialects) {
+ dialect->getAffineMapAliases(affineMapAliases);
+ dialect->getIntegerSetAliases(integerSetAliases);
+ dialect->getTypeAliases(typeAliases);
+ }
+
+ // Register the affine aliases.
+ // Create a regex for the non-alias names of sets and maps, so that an alias
+ // is not registered with a conflicting name.
+ llvm::Regex reservedAffineNames("(set|map)[0-9]+");
+
+ // AffineMap aliases
+ for (auto &affineAliasPair : affineMapAliases) {
+ if (!reservedAffineNames.match(affineAliasPair.first))
+ registerSymbolAlias(affineAliasPair.first, affineAliasPair.second,
+ affineMapIds, usedAliases, affineMapToAlias);
+ }
+
+ // IntegerSet aliases
+ for (auto &integerSetAliasPair : integerSetAliases) {
+ if (!reservedAffineNames.match(integerSetAliasPair.first))
+ registerSymbolAlias(integerSetAliasPair.first, integerSetAliasPair.second,
+ integerSetIds, usedAliases, integerSetToAlias);
+ }
+
+ // Clear the set of used identifiers as types can have the same identifiers as
+ // affine structures.
+ usedAliases.clear();
+
+ for (auto &typeAliasPair : typeAliases)
+ registerSymbolAlias(typeAliasPair.first, typeAliasPair.second, usedTypes,
+ usedAliases, typeToAlias);
+}
+
// Initializes module state, populating affine map and integer set state.
void ModuleState::initialize(const Module *module) {
for (auto &fn : *module) {
const_cast<Function &>(fn).walkInsts(
[&](Instruction *op) { ModuleState::visitInstruction(op); });
}
+
+ // Initialize the symbol aliases.
+ initializeSymbolAliases();
}
//===----------------------------------------------------------------------===//
ArrayRef<const char *> elidedAttrs = {});
void printAffineMapId(int affineMapId) const;
void printAffineMapReference(AffineMap affineMap);
+ void printAffineMapAlias(StringRef alias) const;
void printIntegerSetId(int integerSetId) const;
void printIntegerSetReference(IntegerSet integerSet);
+ void printIntegerSetAlias(StringRef alias) const;
void printDenseElementsAttr(DenseElementsAttr attr);
/// This enum is used to represent the binding stength of the enclosing
os << "#map" << affineMapId;
}
+void ModulePrinter::printAffineMapAlias(StringRef alias) const {
+ os << '#' << alias;
+}
+
void ModulePrinter::printAffineMapReference(AffineMap affineMap) {
+ // Check for an affine map alias.
+ auto alias = state.getAffineMapAlias(affineMap);
+ if (!alias.empty())
+ return printAffineMapAlias(alias);
+
int mapId = state.getAffineMapId(affineMap);
if (mapId >= 0) {
// Map will be printed at top of module so print reference to its id.
os << "#set" << integerSetId;
}
+void ModulePrinter::printIntegerSetAlias(StringRef alias) const {
+ os << '#' << alias;
+}
+
void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) {
+ // Check for an integer set alias.
+ auto alias = state.getIntegerSetAlias(integerSet);
+ if (!alias.empty()) {
+ printIntegerSetAlias(alias);
+ return;
+ }
+
int setId;
if ((setId = state.getIntegerSetId(integerSet)) >= 0) {
// The set will be printed at top of module; so print reference to its id.
void ModulePrinter::print(const Module *module) {
for (const auto &map : state.getAffineMapIds()) {
- printAffineMapId(state.getAffineMapId(map));
+ StringRef alias = state.getAffineMapAlias(map);
+ if (!alias.empty())
+ printAffineMapAlias(alias);
+ else
+ printAffineMapId(state.getAffineMapId(map));
os << " = ";
map.print(os);
os << '\n';
}
for (const auto &set : state.getIntegerSetIds()) {
- printIntegerSetId(state.getIntegerSetId(set));
+ StringRef alias = state.getIntegerSetAlias(set);
+ if (!alias.empty())
+ printIntegerSetAlias(alias);
+ else
+ printIntegerSetId(state.getIntegerSetId(set));
os << " = ";
set.print(os);
os << '\n';
}
+ for (const auto &type : state.getTypeIds()) {
+ StringRef alias = state.getTypeAlias(type);
+ if (!alias.empty())
+ os << '!' << alias << " = type " << type << '\n';
+ }
for (auto const &fn : *module)
print(&fn);
}
}
void ModulePrinter::printType(Type type) {
+ // Check for an alias for this type.
+ StringRef alias = state.getTypeAlias(type);
+ if (!alias.empty()) {
+ os << '!' << alias;
+ return;
+ }
+
switch (type.getKind()) {
default: {
auto &dialect = type.getDialect();