Method to align/merge dimensional/symbolic identifiers between two FlatAffineConstraints
authorUday Bondhugula <bondhugula@google.com>
Fri, 1 Mar 2019 16:49:20 +0000 (08:49 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 23:51:47 +0000 (16:51 -0700)
- add a method to merge and align the spaces (identifiers) of two
  FlatAffineConstraints (both get dimension-wise and symbol-wise unique
  columns)
- this completes several TODOs, gets rid of previous assumptions/restrictions
  in composeMap, unionBoundingBox, and reuses common code
- remove previous workarounds / duplicated funcitonality in
  FlatAffineConstraints::composeMap and unionBoundingBox, use mergeAlignIds
  from both

PiperOrigin-RevId: 236320581

mlir/include/mlir/Analysis/AffineStructures.h
mlir/include/mlir/Analysis/Utils.h
mlir/lib/Analysis/AffineStructures.cpp
mlir/lib/Analysis/Utils.cpp

index 7e75634e537ce9bed9ed19c37a5bd3b6be1f685b..48d9383bef0865db8b643395f9575b13d80022a5 100644 (file)
@@ -508,7 +508,8 @@ public:
   /// contains the points of 'this' set and that of 'other', with the symbols
   /// being treated specially. For each of the dimensions, the min of the lower
   /// bounds (symbolic) and the max of the upper bounds (symbolic) is computed
-  /// to determine such a bounding box.
+  /// to determine such a bounding box. `other' is expected to have the same
+  /// dimensional identifiers as this constraint system (in the same order).
   ///
   /// Eg: if 'this' is {0 <= d0 <= 127}, 'other' is {16 <= d0 <= 192}, the
   ///      output is {0 <= d0 <= 192}.
@@ -532,6 +533,13 @@ public:
   inline ArrayRef<Optional<Value *>> getIds() const {
     return {ids.data(), ids.size()};
   }
+  inline MutableArrayRef<Optional<Value *>> getIds() {
+    return {ids.data(), ids.size()};
+  }
+
+  /// Returns the optional Value corresponding to the pos^th identifier.
+  inline Optional<Value *> getId(unsigned pos) const { return ids[pos]; }
+  inline Optional<Value *> &getId(unsigned pos) { return ids[pos]; }
 
   /// Returns the Value associated with the pos^th identifier. Asserts if
   /// no Value identifier was associated.
@@ -707,8 +715,8 @@ private:
 
   /// Values corresponding to the (column) identifiers of this constraint
   /// system appearing in the order the identifiers correspond to columns.
-  /// Temporary ones or those that aren't associated to any Value are to be
-  /// set to None.
+  /// Temporary ones or those that aren't associated to any Value are set to
+  /// None.
   SmallVector<Optional<Value *>, 8> ids;
 
   /// A parameter that controls detection of an unrealistic number of
index cb78b3872b56d0bc017e3db90ae8b92801633fe0..dff0f57aaccfc0e328bab79baab2925e9e8a7f74 100644 (file)
@@ -187,6 +187,7 @@ struct MemRefRegion {
   /// Returns the size of this MemRefRegion in bytes.
   Optional<int64_t> getRegionSize();
 
+  // Wrapper around FlatAffineConstraints::unionBoundingBox.
   bool unionBoundingBox(const MemRefRegion &other);
 
   /// Returns the rank of the memref that this region corresponds to.
index 1837e9cf47aebc9b3ac9d140d0ff001b0ec0cb5c..3066f95d4f7d410a14fec8b6fa8a3345341d8bdd 100644 (file)
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/Support/MathExtras.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
 #define DEBUG_TYPE "affine-structures"
 
 using namespace mlir;
-using namespace llvm;
+using llvm::SmallDenseMap;
+using llvm::SmallDenseSet;
+using llvm::SmallPtrSet;
 
 namespace {
 
@@ -480,65 +483,156 @@ void FlatAffineConstraints::addId(IdKind kind, unsigned pos, Value *id) {
   assert(ids.size() == getNumIds());
 }
 
-// This routine may add additional local variables if the flattened expression
-// corresponding to the map has such variables due to the presence of
-// mod's, ceildiv's, and floordiv's.
-bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
-  // Assert if the map and this constraint set aren't associated with the same
-  // identifiers in the same order.
-  assert(vMap->getNumDims() <= getNumDimIds());
-  assert(vMap->getNumSymbols() <= getNumSymbolIds());
-  for (unsigned i = 0, e = vMap->getNumDims(); i < e; i++) {
-    assert(ids[i].hasValue());
-    assert(vMap->getOperand(i) == ids[i].getValue());
+/// Checks if two constraint systems are in the same space, i.e., if they are
+/// associated with the same set of identifiers, appearing in the same order.
+bool areIdsAligned(const FlatAffineConstraints &A,
+                   const FlatAffineConstraints &B) {
+  return A.getNumDimIds() == B.getNumDimIds() &&
+         A.getNumSymbolIds() == B.getNumSymbolIds() &&
+         A.getNumIds() == B.getNumIds() && A.getIds().equals(B.getIds());
+}
+
+/// Checks if the SSA values associated with `cst''s identifiers are unique.
+static bool areIdsUnique(const FlatAffineConstraints &cst) {
+  SmallPtrSet<Value *, 8> uniqueIds;
+  for (auto id : cst.getIds()) {
+    if (id.hasValue() && !uniqueIds.insert(id.getValue()).second)
+      return false;
+  }
+  return true;
+}
+
+// Swap the posA^th identifier with the posB^th identifier.
+static void swapId(FlatAffineConstraints *A, unsigned posA, unsigned posB) {
+  assert(posA < A->getNumIds() && "invalid position A");
+  assert(posB < A->getNumIds() && "invalid position B");
+
+  if (posA == posB)
+    return;
+
+  for (unsigned r = 0, e = A->getNumInequalities(); r < e; r++) {
+    std::swap(A->atIneq(r, posA), A->atIneq(r, posB));
+  }
+  for (unsigned r = 0, e = A->getNumEqualities(); r < e; r++) {
+    std::swap(A->atEq(r, posA), A->atEq(r, posB));
   }
-  for (unsigned i = 0, e = vMap->getNumSymbols(); i < e; i++) {
-    assert(ids[numDims + i].hasValue());
-    assert(vMap->getOperand(vMap->getNumDims() + i) ==
-           ids[numDims + i].getValue());
+  std::swap(A->getId(posA), A->getId(posB));
+}
+
+/// Merge and align the identifiers of A and B so that both constraint systems
+/// get the union of the contained identifiers that is dimension-wise and
+/// symbol-wise unique; both constraint systems are updated so that they have
+/// the union of all identifiers, with A's original identifiers appearing first
+/// followed by any of B's identifiers that didn't appear in A. Local
+/// identifiers of each system are by design separate/local and are placed one
+/// after other (A's followed by B's).
+//  Eg: Input: A has ((%i %j) [%M %N]) and B has (%k, %j) [%P, %N, %M])
+//      Output: both A, B have (%i, %j, %k) [%M, %N, %P]
+//
+// TODO(mlir-team): expose this function at some point.
+static void mergeAndAlignIds(FlatAffineConstraints *A,
+                             FlatAffineConstraints *B) {
+  // A merge/align isn't meaningful if a cst's ids aren't distinct.
+  assert(areIdsUnique(*A) && "A's id values aren't unique");
+  assert(areIdsUnique(*B) && "B's id values aren't unique");
+
+  assert(std::all_of(A->getIds().begin(),
+                     A->getIds().begin() + A->getNumDimAndSymbolIds(),
+                     [](Optional<Value *> id) { return id.hasValue(); }));
+
+  assert(std::all_of(B->getIds().begin(),
+                     B->getIds().begin() + B->getNumDimAndSymbolIds(),
+                     [](Optional<Value *> id) { return id.hasValue(); }));
+
+  // Place local id's of A after local id's of B.
+  for (unsigned l = 0, e = A->getNumLocalIds(); l < e; l++) {
+    B->addLocalId(0);
+  }
+  for (unsigned t = 0, e = B->getNumLocalIds() - A->getNumLocalIds(); t < e;
+       t++) {
+    A->addLocalId(A->getNumLocalIds());
   }
 
+  SmallVector<Value *, 4> aDimValues, aSymValues;
+  A->getIdValues(0, A->getNumDimIds(), &aDimValues);
+  A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &aSymValues);
+  {
+    // Merge dims from A into B.
+    unsigned d = 0;
+    for (auto *aDimValue : aDimValues) {
+      unsigned loc;
+      if (B->findId(*aDimValue, &loc)) {
+        assert(loc < B->getNumDimIds() &&
+               "A's dim appears in B's non-dim position");
+        swapId(B, d, loc);
+      } else {
+        B->addDimId(d);
+        B->setIdValue(d, aDimValue);
+      }
+      d++;
+    }
+
+    // Dimensions that are in B, but not in A, are added at the end.
+    for (unsigned t = A->getNumDimIds(), e = B->getNumDimIds(); t < e; t++) {
+      A->addDimId(A->getNumDimIds());
+      A->setIdValue(A->getNumDimIds() - 1, B->getIdValue(t));
+    }
+  }
+  {
+    // Merge symbols: merge A's symbols into B first.
+    unsigned s = B->getNumDimIds();
+    for (auto *aSymValue : aSymValues) {
+      unsigned loc;
+      if (B->findId(*aSymValue, &loc)) {
+        assert(loc >= B->getNumDimIds() && loc < B->getNumDimAndSymbolIds() &&
+               "A's symbol appears in B's non-symbol position");
+        swapId(B, s, loc);
+      } else {
+        B->addSymbolId(s - B->getNumDimIds());
+        B->setIdValue(s, aSymValue);
+      }
+      s++;
+    }
+    // Symbols that are in B, but not in A, are added at the end.
+    for (unsigned t = A->getNumDimAndSymbolIds(),
+                  e = B->getNumDimAndSymbolIds();
+         t < e; t++) {
+      A->addSymbolId(A->getNumSymbolIds());
+      A->setIdValue(A->getNumDimAndSymbolIds() - 1, B->getIdValue(t));
+    }
+  }
+  assert(areIdsAligned(*A, *B) && "IDs expected to be aligned");
+}
+
+// This routine may add additional local variables if the flattened expression
+// corresponding to the map has such variables due to mod's, ceildiv's, and
+// floordiv's in it.
+bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
   std::vector<SmallVector<int64_t, 8>> flatExprs;
-  FlatAffineConstraints cst;
-  if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &cst)) {
+  FlatAffineConstraints localCst;
+  if (!getFlattenedAffineExprs(vMap->getAffineMap(), &flatExprs, &localCst)) {
     LLVM_DEBUG(llvm::dbgs()
                << "composition unimplemented for semi-affine maps\n");
     return false;
   }
   assert(flatExprs.size() == vMap->getNumResults());
 
-  // Make the value map and the flat affine cst dimensions compatible.
-  // A lot of this code will be refactored/cleaned up.
-  // TODO(bondhugula): the next ~20 lines of code is pretty UGLY. This needs
-  // to be factored out into an FlatAffineConstraints::alignAndMerge().
-  for (unsigned l = 0, e = cst.getNumLocalIds(); l < e; l++) {
-    addLocalId(0);
-  }
+  // Make the value map and the flat affine localCst dimensions compatible.
+  SmallVector<Value *, 8> values(vMap->getOperands().begin(),
+                                 vMap->getOperands().end());
+  localCst.setIdValues(0, localCst.getNumDimAndSymbolIds(), values);
+  // Align localCst and this - localCst's identifiers appear first in the union.
+  mergeAndAlignIds(&localCst, this);
 
+  // Add dimensions corresponding to the map's results.
   for (unsigned t = 0, e = vMap->getNumResults(); t < e; t++) {
     // TODO: Consider using a batched version to add a range of IDs.
     addDimId(0);
-    cst.addDimId(0);
+    localCst.addDimId(0);
   }
 
-  assert(cst.getNumDimIds() <= getNumDimIds());
-  for (unsigned t = 0, e = getNumDimIds() - cst.getNumDimIds(); t < e; t++) {
-    // Dimensions that are in 'this' but not in vMap/cst are added at the end.
-    cst.addDimId(cst.getNumDimIds());
-  }
-  assert(cst.getNumSymbolIds() <= getNumSymbolIds());
-  for (unsigned t = 0, e = getNumSymbolIds() - cst.getNumSymbolIds(); t < e;
-       t++) {
-    // Dimensions that are in 'this' but not in vMap/cst are added at the end.
-    cst.addSymbolId(cst.getNumSymbolIds());
-  }
-  assert(cst.getNumLocalIds() <= getNumLocalIds());
-  for (unsigned t = 0, e = getNumLocalIds() - cst.getNumLocalIds(); t < e;
-       t++) {
-    cst.addLocalId(cst.getNumLocalIds());
-  }
-  /// Finally, append cst to this constraint set.
-  append(cst);
+  // Finally, append localCst to this constraint set.
+  append(localCst);
 
   // We add one equality for each result connecting the result dim of the map to
   // the other identifiers.
@@ -548,27 +642,26 @@ bool FlatAffineConstraints::composeMap(AffineValueMap *vMap) {
   //  add two equalities overall: d_0 - i0 - 1 == 0, d1 - i0 - 8*i2 == 0.
   for (unsigned r = 0, e = flatExprs.size(); r < e; r++) {
     const auto &flatExpr = flatExprs[r];
+    assert(flatExpr.size() >= vMap->getNumOperands() + 1);
+
     // eqToAdd is the equality corresponding to the flattened affine expression.
     SmallVector<int64_t, 8> eqToAdd(getNumCols(), 0);
     // Set the coefficient for this result to one.
     eqToAdd[r] = 1;
 
-    assert(flatExpr.size() >= vMap->getNumOperands() + 1);
-
     // Dims and symbols.
     for (unsigned i = 0, e = vMap->getNumOperands(); i < e; i++) {
       unsigned loc;
       bool ret = findId(*vMap->getOperand(i), &loc);
       assert(ret && "value map's id can't be found");
       (void)ret;
-      // We need to negate 'eq[r]' since the newly added dimension is going to
-      // be set to this one.
+      // Negate 'eq[r]' since the newly added dimension will be set to this one.
       eqToAdd[loc] = -flatExpr[i];
     }
-    // Local vars common to eq and cst are at the beginning.
-    int j = getNumDimIds() + getNumSymbolIds();
-    int end = flatExpr.size() - 1;
-    for (int i = vMap->getNumOperands(); i < end; i++, j++) {
+    // Local vars common to eq and localCst are at the beginning.
+    unsigned j = getNumDimIds() + getNumSymbolIds();
+    unsigned end = flatExpr.size() - 1;
+    for (unsigned i = vMap->getNumOperands(); i < end; i++, j++) {
       eqToAdd[j] = -flatExpr[i];
     }
 
@@ -2447,52 +2540,25 @@ static BoundCmpResult compareBounds(ArrayRef<int64_t> a, ArrayRef<int64_t> b) {
 }
 }; // namespace
 
-// TODO(bondhugula,andydavis): This still doesn't do a comprehensive merge of
-// the symbols. Assumes the common symbols appear in the same order (the
-// current/common use case).
-static void mergeSymbols(FlatAffineConstraints *A, FlatAffineConstraints *B) {
-  SmallVector<Value *, 4> symbolsA, symbolsB;
-  A->getIdValues(A->getNumDimIds(), A->getNumDimAndSymbolIds(), &symbolsA);
-  B->getIdValues(B->getNumDimIds(), B->getNumDimAndSymbolIds(), &symbolsB);
-
-  // Both symbol list have a handful symbols each typically (3-4); a merge
-  // quadratic in complexity with a linear search is fine.
-  for (auto *symbolB : symbolsB) {
-    if (llvm::is_contained(symbolsA, symbolB)) {
-      A->addSymbolId(symbolsA.size(), symbolB);
-      symbolsA.push_back(symbolB);
-    }
-  }
-  // symbolsA now holds the merged symbol list.
-  symbolsB.reserve(symbolsA.size());
-  unsigned iB = 0;
-  for (auto *symbolA : symbolsA) {
-    assert(iB < symbolsB.size());
-    if (symbolA != symbolsB[iB]) {
-      symbolsB.insert(symbolsB.begin() + iB, symbolA);
-      B->addSymbolId(iB, symbolA);
-    }
-    ++iB;
-  }
-}
-
-// Compute the bounding box with respect to 'other' by finding the min of the
+// Computes the bounding box with respect to 'other' by finding the min of the
 // lower bounds and the max of the upper bounds along each of the dimensions.
 bool FlatAffineConstraints::unionBoundingBox(
-    const FlatAffineConstraints &otherArg) {
-  assert(otherArg.getNumDimIds() == numDims && "dims mismatch");
-
-  Optional<FlatAffineConstraints> copy;
-  if (!otherArg.getIds().equals(getIds())) {
-    copy.emplace(FlatAffineConstraints(otherArg));
-    mergeSymbols(this, &copy.getValue());
-    assert(getIds().equals(copy->getIds()) && "merge failed");
-  }
-
-  const auto &other = copy ? *copy : otherArg;
-
-  assert(other.getNumLocalIds() == 0 && "local ids not eliminated");
-  assert(getNumLocalIds() == 0 && "local ids not eliminated");
+    const FlatAffineConstraints &otherCst) {
+  assert(otherCst.getNumDimIds() == numDims && "dims mismatch");
+  assert(otherCst.getIds()
+             .slice(0, getNumDimIds())
+             .equals(getIds().slice(0, getNumDimIds())) &&
+         "dim values mismatch");
+  assert(otherCst.getNumLocalIds() == 0 && "local ids not supported here");
+  assert(getNumLocalIds() == 0 && "local ids not supported yet here");
+
+  Optional<FlatAffineConstraints> otherCopy;
+  if (!areIdsAligned(*this, otherCst)) {
+    otherCopy.emplace(FlatAffineConstraints(otherCst));
+    mergeAndAlignIds(this, &otherCopy.getValue());
+  }
+
+  const auto &other = otherCopy ? *otherCopy : otherCst;
 
   std::vector<SmallVector<int64_t, 8>> boundingLbs;
   std::vector<SmallVector<int64_t, 8>> boundingUbs;
index 0c51fd920d019211749f309cff3673fdb795955a..6e3fc38d2f625fdbf1d3621c685e11dca121388c 100644 (file)
@@ -164,11 +164,14 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
     operands[i] = accessValueMap.getOperand(i);
 
   if (sliceState != nullptr) {
+    operands.reserve(operands.size() + sliceState->lbOperands[0].size());
     // Append slice operands to 'operands' as symbols.
-    operands.append(sliceState->lbOperands[0].begin(),
-                    sliceState->lbOperands[0].end());
-    // Update 'numSymbols' by operands from 'sliceState'.
-    numSymbols += sliceState->lbOperands[0].size();
+    for (auto extraOperand : sliceState->lbOperands[0]) {
+      if (!llvm::is_contained(operands, extraOperand)) {
+        operands.push_back(extraOperand);
+        numSymbols++;
+      }
+    }
   }
 
   // We'll first associate the dims and symbols of the access map to the dims
@@ -208,7 +211,6 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
       if (!cst.findId(*operand, &loc)) {
         if (isValidSymbol(operand)) {
           cst.addSymbolId(cst.getNumSymbolIds(), const_cast<Value *>(operand));
-          loc = cst.getNumDimIds() + cst.getNumSymbolIds() - 1;
           // Check if the symbol is a constant.
           if (auto *opInst = operand->getDefiningInst()) {
             if (auto constOp = opInst->dyn_cast<ConstantIndexOp>()) {
@@ -217,7 +219,6 @@ bool MemRefRegion::compute(Instruction *inst, unsigned loopDepth,
           }
         } else {
           cst.addDimId(cst.getNumDimIds(), const_cast<Value *>(operand));
-          loc = cst.getNumDimIds() - 1;
         }
       }
     }