[MLIR][FlatAffineConstraints] Add support for extracting divisions with tighter bounds
authorGroverkss <groverkss@gmail.com>
Sat, 11 Dec 2021 08:27:19 +0000 (13:57 +0530)
committerGroverkss <groverkss@gmail.com>
Sat, 11 Dec 2021 10:53:54 +0000 (16:23 +0530)
This patch adds support for extracting divisions when the set contains bounds
which are tighter than the division bounds. For example:

```
     3q - i + 2 >= 0                       <-- Lower bound for 'q'
    -3q + i - 1 >= 0                       <-- Tighter upper bound for 'q'
```

Here, the actual upper bound for division for `q` would be `-3q + i >= 0`, but
since this actual upper bound is implied by a tighter upper bound, which awe can still
extract the divison.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D115096

mlir/lib/Analysis/AffineStructures.cpp
mlir/unittests/Analysis/AffineStructuresTest.cpp

index 8395a9d..f4d8574 100644 (file)
@@ -1215,15 +1215,28 @@ bool FlatAffineConstraints::containsPoint(ArrayRef<int64_t> point) const {
 ///      -divisor * id + expr                 >= 0  <-- Upper bound for 'id'
 ///
 /// For example:
-///       32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
-///       32*k  <= 16*i + j                     <-- Upper bound for 'k'
-///       expr = 16*i + j, divisor = 32
-///       k = ( 16*i + j ) floordiv 32
+///     32*k >= 16*i + j - 31                 <-- Lower bound for 'k'
+///     32*k  <= 16*i + j                     <-- Upper bound for 'k'
+///     expr = 16*i + j, divisor = 32
+///     k = ( 16*i + j ) floordiv 32
 ///
-///       4q >= i + j - 2                       <-- Lower bound for 'q'
-///       4q <= i + j + 1                       <-- Upper bound for 'q'
-///       expr = i + j + 1, divisor = 4
-///       q = (i + j + 1) floordiv 4
+///     4q >= i + j - 2                       <-- Lower bound for 'q'
+///     4q <= i + j + 1                       <-- Upper bound for 'q'
+///     expr = i + j + 1, divisor = 4
+///     q = (i + j + 1) floordiv 4
+//
+/// This function also supports detecting divisions from bounds that are
+/// strictly tighter than the division bounds described above, since tighter
+/// bounds imply the division bounds. For example:
+///     4q - i - j + 2 >= 0                       <-- Lower bound for 'q'
+///    -4q + i + j     >= 0                       <-- Tight upper bound for 'q'
+///
+/// To extract floor divisions with tighter bounds, we assume that that the
+/// constraints are of the form:
+///     c <= expr - divisior * id <= divisor - 1, where 0 <= c <= divisor - 1
+/// Rearranging, we have:
+///     divisor * id - expr + (divisor - 1) >= 0  <-- Lower bound for 'id'
+///    -divisor * id + expr - c             >= 0  <-- Upper bound for 'id'
 ///
 /// If successful, `expr` is set to dividend of the division and `divisor` is
 /// set to the denominator of the division.
@@ -1238,21 +1251,11 @@ static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos,
   assert(lbIneq <= cst.getNumInequalities() &&
          "Invalid upper bound inequality position");
 
-  // Due to the form of the inequalities, sum of constants of the
-  // inequalities is (divisor - 1).
-  int64_t denominator = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
-                        cst.atIneq(ubIneq, cst.getNumCols() - 1) + 1;
+  // Extract divisor from the lower bound.
+  divisor = cst.atIneq(lbIneq, pos);
 
-  // Divisor should be positive.
-  if (denominator <= 0)
-    return failure();
-
-  // Check if coeff of variable is equal to divisor.
-  if (denominator != cst.atIneq(lbIneq, pos))
-    return failure();
-
-  // Check if constraints are opposite of each other. Constant term
-  // is not required to be opposite and is not checked.
+  // First, check if the constraints are opposite of each other except the
+  // constant term.
   unsigned i = 0, e = 0;
   for (i = 0, e = cst.getNumIds(); i < e; ++i)
     if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i))
@@ -1261,15 +1264,30 @@ static LogicalResult getDivRepr(const FlatAffineConstraints &cst, unsigned pos,
   if (i < e)
     return failure();
 
-  // Set expr with dividend of the division.
-  SmallVector<int64_t, 8> dividend(cst.getNumCols());
-  for (i = 0, e = cst.getNumCols(); i < e; ++i)
+  // Then, check if the constant term is of the proper form.
+  // Due to the form of the upper/lower bound inequalities, the sum of their
+  // constants is `divisor - 1 - c`. From this, we can extract c:
+  int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
+                        cst.atIneq(ubIneq, cst.getNumCols() - 1);
+  int64_t c = divisor - 1 - constantSum;
+
+  // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. This also
+  // implictly checks that `divisor` is positive.
+  if (!(c >= 0 && c <= divisor - 1))
+    return failure();
+
+  // The inequality pair can be used to extract the division.
+  // Set `expr` to the dividend of the division except the constant term, which
+  // is set below.
+  expr.resize(cst.getNumCols(), 0);
+  for (i = 0, e = cst.getNumIds(); i < e; ++i)
     if (i != pos)
-      dividend[i] = cst.atIneq(ubIneq, i);
-  expr = dividend;
+      expr[i] = cst.atIneq(ubIneq, i);
 
-  // Set divisor.
-  divisor = denominator;
+  // From the upper bound inequality's form, its constant term is equal to the
+  // constant term of `expr`, minus `c`. From this,
+  // constant term of `expr` = constant term of upper bound + `c`.
+  expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c;
 
   return success();
 }
index 7c447d7..d459d08 100644 (file)
@@ -623,11 +623,15 @@ static void checkDivisionRepresentation(
 
   fac.getLocalReprs(dividends, denominators);
 
-  // Check that the `dividends` and `expectedDividends` match.
-  EXPECT_TRUE(expectedDividends == dividends);
-
   // Check that the `denominators` and `expectedDenominators` match.
   EXPECT_TRUE(expectedDenominators == denominators);
+
+  // Check that the `dividends` and `expectedDividends` match. If the
+  // denominator for a division is zero, we ignore its dividend.
+  EXPECT_TRUE(dividends.size() == expectedDividends.size());
+  for (unsigned i = 0, e = dividends.size(); i < e; ++i)
+    if (denominators[i] != 0)
+      EXPECT_TRUE(expectedDividends[i] == dividends[i]);
 }
 
 TEST(FlatAffineConstraintsTest, computeLocalReprSimple) {
@@ -687,6 +691,57 @@ TEST(FlatAffineConstraintsTest, computeLocalReprRecursive) {
   checkDivisionRepresentation(fac, divisions, denoms);
 }
 
+TEST(FlatAffineConstraintsTest, computeLocalReprTightUpperBound) {
+  MLIRContext context;
+
+  {
+    FlatAffineConstraints fac = parseFAC("(i) : (i mod 3 - 1 >= 0)", &context);
+
+    // The set formed by the fac is:
+    //        3q - i + 2 >= 0             <-- Division lower bound
+    //       -3q + i - 1 >= 0
+    //       -3q + i     >= 0             <-- Division upper bound
+    // We remove redundant constraints to get the set:
+    //        3q - i + 2 >= 0             <-- Division lower bound
+    //       -3q + i - 1 >= 0             <-- Tighter division upper bound
+    // thus, making the upper bound tighter.
+    fac.removeRedundantConstraints();
+
+    std::vector<SmallVector<int64_t, 8>> divisions = {{1, 0, 0}};
+    SmallVector<unsigned, 8> denoms = {3};
+
+    // Check if the divisions can be computed even with a tighter upper bound.
+    checkDivisionRepresentation(fac, divisions, denoms);
+  }
+
+  {
+    FlatAffineConstraints fac = parseFAC(
+        "(i, j, q) : (4*q - i - j + 2 >= 0, -4*q + i + j >= 0)", &context);
+    // Convert `q` to a local variable.
+    fac.convertDimToLocal(2, 3);
+
+    std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 1}};
+    SmallVector<unsigned, 8> denoms = {4};
+
+    // Check if the divisions can be computed even with a tighter upper bound.
+    checkDivisionRepresentation(fac, divisions, denoms);
+  }
+}
+
+TEST(FlatAffineConstraintsTest, computeLocalReprNoRepr) {
+  MLIRContext context;
+  FlatAffineConstraints fac =
+      parseFAC("(x, q) : (x - 3 * q >= 0, -x + 3 * q + 3 >= 0)", &context);
+  // Convert q to a local variable.
+  fac.convertDimToLocal(1, 2);
+
+  std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0}};
+  SmallVector<unsigned, 8> denoms = {0};
+
+  // Check that no division is computed.
+  checkDivisionRepresentation(fac, divisions, denoms);
+}
+
 TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
   // (x) : (exists y: 2x + y = 1 and y = 2).
   FlatAffineConstraints fac(1, 0, 1);