[MLIR][Presburger] Introduce SimplexRollbackScopeExit to rollback on scope exit
authorArjun P <arjunpitchanathan@gmail.com>
Mon, 21 Mar 2022 19:59:17 +0000 (19:59 +0000)
committerArjun P <arjunpitchanathan@gmail.com>
Thu, 24 Mar 2022 00:27:47 +0000 (00:27 +0000)
This simplifies many places where we just want to do something in a "transient context"
and return some value.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D122172

mlir/include/mlir/Analysis/Presburger/Simplex.h
mlir/lib/Analysis/Presburger/PresburgerRelation.cpp
mlir/lib/Analysis/Presburger/Simplex.cpp

index 8a4ea2a..90f092f 100644 (file)
@@ -664,6 +664,23 @@ private:
   void reduceBasis(Matrix &basis, unsigned level);
 };
 
+/// Takes a snapshot of the simplex state on construction and rolls back to the
+/// snapshot on destruction.
+///
+/// Useful for performing operations in a "transient context", all changes from
+/// which get rolled back on scope exit.
+class SimplexRollbackScopeExit {
+public:
+  SimplexRollbackScopeExit(Simplex &simplex) : simplex(simplex) {
+    snapshot = simplex.getSnapshot();
+  };
+  ~SimplexRollbackScopeExit() { simplex.rollback(snapshot); }
+
+private:
+  SimplexBase &simplex;
+  unsigned snapshot;
+};
+
 } // namespace presburger
 } // namespace mlir
 
index 934c80f..891f8d5 100644 (file)
@@ -232,12 +232,11 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex,
   // inequality, s_{i,j+1}. This function recurses into the next level i + 1
   // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}.
   auto recurseWithInequality = [&, i](ArrayRef<int64_t> ineq) {
-    size_t snapshot = simplex.getSnapshot();
+    SimplexRollbackScopeExit scopeExit(simplex);
     b.addInequality(ineq);
     simplex.addInequality(ineq);
     subtractRecursively(b, simplex, s, i + 1, result);
     b.removeInequality(b.getNumInequalities() - 1);
-    simplex.rollback(snapshot);
   };
 
   // For each inequality ineq, we first recurse with the part where ineq
@@ -519,16 +518,11 @@ PresburgerRelation SetCoalescer::coalesce() {
 /// that all inequalities of `cuttingIneqsB` are redundant for the facet of
 /// `simp` where `ineq` holds as an equality is contained within `a`.
 bool SetCoalescer::isFacetContained(ArrayRef<int64_t> ineq, Simplex &simp) {
-  unsigned snapshot = simp.getSnapshot();
+  SimplexRollbackScopeExit scopeExit(simp);
   simp.addEquality(ineq);
-  if (llvm::any_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
-        return !simp.isRedundantInequality(curr);
-      })) {
-    simp.rollback(snapshot);
-    return false;
-  }
-  simp.rollback(snapshot);
-  return true;
+  return llvm::all_of(cuttingIneqsB, [&simp](ArrayRef<int64_t> curr) {
+    return simp.isRedundantInequality(curr);
+  });
 }
 
 void SetCoalescer::addCoalescedDisjunct(unsigned i, unsigned j,
index a79bbb0..341bcbc 100644 (file)
@@ -888,11 +888,11 @@ MaybeOptimum<Fraction> Simplex::computeOptimum(Direction direction,
                                                ArrayRef<int64_t> coeffs) {
   if (empty)
     return OptimumKind::Empty;
-  unsigned snapshot = getSnapshot();
+
+  SimplexRollbackScopeExit scopeExit(*this);
   unsigned conIndex = addRow(coeffs);
   unsigned row = con[conIndex].pos;
   MaybeOptimum<Fraction> optimum = computeRowOptimum(direction, row);
-  rollback(snapshot);
   return optimum;
 }
 
@@ -1205,7 +1205,7 @@ public:
     // tableau before returning. We instead add a row for the objective function
     // ourselves, call into computeOptimum, compute the duals from the tableau
     // state, and finally rollback the addition of the row before returning.
-    unsigned snap = simplex.getSnapshot();
+    SimplexRollbackScopeExit scopeExit(simplex);
     unsigned conIndex = simplex.addRow(getCoeffsForDirection(dir));
     unsigned row = simplex.con[conIndex].pos;
     MaybeOptimum<Fraction> maybeWidth =
@@ -1248,7 +1248,6 @@ public:
       else
         dual.push_back(0);
     }
-    simplex.rollback(snap);
     return *maybeWidth;
   }