From 255ba1c334b86792054c152ce8533dca5b452b41 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Wed, 22 Mar 2023 02:19:20 -0700 Subject: [PATCH] [mlir][AffineMap] NFC - Refactor getProjectedMap and split into projectDims and projectSymbols The default behavior of getProjectedMap may be surprising as it implicitly compresses the dims and the unused symbols. Make these explicit in the API and refactor to more idiomatic implementations with better reuse. Differential Revision: https://reviews.llvm.org/D146611 --- mlir/include/mlir/IR/AffineMap.h | 58 +++++++++----- mlir/lib/IR/AffineMap.cpp | 160 +++++++++++++++++++++++++-------------- 2 files changed, 144 insertions(+), 74 deletions(-) diff --git a/mlir/include/mlir/IR/AffineMap.h b/mlir/include/mlir/IR/AffineMap.h index 0f4c746..cc7c794 100644 --- a/mlir/include/mlir/IR/AffineMap.h +++ b/mlir/include/mlir/IR/AffineMap.h @@ -403,6 +403,9 @@ private: /// Simplifies an affine map by simplifying its underlying AffineExpr results. AffineMap simplifyAffineMap(AffineMap map); +/// Drop the dims that are listed in `unusedDims`. +AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims); + /// Drop the dims that are not used. AffineMap compressUnusedDims(AffineMap map); @@ -411,8 +414,9 @@ AffineMap compressUnusedDims(AffineMap map); /// dims and symbols. SmallVector compressUnusedDims(ArrayRef maps); -/// Drop the dims that are not listed in `unusedDims`. -AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims); +/// Drop the symbols that are listed in `unusedSymbols`. +AffineMap compressSymbols(AffineMap map, + const llvm::SmallBitVector &unusedSymbols); /// Drop the symbols that are not used. AffineMap compressUnusedSymbols(AffineMap map); @@ -422,10 +426,6 @@ AffineMap compressUnusedSymbols(AffineMap map); /// dims and symbols. SmallVector compressUnusedSymbols(ArrayRef maps); -/// Drop the symbols that are not listed in `unusedSymbols`. -AffineMap compressSymbols(AffineMap map, - const llvm::SmallBitVector &unusedSymbols); - /// Returns a map with the same dimension and symbol count as `map`, but whose /// results are the unique affine expressions of `map`. AffineMap removeDuplicateExprs(AffineMap map); @@ -469,7 +469,7 @@ AffineMap inversePermutation(AffineMap map); /// Return the reverse map of a projected permutation where the projected /// dimensions are transformed into 0s. /// -/// Prerequisites: `map` must be a projected permuation. +/// Prerequisites: `map` must be a projected permutation. /// /// Example 1: /// @@ -559,9 +559,38 @@ AffineMap concatAffineMaps(ArrayRef maps); /// projected_dimensions : {1} /// result : affine_map<(d0, d1) -> (d0, 0)> /// -/// This function also compresses unused symbols away. +/// This function also compresses the dims when the boolean flag is true. +AffineMap projectDims(AffineMap map, + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag = false); +/// Symbol counterpart of `projectDims`. +/// This function also compresses the symbols when the boolean flag is true. +AffineMap projectSymbols(AffineMap map, + const llvm::SmallBitVector &projectedSymbols, + bool compressSymbolsFlag = false); +/// Calls `projectDims(map, projectedDimensions, compressDimsFlag)`. +/// If `compressSymbolsFlag` is true, additionally call `compressUnusedSymbols`. AffineMap getProjectedMap(AffineMap map, - const llvm::SmallBitVector &projectedDimensions); + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag = true, + bool compressSymbolsFlag = true); + +// Return a bitvector where each bit set indicates a dimension that is not used +// by any of the maps in the input array `maps`. +llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef maps); + +// Return a bitvector where each bit set indicates a symbol that is not used +// by any of the maps in the input array `maps`. +llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef maps); + +inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) { + map.print(os); + return os; +} + +//===----------------------------------------------------------------------===// +// Templated helper functions. +//===----------------------------------------------------------------------===// /// Apply a permutation from `map` to `source` and return the result. template @@ -584,7 +613,7 @@ SmallVector applyPermutationMap(AffineMap map, llvm::ArrayRef source) { return result; } -/// Calculates maxmimum dimension and symbol positions from the expressions +/// Calculates maximum dimension and symbol positions from the expressions /// in `exprsLists` and stores them in `maxDim` and `maxSym` respectively. template static void getMaxDimAndSymbol(ArrayRef exprsList, @@ -601,15 +630,6 @@ static void getMaxDimAndSymbol(ArrayRef exprsList, } } -inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) { - map.print(os); - return os; -} - -// Return a bitvector where each bit set indicates a dimension that is not used -// by any of the maps in the input array `maps`. -llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef maps); - } // namespace mlir namespace llvm { diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 90c5466..c924d2b 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -12,12 +12,14 @@ #include "mlir/IR/BuiltinTypes.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Support/MathExtras.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/raw_ostream.h" #include #include +#include using namespace mlir; @@ -569,32 +571,13 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const { return getSliceMap(getNumResults() - numResults, numResults); } -AffineMap mlir::compressDims(AffineMap map, - const llvm::SmallBitVector &unusedDims) { - unsigned numDims = 0; - SmallVector dimReplacements; - dimReplacements.reserve(map.getNumDims()); - MLIRContext *context = map.getContext(); - for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) { - if (unusedDims.test(dim)) - dimReplacements.push_back(getAffineConstantExpr(0, context)); - else - dimReplacements.push_back(getAffineDimExpr(numDims++, context)); - } - SmallVector resultExprs; - resultExprs.reserve(map.getNumResults()); - for (auto e : map.getResults()) - resultExprs.push_back(e.replaceDims(dimReplacements)); - return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context); -} - -AffineMap mlir::compressUnusedDims(AffineMap map) { - return compressDims(map, getUnusedDimsBitVector({map})); -} - -static SmallVector -compressUnusedImpl(ArrayRef maps, - llvm::function_ref compressionFun) { +/// Implementation detail to compress multiple affine maps with a compressionFun +/// that is expected to be either compressUnusedDims or compressUnusedSymbols. +/// The implementation keeps track of num dims and symbols across the different +/// affine maps. +static SmallVector compressUnusedListImpl( + ArrayRef maps, + llvm::function_ref compressionFun) { if (maps.empty()) return SmallVector(); SmallVector allExprs; @@ -622,41 +605,31 @@ compressUnusedImpl(ArrayRef maps, return res; } +AffineMap mlir::compressDims(AffineMap map, + const llvm::SmallBitVector &unusedDims) { + return projectDims(map, unusedDims, /*compressDimsFlag=*/true); +} + +AffineMap mlir::compressUnusedDims(AffineMap map) { + return compressDims(map, getUnusedDimsBitVector({map})); +} + SmallVector mlir::compressUnusedDims(ArrayRef maps) { - return compressUnusedImpl(maps, - [](AffineMap m) { return compressUnusedDims(m); }); + return compressUnusedListImpl( + maps, [](AffineMap m) { return compressUnusedDims(m); }); } AffineMap mlir::compressSymbols(AffineMap map, const llvm::SmallBitVector &unusedSymbols) { - unsigned numSymbols = 0; - SmallVector symReplacements; - symReplacements.reserve(map.getNumSymbols()); - MLIRContext *context = map.getContext(); - for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) { - if (unusedSymbols.test(sym)) - symReplacements.push_back(getAffineConstantExpr(0, context)); - else - symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context)); - } - SmallVector resultExprs; - resultExprs.reserve(map.getNumResults()); - for (auto e : map.getResults()) - resultExprs.push_back(e.replaceSymbols(symReplacements)); - return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context); + return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true); } AffineMap mlir::compressUnusedSymbols(AffineMap map) { - llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true); - map.walkExprs([&](AffineExpr expr) { - if (auto symExpr = expr.dyn_cast()) - unusedSymbols.reset(symExpr.getPosition()); - }); - return compressSymbols(map, unusedSymbols); + return compressSymbols(map, getUnusedSymbolsBitVector({map})); } SmallVector mlir::compressUnusedSymbols(ArrayRef maps) { - return compressUnusedImpl( + return compressUnusedListImpl( maps, [](AffineMap m) { return compressUnusedSymbols(m); }); } @@ -741,15 +714,80 @@ AffineMap mlir::concatAffineMaps(ArrayRef maps) { maps.front().getContext()); } +/// Common implementation to project out dimensions or symbols from an affine +/// map based on the template type. +/// Additionally, if 'compress' is true, the projected out dimensions or symbols +/// are also dropped from the resulting map. +template +static AffineMap projectCommonImpl(AffineMap map, + const llvm::SmallBitVector &toProject, + bool compress) { + static_assert(llvm::is_one_of::value, + "expected AffineDimExpr or AffineSymbolExpr"); + + constexpr bool isDim = std::is_same::value; + int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols(); + SmallVector replacements; + replacements.reserve(numDimOrSym); + + auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr; + auto replaceDims = [](AffineExpr e, ArrayRef replacements) { + return e.replaceDims(replacements); + }; + auto replaceSymbols = [](AffineExpr e, ArrayRef replacements) { + return e.replaceSymbols(replacements); + }; + auto replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols; + + MLIRContext *context = map.getContext(); + int64_t newNumDimOrSym = 0; + for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) { + if (toProject.test(dimOrSym)) { + replacements.push_back(getAffineConstantExpr(0, context)); + continue; + } + int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym; + replacements.push_back(createNewDimOrSym(newPos, context)); + } + SmallVector resultExprs; + resultExprs.reserve(map.getNumResults()); + for (auto e : map.getResults()) + resultExprs.push_back(replaceNewDimOrSym(e, replacements)); + + int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims(); + int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols(); + return AffineMap::get(numDims, numSyms, resultExprs, context); +} + +AffineMap mlir::projectDims(AffineMap map, + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag) { + return projectCommonImpl(map, projectedDimensions, + compressDimsFlag); +} + +AffineMap mlir::projectSymbols(AffineMap map, + const llvm::SmallBitVector &projectedSymbols, + bool compressSymbolsFlag) { + return projectCommonImpl(map, projectedSymbols, + compressSymbolsFlag); +} + AffineMap mlir::getProjectedMap(AffineMap map, - const llvm::SmallBitVector &unusedDims) { - return compressUnusedSymbols(compressDims(map, unusedDims)); + const llvm::SmallBitVector &projectedDimensions, + bool compressDimsFlag, + bool compressSymbolsFlag) { + map = projectDims(map, projectedDimensions, compressDimsFlag); + if (compressSymbolsFlag) + map = compressUnusedSymbols(map); + return map; } llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef maps) { unsigned numDims = maps[0].getNumDims(); llvm::SmallBitVector numDimsBitVector(numDims, true); - for (const auto &m : maps) { + for (AffineMap m : maps) { for (unsigned i = 0; i < numDims; ++i) { if (m.isFunctionOfDim(i)) numDimsBitVector.reset(i); @@ -758,6 +796,18 @@ llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef maps) { return numDimsBitVector; } +llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef maps) { + unsigned numSymbols = maps[0].getNumSymbols(); + llvm::SmallBitVector numSymbolsBitVector(numSymbols, true); + for (AffineMap m : maps) { + for (unsigned i = 0; i < numSymbols; ++i) { + if (m.isFunctionOfSymbol(i)) + numSymbolsBitVector.reset(i); + } + } + return numSymbolsBitVector; +} + //===----------------------------------------------------------------------===// // MutableAffineMap. //===----------------------------------------------------------------------===// @@ -784,8 +834,8 @@ bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const { return false; } -// Simplifies the result affine expressions of this map. The expressions have to -// be pure for the simplification implemented. +// Simplifies the result affine expressions of this map. The expressions +// have to be pure for the simplification implemented. void MutableAffineMap::simplify() { // Simplify each of the results if possible. // TODO: functional-style map -- 2.7.4