/// -divisor * id + expr >= 0 <-- Upper bound for 'id'
///
/// For example:
-/// 32*k >= 16*i + j - 31 <-- Lower bound for 'k'
-/// 32*k <= 16*i + j <-- Upper bound for 'k'
-/// expr = 16*i + j, divisor = 32
-/// k = ( 16*i + j ) floordiv 32
+/// 32*k >= 16*i + j - 31 <-- Lower bound for 'k'
+/// 32*k <= 16*i + j <-- Upper bound for 'k'
+/// expr = 16*i + j, divisor = 32
+/// k = ( 16*i + j ) floordiv 32
///
-/// 4q >= i + j - 2 <-- Lower bound for 'q'
-/// 4q <= i + j + 1 <-- Upper bound for 'q'
-/// expr = i + j + 1, divisor = 4
-/// q = (i + j + 1) floordiv 4
+/// 4q >= i + j - 2 <-- Lower bound for 'q'
+/// 4q <= i + j + 1 <-- Upper bound for 'q'
+/// expr = i + j + 1, divisor = 4
+/// q = (i + j + 1) floordiv 4
+//
+/// This function also supports detecting divisions from bounds that are
+/// strictly tighter than the division bounds described above, since tighter
+/// bounds imply the division bounds. For example:
+/// 4q - i - j + 2 >= 0 <-- Lower bound for 'q'
+/// -4q + i + j >= 0 <-- Tight upper bound for 'q'
+///
+/// To extract floor divisions with tighter bounds, we assume that that the
+/// constraints are of the form:
+/// c <= expr - divisior * id <= divisor - 1, where 0 <= c <= divisor - 1
+/// Rearranging, we have:
+/// divisor * id - expr + (divisor - 1) >= 0 <-- Lower bound for 'id'
+/// -divisor * id + expr - c >= 0 <-- Upper bound for 'id'
///
/// If successful, `expr` is set to dividend of the division and `divisor` is
/// set to the denominator of the division.
assert(lbIneq <= cst.getNumInequalities() &&
"Invalid upper bound inequality position");
- // Due to the form of the inequalities, sum of constants of the
- // inequalities is (divisor - 1).
- int64_t denominator = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
- cst.atIneq(ubIneq, cst.getNumCols() - 1) + 1;
+ // Extract divisor from the lower bound.
+ divisor = cst.atIneq(lbIneq, pos);
- // Divisor should be positive.
- if (denominator <= 0)
- return failure();
-
- // Check if coeff of variable is equal to divisor.
- if (denominator != cst.atIneq(lbIneq, pos))
- return failure();
-
- // Check if constraints are opposite of each other. Constant term
- // is not required to be opposite and is not checked.
+ // First, check if the constraints are opposite of each other except the
+ // constant term.
unsigned i = 0, e = 0;
for (i = 0, e = cst.getNumIds(); i < e; ++i)
if (cst.atIneq(ubIneq, i) != -cst.atIneq(lbIneq, i))
if (i < e)
return failure();
- // Set expr with dividend of the division.
- SmallVector<int64_t, 8> dividend(cst.getNumCols());
- for (i = 0, e = cst.getNumCols(); i < e; ++i)
+ // Then, check if the constant term is of the proper form.
+ // Due to the form of the upper/lower bound inequalities, the sum of their
+ // constants is `divisor - 1 - c`. From this, we can extract c:
+ int64_t constantSum = cst.atIneq(lbIneq, cst.getNumCols() - 1) +
+ cst.atIneq(ubIneq, cst.getNumCols() - 1);
+ int64_t c = divisor - 1 - constantSum;
+
+ // Check if `c` satisfies the condition `0 <= c <= divisor - 1`. This also
+ // implictly checks that `divisor` is positive.
+ if (!(c >= 0 && c <= divisor - 1))
+ return failure();
+
+ // The inequality pair can be used to extract the division.
+ // Set `expr` to the dividend of the division except the constant term, which
+ // is set below.
+ expr.resize(cst.getNumCols(), 0);
+ for (i = 0, e = cst.getNumIds(); i < e; ++i)
if (i != pos)
- dividend[i] = cst.atIneq(ubIneq, i);
- expr = dividend;
+ expr[i] = cst.atIneq(ubIneq, i);
- // Set divisor.
- divisor = denominator;
+ // From the upper bound inequality's form, its constant term is equal to the
+ // constant term of `expr`, minus `c`. From this,
+ // constant term of `expr` = constant term of upper bound + `c`.
+ expr.back() = cst.atIneq(ubIneq, cst.getNumCols() - 1) + c;
return success();
}
fac.getLocalReprs(dividends, denominators);
- // Check that the `dividends` and `expectedDividends` match.
- EXPECT_TRUE(expectedDividends == dividends);
-
// Check that the `denominators` and `expectedDenominators` match.
EXPECT_TRUE(expectedDenominators == denominators);
+
+ // Check that the `dividends` and `expectedDividends` match. If the
+ // denominator for a division is zero, we ignore its dividend.
+ EXPECT_TRUE(dividends.size() == expectedDividends.size());
+ for (unsigned i = 0, e = dividends.size(); i < e; ++i)
+ if (denominators[i] != 0)
+ EXPECT_TRUE(expectedDividends[i] == dividends[i]);
}
TEST(FlatAffineConstraintsTest, computeLocalReprSimple) {
checkDivisionRepresentation(fac, divisions, denoms);
}
+TEST(FlatAffineConstraintsTest, computeLocalReprTightUpperBound) {
+ MLIRContext context;
+
+ {
+ FlatAffineConstraints fac = parseFAC("(i) : (i mod 3 - 1 >= 0)", &context);
+
+ // The set formed by the fac is:
+ // 3q - i + 2 >= 0 <-- Division lower bound
+ // -3q + i - 1 >= 0
+ // -3q + i >= 0 <-- Division upper bound
+ // We remove redundant constraints to get the set:
+ // 3q - i + 2 >= 0 <-- Division lower bound
+ // -3q + i - 1 >= 0 <-- Tighter division upper bound
+ // thus, making the upper bound tighter.
+ fac.removeRedundantConstraints();
+
+ std::vector<SmallVector<int64_t, 8>> divisions = {{1, 0, 0}};
+ SmallVector<unsigned, 8> denoms = {3};
+
+ // Check if the divisions can be computed even with a tighter upper bound.
+ checkDivisionRepresentation(fac, divisions, denoms);
+ }
+
+ {
+ FlatAffineConstraints fac = parseFAC(
+ "(i, j, q) : (4*q - i - j + 2 >= 0, -4*q + i + j >= 0)", &context);
+ // Convert `q` to a local variable.
+ fac.convertDimToLocal(2, 3);
+
+ std::vector<SmallVector<int64_t, 8>> divisions = {{1, 1, 0, 1}};
+ SmallVector<unsigned, 8> denoms = {4};
+
+ // Check if the divisions can be computed even with a tighter upper bound.
+ checkDivisionRepresentation(fac, divisions, denoms);
+ }
+}
+
+TEST(FlatAffineConstraintsTest, computeLocalReprNoRepr) {
+ MLIRContext context;
+ FlatAffineConstraints fac =
+ parseFAC("(x, q) : (x - 3 * q >= 0, -x + 3 * q + 3 >= 0)", &context);
+ // Convert q to a local variable.
+ fac.convertDimToLocal(1, 2);
+
+ std::vector<SmallVector<int64_t, 8>> divisions = {{0, 0, 0}};
+ SmallVector<unsigned, 8> denoms = {0};
+
+ // Check that no division is computed.
+ checkDivisionRepresentation(fac, divisions, denoms);
+}
+
TEST(FlatAffineConstraintsTest, simplifyLocalsTest) {
// (x) : (exists y: 2x + y = 1 and y = 2).
FlatAffineConstraints fac(1, 0, 1);