From 00b293e83f6bb84f970eea972f022d578923d832 Mon Sep 17 00:00:00 2001 From: Arjun P Date: Thu, 7 Apr 2022 14:41:04 +0100 Subject: [PATCH] [MLIR][Presburger] refactor subtraction to be non-recursive Subtraction was previously implemented recursively. This refactors it to be non-recursive to avoid issues with potential stack overflows. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D123248 --- .../lib/Analysis/Presburger/PresburgerRelation.cpp | 334 ++++++++++++--------- 1 file changed, 187 insertions(+), 147 deletions(-) diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp index 1515897..1bc77ed 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -129,18 +129,17 @@ static SmallVector getIneqCoeffsFromIdx(const IntegerRelation &rel, return getNegatedCoeffs(eqCoeffs); } -/// Return the set difference b \ s and accumulate the result into `result`. -/// `simplex` must correspond to b. +/// Return the set difference b \ s. /// -/// In the following, U denotes union, ^ denotes intersection, \ denotes set +/// In the following, U denotes union, /\ denotes intersection, \ denotes set /// difference and ~ denotes complement. -/// Let b be the IntegerRelation and s = (U_i s_i) be the set. We want -/// b \ (U_i s_i). /// -/// Let s_i = ^_j s_ij, where each s_ij is a single inequality. To compute -/// b \ s_i = b ^ ~s_i, we partition s_i based on the first violated inequality: -/// ~s_i = (~s_i1) U (s_i1 ^ ~s_i2) U (s_i1 ^ s_i2 ^ ~s_i3) U ... -/// And the required result is (b ^ ~s_i1) U (b ^ s_i1 ^ ~s_i2) U ... +/// Let s = (U_i s_i). We want b \ (U_i s_i). +/// +/// Let s_i = /\_j s_ij, where each s_ij is a single inequality. To compute +/// b \ s_i = b /\ ~s_i, we partition s_i based on the first violated +/// inequality: ~s_i = (~s_i1) U (s_i1 /\ ~s_i2) U (s_i1 /\ s_i2 /\ ~s_i3) U ... +/// And the required result is (b /\ ~s_i1) U (b /\ s_i1 /\ ~s_i2) U ... /// We recurse by subtracting U_{j > i} S_j from each of these parts and /// returning the union of the results. Each equality is handled as a /// conjunction of two inequalities. @@ -162,151 +161,192 @@ static SmallVector getIneqCoeffsFromIdx(const IntegerRelation &rel, /// that some constraints are redundant. These redundant constraints are /// ignored. /// -/// b should not have duplicate divs because this might lead to existing -/// divs disappearing in the call to mergeLocalIds below, which cannot be -/// handled. -static void subtractRecursively(IntegerRelation &b, Simplex &simplex, - const PresburgerRelation &s, unsigned i, - PresburgerRelation &result) { - - if (i == s.getNumDisjuncts()) { - result.unionInPlace(b); - return; - } +static PresburgerRelation getSetDifference(IntegerRelation b, + const PresburgerRelation &s) { + assert(b.isSpaceCompatible(s) && "Spaces should match"); + if (b.isEmptyByGCDTest()) + return PresburgerRelation::getEmpty(b.getSpaceWithoutLocals()); - IntegerRelation sI = s.getDisjunct(i); - // Remove the duplicate divs up front to avoid them possibly disappearing - // in the call to mergeLocalIds below. - sI.removeDuplicateDivs(); - - // Below, we append some additional constraints and ids to b. We want to - // rollback b to its initial state before returning, which we will do by - // removing all constraints beyond the original number of inequalities - // and equalities, so we store these counts first. - IntegerRelation::CountsSnapshot initBCounts = b.getCounts(); - // Similarly, we also want to rollback simplex to its original state. - unsigned initialSnapshot = simplex.getSnapshot(); - - // Find out which inequalities of sI correspond to division inequalities for - // the local variables of sI. - std::vector repr(sI.getNumLocalIds()); - sI.getLocalReprs(repr); - - // Add sI's locals to b, after b's locals. Also add b's locals to sI, before - // sI's locals. - b.mergeLocalIds(sI); - unsigned numLocalsAdded = - b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds(); - // Update simplex to also include the new locals in `b` from merging. - simplex.appendVariable(numLocalsAdded); - - // Equalities are processed by considering them as a pair of inequalities. - // The first sI.getNumInequalities() elements are for sI's inequalities; - // then a pair of inequalities occurs for each of sI's equalities. - // If the equality is expr == 0, the first element in the pair - // corresponds to expr >= 0, and the second to expr <= 0. - llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() + - 2 * sI.getNumEqualities()); - - // Add all division inequalities to `b`. - for (MaybeLocalRepr &maybeInequality : repr) { - assert(maybeInequality.kind == ReprKind::Inequality && - "Subtraction is not supported when a representation of the local " - "variables of the subtrahend cannot be found!"); - unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx; - unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx; - - b.addInequality(sI.getInequality(lb)); - b.addInequality(sI.getInequality(ub)); - - assert(lb != ub && - "Upper and lower bounds must be different inequalities!"); - - // We just added these inequalities to `b`, so there is no point considering - // the parts where these inequalities occur complemented -- such parts are - // empty. Therefore, we mark that these can be ignored. - canIgnoreIneq[lb] = true; - canIgnoreIneq[ub] = true; - } - - unsigned offset = simplex.getNumConstraints(); - unsigned snapshotBeforeIntersect = simplex.getSnapshot(); - simplex.intersectIntegerRelation(sI); - - if (simplex.isEmpty()) { - // b ^ s_i is empty, so b \ s_i = b. We move directly to i + 1. - // We are ignoring level i completely, so we restore the state - // *before* going to level i + 1. - b.truncate(initBCounts); - simplex.rollback(initialSnapshot); - subtractRecursively(b, simplex, s, i + 1, result); - return; - } + // Remove duplicate divs up front here to avoid existing + // divs disappearing in the call to mergeLocalIds below. + b.removeDuplicateDivs(); - simplex.detectRedundant(); - - unsigned totalNewSimplexInequalities = - 2 * sI.getNumEqualities() + sI.getNumInequalities(); - // Redundant inequalities can be safely ignored. This is not required for - // correctness but improves performance and results in a more compact - // representation of the set difference. - for (unsigned j = 0; j < totalNewSimplexInequalities; j++) - canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j); - simplex.rollback(snapshotBeforeIntersect); - - SmallVector ineqsToProcess(totalNewSimplexInequalities); - for (unsigned i = 0; i < totalNewSimplexInequalities; ++i) - if (!canIgnoreIneq[i]) - ineqsToProcess.push_back(i); - - // Recurse with the part b ^ ~ineq. Note that b is modified throughout - // subtractRecursively. At the time this function is called, the current b is - // actually equal to b ^ s_i1 ^ s_i2 ^ ... ^ s_ij, and ineq is the next - // inequality, s_{i,j+1}. This function recurses into the next level i + 1 - // with the part b ^ s_i1 ^ s_i2 ^ ... ^ s_ij ^ ~s_{i,j+1}. - auto recurseWithInequality = [&, i](ArrayRef ineq) { - b.addInequality(ineq); - simplex.addInequality(ineq); - subtractRecursively(b, simplex, s, i + 1, result); + PresburgerRelation result = + PresburgerRelation::getEmpty(b.getSpaceWithoutLocals()); + Simplex simplex(b); + + // This algorithm is more naturally expressed recursively, but we implement + // it iteratively here to avoid issues with stack sizes. + // + // Each level of the recursion has five stack variables. + struct Frame { + // A snapshot of the simplex state to rollback to. + unsigned simplexSnapshot; + // A CountsSnapshot of `b` to rollback to. + IntegerRelation::CountsSnapshot bCounts; + // The IntegerRelation currently being operated on. + IntegerRelation sI; + // A list of indexes (see getIneqCoeffsFromIdx) of inequalities to be + // processed. + SmallVector ineqsToProcess; + // The index of the last inequality that was processed at this level. + // This is empty when we are coming to this level for the first time. + Optional lastIneqProcessed; }; + SmallVector frames; + + // When we "recurse", we ensure the current frame is stored in `frames` and + // increment `level`. When we "tail recurse", we just increment `level`, + // without storing any frame. Accordingly, when we return, we return to the + // last level that has a frame associated with it. + unsigned level = 1; + while (level > 0) { + if (level - 1 >= s.getNumDisjuncts()) { + // No more parts to subtract; add to the result and return. + result.unionInPlace(b); + level = frames.size(); + continue; + } - // For each inequality ineq, we first recurse with the part where ineq - // is not satisfied, and then add the ineq to b and simplex because - // ineq must be satisfied by all later parts. - auto processInequality = [&](ArrayRef ineq) { - unsigned snapshot = simplex.getSnapshot(); - IntegerRelation::CountsSnapshot bCounts = b.getCounts(); - recurseWithInequality(getComplementIneq(ineq)); - simplex.rollback(snapshot); - b.truncate(bCounts); - - b.addInequality(ineq); - simplex.addInequality(ineq); - }; + if (level > frames.size()) { + // No frame for this level yet, so we have just recursed into this level. + IntegerRelation sI = s.getDisjunct(level - 1); + // Remove the duplicate divs up front to avoid them possibly disappearing + // in the call to mergeLocalIds below. + sI.removeDuplicateDivs(); + + // Below, we append some additional constraints and ids to b. We want to + // rollback b to its initial state before returning, which we will do by + // removing all constraints beyond the original number of inequalities + // and equalities, so we store these counts first. + IntegerRelation::CountsSnapshot initBCounts = b.getCounts(); + // Similarly, we also want to rollback simplex to its original state. + unsigned initialSnapshot = simplex.getSnapshot(); + + // Find out which inequalities of sI correspond to division inequalities + // for the local variables of sI. + std::vector repr(sI.getNumLocalIds()); + sI.getLocalReprs(repr); + + // Add sI's locals to b, after b's locals. Only those locals of sI which + // do not already exist in b will be added. (i.e., duplicate divisions + // will not be added.) Also add b's locals to sI, in such a way that both + // have the same locals in the same order in the end. + b.mergeLocalIds(sI); + + // Mark which inequalities of sI are division inequalities and add all + // such inequalities to b. + llvm::SmallBitVector canIgnoreIneq(sI.getNumInequalities() + + 2 * sI.getNumEqualities()); + for (MaybeLocalRepr &maybeInequality : repr) { + assert( + maybeInequality.kind == ReprKind::Inequality && + "Subtraction is not supported when a representation of the local " + "variables of the subtrahend cannot be found!"); + unsigned lb = maybeInequality.repr.inequalityPair.lowerBoundIdx; + unsigned ub = maybeInequality.repr.inequalityPair.upperBoundIdx; + + b.addInequality(sI.getInequality(lb)); + b.addInequality(sI.getInequality(ub)); + + assert(lb != ub && + "Upper and lower bounds must be different inequalities!"); + canIgnoreIneq[lb] = true; + canIgnoreIneq[ub] = true; + } - for (unsigned idx : ineqsToProcess) - processInequality(getIneqCoeffsFromIdx(sI, idx)); -} + unsigned offset = simplex.getNumConstraints(); + unsigned numLocalsAdded = + b.getNumLocalIds() - initBCounts.getSpace().getNumLocalIds(); + simplex.appendVariable(numLocalsAdded); + + unsigned snapshotBeforeIntersect = simplex.getSnapshot(); + simplex.intersectIntegerRelation(sI); + + if (simplex.isEmpty()) { + // b /\ s_i is empty, so b \ s_i = b. We move directly to i + 1. + // We are ignoring level i completely, so we restore the state + // *before* going to the next level. We are "tail recursing", so + // we don't add a frame before going to the next level. + b.truncate(initBCounts); + simplex.rollback(initialSnapshot); + ++level; + continue; + } -/// Return the set difference disjunct \ set. -/// -/// The disjunct here is modified in subtractRecursively, so it cannot be a -/// const reference even though it is restored to its original state before -/// returning from that function. -static PresburgerRelation getSetDifference(IntegerRelation disjunct, - const PresburgerRelation &set) { - assert(disjunct.isSpaceCompatible(set) && "Spaces should match"); - if (disjunct.isEmptyByGCDTest()) - return PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals()); - - // Remove duplicate divs up front here as subtractRecursively does not support - // this set having duplicate divs. - disjunct.removeDuplicateDivs(); + simplex.detectRedundant(); + + // Equalities are added to simplex as a pair of inequalities. + unsigned totalNewSimplexInequalities = + 2 * sI.getNumEqualities() + sI.getNumInequalities(); + for (unsigned j = 0; j < totalNewSimplexInequalities; j++) + canIgnoreIneq[j] = simplex.isMarkedRedundant(offset + j); + simplex.rollback(snapshotBeforeIntersect); + + SmallVector ineqsToProcess(totalNewSimplexInequalities); + for (unsigned i = 0; i < totalNewSimplexInequalities; ++i) + if (!canIgnoreIneq[i]) + ineqsToProcess.push_back(i); + + if (ineqsToProcess.empty()) { + // Nothing to process; return. (we have no frame to pop.) + level = frames.size(); + continue; + } + + unsigned simplexSnapshot = simplex.getSnapshot(); + IntegerRelation::CountsSnapshot bCounts = b.getCounts(); + frames.push_back(Frame{simplexSnapshot, bCounts, sI, ineqsToProcess, + /*lastIneqProcessed=*/llvm::None}); + // We have completed the initial setup for this level. + // Fallthrough to the main recursive part below. + } + + // For each inequality ineq, we first recurse with the part where ineq + // is not satisfied, and then add ineq to b and simplex because + // ineq must be satisfied by all later parts. + if (level == frames.size()) { + Frame &frame = frames.back(); + if (frame.lastIneqProcessed) { + // Let the current value of b be b' and + // let the initial value of b when we first came to this level be b. + // + // b' is equal to b /\ s_i1 /\ s_i2 /\ ... /\ s_i{j-1} /\ ~s_ij. + // We had previously recursed with the part where s_ij was not + // satisfied; all further parts satisfy s_ij, so we rollback to the + // state before adding this complement constraint, and add s_ij to b. + simplex.rollback(frame.simplexSnapshot); + b.truncate(frame.bCounts); + SmallVector ineq = + getIneqCoeffsFromIdx(frame.sI, *frame.lastIneqProcessed); + b.addInequality(ineq); + simplex.addInequality(ineq); + } + + if (frame.ineqsToProcess.empty()) { + // No ineqs left to process; pop this level's frame and return. + frames.pop_back(); + level = frames.size(); + continue; + } + + // "Recurse" with the part where the ineq is not satisfied. + frame.bCounts = b.getCounts(); + frame.simplexSnapshot = simplex.getSnapshot(); + + unsigned idx = frame.ineqsToProcess.back(); + SmallVector ineq = + getComplementIneq(getIneqCoeffsFromIdx(frame.sI, idx)); + b.addInequality(ineq); + simplex.addInequality(ineq); + + frame.ineqsToProcess.pop_back(); + frame.lastIneqProcessed = idx; + ++level; + continue; + } + } - PresburgerRelation result = - PresburgerRelation::getEmpty(disjunct.getSpaceWithoutLocals()); - Simplex simplex(disjunct); - subtractRecursively(disjunct, simplex, set, 0, result); return result; } -- 2.7.4