[MLIR][Presburger] subtract: fix bug in the non-recursive implementation
authorArjun P <arjunpitchanathan@gmail.com>
Sun, 10 Apr 2022 11:25:14 +0000 (12:25 +0100)
committerArjun P <arjunpitchanathan@gmail.com>
Mon, 11 Apr 2022 19:45:29 +0000 (20:45 +0100)
When making the subtract implementation non-recursive, tail calls were
implemented by incrementing the level but not pushing a frame, and returning
was implemented as returning to the level corresponding to the number of frames in the stack.

This is incorrect, as there could be a case where we tail-recurse at `level`,
and then recurse at `level + 1`, pushing a frame. However, because the previous
frame was missing, this new frame would be interpreted as corresponding to
`level` and not `level + 1`. Fix this by removing the special handling of tail
calls and just doing them as normal recursion, as this is the simplest correct
implementation and handling them specifically would be a premature optimization.

The impact of this bug is only on performance as this can only lead to
unnecessary subtractions of the same disjuncts multiples times. As subtraction
is idempotent, and rationally empty disjuncts are always discarded, this
does not affect the output, so this patch does not include a regression test.
(This also does not affect termination.)

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D123327

mlir/lib/Analysis/Presburger/PresburgerRelation.cpp

index 1bc77ed..7fc76b6 100644 (file)
@@ -196,9 +196,7 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
   SmallVector<Frame, 2> frames;
 
   // When we "recurse", we ensure the current frame is stored in `frames` and
-  // increment `level`. When we "tail recurse", we just increment `level`,
-  // without storing any frame. Accordingly, when we return, we return to the
-  // last level that has a frame associated with it.
+  // increment `level`. When we return, we decrement `level`.
   unsigned level = 1;
   while (level > 0) {
     if (level - 1 >= s.getNumDisjuncts()) {
@@ -266,10 +264,17 @@ static PresburgerRelation getSetDifference(IntegerRelation b,
       if (simplex.isEmpty()) {
         // b /\ s_i is empty, so b \ s_i = b. We move directly to i + 1.
         // We are ignoring level i completely, so we restore the state
-        // *before* going to the next level. We are "tail recursing", so
-        // we don't add a frame before going to the next level.
+        // *before* going to the next level.
         b.truncate(initBCounts);
         simplex.rollback(initialSnapshot);
+        // Recurse. We haven't processed any inequalities and
+        // we don't need to process anything when we return.
+        //
+        // TODO: consider supporting tail recursion directly if this becomes
+        // relevant for performance.
+        frames.push_back(Frame{initialSnapshot, initBCounts, sI,
+                               /*ineqsToProcess=*/{},
+                               /*lastIneqProcessed=*/{}});
         ++level;
         continue;
       }