#include "mlir/IR/IntegerSet.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#define DEBUG_TYPE "affine-structures"
using namespace mlir;
-using namespace llvm;
+using llvm::SmallDenseMap;
+using llvm::SmallDenseSet;
+using llvm::SmallPtrSet;
namespace {
assert(ids.size() == getNumIds());
}
-// This routine may add additional local variables if the flattened expression
-// corresponding to the map has such variables due to the presence of
-// mod's, ceildiv's, and floordiv's.
-bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
- // Assert if the map and this constraint set aren't associated with the same
- // identifiers in the same order.
- assert(vMap->getNumDims() <= getNumDimIds());
- assert(vMap->getNumSymbols() <= getNumSymbolIds());
- for (unsigned i = 0, e = vMap->getNumDims(); i < e; i++) {
- assert(ids[i].hasValue());
- assert(vMap->getOperand(i) == ids[i].getValue());
+/// Checks if two constraint systems are in the same space, i.e., if they are
+/// associated with the same set of identifiers, appearing in the same order.
+bool areIdsAligned(const FlatAffineConstraints &A,
+ const FlatAffineConstraints &B) {
+ return A.getNumDimIds() == B.getNumDimIds() &&
+ A.getNumSymbolIds() == B.getNumSymbolIds() &&
+ A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
+}
+
+/// Checks if the SSA values associated with `cst''s identifiers are unique.
+static bool areIdsUnique(const FlatAffineConstraints &cst) {
+ SmallPtrSet<Value *, 8> uniqueIds;
+ for (auto id : cst.getIds()) {
+ if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
+ return false;
+ }
+ return true;
+}
+
+// Swap the posA^th identifier with the posB^th identifier.
+static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
+ assert(posA < A->getNumIds() && "invalid position A");
+ assert(posB < A->getNumIds() && "invalid position B");
+
+ if (posA == posB)
+ return;
+
+ for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) {
+ std::swap(A->atIneq(r, posA), A->atIneq(r, posB));
+ }
+ for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) {
+ std::swap(A->atEq(r, posA), A->atEq(r, posB));
}
- for (unsigned i = 0, e = vMap->getNumSymbols(); i < e; i++) {
- assert(ids[numDims + i].hasValue());
- assert(vMap->getOperand(vMap->getNumDims() + i) ==
- ids[numDims + i].getValue());
+ std::swap(A->getId(posA), A->getId(posB));
+}
+
+/// Merge and align the identifiers of A and B so that both constraint systems
+/// get the union of the contained identifiers that is dimension-wise and
+/// symbol-wise unique; both constraint systems are updated so that they have
+/// the union of all identifiers, with A's original identifiers appearing first
+/// followed by any of B's identifiers that didn't appear in A. Local
+/// identifiers of each system are by design separate/local and are placed one
+/// after other (A's followed by B's).
+// Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
+// Output: both A, B have (%i, %j, %k) [%M, %N, %P]
+//
+// TODO(mlir-team): expose this function at some point.
+static void mergeAndAlignIds(FlatAffineConstraints *A,
+ FlatAffineConstraints *B) {
+ // A merge/align isn't meaningful if a cst's ids aren't distinct.
+ assert(areIdsUnique(*A) && "A's id values aren't unique");
+ assert(areIdsUnique(*B) && "B's id values aren't unique");
+
+ assert(std::all_of(A->getIds().begin(),
+ A->getIds().begin() + A->getNumDimAndSymbolIds(),
+ [](Optional<Value *> id) { return id.hasValue(); }));
+
+ assert(std::all_of(B->getIds().begin(),
+ B->getIds().begin() + B->getNumDimAndSymbolIds(),
+ [](Optional<Value *> id) { return id.hasValue(); }));
+
+ // Place local id's of A after local id's of B.
+ for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
+ B->addLocalId(0);
+ }
+ for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
+ t++) {
+ A->addLocalId(A->getNumLocalIds());
}
+ SmallVector<Value *, 4> aDimValues, aSymValues;
+ A->getIdValues(0, A->getNumDimIds(), &aDimValues);
+ A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
+ {
+ // Merge dims from A into B.
+ unsigned d = 0;
+ for (auto *aDimValue : aDimValues) {
+ unsigned loc;
+ if (B->findId(*aDimValue, &loc)) {
+ assert(loc < B->getNumDimIds() &&
+ "A's dim appears in B's non-dim position");
+ swapId(B, d, loc);
+ } else {
+ B->addDimId(d);
+ B->setIdValue(d, aDimValue);
+ }
+ d++;
+ }
+
+ // Dimensions that are in B, but not in A, are added at the end.
+ for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
+ A->addDimId(A->getNumDimIds());
+ A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
+ }
+ }
+ {
+ // Merge symbols: merge A's symbols into B first.
+ unsigned s = B->getNumDimIds();
+ for (auto *aSymValue : aSymValues) {
+ unsigned loc;
+ if (B->findId(*aSymValue, &loc)) {
+ assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
+ "A's symbol appears in B's non-symbol position");
+ swapId(B, s, loc);
+ } else {
+ B->addSymbolId(s - B->getNumDimIds());
+ B->setIdValue(s, aSymValue);
+ }
+ s++;
+ }
+ // Symbols that are in B, but not in A, are added at the end.
+ for (unsigned t = A->getNumDimAndSymbolIds(),
+ e = B->getNumDimAndSymbolIds();
+ t < e; t++) {
+ A->addSymbolId(A->getNumSymbolIds());
+ A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
+ }
+ }
+ assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
+}
+
+// This routine may add additional local variables if the flattened expression
+// corresponding to the map has such variables due to mod's, ceildiv's, and
+// floordiv's in it.
+bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
std::vector<SmallVector<int64_t, 8>> flatExprs;
- FlatAffineConstraints cst;
- if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) {
+ FlatAffineConstraints localCst;
+ if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &localCst)) {
LLVM_DEBUG(llvm::dbgs()
<< "composition unimplemented for semi-affine maps\n");
return false;
}
assert(flatExprs.size() == vMap->getNumResults());
- // Make the value map and the flat affine cst dimensions compatible.
- // A lot of this code will be refactored/cleaned up.
- // TODO(bondhugula): the next ~20 lines of code is pretty UGLY. This needs
- // to be factored out into an FlatAffineConstraints::alignAndMerge().
- for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) {
- addLocalId(0);
- }
+ // Make the value map and the flat affine localCst dimensions compatible.
+ SmallVector<Value *, 8> values(vMap->getOperands().begin(),
+ vMap->getOperands().end());
+ localCst.setIdValues(0, localCst.getNumDimAndSymbolIds(), values);
+ // Align localCst and this - localCst's identifiers appear first in the union.
+ mergeAndAlignIds(&localCst, this);
+ // Add dimensions corresponding to the map's results.
for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
// TODO: Consider using a batched version to add a range of IDs.
addDimId(0);
- cst.addDimId(0);
+ localCst.addDimId(0);
}
- assert(cst.getNumDimIds() <= getNumDimIds());
- for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) {
- // Dimensions that are in 'this' but not in vMap/cst are added at the end.
- cst.addDimId(cst.getNumDimIds());
- }
- assert(cst.getNumSymbolIds() <= getNumSymbolIds());
- for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e;
- t++) {
- // Dimensions that are in 'this' but not in vMap/cst are added at the end.
- cst.addSymbolId(cst.getNumSymbolIds());
- }
- assert(cst.getNumLocalIds() <= getNumLocalIds());
- for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e;
- t++) {
- cst.addLocalId(cst.getNumLocalIds());
- }
- /// Finally, append cst to this constraint set.
- append(cst);
+ // Finally, append localCst to this constraint set.
+ append(localCst);
// We add one equality for each result connecting the result dim of the map to
// the other identifiers.
// add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
const auto &flatExpr = flatExprs[r];
+ assert(flatExpr.size() >= vMap->getNumOperands() + 1);
+
// eqToAdd is the equality corresponding to the flattened affine expression.
SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
// Set the coefficient for this result to one.
eqToAdd[r] = 1;
- assert(flatExpr.size() >= vMap->getNumOperands() + 1);
-
// Dims and symbols.
for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
unsigned loc;
bool ret = findId(*vMap->getOperand(i), &loc);
assert(ret && "value map's id can't be found");
(void)ret;
- // We need to negate 'eq[r]' since the newly added dimension is going to
- // be set to this one.
+ // Negate 'eq[r]' since the newly added dimension will be set to this one.
eqToAdd[loc] = -flatExpr[i];
}
- // Local vars common to eq and cst are at the beginning.
- int j = getNumDimIds() + getNumSymbolIds();
- int end = flatExpr.size() - 1;
- for (int i = vMap->getNumOperands(); i < end; i++, j++) {
+ // Local vars common to eq and localCst are at the beginning.
+ unsigned j = getNumDimIds() + getNumSymbolIds();
+ unsigned end = flatExpr.size() - 1;
+ for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
eqToAdd[j] = -flatExpr[i];
}
}
}; // namespace
-// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of
-// the symbols. Assumes the common symbols appear in the same order (the
-// current/common use case).
-static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) {
- SmallVector<Value *, 4> symbolsA, symbolsB;
- A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA);
- B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB);
-
- // Both symbol list have a handful symbols each typically (3-4); a merge
- // quadratic in complexity with a linear search is fine.
- for (auto *symbolB : symbolsB) {
- if (llvm::is_contained(symbolsA, symbolB)) {
- A->addSymbolId(symbolsA.size(), symbolB);
- symbolsA.push_back(symbolB);
- }
- }
- // symbolsA now holds the merged symbol list.
- symbolsB.reserve(symbolsA.size());
- unsigned iB = 0;
- for (auto *symbolA : symbolsA) {
- assert(iB < symbolsB.size());
- if (symbolA != symbolsB[iB]) {
- symbolsB.insert(symbolsB.begin() + iB, symbolA);
- B->addSymbolId(iB, symbolA);
- }
- ++iB;
- }
-}
-
-// Compute the bounding box with respect to 'other' by finding the min of the
+// Computes the bounding box with respect to 'other' by finding the min of the
// lower bounds and the max of the upper bounds along each of the dimensions.
bool FlatAffineConstraints::unionBoundingBox(
- const FlatAffineConstraints &otherArg) {
- assert(otherArg.getNumDimIds() == numDims && "dims mismatch");
-
- Optional<FlatAffineConstraints> copy;
- if (!otherArg.getIds().equals(getIds())) {
- copy.emplace(FlatAffineConstraints(otherArg));
- mergeSymbols(this, ©.getValue());
- assert(getIds().equals(copy->getIds()) && "merge failed");
- }
-
- const auto &other = copy ? *copy : otherArg;
-
- assert(other.getNumLocalIds() == 0 && "local ids not eliminated");
- assert(getNumLocalIds() == 0 && "local ids not eliminated");
+ const FlatAffineConstraints &otherCst) {
+ assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
+ assert(otherCst.getIds()
+ .slice(0, getNumDimIds())
+ .equals(getIds().slice(0, getNumDimIds())) &&
+ "dim values mismatch");
+ assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
+ assert(getNumLocalIds() == 0 && "local ids not supported yet here");
+
+ Optional<FlatAffineConstraints> otherCopy;
+ if (!areIdsAligned(*this, otherCst)) {
+ otherCopy.emplace(FlatAffineConstraints(otherCst));
+ mergeAndAlignIds(this, &otherCopy.getValue());
+ }
+
+ const auto &other = otherCopy ? *otherCopy : otherCst;
std::vector<SmallVector<int64_t, 8>> boundingLbs;
std::vector<SmallVector<int64_t, 8>> boundingUbs;