From 93b9f50b4c6e84626f976df95602af3ecbb98ce4 Mon Sep 17 00:00:00 2001 From: Arjun P Date: Wed, 23 Mar 2022 23:40:20 +0000 Subject: [PATCH] [MLIR][Presburger] IntegerRelation: implement partial rollback support It is often necessary to "rollback" IntegerRelations to an earlier state. Although providing full rollback support is non-trivial, we really only need to support the case where the only changes made are to append ids or append constraints, and then rollback these additions. This patch adds support to rollback in such situations by recording the number of ids and constraints of each kind and providing support to truncate the IntegerRelation to those counts by removing appended ids and constraints. This already simplifies subtraction a little bit and will also be useful in the implementation of symbolic integer lexmin. Reviewed By: Groverkss Differential Revision: https://reviews.llvm.org/D122178 --- .../mlir/Analysis/Presburger/IntegerRelation.h | 29 ++++++++++++++++++++++ .../mlir/Analysis/Presburger/PresburgerSpace.h | 4 +++ mlir/lib/Analysis/Presburger/IntegerRelation.cpp | 20 +++++++++++++++ .../lib/Analysis/Presburger/PresburgerRelation.cpp | 11 +++----- mlir/lib/Analysis/Presburger/PresburgerSpace.cpp | 6 +++++ 5 files changed, 63 insertions(+), 7 deletions(-) diff --git a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h index ea2c6ae..41c0500 100644 --- a/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h +++ b/mlir/include/mlir/Analysis/Presburger/IntegerRelation.h @@ -148,6 +148,30 @@ public: return inequalities.getRow(idx); } + /// The struct CountsSnapshot stores the count of each IdKind, and also of + /// each constraint type. getCounts() returns a CountsSnapshot object + /// describing the current state of the IntegerRelation. truncate() truncates + /// all ids of each IdKind and all constraints of both kinds beyond the counts + /// in the specified CountsSnapshot object. This can be used to achieve + /// rudimentary rollback support. As long as none of the existing constraints + /// or ids are disturbed, and only additional ids or constraints are added, + /// this addition can be rolled back using truncate. + struct CountsSnapshot { + public: + CountsSnapshot(const PresburgerLocalSpace &space, unsigned numIneqs, + unsigned numEqs) + : space(space), numIneqs(numIneqs), numEqs(numEqs) {} + const PresburgerLocalSpace &getSpace() const { return space; }; + unsigned getNumIneqs() const { return numIneqs; } + unsigned getNumEqs() const { return numEqs; } + + private: + PresburgerLocalSpace space; + unsigned numIneqs, numEqs; + }; + CountsSnapshot getCounts() const; + void truncate(const CountsSnapshot &counts); + /// Insert `num` identifiers of the specified kind at position `pos`. /// Positions are relative to the kind of identifier. The coefficient columns /// corresponding to the added identifiers are initialized to zero. Return the @@ -491,6 +515,11 @@ protected: /// arrays as needed. void removeIdRange(unsigned idStart, unsigned idLimit); + using PresburgerSpace::truncateIdKind; + /// Truncate the ids to the number in the space of the specified + /// CountsSnapshot. + void truncateIdKind(IdKind kind, const CountsSnapshot &counts); + /// A parameter that controls detection of an unrealistic number of /// constraints. If the number of constraints is this many times the number of /// variables, we consider such a system out of line with the intended use diff --git a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h index f9b03e4..a832f00 100644 --- a/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h +++ b/mlir/include/mlir/Analysis/Presburger/PresburgerSpace.h @@ -109,6 +109,10 @@ public: /// idLimit). The range is relative to the kind of identifier. virtual void removeIdRange(IdKind kind, unsigned idStart, unsigned idLimit); + /// Truncate the ids of the specified kind to the specified number by dropping + /// some ids at the end. `num` must be less than the current number. + void truncateIdKind(IdKind kind, unsigned num); + /// Returns true if both the spaces are equal i.e. if both spaces have the /// same number of identifiers of each kind (excluding Local Identifiers). bool isEqual(const PresburgerSpace &other) const; diff --git a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp index 611615b..4b41c23 100644 --- a/mlir/lib/Analysis/Presburger/IntegerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/IntegerRelation.cpp @@ -126,6 +126,26 @@ void removeConstraintsInvolvingIdRange(IntegerRelation &poly, unsigned begin, if (!rangeIsZero(poly.getInequality(i - 1).slice(begin, count))) poly.removeInequality(i - 1); } + +IntegerRelation::CountsSnapshot IntegerRelation::getCounts() const { + return {PresburgerLocalSpace(*this), getNumInequalities(), + getNumEqualities()}; +} + +void IntegerRelation::truncateIdKind(IdKind kind, + const CountsSnapshot &counts) { + truncateIdKind(kind, counts.getSpace().getNumIdKind(kind)); +} + +void IntegerRelation::truncate(const CountsSnapshot &counts) { + truncateIdKind(IdKind::Domain, counts); + truncateIdKind(IdKind::Range, counts); + truncateIdKind(IdKind::Symbol, counts); + truncateIdKind(IdKind::Local, counts); + removeInequalityRange(counts.getNumIneqs(), getNumInequalities()); + removeInequalityRange(counts.getNumEqs(), getNumEqualities()); +} + unsigned IntegerRelation::insertId(IdKind kind, unsigned pos, unsigned num) { assert(pos <= getNumIdKind(kind)); diff --git a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp index 37a3b78..934c80f 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerRelation.cpp @@ -153,16 +153,12 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex, // rollback b to its initial state before returning, which we will do by // removing all constraints beyond the original number of inequalities // and equalities, so we store these counts first. - const unsigned bInitNumIneqs = b.getNumInequalities(); - const unsigned bInitNumEqs = b.getNumEqualities(); - const unsigned bInitNumLocals = b.getNumLocalIds(); + const IntegerRelation::CountsSnapshot bCounts = b.getCounts(); // Similarly, we also want to rollback simplex to its original state. const unsigned initialSnapshot = simplex.getSnapshot(); auto restoreState = [&]() { - b.removeIdRange(IdKind::Local, bInitNumLocals, b.getNumLocalIds()); - b.removeInequalityRange(bInitNumIneqs, b.getNumInequalities()); - b.removeEqualityRange(bInitNumEqs, b.getNumEqualities()); + b.truncate(bCounts); simplex.rollback(initialSnapshot); }; @@ -198,7 +194,8 @@ static void subtractRecursively(IntegerRelation &b, Simplex &simplex, } unsigned offset = simplex.getNumConstraints(); - unsigned numLocalsAdded = b.getNumLocalIds() - bInitNumLocals; + unsigned numLocalsAdded = + b.getNumLocalIds() - bCounts.getSpace().getNumLocalIds(); simplex.appendVariable(numLocalsAdded); unsigned snapshotBeforeIntersect = simplex.getSnapshot(); diff --git a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp index cbeba24..c4e06aa 100644 --- a/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp +++ b/mlir/lib/Analysis/Presburger/PresburgerSpace.cpp @@ -91,6 +91,12 @@ void PresburgerSpace::removeIdRange(IdKind kind, unsigned idStart, llvm_unreachable("PresburgerSpace does not support local identifiers!"); } +void PresburgerSpace::truncateIdKind(IdKind kind, unsigned num) { + unsigned curNum = getNumIdKind(kind); + assert(num <= curNum && "Can't truncate to more ids!"); + removeIdRange(kind, num, curNum); +} + unsigned PresburgerLocalSpace::insertId(IdKind kind, unsigned pos, unsigned num) { if (kind == IdKind::Local) { -- 2.7.4