[mlir][AffineMap] NFC - Refactor getProjectedMap and split into projectDims and proje...
authorNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 22 Mar 2023 09:19:20 +0000 (02:19 -0700)
committerNicolas Vasilache <nicolas.vasilache@gmail.com>
Wed, 22 Mar 2023 12:30:48 +0000 (05:30 -0700)
The default behavior of getProjectedMap may be surprising as it implicitly compresses the dims and
the unused symbols.

Make these explicit in the API and refactor to more idiomatic implementations with better reuse.

Differential Revision: https://reviews.llvm.org/D146611

mlir/include/mlir/IR/AffineMap.h
mlir/lib/IR/AffineMap.cpp

index 0f4c746..cc7c794 100644 (file)
@@ -403,6 +403,9 @@ private:
 /// Simplifies an affine map by simplifying its underlying AffineExpr results.
 AffineMap simplifyAffineMap(AffineMap map);
 
+/// Drop the dims that are listed in `unusedDims`.
+AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
+
 /// Drop the dims that are not used.
 AffineMap compressUnusedDims(AffineMap map);
 
@@ -411,8 +414,9 @@ AffineMap compressUnusedDims(AffineMap map);
 /// dims and symbols.
 SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
 
-/// Drop the dims that are not listed in `unusedDims`.
-AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
+/// Drop the symbols that are listed in `unusedSymbols`.
+AffineMap compressSymbols(AffineMap map,
+                          const llvm::SmallBitVector &unusedSymbols);
 
 /// Drop the symbols that are not used.
 AffineMap compressUnusedSymbols(AffineMap map);
@@ -422,10 +426,6 @@ AffineMap compressUnusedSymbols(AffineMap map);
 /// dims and symbols.
 SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
 
-/// Drop the symbols that are not listed in `unusedSymbols`.
-AffineMap compressSymbols(AffineMap map,
-                          const llvm::SmallBitVector &unusedSymbols);
-
 /// Returns a map with the same dimension and symbol count as `map`, but whose
 /// results are the unique affine expressions of `map`.
 AffineMap removeDuplicateExprs(AffineMap map);
@@ -469,7 +469,7 @@ AffineMap inversePermutation(AffineMap map);
 /// Return the reverse map of a projected permutation where the projected
 /// dimensions are transformed into 0s.
 ///
-/// Prerequisites: `map` must be a projected permuation.
+/// Prerequisites: `map` must be a projected permutation.
 ///
 /// Example 1:
 ///
@@ -559,9 +559,38 @@ AffineMap concatAffineMaps(ArrayRef<AffineMap> maps);
 ///    projected_dimensions : {1}
 ///    result               : affine_map<(d0, d1) -> (d0, 0)>
 ///
-/// This function also compresses unused symbols away.
+/// This function also compresses the dims when the boolean flag is true.
+AffineMap projectDims(AffineMap map,
+                      const llvm::SmallBitVector &projectedDimensions,
+                      bool compressDimsFlag = false);
+/// Symbol counterpart of `projectDims`.
+/// This function also compresses the symbols when the boolean flag is true.
+AffineMap projectSymbols(AffineMap map,
+                         const llvm::SmallBitVector &projectedSymbols,
+                         bool compressSymbolsFlag = false);
+/// Calls `projectDims(map, projectedDimensions, compressDimsFlag)`.
+/// If `compressSymbolsFlag` is true, additionally call `compressUnusedSymbols`.
 AffineMap getProjectedMap(AffineMap map,
-                          const llvm::SmallBitVector &projectedDimensions);
+                          const llvm::SmallBitVector &projectedDimensions,
+                          bool compressDimsFlag = true,
+                          bool compressSymbolsFlag = true);
+
+// Return a bitvector where each bit set indicates a dimension that is not used
+// by any of the maps in the input array `maps`.
+llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
+
+// Return a bitvector where each bit set indicates a symbol that is not used
+// by any of the maps in the input array `maps`.
+llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps);
+
+inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
+  map.print(os);
+  return os;
+}
+
+//===----------------------------------------------------------------------===//
+// Templated helper functions.
+//===----------------------------------------------------------------------===//
 
 /// Apply a permutation from `map` to `source` and return the result.
 template <typename T>
@@ -584,7 +613,7 @@ SmallVector<T> applyPermutationMap(AffineMap map, llvm::ArrayRef<T> source) {
   return result;
 }
 
-/// Calculates maxmimum dimension and symbol positions from the expressions
+/// Calculates maximum dimension and symbol positions from the expressions
 /// in `exprsLists` and stores them in `maxDim` and `maxSym` respectively.
 template <typename AffineExprContainer>
 static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
@@ -601,15 +630,6 @@ static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
   }
 }
 
-inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
-  map.print(os);
-  return os;
-}
-
-// Return a bitvector where each bit set indicates a dimension that is not used
-// by any of the maps in the input array `maps`.
-llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
-
 } // namespace mlir
 
 namespace llvm {
index 90c5466..c924d2b 100644 (file)
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LogicalResult.h"
 #include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallBitVector.h"
 #include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Support/raw_ostream.h"
 #include <numeric>
 #include <optional>
+#include <type_traits>
 
 using namespace mlir;
 
@@ -569,32 +571,13 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
   return getSliceMap(getNumResults() - numResults, numResults);
 }
 
-AffineMap mlir::compressDims(AffineMap map,
-                             const llvm::SmallBitVector &unusedDims) {
-  unsigned numDims = 0;
-  SmallVector<AffineExpr> dimReplacements;
-  dimReplacements.reserve(map.getNumDims());
-  MLIRContext *context = map.getContext();
-  for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
-    if (unusedDims.test(dim))
-      dimReplacements.push_back(getAffineConstantExpr(0, context));
-    else
-      dimReplacements.push_back(getAffineDimExpr(numDims++, context));
-  }
-  SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(map.getNumResults());
-  for (auto e : map.getResults())
-    resultExprs.push_back(e.replaceDims(dimReplacements));
-  return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
-}
-
-AffineMap mlir::compressUnusedDims(AffineMap map) {
-  return compressDims(map, getUnusedDimsBitVector({map}));
-}
-
-static SmallVector<AffineMap>
-compressUnusedImpl(ArrayRef<AffineMap> maps,
-                   llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
+/// Implementation detail to compress multiple affine maps with a compressionFun
+/// that is expected to be either compressUnusedDims or compressUnusedSymbols.
+/// The implementation keeps track of num dims and symbols across the different
+/// affine maps.
+static SmallVector<AffineMap> compressUnusedListImpl(
+    ArrayRef<AffineMap> maps,
+    llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
   if (maps.empty())
     return SmallVector<AffineMap>();
   SmallVector<AffineExpr> allExprs;
@@ -622,41 +605,31 @@ compressUnusedImpl(ArrayRef<AffineMap> maps,
   return res;
 }
 
+AffineMap mlir::compressDims(AffineMap map,
+                             const llvm::SmallBitVector &unusedDims) {
+  return projectDims(map, unusedDims, /*compressDimsFlag=*/true);
+}
+
+AffineMap mlir::compressUnusedDims(AffineMap map) {
+  return compressDims(map, getUnusedDimsBitVector({map}));
+}
+
 SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
-  return compressUnusedImpl(maps,
-                            [](AffineMap m) { return compressUnusedDims(m); });
+  return compressUnusedListImpl(
+      maps, [](AffineMap m) { return compressUnusedDims(m); });
 }
 
 AffineMap mlir::compressSymbols(AffineMap map,
                                 const llvm::SmallBitVector &unusedSymbols) {
-  unsigned numSymbols = 0;
-  SmallVector<AffineExpr> symReplacements;
-  symReplacements.reserve(map.getNumSymbols());
-  MLIRContext *context = map.getContext();
-  for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
-    if (unusedSymbols.test(sym))
-      symReplacements.push_back(getAffineConstantExpr(0, context));
-    else
-      symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
-  }
-  SmallVector<AffineExpr> resultExprs;
-  resultExprs.reserve(map.getNumResults());
-  for (auto e : map.getResults())
-    resultExprs.push_back(e.replaceSymbols(symReplacements));
-  return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
+  return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true);
 }
 
 AffineMap mlir::compressUnusedSymbols(AffineMap map) {
-  llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
-  map.walkExprs([&](AffineExpr expr) {
-    if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
-      unusedSymbols.reset(symExpr.getPosition());
-  });
-  return compressSymbols(map, unusedSymbols);
+  return compressSymbols(map, getUnusedSymbolsBitVector({map}));
 }
 
 SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
-  return compressUnusedImpl(
+  return compressUnusedListImpl(
       maps, [](AffineMap m) { return compressUnusedSymbols(m); });
 }
 
@@ -741,15 +714,80 @@ AffineMap mlir::concatAffineMaps(ArrayRef<AffineMap> maps) {
                         maps.front().getContext());
 }
 
+/// Common implementation to project out dimensions or symbols from an affine
+/// map based on the template type.
+/// Additionally, if 'compress' is true, the projected out dimensions or symbols
+/// are also dropped from the resulting map.
+template <typename AffineDimOrSymExpr>
+static AffineMap projectCommonImpl(AffineMap map,
+                                   const llvm::SmallBitVector &toProject,
+                                   bool compress) {
+  static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr,
+                                AffineSymbolExpr>::value,
+                "expected AffineDimExpr or AffineSymbolExpr");
+
+  constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
+  int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols();
+  SmallVector<AffineExpr> replacements;
+  replacements.reserve(numDimOrSym);
+
+  auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr;
+  auto replaceDims = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
+    return e.replaceDims(replacements);
+  };
+  auto replaceSymbols = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
+    return e.replaceSymbols(replacements);
+  };
+  auto replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
+
+  MLIRContext *context = map.getContext();
+  int64_t newNumDimOrSym = 0;
+  for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
+    if (toProject.test(dimOrSym)) {
+      replacements.push_back(getAffineConstantExpr(0, context));
+      continue;
+    }
+    int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
+    replacements.push_back(createNewDimOrSym(newPos, context));
+  }
+  SmallVector<AffineExpr> resultExprs;
+  resultExprs.reserve(map.getNumResults());
+  for (auto e : map.getResults())
+    resultExprs.push_back(replaceNewDimOrSym(e, replacements));
+
+  int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims();
+  int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols();
+  return AffineMap::get(numDims, numSyms, resultExprs, context);
+}
+
+AffineMap mlir::projectDims(AffineMap map,
+                            const llvm::SmallBitVector &projectedDimensions,
+                            bool compressDimsFlag) {
+  return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
+                                          compressDimsFlag);
+}
+
+AffineMap mlir::projectSymbols(AffineMap map,
+                               const llvm::SmallBitVector &projectedSymbols,
+                               bool compressSymbolsFlag) {
+  return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
+                                             compressSymbolsFlag);
+}
+
 AffineMap mlir::getProjectedMap(AffineMap map,
-                                const llvm::SmallBitVector &unusedDims) {
-  return compressUnusedSymbols(compressDims(map, unusedDims));
+                                const llvm::SmallBitVector &projectedDimensions,
+                                bool compressDimsFlag,
+                                bool compressSymbolsFlag) {
+  map = projectDims(map, projectedDimensions, compressDimsFlag);
+  if (compressSymbolsFlag)
+    map = compressUnusedSymbols(map);
+  return map;
 }
 
 llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
   unsigned numDims = maps[0].getNumDims();
   llvm::SmallBitVector numDimsBitVector(numDims, true);
-  for (const auto &m : maps) {
+  for (AffineMap m : maps) {
     for (unsigned i = 0; i < numDims; ++i) {
       if (m.isFunctionOfDim(i))
         numDimsBitVector.reset(i);
@@ -758,6 +796,18 @@ llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
   return numDimsBitVector;
 }
 
+llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
+  unsigned numSymbols = maps[0].getNumSymbols();
+  llvm::SmallBitVector numSymbolsBitVector(numSymbols, true);
+  for (AffineMap m : maps) {
+    for (unsigned i = 0; i < numSymbols; ++i) {
+      if (m.isFunctionOfSymbol(i))
+        numSymbolsBitVector.reset(i);
+    }
+  }
+  return numSymbolsBitVector;
+}
+
 //===----------------------------------------------------------------------===//
 // MutableAffineMap.
 //===----------------------------------------------------------------------===//
@@ -784,8 +834,8 @@ bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
   return false;
 }
 
-// Simplifies the result affine expressions of this map. The expressions have to
-// be pure for the simplification implemented.
+// Simplifies the result affine expressions of this map. The expressions
+// have to be pure for the simplification implemented.
 void MutableAffineMap::simplify() {
   // Simplify each of the results if possible.
   // TODO: functional-style map