[MLIR][Presburger] Implement domain and range restriction for PresburgerRelation
authoriambrj <joshibharathiramana@gmail.com>
Tue, 18 Jul 2023 13:36:30 +0000 (19:06 +0530)
committerGroverkss <groverkss@gmail.com>
Tue, 18 Jul 2023 13:42:12 +0000 (19:12 +0530)
This patch implements domain and range restriction for PresburgerRelation

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D154798

mlir/include/mlir/Analysis/Presburger/PresburgerRelation.h
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
mlir/unittests/Analysis/Presburger/PresburgerRelationTest.cpp

index adcab9b..caf6b30 100644 (file)
@@ -64,6 +64,8 @@ public:
   /// exceeds that of some disjunct, an assert failure will occur.
   void setSpace(const PresburgerSpace &oSpace);
 
+  void insertVarInPlace(VarKind kind, unsigned pos, unsigned num = 1);
+
   /// Return a reference to the list of disjuncts.
   ArrayRef<IntegerRelation> getAllDisjuncts() const;
 
@@ -83,6 +85,18 @@ public:
   /// Return the intersection of this set and the given set.
   PresburgerRelation intersect(const PresburgerRelation &set) const;
 
+  /// Intersect the given `set` with the range in-place.
+  ///
+  /// Formally, let the relation `this` be R: A -> B and `set` is C, then this
+  /// operation modifies R to be A -> (B intersection C).
+  PresburgerRelation intersectRange(PresburgerSet &set);
+
+  /// Intersect the given `set` with the domain in-place.
+  ///
+  /// Formally, let the relation `this` be R: A -> B and `set` is C, then this
+  /// operation modifies R to be (A intersection C) -> B.
+  PresburgerRelation intersectDomain(const PresburgerSet &set);
+
   /// Invert the relation, i.e. swap its domain and range.
   ///
   /// Formally, if `this`: A -> B then `inverse` updates `this` in-place to
index 4fe63a6..71625b9 100644 (file)
@@ -30,6 +30,13 @@ void PresburgerRelation::setSpace(const PresburgerSpace &oSpace) {
     disjunct.setSpaceExceptLocals(space);
 }
 
+void PresburgerRelation::insertVarInPlace(VarKind kind, unsigned pos,
+                                          unsigned num) {
+  for (IntegerRelation &cs : disjuncts)
+    cs.insertVar(kind, pos, num);
+  space.insertVar(kind, pos, num);
+}
+
 unsigned PresburgerRelation::getNumDisjuncts() const {
   return disjuncts.size();
 }
@@ -117,6 +124,26 @@ PresburgerRelation::intersect(const PresburgerRelation &set) const {
   return result;
 }
 
+PresburgerRelation PresburgerRelation::intersectRange(PresburgerSet &set) {
+  assert(space.getRangeSpace().isCompatible(set.getSpace()) &&
+         "Range of `this` must be compatible with range of `set`");
+
+  PresburgerRelation other = set;
+  other.insertVarInPlace(VarKind::Domain, 0, getNumDomainVars());
+  return intersect(other);
+}
+
+PresburgerRelation
+PresburgerRelation::intersectDomain(const PresburgerSet &set) {
+  assert(space.getDomainSpace().isCompatible(set.getSpace()) &&
+         "Domain of `this` must be compatible with range of `set`");
+
+  PresburgerRelation other = set;
+  other.insertVarInPlace(VarKind::Domain, 0, getNumDomainVars());
+  other.inverse();
+  return intersect(other);
+}
+
 void PresburgerRelation::inverse() {
   for (IntegerRelation &cs : disjuncts)
     cs.inverse();
index 8b6e3c4..7054ed0 100644 (file)
@@ -31,6 +31,73 @@ parsePresburgerRelationFromPresburgerSet(ArrayRef<StringRef> strs,
   return result;
 }
 
+TEST(PresburgerRelationTest, intersectDomainAndRange) {
+  PresburgerRelation rel = parsePresburgerRelationFromPresburgerSet(
+      {// (x, y) -> (x + N, y - N)
+       "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0)",
+       // (x, y) -> (x + y, x - y)
+       "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0)",
+       // (x, y) -> (x - y, y - x)}
+       "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0)"},
+      2);
+
+  {
+    PresburgerSet set =
+        parsePresburgerSet({// (2x, x)
+                            "(a, b)[N] : (a - 2 * b == 0)",
+                            // (x, -x)
+                            "(a, b)[N] : (a + b == 0)",
+                            // (N, N)
+                            "(a, b)[N] : (a - N == 0, b - N == 0)"});
+
+    PresburgerRelation expectedRel = parsePresburgerRelationFromPresburgerSet(
+        {"(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, x - 2 * y == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, x + y == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, x - N == 0, y - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, x - 2 * y == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, x + y == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, x - N == 0, y - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, x - 2 * y == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, x + y == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, x - N == 0, y - N "
+         "== 0)"},
+        2);
+
+    PresburgerRelation computedRel = rel.intersectDomain(set);
+    EXPECT_TRUE(computedRel.isEqual(expectedRel));
+  }
+
+  {
+    PresburgerSet set =
+        parsePresburgerSet({// (2x, x)
+                            "(a, b)[N] : (a - 2 * b == 0)",
+                            // (x, -x)
+                            "(a, b)[N] : (a + b == 0)",
+                            // (N, N)
+                            "(a, b)[N] : (a - N == 0, b - N == 0)"});
+
+    PresburgerRelation expectedRel = parsePresburgerRelationFromPresburgerSet(
+        {"(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, a - 2 * b == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, a + b == 0)",
+         "(x, y, a, b)[N] : (x - a + N == 0, y - b - N == 0, a - N == 0, b - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, a - 2 * b == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, a + b == 0)",
+         "(x, y, a, b)[N] : (a - x - y == 0, b - x + y == 0, a - N == 0, b - N "
+         "== 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, a - 2 * b == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, a + b == 0)",
+         "(x, y, a, b)[N] : (a - x + y == 0, b - y + x == 0, a - N == 0, b - N "
+         "== 0)"},
+        2);
+
+    PresburgerRelation computedRel = rel.intersectRange(set);
+    EXPECT_TRUE(computedRel.isEqual(expectedRel));
+  }
+}
+
 TEST(PresburgerRelationTest, applyDomainAndRange) {
   {
     PresburgerRelation map1 = parsePresburgerRelationFromPresburgerSet(