Provide dialect hooks for defining named aliases for AffineMap/IntegerSet/Type.
authorRiver Riddle <riverriddle@google.com>
Fri, 11 Jan 2019 06:08:39 +0000 (22:08 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 22:08:55 +0000 (15:08 -0700)
The AsmPrinter will then query registered dialects for aliases of symbols used within the module and use them in place.

PiperOrigin-RevId: 228831678

mlir/include/mlir/IR/Dialect.h
mlir/lib/IR/AsmPrinter.cpp

index 89287648f9ed45fd34f55203e1cceadb642c6536..e1ef802fd8ef727b870e4287b809b24b7569a1cf 100644 (file)
@@ -25,6 +25,8 @@
 #include "mlir/IR/OperationSupport.h"
 
 namespace mlir {
+class AffineMap;
+class IntegerSet;
 class Type;
 
 using DialectConstantFoldHook = std::function<bool(
@@ -63,7 +65,18 @@ public:
   /// characters.
   DialectTypePrinterHook typePrintHook = nullptr;
 
-  // TODO: Hook to return the list of named types that are known.
+  /// Registered hooks for getting identifier aliases for symbols. The
+  /// identifier is used in place of the symbol when printing textual IR.
+  ///
+  /// Hook for defining AffineMap aliases.
+  virtual void getAffineMapAliases(
+      SmallVectorImpl<std::pair<StringRef, AffineMap>> &aliases) {}
+  /// Hook for defining IntegerSet aliases.
+  virtual void getIntegerSetAliases(
+      SmallVectorImpl<std::pair<StringRef, IntegerSet>> &aliases) {}
+  /// Hook for defining Type aliases.
+  virtual void
+  getTypeAliases(SmallVectorImpl<std::pair<StringRef, Type>> &aliases) {}
 
   virtual ~Dialect();
 
index 68616917c846cfff0c65d73276b81f32b5a96c31..6922c66a9a6939a7fbee70784572e42086e1e74f 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/IR/InstVisitor.h"
 #include "mlir/IR/Instructions.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
 #include "llvm/ADT/APFloat.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/StringSet.h"
+#include "llvm/Support/Regex.h"
 using namespace mlir;
 
 void Identifier::print(raw_ostream &os) const { os << str(); }
@@ -65,6 +68,10 @@ public:
   // Initializes module state, populating affine map state.
   void initialize(const Module *module);
 
+  StringRef getAffineMapAlias(AffineMap affineMap) const {
+    return affineMapToAlias.lookup(affineMap);
+  }
+
   int getAffineMapId(AffineMap affineMap) const {
     auto it = affineMapIds.find(affineMap);
     if (it == affineMapIds.end()) {
@@ -75,6 +82,10 @@ public:
 
   ArrayRef<AffineMap> getAffineMapIds() const { return affineMapsById; }
 
+  StringRef getIntegerSetAlias(IntegerSet integerSet) const {
+    return integerSetToAlias.lookup(integerSet);
+  }
+
   int getIntegerSetId(IntegerSet integerSet) const {
     auto it = integerSetIds.find(integerSet);
     if (it == integerSetIds.end()) {
@@ -85,6 +96,10 @@ public:
 
   ArrayRef<IntegerSet> getIntegerSetIds() const { return integerSetsById; }
 
+  StringRef getTypeAlias(Type ty) const { return typeToAlias.lookup(ty); }
+
+  ArrayRef<Type> getTypeIds() const { return usedTypes.getArrayRef(); }
+
 private:
   void recordAffineMapReference(AffineMap affineMap) {
     if (affineMapIds.count(affineMap) == 0) {
@@ -100,6 +115,8 @@ private:
     }
   }
 
+  void recordTypeReference(Type ty) { usedTypes.insert(ty); }
+
   // Return true if this map could be printed using the shorthand form.
   static bool hasShorthandForm(AffineMap boundMap) {
     if (boundMap.isSingleConstant())
@@ -120,16 +137,25 @@ private:
   void visitType(Type type);
   void visitAttribute(Attribute attr);
 
+  // Initialize symbol aliases.
+  void initializeSymbolAliases();
+
   DenseMap<AffineMap, int> affineMapIds;
   std::vector<AffineMap> affineMapsById;
+  DenseMap<AffineMap, StringRef> affineMapToAlias;
 
   DenseMap<IntegerSet, int> integerSetIds;
   std::vector<IntegerSet> integerSetsById;
+  DenseMap<IntegerSet, StringRef> integerSetToAlias;
+
+  llvm::SetVector<Type> usedTypes;
+  DenseMap<Type, StringRef> typeToAlias;
 };
 } // end anonymous namespace
 
 // TODO Support visiting other types/instructions when implemented.
 void ModuleState::visitType(Type type) {
+  recordTypeReference(type);
   if (auto funcType = type.dyn_cast<FunctionType>()) {
     // Visit input and result types for functions.
     for (auto input : funcType.getInputs())
@@ -141,6 +167,8 @@ void ModuleState::visitType(Type type) {
     for (auto map : memref.getAffineMaps()) {
       recordAffineMapReference(map);
     }
+  } else if (auto vecOrTensor = type.dyn_cast<VectorOrTensorType>()) {
+    visitType(vecOrTensor.getElementType());
   }
 }
 
@@ -193,6 +221,70 @@ void ModuleState::visitInstruction(const Instruction *inst) {
   }
 }
 
+// Utility to generate a function to register a symbol alias.
+template <typename SymbolsInModuleSetTy, typename SymbolTy>
+static void registerSymbolAlias(StringRef name, SymbolTy sym,
+                                SymbolsInModuleSetTy &symbolsInModuleSet,
+                                llvm::StringSet<> &usedAliases,
+                                DenseMap<SymbolTy, StringRef> &symToAlias) {
+  assert(!name.empty() && "expected alias name to be non-empty");
+  assert(sym && "expected alias symbol to be non-null");
+  // TODO(riverriddle) Assert that the provided alias name can be lexed as
+  // an identifier.
+
+  // Check if the symbol is not referenced by the module or the name is
+  // already used by another alias.
+  if (!symbolsInModuleSet.count(sym) || !usedAliases.insert(name).second)
+    return;
+  symToAlias.try_emplace(sym, name);
+}
+
+void ModuleState::initializeSymbolAliases() {
+  // Track the identifiers in use for each symbol so that the same identifier
+  // isn't used twice.
+  llvm::StringSet<> usedAliases;
+
+  // Get the currently registered dialects.
+  auto dialects = context->getRegisteredDialects();
+
+  // Collect the set of aliases from each dialect.
+  SmallVector<std::pair<StringRef, AffineMap>, 8> affineMapAliases;
+  SmallVector<std::pair<StringRef, IntegerSet>, 8> integerSetAliases;
+  SmallVector<std::pair<StringRef, Type>, 16> typeAliases;
+  for (auto *dialect : dialects) {
+    dialect->getAffineMapAliases(affineMapAliases);
+    dialect->getIntegerSetAliases(integerSetAliases);
+    dialect->getTypeAliases(typeAliases);
+  }
+
+  // Register the affine aliases.
+  // Create a regex for the non-alias names of sets and maps, so that an alias
+  // is not registered with a conflicting name.
+  llvm::Regex reservedAffineNames("(set|map)[0-9]+");
+
+  // AffineMap aliases
+  for (auto &affineAliasPair : affineMapAliases) {
+    if (!reservedAffineNames.match(affineAliasPair.first))
+      registerSymbolAlias(affineAliasPair.first, affineAliasPair.second,
+                          affineMapIds, usedAliases, affineMapToAlias);
+  }
+
+  // IntegerSet aliases
+  for (auto &integerSetAliasPair : integerSetAliases) {
+    if (!reservedAffineNames.match(integerSetAliasPair.first))
+      registerSymbolAlias(integerSetAliasPair.first, integerSetAliasPair.second,
+                          integerSetIds, usedAliases, integerSetToAlias);
+  }
+
+  // Clear the set of used identifiers as types can have the same identifiers as
+  // affine structures.
+  usedAliases.clear();
+
+  for (auto &typeAliasPair : typeAliases)
+    registerSymbolAlias(typeAliasPair.first, typeAliasPair.second, usedTypes,
+                        usedAliases, typeToAlias);
+}
+
 // Initializes module state, populating affine map and integer set state.
 void ModuleState::initialize(const Module *module) {
   for (auto &fn : *module) {
@@ -201,6 +293,9 @@ void ModuleState::initialize(const Module *module) {
     const_cast<Function &>(fn).walkInsts(
         [&](Instruction *op) { ModuleState::visitInstruction(op); });
   }
+
+  // Initialize the symbol aliases.
+  initializeSymbolAliases();
 }
 
 //===----------------------------------------------------------------------===//
@@ -238,8 +333,10 @@ protected:
                              ArrayRef<const char *> elidedAttrs = {});
   void printAffineMapId(int affineMapId) const;
   void printAffineMapReference(AffineMap affineMap);
+  void printAffineMapAlias(StringRef alias) const;
   void printIntegerSetId(int integerSetId) const;
   void printIntegerSetReference(IntegerSet integerSet);
+  void printIntegerSetAlias(StringRef alias) const;
   void printDenseElementsAttr(DenseElementsAttr attr);
 
   /// This enum is used to represent the binding stength of the enclosing
@@ -259,7 +356,16 @@ void ModulePrinter::printAffineMapId(int affineMapId) const {
   os << "#map" << affineMapId;
 }
 
+void ModulePrinter::printAffineMapAlias(StringRef alias) const {
+  os << '#' << alias;
+}
+
 void ModulePrinter::printAffineMapReference(AffineMap affineMap) {
+  // Check for an affine map alias.
+  auto alias = state.getAffineMapAlias(affineMap);
+  if (!alias.empty())
+    return printAffineMapAlias(alias);
+
   int mapId = state.getAffineMapId(affineMap);
   if (mapId >= 0) {
     // Map will be printed at top of module so print reference to its id.
@@ -275,7 +381,18 @@ void ModulePrinter::printIntegerSetId(int integerSetId) const {
   os << "#set" << integerSetId;
 }
 
+void ModulePrinter::printIntegerSetAlias(StringRef alias) const {
+  os << '#' << alias;
+}
+
 void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) {
+  // Check for an integer set alias.
+  auto alias = state.getIntegerSetAlias(integerSet);
+  if (!alias.empty()) {
+    printIntegerSetAlias(alias);
+    return;
+  }
+
   int setId;
   if ((setId = state.getIntegerSetId(integerSet)) >= 0) {
     // The set will be printed at top of module; so print reference to its id.
@@ -288,17 +405,30 @@ void ModulePrinter::printIntegerSetReference(IntegerSet integerSet) {
 
 void ModulePrinter::print(const Module *module) {
   for (const auto &map : state.getAffineMapIds()) {
-    printAffineMapId(state.getAffineMapId(map));
+    StringRef alias = state.getAffineMapAlias(map);
+    if (!alias.empty())
+      printAffineMapAlias(alias);
+    else
+      printAffineMapId(state.getAffineMapId(map));
     os << " = ";
     map.print(os);
     os << '\n';
   }
   for (const auto &set : state.getIntegerSetIds()) {
-    printIntegerSetId(state.getIntegerSetId(set));
+    StringRef alias = state.getIntegerSetAlias(set);
+    if (!alias.empty())
+      printIntegerSetAlias(alias);
+    else
+      printIntegerSetId(state.getIntegerSetId(set));
     os << " = ";
     set.print(os);
     os << '\n';
   }
+  for (const auto &type : state.getTypeIds()) {
+    StringRef alias = state.getTypeAlias(type);
+    if (!alias.empty())
+      os << '!' << alias << " = type " << type << '\n';
+  }
   for (auto const &fn : *module)
     print(&fn);
 }
@@ -485,6 +615,13 @@ void ModulePrinter::printDenseElementsAttr(DenseElementsAttr attr) {
 }
 
 void ModulePrinter::printType(Type type) {
+  // Check for an alias for this type.
+  StringRef alias = state.getTypeAlias(type);
+  if (!alias.empty()) {
+    os << '!' << alias;
+    return;
+  }
+
   switch (type.getKind()) {
   default: {
     auto &dialect = type.getDialect();