/// Simplifies an affine map by simplifying its underlying AffineExpr results.
AffineMap simplifyAffineMap(AffineMap map);
+/// Drop the dims that are listed in `unusedDims`.
+AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
+
/// Drop the dims that are not used.
AffineMap compressUnusedDims(AffineMap map);
/// dims and symbols.
SmallVector<AffineMap> compressUnusedDims(ArrayRef<AffineMap> maps);
-/// Drop the dims that are not listed in `unusedDims`.
-AffineMap compressDims(AffineMap map, const llvm::SmallBitVector &unusedDims);
+/// Drop the symbols that are listed in `unusedSymbols`.
+AffineMap compressSymbols(AffineMap map,
+ const llvm::SmallBitVector &unusedSymbols);
/// Drop the symbols that are not used.
AffineMap compressUnusedSymbols(AffineMap map);
/// dims and symbols.
SmallVector<AffineMap> compressUnusedSymbols(ArrayRef<AffineMap> maps);
-/// Drop the symbols that are not listed in `unusedSymbols`.
-AffineMap compressSymbols(AffineMap map,
- const llvm::SmallBitVector &unusedSymbols);
-
/// Returns a map with the same dimension and symbol count as `map`, but whose
/// results are the unique affine expressions of `map`.
AffineMap removeDuplicateExprs(AffineMap map);
/// Return the reverse map of a projected permutation where the projected
/// dimensions are transformed into 0s.
///
-/// Prerequisites: `map` must be a projected permuation.
+/// Prerequisites: `map` must be a projected permutation.
///
/// Example 1:
///
/// projected_dimensions : {1}
/// result : affine_map<(d0, d1) -> (d0, 0)>
///
-/// This function also compresses unused symbols away.
+/// This function also compresses the dims when the boolean flag is true.
+AffineMap projectDims(AffineMap map,
+ const llvm::SmallBitVector &projectedDimensions,
+ bool compressDimsFlag = false);
+/// Symbol counterpart of `projectDims`.
+/// This function also compresses the symbols when the boolean flag is true.
+AffineMap projectSymbols(AffineMap map,
+ const llvm::SmallBitVector &projectedSymbols,
+ bool compressSymbolsFlag = false);
+/// Calls `projectDims(map, projectedDimensions, compressDimsFlag)`.
+/// If `compressSymbolsFlag` is true, additionally call `compressUnusedSymbols`.
AffineMap getProjectedMap(AffineMap map,
- const llvm::SmallBitVector &projectedDimensions);
+ const llvm::SmallBitVector &projectedDimensions,
+ bool compressDimsFlag = true,
+ bool compressSymbolsFlag = true);
+
+// Return a bitvector where each bit set indicates a dimension that is not used
+// by any of the maps in the input array `maps`.
+llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
+
+// Return a bitvector where each bit set indicates a symbol that is not used
+// by any of the maps in the input array `maps`.
+llvm::SmallBitVector getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps);
+
+inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
+ map.print(os);
+ return os;
+}
+
+//===----------------------------------------------------------------------===//
+// Templated helper functions.
+//===----------------------------------------------------------------------===//
/// Apply a permutation from `map` to `source` and return the result.
template <typename T>
return result;
}
-/// Calculates maxmimum dimension and symbol positions from the expressions
+/// Calculates maximum dimension and symbol positions from the expressions
/// in `exprsLists` and stores them in `maxDim` and `maxSym` respectively.
template <typename AffineExprContainer>
static void getMaxDimAndSymbol(ArrayRef<AffineExprContainer> exprsList,
}
}
-inline raw_ostream &operator<<(raw_ostream &os, AffineMap map) {
- map.print(os);
- return os;
-}
-
-// Return a bitvector where each bit set indicates a dimension that is not used
-// by any of the maps in the input array `maps`.
-llvm::SmallBitVector getUnusedDimsBitVector(ArrayRef<AffineMap> maps);
-
} // namespace mlir
namespace llvm {
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
+#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include <optional>
+#include <type_traits>
using namespace mlir;
return getSliceMap(getNumResults() - numResults, numResults);
}
-AffineMap mlir::compressDims(AffineMap map,
- const llvm::SmallBitVector &unusedDims) {
- unsigned numDims = 0;
- SmallVector<AffineExpr> dimReplacements;
- dimReplacements.reserve(map.getNumDims());
- MLIRContext *context = map.getContext();
- for (unsigned dim = 0, e = map.getNumDims(); dim < e; ++dim) {
- if (unusedDims.test(dim))
- dimReplacements.push_back(getAffineConstantExpr(0, context));
- else
- dimReplacements.push_back(getAffineDimExpr(numDims++, context));
- }
- SmallVector<AffineExpr> resultExprs;
- resultExprs.reserve(map.getNumResults());
- for (auto e : map.getResults())
- resultExprs.push_back(e.replaceDims(dimReplacements));
- return AffineMap::get(numDims, map.getNumSymbols(), resultExprs, context);
-}
-
-AffineMap mlir::compressUnusedDims(AffineMap map) {
- return compressDims(map, getUnusedDimsBitVector({map}));
-}
-
-static SmallVector<AffineMap>
-compressUnusedImpl(ArrayRef<AffineMap> maps,
- llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
+/// Implementation detail to compress multiple affine maps with a compressionFun
+/// that is expected to be either compressUnusedDims or compressUnusedSymbols.
+/// The implementation keeps track of num dims and symbols across the different
+/// affine maps.
+static SmallVector<AffineMap> compressUnusedListImpl(
+ ArrayRef<AffineMap> maps,
+ llvm::function_ref<AffineMap(AffineMap)> compressionFun) {
if (maps.empty())
return SmallVector<AffineMap>();
SmallVector<AffineExpr> allExprs;
return res;
}
+AffineMap mlir::compressDims(AffineMap map,
+ const llvm::SmallBitVector &unusedDims) {
+ return projectDims(map, unusedDims, /*compressDimsFlag=*/true);
+}
+
+AffineMap mlir::compressUnusedDims(AffineMap map) {
+ return compressDims(map, getUnusedDimsBitVector({map}));
+}
+
SmallVector<AffineMap> mlir::compressUnusedDims(ArrayRef<AffineMap> maps) {
- return compressUnusedImpl(maps,
- [](AffineMap m) { return compressUnusedDims(m); });
+ return compressUnusedListImpl(
+ maps, [](AffineMap m) { return compressUnusedDims(m); });
}
AffineMap mlir::compressSymbols(AffineMap map,
const llvm::SmallBitVector &unusedSymbols) {
- unsigned numSymbols = 0;
- SmallVector<AffineExpr> symReplacements;
- symReplacements.reserve(map.getNumSymbols());
- MLIRContext *context = map.getContext();
- for (unsigned sym = 0, e = map.getNumSymbols(); sym < e; ++sym) {
- if (unusedSymbols.test(sym))
- symReplacements.push_back(getAffineConstantExpr(0, context));
- else
- symReplacements.push_back(getAffineSymbolExpr(numSymbols++, context));
- }
- SmallVector<AffineExpr> resultExprs;
- resultExprs.reserve(map.getNumResults());
- for (auto e : map.getResults())
- resultExprs.push_back(e.replaceSymbols(symReplacements));
- return AffineMap::get(map.getNumDims(), numSymbols, resultExprs, context);
+ return projectSymbols(map, unusedSymbols, /*compressSymbolsFlag=*/true);
}
AffineMap mlir::compressUnusedSymbols(AffineMap map) {
- llvm::SmallBitVector unusedSymbols(map.getNumSymbols(), true);
- map.walkExprs([&](AffineExpr expr) {
- if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
- unusedSymbols.reset(symExpr.getPosition());
- });
- return compressSymbols(map, unusedSymbols);
+ return compressSymbols(map, getUnusedSymbolsBitVector({map}));
}
SmallVector<AffineMap> mlir::compressUnusedSymbols(ArrayRef<AffineMap> maps) {
- return compressUnusedImpl(
+ return compressUnusedListImpl(
maps, [](AffineMap m) { return compressUnusedSymbols(m); });
}
maps.front().getContext());
}
+/// Common implementation to project out dimensions or symbols from an affine
+/// map based on the template type.
+/// Additionally, if 'compress' is true, the projected out dimensions or symbols
+/// are also dropped from the resulting map.
+template <typename AffineDimOrSymExpr>
+static AffineMap projectCommonImpl(AffineMap map,
+ const llvm::SmallBitVector &toProject,
+ bool compress) {
+ static_assert(llvm::is_one_of<AffineDimOrSymExpr, AffineDimExpr,
+ AffineSymbolExpr>::value,
+ "expected AffineDimExpr or AffineSymbolExpr");
+
+ constexpr bool isDim = std::is_same<AffineDimOrSymExpr, AffineDimExpr>::value;
+ int64_t numDimOrSym = (isDim) ? map.getNumDims() : map.getNumSymbols();
+ SmallVector<AffineExpr> replacements;
+ replacements.reserve(numDimOrSym);
+
+ auto createNewDimOrSym = (isDim) ? getAffineDimExpr : getAffineSymbolExpr;
+ auto replaceDims = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
+ return e.replaceDims(replacements);
+ };
+ auto replaceSymbols = [](AffineExpr e, ArrayRef<AffineExpr> replacements) {
+ return e.replaceSymbols(replacements);
+ };
+ auto replaceNewDimOrSym = (isDim) ? replaceDims : replaceSymbols;
+
+ MLIRContext *context = map.getContext();
+ int64_t newNumDimOrSym = 0;
+ for (unsigned dimOrSym = 0; dimOrSym < numDimOrSym; ++dimOrSym) {
+ if (toProject.test(dimOrSym)) {
+ replacements.push_back(getAffineConstantExpr(0, context));
+ continue;
+ }
+ int64_t newPos = compress ? newNumDimOrSym++ : dimOrSym;
+ replacements.push_back(createNewDimOrSym(newPos, context));
+ }
+ SmallVector<AffineExpr> resultExprs;
+ resultExprs.reserve(map.getNumResults());
+ for (auto e : map.getResults())
+ resultExprs.push_back(replaceNewDimOrSym(e, replacements));
+
+ int64_t numDims = (compress && isDim) ? newNumDimOrSym : map.getNumDims();
+ int64_t numSyms = (compress && !isDim) ? newNumDimOrSym : map.getNumSymbols();
+ return AffineMap::get(numDims, numSyms, resultExprs, context);
+}
+
+AffineMap mlir::projectDims(AffineMap map,
+ const llvm::SmallBitVector &projectedDimensions,
+ bool compressDimsFlag) {
+ return projectCommonImpl<AffineDimExpr>(map, projectedDimensions,
+ compressDimsFlag);
+}
+
+AffineMap mlir::projectSymbols(AffineMap map,
+ const llvm::SmallBitVector &projectedSymbols,
+ bool compressSymbolsFlag) {
+ return projectCommonImpl<AffineSymbolExpr>(map, projectedSymbols,
+ compressSymbolsFlag);
+}
+
AffineMap mlir::getProjectedMap(AffineMap map,
- const llvm::SmallBitVector &unusedDims) {
- return compressUnusedSymbols(compressDims(map, unusedDims));
+ const llvm::SmallBitVector &projectedDimensions,
+ bool compressDimsFlag,
+ bool compressSymbolsFlag) {
+ map = projectDims(map, projectedDimensions, compressDimsFlag);
+ if (compressSymbolsFlag)
+ map = compressUnusedSymbols(map);
+ return map;
}
llvm::SmallBitVector mlir::getUnusedDimsBitVector(ArrayRef<AffineMap> maps) {
unsigned numDims = maps[0].getNumDims();
llvm::SmallBitVector numDimsBitVector(numDims, true);
- for (const auto &m : maps) {
+ for (AffineMap m : maps) {
for (unsigned i = 0; i < numDims; ++i) {
if (m.isFunctionOfDim(i))
numDimsBitVector.reset(i);
return numDimsBitVector;
}
+llvm::SmallBitVector mlir::getUnusedSymbolsBitVector(ArrayRef<AffineMap> maps) {
+ unsigned numSymbols = maps[0].getNumSymbols();
+ llvm::SmallBitVector numSymbolsBitVector(numSymbols, true);
+ for (AffineMap m : maps) {
+ for (unsigned i = 0; i < numSymbols; ++i) {
+ if (m.isFunctionOfSymbol(i))
+ numSymbolsBitVector.reset(i);
+ }
+ }
+ return numSymbolsBitVector;
+}
+
//===----------------------------------------------------------------------===//
// MutableAffineMap.
//===----------------------------------------------------------------------===//
return false;
}
-// Simplifies the result affine expressions of this map. The expressions have to
-// be pure for the simplification implemented.
+// Simplifies the result affine expressions of this map. The expressions
+// have to be pure for the simplification implemented.
void MutableAffineMap::simplify() {
// Simplify each of the results if possible.
// TODO: functional-style map