From 6ef5fc582ea8989297822e321bb6e39b29f7da69 Mon Sep 17 00:00:00 2001 From: Uday Bondhugula Date: Fri, 1 Mar 2019 08:49:20 -0800 Subject: [PATCH] Method to align/merge dimensional/symbolic identifiers between two FlatAffineConstraints - add a method to merge and align the spaces (identifiers) of two FlatAffineConstraints (both get dimension-wise and symbol-wise unique columns) - this completes several TODOs, gets rid of previous assumptions/restrictions in composeMap, unionBoundingBox, and reuses common code - remove previous workarounds / duplicated funcitonality in FlatAffineConstraints::composeMap and unionBoundingBox, use mergeAlignIds from both PiperOrigin-RevId: 236320581 --- mlir/include/mlir/Analysis/AffineStructures.h | 14 +- mlir/include/mlir/Analysis/Utils.h | 1 + mlir/lib/Analysis/AffineStructures.cpp | 258 +++++++++++------- mlir/lib/Analysis/Utils.cpp | 13 +- 4 files changed, 181 insertions(+), 105 deletions(-) diff --git a/mlir/include/mlir/Analysis/AffineStructures.h b/mlir/include/mlir/Analysis/AffineStructures.h index 7e75634e537c..48d9383bef08 100644 --- a/mlir/include/mlir/Analysis/AffineStructures.h +++ b/mlir/include/mlir/Analysis/AffineStructures.h @@ -508,7 +508,8 @@ public: /// contains the points of 'this' set and that of 'other', with the symbols /// being treated specially. For each of the dimensions, the min of the lower /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed - /// to determine such a bounding box. + /// to determine such a bounding box. `other' is expected to have the same + /// dimensional identifiers as this constraint system (in the same order). /// /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the /// output is {0 <= d0 <= 192}. @@ -532,6 +533,13 @@ public: inline ArrayRef> getIds() const { return {ids.data(), ids.size()}; } + inline MutableArrayRef> getIds() { + return {ids.data(), ids.size()}; + } + + /// Returns the optional Value corresponding to the pos^th identifier. + inline Optional getId(unsigned pos) const { return ids[pos]; } + inline Optional &getId(unsigned pos) { return ids[pos]; } /// Returns the Value associated with the pos^th identifier. Asserts if /// no Value identifier was associated. @@ -707,8 +715,8 @@ private: /// Values corresponding to the (column) identifiers of this constraint /// system appearing in the order the identifiers correspond to columns. - /// Temporary ones or those that aren't associated to any Value are to be - /// set to None. + /// Temporary ones or those that aren't associated to any Value are set to + /// None. SmallVector, 8> ids; /// A parameter that controls detection of an unrealistic number of diff --git a/mlir/include/mlir/Analysis/Utils.h b/mlir/include/mlir/Analysis/Utils.h index cb78b3872b56..dff0f57aaccf 100644 --- a/mlir/include/mlir/Analysis/Utils.h +++ b/mlir/include/mlir/Analysis/Utils.h @@ -187,6 +187,7 @@ struct MemRefRegion { /// Returns the size of this MemRefRegion in bytes. Optional getRegionSize(); + // Wrapper around FlatAffineConstraints::unionBoundingBox. bool unionBoundingBox(const MemRefRegion &other); /// Returns the rank of the memref that this region corresponds to. diff --git a/mlir/lib/Analysis/AffineStructures.cpp b/mlir/lib/Analysis/AffineStructures.cpp index 1837e9cf47ae..3066f95d4f7d 100644 --- a/mlir/lib/Analysis/AffineStructures.cpp +++ b/mlir/lib/Analysis/AffineStructures.cpp @@ -28,13 +28,16 @@ #include "mlir/IR/IntegerSet.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #define DEBUG_TYPE "affine-structures" using namespace mlir; -using namespace llvm; +using llvm::SmallDenseMap; +using llvm::SmallDenseSet; +using llvm::SmallPtrSet; namespace { @@ -480,65 +483,156 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) { assert(ids.size() == getNumIds()); } -// This routine may add additional local variables if the flattened expression -// corresponding to the map has such variables due to the presence of -// mod's, ceildiv's, and floordiv's. -bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { - // Assert if the map and this constraint set aren't associated with the same - // identifiers in the same order. - assert(vMap->getNumDims() <= getNumDimIds()); - assert(vMap->getNumSymbols() <= getNumSymbolIds()); - for (unsigned i = 0, e = vMap->getNumDims(); i < e; i++) { - assert(ids[i].hasValue()); - assert(vMap->getOperand(i) == ids[i].getValue()); +/// Checks if two constraint systems are in the same space, i.e., if they are +/// associated with the same set of identifiers, appearing in the same order. +bool areIdsAligned(const FlatAffineConstraints &A, + const FlatAffineConstraints &B) { + return A.getNumDimIds() == B.getNumDimIds() && + A.getNumSymbolIds() == B.getNumSymbolIds() && + A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds()); +} + +/// Checks if the SSA values associated with `cst''s identifiers are unique. +static bool areIdsUnique(const FlatAffineConstraints &cst) { + SmallPtrSet uniqueIds; + for (auto id : cst.getIds()) { + if (id.hasValue() && !uniqueIds.insert(id.getValue()).second) + return false; + } + return true; +} + +// Swap the posA^th identifier with the posB^th identifier. +static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) { + assert(posA < A->getNumIds() && "invalid position A"); + assert(posB < A->getNumIds() && "invalid position B"); + + if (posA == posB) + return; + + for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) { + std::swap(A->atIneq(r, posA), A->atIneq(r, posB)); + } + for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) { + std::swap(A->atEq(r, posA), A->atEq(r, posB)); } - for (unsigned i = 0, e = vMap->getNumSymbols(); i < e; i++) { - assert(ids[numDims + i].hasValue()); - assert(vMap->getOperand(vMap->getNumDims() + i) == - ids[numDims + i].getValue()); + std::swap(A->getId(posA), A->getId(posB)); +} + +/// Merge and align the identifiers of A and B so that both constraint systems +/// get the union of the contained identifiers that is dimension-wise and +/// symbol-wise unique; both constraint systems are updated so that they have +/// the union of all identifiers, with A's original identifiers appearing first +/// followed by any of B's identifiers that didn't appear in A. Local +/// identifiers of each system are by design separate/local and are placed one +/// after other (A's followed by B's). +// Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M]) +// Output: both A, B have (%i, %j, %k) [%M, %N, %P] +// +// TODO(mlir-team): expose this function at some point. +static void mergeAndAlignIds(FlatAffineConstraints *A, + FlatAffineConstraints *B) { + // A merge/align isn't meaningful if a cst's ids aren't distinct. + assert(areIdsUnique(*A) && "A's id values aren't unique"); + assert(areIdsUnique(*B) && "B's id values aren't unique"); + + assert(std::all_of(A->getIds().begin(), + A->getIds().begin() + A->getNumDimAndSymbolIds(), + [](Optional id) { return id.hasValue(); })); + + assert(std::all_of(B->getIds().begin(), + B->getIds().begin() + B->getNumDimAndSymbolIds(), + [](Optional id) { return id.hasValue(); })); + + // Place local id's of A after local id's of B. + for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) { + B->addLocalId(0); + } + for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e; + t++) { + A->addLocalId(A->getNumLocalIds()); } + SmallVector aDimValues, aSymValues; + A->getIdValues(0, A->getNumDimIds(), &aDimValues); + A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues); + { + // Merge dims from A into B. + unsigned d = 0; + for (auto *aDimValue : aDimValues) { + unsigned loc; + if (B->findId(*aDimValue, &loc)) { + assert(loc < B->getNumDimIds() && + "A's dim appears in B's non-dim position"); + swapId(B, d, loc); + } else { + B->addDimId(d); + B->setIdValue(d, aDimValue); + } + d++; + } + + // Dimensions that are in B, but not in A, are added at the end. + for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) { + A->addDimId(A->getNumDimIds()); + A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t)); + } + } + { + // Merge symbols: merge A's symbols into B first. + unsigned s = B->getNumDimIds(); + for (auto *aSymValue : aSymValues) { + unsigned loc; + if (B->findId(*aSymValue, &loc)) { + assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() && + "A's symbol appears in B's non-symbol position"); + swapId(B, s, loc); + } else { + B->addSymbolId(s - B->getNumDimIds()); + B->setIdValue(s, aSymValue); + } + s++; + } + // Symbols that are in B, but not in A, are added at the end. + for (unsigned t = A->getNumDimAndSymbolIds(), + e = B->getNumDimAndSymbolIds(); + t < e; t++) { + A->addSymbolId(A->getNumSymbolIds()); + A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t)); + } + } + assert(areIdsAligned(*A, *B) && "IDs expected to be aligned"); +} + +// This routine may add additional local variables if the flattened expression +// corresponding to the map has such variables due to mod's, ceildiv's, and +// floordiv's in it. +bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { std::vector> flatExprs; - FlatAffineConstraints cst; - if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) { + FlatAffineConstraints localCst; + if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &localCst)) { LLVM_DEBUG(llvm::dbgs() << "composition unimplemented for semi-affine maps\n"); return false; } assert(flatExprs.size() == vMap->getNumResults()); - // Make the value map and the flat affine cst dimensions compatible. - // A lot of this code will be refactored/cleaned up. - // TODO(bondhugula): the next ~20 lines of code is pretty UGLY. This needs - // to be factored out into an FlatAffineConstraints::alignAndMerge(). - for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) { - addLocalId(0); - } + // Make the value map and the flat affine localCst dimensions compatible. + SmallVector values(vMap->getOperands().begin(), + vMap->getOperands().end()); + localCst.setIdValues(0, localCst.getNumDimAndSymbolIds(), values); + // Align localCst and this - localCst's identifiers appear first in the union. + mergeAndAlignIds(&localCst, this); + // Add dimensions corresponding to the map's results. for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) { // TODO: Consider using a batched version to add a range of IDs. addDimId(0); - cst.addDimId(0); + localCst.addDimId(0); } - assert(cst.getNumDimIds() <= getNumDimIds()); - for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) { - // Dimensions that are in 'this' but not in vMap/cst are added at the end. - cst.addDimId(cst.getNumDimIds()); - } - assert(cst.getNumSymbolIds() <= getNumSymbolIds()); - for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e; - t++) { - // Dimensions that are in 'this' but not in vMap/cst are added at the end. - cst.addSymbolId(cst.getNumSymbolIds()); - } - assert(cst.getNumLocalIds() <= getNumLocalIds()); - for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e; - t++) { - cst.addLocalId(cst.getNumLocalIds()); - } - /// Finally, append cst to this constraint set. - append(cst); + // Finally, append localCst to this constraint set. + append(localCst); // We add one equality for each result connecting the result dim of the map to // the other identifiers. @@ -548,27 +642,26 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) { // add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0. for (unsigned r = 0, e = flatExprs.size(); r < e; r++) { const auto &flatExpr = flatExprs[r]; + assert(flatExpr.size() >= vMap->getNumOperands() + 1); + // eqToAdd is the equality corresponding to the flattened affine expression. SmallVector eqToAdd(getNumCols(), 0); // Set the coefficient for this result to one. eqToAdd[r] = 1; - assert(flatExpr.size() >= vMap->getNumOperands() + 1); - // Dims and symbols. for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) { unsigned loc; bool ret = findId(*vMap->getOperand(i), &loc); assert(ret && "value map's id can't be found"); (void)ret; - // We need to negate 'eq[r]' since the newly added dimension is going to - // be set to this one. + // Negate 'eq[r]' since the newly added dimension will be set to this one. eqToAdd[loc] = -flatExpr[i]; } - // Local vars common to eq and cst are at the beginning. - int j = getNumDimIds() + getNumSymbolIds(); - int end = flatExpr.size() - 1; - for (int i = vMap->getNumOperands(); i < end; i++, j++) { + // Local vars common to eq and localCst are at the beginning. + unsigned j = getNumDimIds() + getNumSymbolIds(); + unsigned end = flatExpr.size() - 1; + for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) { eqToAdd[j] = -flatExpr[i]; } @@ -2447,52 +2540,25 @@ static BoundCmpResult compareBounds(ArrayRef a, ArrayRef b) { } }; // namespace -// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of -// the symbols. Assumes the common symbols appear in the same order (the -// current/common use case). -static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) { - SmallVector symbolsA, symbolsB; - A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA); - B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB); - - // Both symbol list have a handful symbols each typically (3-4); a merge - // quadratic in complexity with a linear search is fine. - for (auto *symbolB : symbolsB) { - if (llvm::is_contained(symbolsA, symbolB)) { - A->addSymbolId(symbolsA.size(), symbolB); - symbolsA.push_back(symbolB); - } - } - // symbolsA now holds the merged symbol list. - symbolsB.reserve(symbolsA.size()); - unsigned iB = 0; - for (auto *symbolA : symbolsA) { - assert(iB < symbolsB.size()); - if (symbolA != symbolsB[iB]) { - symbolsB.insert(symbolsB.begin() + iB, symbolA); - B->addSymbolId(iB, symbolA); - } - ++iB; - } -} - -// Compute the bounding box with respect to 'other' by finding the min of the +// Computes the bounding box with respect to 'other' by finding the min of the // lower bounds and the max of the upper bounds along each of the dimensions. bool FlatAffineConstraints::unionBoundingBox( - const FlatAffineConstraints &otherArg) { - assert(otherArg.getNumDimIds() == numDims && "dims mismatch"); - - Optional copy; - if (!otherArg.getIds().equals(getIds())) { - copy.emplace(FlatAffineConstraints(otherArg)); - mergeSymbols(this, ©.getValue()); - assert(getIds().equals(copy->getIds()) && "merge failed"); - } - - const auto &other = copy ? *copy : otherArg; - - assert(other.getNumLocalIds() == 0 && "local ids not eliminated"); - assert(getNumLocalIds() == 0 && "local ids not eliminated"); + const FlatAffineConstraints &otherCst) { + assert(otherCst.getNumDimIds() == numDims && "dims mismatch"); + assert(otherCst.getIds() + .slice(0, getNumDimIds()) + .equals(getIds().slice(0, getNumDimIds())) && + "dim values mismatch"); + assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here"); + assert(getNumLocalIds() == 0 && "local ids not supported yet here"); + + Optional otherCopy; + if (!areIdsAligned(*this, otherCst)) { + otherCopy.emplace(FlatAffineConstraints(otherCst)); + mergeAndAlignIds(this, &otherCopy.getValue()); + } + + const auto &other = otherCopy ? *otherCopy : otherCst; std::vector> boundingLbs; std::vector> boundingUbs; diff --git a/mlir/lib/Analysis/Utils.cpp b/mlir/lib/Analysis/Utils.cpp index 0c51fd920d01..6e3fc38d2f62 100644 --- a/mlir/lib/Analysis/Utils.cpp +++ b/mlir/lib/Analysis/Utils.cpp @@ -164,11 +164,14 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, operands[i] = accessValueMap.getOperand(i); if (sliceState != nullptr) { + operands.reserve(operands.size() + sliceState->lbOperands[0].size()); // Append slice operands to 'operands' as symbols. - operands.append(sliceState->lbOperands[0].begin(), - sliceState->lbOperands[0].end()); - // Update 'numSymbols' by operands from 'sliceState'. - numSymbols += sliceState->lbOperands[0].size(); + for (auto extraOperand : sliceState->lbOperands[0]) { + if (!llvm::is_contained(operands, extraOperand)) { + operands.push_back(extraOperand); + numSymbols++; + } + } } // We'll first associate the dims and symbols of the access map to the dims @@ -208,7 +211,6 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, if (!cst.findId(*operand, &loc)) { if (isValidSymbol(operand)) { cst.addSymbolId(cst.getNumSymbolIds(), const_cast(operand)); - loc = cst.getNumDimIds() + cst.getNumSymbolIds() - 1; // Check if the symbol is a constant. if (auto *opInst = operand->getDefiningInst()) { if (auto constOp = opInst->dyn_cast()) { @@ -217,7 +219,6 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth, } } else { cst.addDimId(cst.getNumDimIds(), const_cast(operand)); - loc = cst.getNumDimIds() - 1; } } } -- 2.34.1