[MLIR][NFC] flat affine constraints - refactor to share, renames
authorUday Bondhugula <uday@polymagelabs.com>
Tue, 24 Mar 2020 04:22:41 +0000 (09:52 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Tue, 24 Mar 2020 05:27:42 +0000 (10:57 +0530)
- refactor to remove duplicate code
- some renaming / comment updates for readability

Signed-off-by: Uday Bondhugula <uday@polymagelabs.com>
Differential Revision: https://reviews.llvm.org/D76667

mlir/include/mlir/Analysis/AffineStructures.h
mlir/lib/Analysis/AffineStructures.cpp

index e37d698..5d99320 100644 (file)
@@ -443,16 +443,17 @@ public:
   /// identifier. Returns None if it's not a constant. This method employs
   /// trivial (low complexity / cost) checks and detection. Symbolic identifiers
   /// are treated specially, i.e., it looks for constant differences between
-  /// affine expressions involving only the symbolic identifiers. See comments
-  /// at function definition for examples. 'lb' and 'lbDivisor', if provided,
-  /// are used to express the lower bound associated with the constant
-  /// difference: 'lb' has the coefficients and lbDivisor, the divisor. For eg.,
-  /// if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with three
-  /// symbolic identifiers, *lb = [1, 0, 1], lbDivisor = 32.
+  /// affine expressions involving only the symbolic identifiers. `lb` and
+  /// `ub` (along with the `boundFloorDivisor`) are set to represent the lower
+  /// and upper bound associated with the constant difference: `lb`, `ub` have
+  /// the coefficients, and boundFloorDivisor, their divisor.
+  /// Ex: if the lower bound is [(s0 + s2 - 1) floordiv 32] for a system with
+  /// three symbolic identifiers, *lb = [1, 0, 1], boundDivisor = 32. See
+  /// comments at function definition for examples.
   Optional<int64_t>
   getConstantBoundOnDimSize(unsigned pos,
                             SmallVectorImpl<int64_t> *lb = nullptr,
-                            int64_t *lbFloorDivisor = nullptr,
+                            int64_t *boundFloorDivisor = nullptr,
                             SmallVectorImpl<int64_t> *ub = nullptr) const;
 
   /// Returns the constant lower bound for the pos^th identifier if there is
index 3448443..6ebc673 100644 (file)
@@ -1201,17 +1201,30 @@ static bool detectAsMod(const FlatAffineConstraints &cst, unsigned pos,
   return false;
 }
 
-/// Gather all lower and upper bounds of the identifier at `pos`.
+/// Gather all lower and upper bounds of the identifier at `pos`. The bounds are
+/// to be independent of [offset, offset + num) identifiers.
 static void getLowerAndUpperBoundIndices(const FlatAffineConstraints &cst,
                                          unsigned pos,
                                          SmallVectorImpl<unsigned> *lbIndices,
-                                         SmallVectorImpl<unsigned> *ubIndices) {
+                                         SmallVectorImpl<unsigned> *ubIndices,
+                                         unsigned offset = 0,
+                                         unsigned num = 0) {
   assert(pos < cst.getNumIds() && "invalid position");
 
   // Gather all lower bounds and upper bounds of the variable. Since the
   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
   for (unsigned r = 0, e = cst.getNumInequalities(); r < e; r++) {
+    // The bounds are to be independent of [offset, offset + num) columns.
+    unsigned c, f;
+    for (c = offset, f = offset + num; c < f; ++c) {
+      if (c == pos)
+        continue;
+      if (cst.atIneq(r, c) != 0)
+        break;
+    }
+    if (c < f)
+      continue;
     if (cst.atIneq(r, pos) >= 1) {
       // Lower bound.
       lbIndices->push_back(r);
@@ -1866,7 +1879,8 @@ void FlatAffineConstraints::removeEquality(unsigned pos) {
 /// Finds an equality that equates the specified identifier to a constant.
 /// Returns the position of the equality row. If 'symbolic' is set to true,
 /// symbols are also treated like a constant, i.e., an affine function of the
-/// symbols is also treated like a constant.
+/// symbols is also treated like a constant. Returns -1 if such an equality
+/// could not be found.
 static int findEqualityToConstant(const FlatAffineConstraints &cst,
                                   unsigned pos, bool symbolic = false) {
   assert(pos < cst.getNumIds() && "invalid position");
@@ -1937,19 +1951,15 @@ void FlatAffineConstraints::constantFoldIdRange(unsigned pos, unsigned num) {
 //       s0 - 7 <= 8*j <= s0 returns 1 with lb = s0, lbDivisor = 8 (since lb =
 //       ceil(s0 - 7 / 8) = floor(s0 / 8)).
 Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
-    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *lbFloorDivisor,
+    unsigned pos, SmallVectorImpl<int64_t> *lb, int64_t *boundFloorDivisor,
     SmallVectorImpl<int64_t> *ub) const {
   assert(pos < getNumDimIds() && "Invalid identifier position");
   assert(getNumLocalIds() == 0);
 
-  // TODO(bondhugula): eliminate all remaining dimensional identifiers (other
-  // than the one at 'pos' to make this more powerful. Not needed for
-  // hyper-rectangular spaces.
-
   // Find an equality for 'pos'^th identifier that equates it to some function
   // of the symbolic identifiers (+ constant).
-  int eqRow = findEqualityToConstant(*this, pos, /*symbolic=*/true);
-  if (eqRow != -1) {
+  int eqPos = findEqualityToConstant(*this, pos, /*symbolic=*/true);
+  if (eqPos != -1) {
     // This identifier can only take a single value.
     if (lb) {
       // Set lb to that symbolic value.
@@ -1957,18 +1967,18 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
       if (ub)
         ub->resize(getNumSymbolIds() + 1);
       for (unsigned c = 0, f = getNumSymbolIds() + 1; c < f; c++) {
-        int64_t v = atEq(eqRow, pos);
+        int64_t v = atEq(eqPos, pos);
         // atEq(eqRow, pos) is either -1 or 1.
         assert(v * v == 1);
-        (*lb)[c] = v < 0 ? atEq(eqRow, getNumDimIds() + c) / -v
-                         : -atEq(eqRow, getNumDimIds() + c) / v;
+        (*lb)[c] = v < 0 ? atEq(eqPos, getNumDimIds() + c) / -v
+                         : -atEq(eqPos, getNumDimIds() + c) / v;
         // Since this is an equality, ub = lb.
         if (ub)
           (*ub)[c] = (*lb)[c];
       }
-      assert(lbFloorDivisor &&
+      assert(boundFloorDivisor &&
              "both lb and divisor or none should be provided");
-      *lbFloorDivisor = 1;
+      *boundFloorDivisor = 1;
     }
     return 1;
   }
@@ -1990,25 +2000,9 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
   // the bounds can only involve symbolic (and local) identifiers. Since the
   // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
   // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
-  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
-    unsigned c, f;
-    for (c = 0, f = getNumDimIds(); c < f; c++) {
-      if (c != pos && atIneq(r, c) != 0)
-        break;
-    }
-    if (c < getNumDimIds())
-      // Not a pure symbolic bound.
-      continue;
-    if (atIneq(r, pos) >= 1)
-      // Lower bound.
-      lbIndices.push_back(r);
-    else if (atIneq(r, pos) <= -1)
-      // Upper bound.
-      ubIndices.push_back(r);
-  }
-
-  // TODO(bondhugula): eliminate other dimensional identifiers to make this more
-  // powerful. Not needed for hyper-rectangular iteration spaces.
+  getLowerAndUpperBoundIndices(*this, pos, &lbIndices, &ubIndices,
+                               /*offset=*/0,
+                               /*num=*/getNumDimIds());
 
   Optional<int64_t> minDiff = None;
   unsigned minLbPosition, minUbPosition;
@@ -2046,8 +2040,8 @@ Optional<int64_t> FlatAffineConstraints::getConstantBoundOnDimSize(
     // of the variable at 'pos'. We express the ceildiv equivalently as a floor
     // for uniformity. For eg., if the lower bound constraint was: 32*d0 - N +
     // 31 >= 0, the lower bound for d0 is ceil(N - 31, 32), i.e., floor(N, 32).
-    *lbFloorDivisor = atIneq(minLbPosition, pos);
-    assert(*lbFloorDivisor == -atIneq(minUbPosition, pos));
+    *boundFloorDivisor = atIneq(minLbPosition, pos);
+    assert(*boundFloorDivisor == -atIneq(minUbPosition, pos));
     for (unsigned c = 0, e = getNumSymbolIds() + 1; c < e; c++) {
       (*lb)[c] = -atIneq(minLbPosition, getNumDimIds() + c);
     }