[MLIR][Presburger] Support divisions in union of two PWMAFunction
authorGroverkss <groverkss@gmail.com>
Thu, 19 Jan 2023 13:01:56 +0000 (18:31 +0530)
committerGroverkss <groverkss@gmail.com>
Thu, 19 Jan 2023 13:03:00 +0000 (18:33 +0530)
This patch adds support for divisions in the union of two PWMAFunction. This is
now possible because of previous patches, which made divisions explicitly
stored in MultiAffineFunction (MAF). This patch also refactors the previous
implementation, moving the implementation for obtaining a set of points where a
MAF is lexicographically "better" than the other to MAF.

Reviewed By: arjunp

Differential Revision: https://reviews.llvm.org/D138118

mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h
mlir/lib/Analysis/Presburger/PWMAFunction.cpp
mlir/lib/Analysis/Presburger/PresburgerSpace.cpp
mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp

index 4ba0f44..ea34566 100644 (file)
 namespace mlir {
 namespace presburger {
 
+/// Enum representing a binary comparison operator: equal, not equal, less than,
+/// less than or equal, greater than, greater than or equal.
+enum class OrderingKind { EQ, NE, LT, LE, GT, GE };
+
 /// This class represents a multi-affine function with the domain as Z^d, where
 /// `d` is the number of domain variables of the function. For example:
 ///
@@ -65,7 +69,10 @@ public:
   /// Get the `i^th` output expression.
   ArrayRef<MPInt> getOutputExpr(unsigned i) const { return output.getRow(i); }
 
-  // Remove the specified range of outputs.
+  /// Get the divisions used in this function.
+  const DivisionRepr &getDivs() const { return divs; }
+
+  /// Remove the specified range of outputs.
   void removeOutputs(unsigned start, unsigned end);
 
   /// Given a MAF `other`, merges division variables such that both functions
@@ -89,6 +96,14 @@ public:
 
   void subtract(const MultiAffineFunction &other);
 
+  /// Return the set of domain points where the output of `this` and `other`
+  /// are ordered lexicographically according to the given ordering.
+  /// For example, if the given comparison is `LT`, then the returned set
+  /// contains all points where the first output of `this` is lexicographically
+  /// less than `other`.
+  PresburgerSet getLexSet(OrderingKind comp,
+                          const MultiAffineFunction &other) const;
+
   /// Get this function as a relation.
   IntegerRelation getAsRelation() const;
 
@@ -181,6 +196,9 @@ public:
     return valueAt(getMPIntVec(point));
   }
 
+  /// Return all the pieces of this piece-wise function.
+  ArrayRef<Piece> getAllPieces() const { return pieces; }
+
   /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
   /// they have the same dimensions, the same domain and they take the same
   /// value at every point in the domain.
index 03a5dfb..998a70c 100644 (file)
@@ -90,11 +90,14 @@ public:
                            numLocals);
   }
 
-  // Get the domain/range space of this space. The returned space is a set
-  // space.
+  /// Get the domain/range space of this space. The returned space is a set
+  /// space.
   PresburgerSpace getDomainSpace() const;
   PresburgerSpace getRangeSpace() const;
 
+  /// Get the space without local variables.
+  PresburgerSpace getSpaceWithoutLocals() const;
+
   unsigned getNumDomainVars() const { return numDomain; }
   unsigned getNumRangeVars() const { return numRange; }
   unsigned getNumSetDimVars() const { return numRange; }
index c31d50a..64b9ba6 100644 (file)
@@ -172,6 +172,93 @@ void MultiAffineFunction::mergeDivs(MultiAffineFunction &other) {
   other.assertIsConsistent();
 }
 
+PresburgerSet
+MultiAffineFunction::getLexSet(OrderingKind comp,
+                               const MultiAffineFunction &other) const {
+  assert(getSpace().isCompatible(other.getSpace()) &&
+         "Output space of funcs should be compatible");
+
+  // Create copies of functions and merge their local space.
+  MultiAffineFunction funcA = *this;
+  MultiAffineFunction funcB = other;
+  funcA.mergeDivs(funcB);
+
+  // We first create the set `result`, corresponding to the set where output
+  // of funcA is lexicographically larger/smaller than funcB. This is done by
+  // creating a PresburgerSet with the following constraints:
+  //
+  //    (outA[0] > outB[0]) U
+  //    (outA[0] = outB[0], outA[1] > outA[1]) U
+  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
+  //
+  // where `n` is the number of outputs.
+  // If `lexMin` is set, the complement inequality is used:
+  //
+  //    (outA[0] < outB[0]) U
+  //    (outA[0] = outB[0], outA[1] < outA[1]) U
+  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
+  //    ...
+  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
+  PresburgerSpace resultSpace = funcA.getDomainSpace();
+  PresburgerSet result =
+      PresburgerSet::getEmpty(resultSpace.getSpaceWithoutLocals());
+  IntegerPolyhedron levelSet(
+      /*numReservedInequalities=*/1 + 2 * resultSpace.getNumLocalVars(),
+      /*numReservedEqualities=*/funcA.getNumOutputs(),
+      /*numReservedCols=*/resultSpace.getNumVars() + 1, resultSpace);
+
+  // Add division inequalities to `levelSet`.
+  for (unsigned i = 0, e = funcA.getNumDivs(); i < e; ++i) {
+    levelSet.addInequality(getDivUpperBound(funcA.divs.getDividend(i),
+                                            funcA.divs.getDenom(i),
+                                            funcA.divs.getDivOffset() + i));
+    levelSet.addInequality(getDivLowerBound(funcA.divs.getDividend(i),
+                                            funcA.divs.getDenom(i),
+                                            funcA.divs.getDivOffset() + i));
+  }
+
+  for (unsigned level = 0; level < funcA.getNumOutputs(); ++level) {
+    // Create the expression `outA - outB` for this level.
+    SmallVector<MPInt, 8> subExpr =
+        subtractExprs(funcA.getOutputExpr(level), funcB.getOutputExpr(level));
+
+    // TODO: Implement all comparison cases.
+    switch (comp) {
+    case OrderingKind::LT:
+      // For less than, we add an upper bound of -1:
+      //        outA - outB <= -1
+      //        outA <= outB - 1
+      //        outA < outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
+      break;
+    case OrderingKind::GT:
+      // For greater than, we add a lower bound of 1:
+      //        outA - outB >= 1
+      //        outA > outB + 1
+      //        outA > outB
+      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
+      break;
+    case OrderingKind::GE:
+    case OrderingKind::LE:
+    case OrderingKind::EQ:
+    case OrderingKind::NE:
+      assert(false && "Not implemented case");
+    }
+
+    // Union the set with the result.
+    result.unionInPlace(levelSet);
+    // The last inequality in `levelSet` is the bound we inserted. We remove
+    // that for next iteration.
+    levelSet.removeInequality(levelSet.getNumInequalities() - 1);
+    // Add equality `outA - outB == 0` for this level for next iteration.
+    levelSet.addEquality(subExpr);
+  }
+
+  return result;
+}
+
 /// Two PWMAFunctions are equal if they have the same dimensionalities,
 /// the same domain, and take the same value at every point in the domain.
 bool PWMAFunction::isEqual(const PWMAFunction &other) const {
@@ -195,6 +282,8 @@ bool PWMAFunction::isEqual(const PWMAFunction &other) const {
 
 void PWMAFunction::addPiece(const Piece &piece) {
   assert(piece.isConsistent() && "Piece should be consistent");
+  assert(piece.domain.intersect(getDomain()).isIntegerEmpty() &&
+         "Piece should be disjoint from the function");
   pieces.push_back(piece);
 }
 
@@ -263,85 +352,23 @@ PWMAFunction PWMAFunction::unionFunction(
 }
 
 /// A tiebreak function which breaks ties by comparing the outputs
-/// lexicographically. If `lexMin` is true, then the ties are broken by
-/// taking the lexicographically smaller output and otherwise, by taking the
-/// lexicographically larger output.
-template <bool lexMin>
+/// lexicographically based on the given comparison operator.
+/// This is templated since it is passed as a lambda.
+template <OrderingKind comp>
 static PresburgerSet tiebreakLex(const PWMAFunction::Piece &pieceA,
                                  const PWMAFunction::Piece &pieceB) {
-  // TODO: Support local variables here.
-  assert(pieceA.output.getSpace().isCompatible(pieceB.output.getSpace()) &&
-         "Pieces should be compatible");
-  assert(pieceA.domain.getSpace().getNumLocalVars() == 0 &&
-         "Local variables are not supported yet.");
-
-  PresburgerSpace compatibleSpace = pieceA.domain.getSpace();
-  const PresburgerSpace &space = pieceA.domain.getSpace();
-
-  // We first create the set `result`, corresponding to the set where output
-  // of pieceA is lexicographically larger/smaller than pieceB. This is done by
-  // creating a PresburgerSet with the following constraints:
-  //
-  //    (outA[0] > outB[0]) U
-  //    (outA[0] = outB[0], outA[1] > outA[1]) U
-  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] > outA[2]) U
-  //    ...
-  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] > outB[n-1])
-  //
-  // where `n` is the number of outputs.
-  // If `lexMin` is set, the complement inequality is used:
-  //
-  //    (outA[0] < outB[0]) U
-  //    (outA[0] = outB[0], outA[1] < outA[1]) U
-  //    (outA[0] = outB[0], outA[1] = outA[1], outA[2] < outA[2]) U
-  //    ...
-  //    (outA[0] = outB[0], ..., outA[n-2] = outB[n-2], outA[n-1] < outB[n-1])
-  PresburgerSet result = PresburgerSet::getEmpty(compatibleSpace);
-  IntegerPolyhedron levelSet(
-      /*numReservedInequalities=*/1,
-      /*numReservedEqualities=*/pieceA.output.getNumOutputs(),
-      /*numReservedCols=*/space.getNumVars() + 1, space);
-  for (unsigned level = 0; level < pieceA.output.getNumOutputs(); ++level) {
-
-    // Create the expression `outA - outB` for this level.
-    SmallVector<MPInt, 8> subExpr = subtractExprs(
-        pieceA.output.getOutputExpr(level), pieceB.output.getOutputExpr(level));
-
-    if (lexMin) {
-      // For lexMin, we add an upper bound of -1:
-      //        outA - outB <= -1
-      //        outA <= outB - 1
-      //        outA < outB
-      levelSet.addBound(IntegerPolyhedron::BoundType::UB, subExpr, MPInt(-1));
-    } else {
-      // For lexMax, we add a lower bound of 1:
-      //        outA - outB >= 1
-      //        outA > outB + 1
-      //        outA > outB
-      levelSet.addBound(IntegerPolyhedron::BoundType::LB, subExpr, MPInt(1));
-    }
-
-    // Union the set with the result.
-    result.unionInPlace(levelSet);
-    // There is only 1 inequality in `levelSet`, so the index is always 0.
-    levelSet.removeInequality(0);
-    // Add equality `outA - outB == 0` for this level for next iteration.
-    levelSet.addEquality(subExpr);
-  }
-
-  // We then intersect `result` with the domain of pieceA and pieceB, to only
-  // tiebreak on the domain where both are defined.
+  PresburgerSet result = pieceA.output.getLexSet(comp, pieceB.output);
   result = result.intersect(pieceA.domain).intersect(pieceB.domain);
 
   return result;
 }
 
 PWMAFunction PWMAFunction::unionLexMin(const PWMAFunction &func) {
-  return unionFunction(func, tiebreakLex</*lexMin=*/true>);
+  return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::LT>);
 }
 
 PWMAFunction PWMAFunction::unionLexMax(const PWMAFunction &func) {
-  return unionFunction(func, tiebreakLex</*lexMin=*/false>);
+  return unionFunction(func, tiebreakLex</*comp=*/OrderingKind::GT>);
 }
 
 void MultiAffineFunction::subtract(const MultiAffineFunction &other) {
index 648860c..e15db1e 100644 (file)
@@ -22,6 +22,12 @@ PresburgerSpace PresburgerSpace::getRangeSpace() const {
   return PresburgerSpace::getSetSpace(numRange, numSymbols, numLocals);
 }
 
+PresburgerSpace PresburgerSpace::getSpaceWithoutLocals() const {
+  PresburgerSpace space = *this;
+  space.removeVarRange(VarKind::Local, 0, numLocals);
+  return space;
+}
+
 unsigned PresburgerSpace::getNumVarKind(VarKind kind) const {
   if (kind == VarKind::Domain)
     return getNumDomainVars();
index cebc7fa..ee2931e 100644 (file)
@@ -395,3 +395,44 @@ TEST(PWMAFunction, unionLexMinComplex) {
   EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
   EXPECT_TRUE(func2.unionLexMin(func1).isEqual(result));
 }
+
+TEST(PWMAFunction, unionLexMinWithDivs) {
+  {
+    PWMAFunction func1 = parsePWMAF({
+        {"(x, y) : (x mod 5 == 0)", "(x, y) -> (x, 1)"},
+    });
+
+    PWMAFunction func2 = parsePWMAF({
+        {"(x, y) : (x mod 7 == 0)", "(x, y) -> (x + y, 2)"},
+    });
+
+    PWMAFunction result = parsePWMAF({
+        {"(x, y) : (x mod 5 == 0, x mod 7 >= 1)", "(x, y) -> (x, 1)"},
+        {"(x, y) : (x mod 7 == 0, x mod 5 >= 1)", "(x, y) -> (x + y, 2)"},
+        {"(x, y) : (x mod 5 == 0, x mod 7 == 0, y >= 0)", "(x, y) -> (x, 1)"},
+        {"(x, y) : (x mod 7 == 0, x mod 5 == 0, y <= -1)",
+         "(x, y) -> (x + y, 2)"},
+    });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+  }
+
+  {
+    PWMAFunction func1 = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000)", "(x) -> (x floordiv 16)"},
+    });
+
+    PWMAFunction func2 = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000)", "(x) -> ((x + 10) floordiv 17)"},
+    });
+
+    PWMAFunction result = parsePWMAF({
+        {"(x) : (x >= 0, x <= 1000, x floordiv 16 <= (x + 10) floordiv 17)",
+         "(x) -> (x floordiv 16)"},
+        {"(x) : (x >= 0, x <= 1000, x floordiv 16 >= (x + 10) floordiv 17 + 1)",
+         "(x) -> ((x + 10) floordiv 17)"},
+    });
+
+    EXPECT_TRUE(func1.unionLexMin(func2).isEqual(result));
+  }
+}