[MLIR][Presburger] PresburgerSet::subtract: automatically restore state on return
authorArjun P <arjunpitchanathan@gmail.com>
Tue, 22 Feb 2022 17:28:21 +0000 (17:28 +0000)
committerArjun P <arjunpitchanathan@gmail.com>
Wed, 23 Feb 2022 19:20:44 +0000 (19:20 +0000)
Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D120339

mlir/lib/Analysis/Presburger/PresburgerSet.cpp

index b75cd1d..8294629 100644 (file)
@@ -10,6 +10,7 @@
 #include "mlir/Analysis/Presburger/Simplex.h"
 #include "mlir/Analysis/Presburger/Utils.h"
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/ScopeExit.h"
 #include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
@@ -198,7 +199,6 @@ static void subtractRecursively(IntegerPolyhedron &b, Simplex &simplex,
   // Similarly, we also want to rollback simplex to its original state.
   const unsigned initialSnapshot = simplex.getSnapshot();
 
-  // Automatically restore the original state when we return.
   auto restoreState = [&]() {
     b.removeIdRange(IntegerPolyhedron::IdKind::Local, bInitNumLocals,
                     b.getNumLocalIds());
@@ -207,6 +207,9 @@ static void subtractRecursively(IntegerPolyhedron &b, Simplex &simplex,
     simplex.rollback(initialSnapshot);
   };
 
+  // Automatically restore the original state when we return.
+  auto stateRestorer = llvm::make_scope_exit(restoreState);
+
   // Find out which inequalities of sI correspond to division inequalities for
   // the local variables of sI.
   std::vector<MaybeLocalRepr> repr(sI.getNumLocalIds());
@@ -243,11 +246,16 @@ static void subtractRecursively(IntegerPolyhedron &b, Simplex &simplex,
   simplex.intersectIntegerPolyhedron(sI);
 
   if (simplex.isEmpty()) {
-    /// b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
-    /// We are ignoring level i completely, so we restore the state
-    /// *before* going to level i + 1.
+    // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1.
+    // We are ignoring level i completely, so we restore the state
+    // *before* going to level i + 1.
     restoreState();
     subtractRecursively(b, simplex, s, i + 1, result);
+
+    // We already restored the state above and the recursive call should have
+    // restored to the same state before returning, so we don't need to restore
+    // the state again.
+    stateRestorer.release();
     return;
   }
 
@@ -309,8 +317,6 @@ static void subtractRecursively(IntegerPolyhedron &b, Simplex &simplex,
     if (!isMarkedRedundant[offset + 2 * j + 1])
       processInequality(getNegatedCoeffs(coeffs));
   }
-
-  restoreState();
 }
 
 /// Return the set difference poly \ set.