Introduce Fourier-Motzkin variable elimination + other cleanup/support
authorUday Bondhugula <bondhugula@google.com>
Thu, 25 Oct 2018 15:33:02 +0000 (08:33 -0700)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 20:38:24 +0000 (13:38 -0700)
- Introduce Fourier-Motzkin variable elimination to eliminate a dimension from
  a system of linear equalities/inequalities. Update isEmpty to use this.
  Since FM is only exact on rational/real spaces, an emptiness check based on
  this is guaranteed to be exact whenever it says the underlying set is empty;
  if it says, it's not empty, there may still be no integer points in it.
  Also, supports a version that computes "dark shadows".

- Test this by checking for "always false" conditionals in if statements.

- Unique IntegerSet's that are small (few constraints, few variables). This
  basically means the canonical empty set and other small sets that are
  likely commonly used get uniqued; allows checking for the canonical empty set
  by pointer. IntegerSet::kUniquingThreshold gives the threshold constraint size
  for uniqui'ing.

- rename simplify-affine-expr -> simplify-affine-structures

Other cleanup

- IntegerSet::numConstraints, AffineMap::numResults are no longer needed;
  remove them.
- add copy assignment operators for AffineMap, IntegerSet.
- rename Invalid() -> Null() on AffineExpr, AffineMap, IntegerSet
- Misc cleanup for FlatAffineConstraints API

PiperOrigin-RevId: 218690456

20 files changed:
mlir/include/mlir/Analysis/AffineStructures.h
mlir/include/mlir/IR/AffineMap.h
mlir/include/mlir/IR/IntegerSet.h
mlir/include/mlir/IR/Statements.h
mlir/include/mlir/Support/MathExtras.h
mlir/include/mlir/Transforms/Passes.h
mlir/include/mlir/Transforms/Utils.h
mlir/lib/Analysis/AffineAnalysis.cpp
mlir/lib/Analysis/AffineStructures.cpp
mlir/lib/IR/AffineMap.cpp
mlir/lib/IR/AffineMapDetail.h
mlir/lib/IR/Builders.cpp
mlir/lib/IR/IntegerSet.cpp
mlir/lib/IR/IntegerSetDetail.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/Parser/Parser.cpp
mlir/lib/Transforms/LoopUtils.cpp
mlir/lib/Transforms/SimplifyAffineExpr.cpp
mlir/test/Transforms/simplify.mlir
mlir/tools/mlir-opt/mlir-opt.cpp

index 96add34ad0a61c5bdeec01b6b3540c56d7d87a5c..bff3da9a11bc4e9738d7c96a4ced53082ddc3e52 100644 (file)
@@ -232,12 +232,16 @@ public:
   /// Construct a constraint system reserving memory for the specified number of
   /// constraints and identifiers..
   FlatAffineConstraints(unsigned numReservedInequalities,
-                        unsigned numReservedEqualities, unsigned numReservedIds)
+                        unsigned numReservedEqualities,
+                        unsigned numReservedCols, unsigned numDims = 0,
+                        unsigned numSymbols = 0, unsigned numLocals = 0)
       : numReservedEqualities(numReservedEqualities),
-        numReservedInequalities(numReservedInequalities),
-        numReservedIds(numReservedIds) {
-    equalities.reserve(numReservedIds * numReservedEqualities);
-    inequalities.reserve(numReservedIds * numReservedInequalities);
+        numReservedInequalities(numReservedInequalities), numDims(numDims),
+        numSymbols(numSymbols) {
+    assert(numReservedCols >= 1 && "minimum 1 column");
+    equalities.reserve(numReservedCols * numReservedEqualities);
+    inequalities.reserve(numReservedCols * numReservedInequalities);
+    numIds = numDims + numSymbols + numLocals;
   }
 
   explicit FlatAffineConstraints(const HyperRectangularSet &set);
@@ -255,8 +259,10 @@ public:
   // TODO(bondhugula)
   explicit FlatAffineConstraints(const IntegerValueSet &set);
 
+  FlatAffineConstraints(const FlatAffineConstraints &other);
+
   FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef,
-                        const IntegerSet &set);
+                        IntegerSet set);
 
   FlatAffineConstraints(const MutableAffineMap &map);
 
@@ -267,23 +273,25 @@ public:
   // constraints.
   // Returns true if the GCD test fails for any equality, or if any invalid
   // constraints are discovered on any row. Returns false otherwise.
-  // TODO(andydavis) Change this method to operate on cloned constraints.
-  bool isEmpty();
+  bool isEmpty() const;
 
   // Eliminates a single identifier at 'position' from equality and inequality
   // constraints. Returns 'true' if the identifier was eliminated.
   // Returns 'false' otherwise.
-  bool eliminateIdentifier(unsigned position);
+  bool gaussianEliminateId(unsigned position);
 
   // Eliminates identifiers from equality and inequality constraints
   // in column range [posStart, posLimit).
   // Returns the number of variables eliminated.
-  unsigned eliminateIdentifiers(unsigned posStart, unsigned posLimit);
+  unsigned gaussianEliminateIds(unsigned posStart, unsigned posLimit);
 
+  // Clones this object.
+  std::unique_ptr<FlatAffineConstraints> clone() const;
+
+  /// Returns the value at the specified equality row and column.
   inline int64_t atEq(unsigned i, unsigned j) const {
     return equalities[i * (numIds + 1) + j];
   }
-
   inline int64_t &atEq(unsigned i, unsigned j) {
     return equalities[i * (numIds + 1) + j];
   }
@@ -322,11 +330,11 @@ public:
     return inequalities.size() / getNumCols();
   }
 
-  ArrayRef<int64_t> getEquality(unsigned idx) {
+  inline ArrayRef<int64_t> getEquality(unsigned idx) const {
     return ArrayRef<int64_t>(&equalities[idx * getNumCols()], getNumCols());
   }
 
-  ArrayRef<int64_t> getInequality(unsigned idx) {
+  inline ArrayRef<int64_t> getInequality(unsigned idx) const {
     return ArrayRef<int64_t>(&inequalities[idx * getNumCols()], getNumCols());
   }
 
@@ -340,13 +348,25 @@ public:
   void addSymbolId(unsigned pos);
   void addLocalId(unsigned pos);
 
+  /// Eliminates identifier at the specified position using Fourier-Motzkin
+  /// variable elimination. If the result of the elimination is integer exact,
+  /// *isResultIntegerExact is set to true. If 'darkShadow' is set to true, a
+  /// potential under approximation (subset) of the rational shadow / exact
+  /// integer shadow is computed.
+  // See implementation comments for more details.
+  bool FourierMotzkinEliminate(unsigned pos, bool darkShadow = false,
+                               bool *isResultIntegerExact = nullptr);
+
   void removeId(IdKind idKind, unsigned pos);
+  void removeId(unsigned pos);
+
+  void removeDim(unsigned pos);
 
   void removeEquality(unsigned pos);
   void removeInequality(unsigned pos);
 
   unsigned getNumConstraints() const {
-    return equalities.size() + inequalities.size();
+    return getNumInequalities() + getNumEqualities();
   }
   inline unsigned getNumIds() const { return numIds; }
   inline unsigned getNumResultDimIds() const { return numResultDims; }
@@ -356,6 +376,12 @@ public:
     return numIds - numResultDims - numDims - numSymbols;
   }
 
+  /// Clears this list of constraints and copies other into it.
+  void clearAndCopyFrom(const FlatAffineConstraints &other);
+
+  // More expensive ones.
+  void removeDuplicates();
+
   void print(raw_ostream &os) const;
   void dump() const;
 
index 1e98d6d7979d49c3dfc11ba12b6fc27ad8f68de3..7586a00f902385c9641f55bafbc88076a7e425a7 100644 (file)
@@ -47,7 +47,7 @@ public:
   using ImplType = detail::AffineMapStorage;
 
   explicit AffineMap(ImplType *map = nullptr) : map(map) {}
-  static AffineMap Invalid() { return AffineMap(nullptr); }
+  static AffineMap Null() { return AffineMap(nullptr); }
 
   static AffineMap get(unsigned dimCount, unsigned symbolCount,
                        ArrayRef<AffineExpr> results,
index 7ab9db476a441fdf57991d670e4fb85a46e9d81d..71c6ea6f79aacad4059d35533354d1c694ef758c 100644 (file)
@@ -55,18 +55,28 @@ public:
 
   explicit IntegerSet(ImplType *set = nullptr) : set(set) {}
 
+  IntegerSet &operator=(const IntegerSet other) {
+    set = other.set;
+    return *this;
+  }
+
   static IntegerSet get(unsigned dimCount, unsigned symbolCount,
                         ArrayRef<AffineExpr> constraints,
-                        ArrayRef<bool> eqFlags, MLIRContext *context);
+                        ArrayRef<bool> eqFlags);
 
-  // Returns a canonical empty IntegerSet (i.e. a set with no integer points).
+  // Returns the canonical empty IntegerSet (i.e. a set with no integer points).
   static IntegerSet getEmptySet(unsigned numDims, unsigned numSymbols,
                                 MLIRContext *context) {
     auto one = getAffineConstantExpr(1, context);
     /* 1 == 0 */
-    return get(numDims, numSymbols, one, true, context);
+    return get(numDims, numSymbols, one, true);
   }
 
+  /// Returns true if this is the canonical integer set.
+  bool isEmptyIntegerSet() const;
+
+  static IntegerSet Null() { return IntegerSet(nullptr); }
+
   explicit operator bool() { return set; }
   bool operator==(IntegerSet other) const { return set == other.set; }
 
@@ -98,6 +108,8 @@ public:
 
 private:
   ImplType *set;
+  /// Sets with constraints fewer than kUniquingThreshold are uniqued.
+  constexpr static unsigned kUniquingThreshold = 4;
 };
 
 // Make AffineExpr hashable.
index 77a4327164ba8d7438922e16f5aaedcbaf512986..ad51470c20994ef29572d2a988c5620d749bf952 100644 (file)
@@ -547,7 +547,7 @@ private:
 class AffineCondition {
 public:
   const IfStmt *getIfStmt() const { return &stmt; }
-  IntegerSet getSet() const { return set; }
+  IntegerSet getIntegerSet() const { return set; }
 
 private:
   // 'if' statement that contains this affine condition.
index fce4616898b3d33a561be90a225e8fb26746efd6..767677fbc5d92d8e1f0c4774839a9895a95ddd39 100644 (file)
@@ -23,6 +23,7 @@
 #define MLIR_SUPPORT_MATHEXTRAS_H_
 
 #include "mlir/Support/LLVM.h"
+#include "llvm/ADT/APInt.h"
 
 namespace mlir {
 
index bcb3b6de462cafb91bf4f69b574ff76e244ac4e0..1c34e36a57e7cb6cdcf5ed7061aec3aa181d219e 100644 (file)
@@ -51,8 +51,8 @@ MLFunctionPass *createLoopUnrollPass(int unrollFactor = -1,
 /// line if provided.
 MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
 
-/// Creates an affine expression simplification pass.
-FunctionPass *createSimplifyAffineExprPass();
+/// Creates an simplification pass for affine structures.
+FunctionPass *createSimplifyAffineStructuresPass();
 
 /// Creates a pass to pipeline explicit movement of data across levels of the
 /// memory hierarchy.
index 79e5bf46e37a92b47ecbb967e0219d337cb32ae9..1b9820e429cca3d243372624d25b02ff891afc90 100644 (file)
@@ -48,7 +48,7 @@ class SSAValue;
 // extended to add additional indices at any position.
 bool replaceAllMemRefUsesWith(const MLValue *oldMemRef, MLValue *newMemRef,
                               llvm::ArrayRef<MLValue *> extraIndices,
-                              AffineMap indexRemap = AffineMap::Invalid());
+                              AffineMap indexRemap = AffineMap::Null());
 
 /// Creates and inserts into 'builder' a new AffineApplyOp, with the number of
 /// its results equal to the number of operands, as a composition
index 9b5cc99ff8e365f0801be9e890c6d532e257c565..8f792641dabc0dcd75cce5ad3b96e4e6a2399d6d 100644 (file)
@@ -286,7 +286,7 @@ AffineExpr mlir::simplifyAffineExpr(AffineExpr expr, unsigned numDims,
   // TODO(bondhugula): only pure affine for now. The simplification here can be
   // extended to semi-affine maps in the future.
   if (!expr.isPureAffine())
-    return nullptr;
+    return expr;
 
   AffineExprFlattener flattener(numDims, numSymbols, expr.getContext());
   flattener.walkPostOrder(expr);
index e45b5ccc064666aa7bb3a2c290b975e363266c08..d8c41cfed2647509fcbd0e833ba2942304ba8252 100644 (file)
 
 #include "mlir/Analysis/AffineStructures.h"
 #include "mlir/Analysis/AffineAnalysis.h"
-#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
-#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/MLValue.h"
 #include "mlir/Support/MathExtras.h"
-#include "llvm/ADT/DenseMap.h"
+#include "third_party/llvm/llvm/projects/google-mlir/include/mlir/Analysis/AffineStructures.h"
 #include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 
+#define DEBUG_TYPE "affine-structures"
+
 using namespace mlir;
 using namespace llvm;
 
@@ -439,12 +440,45 @@ AffineMap AffineValueMap::getAffineMap() { return map.getAffineMap(); }
 
 AffineValueMap::~AffineValueMap() {}
 
+//===----------------------------------------------------------------------===//
+// FlatAffineConstraints.
+//===----------------------------------------------------------------------===//
+
+// Copy constructor.
+FlatAffineConstraints::FlatAffineConstraints(
+    const FlatAffineConstraints &other) {
+  numReservedEqualities = other.numReservedEqualities;
+  numReservedInequalities = other.numReservedInequalities;
+  numIds = other.getNumIds();
+  numDims = other.getNumDimIds();
+  numSymbols = other.getNumSymbolIds();
+
+  equalities.reserve(numReservedEqualities * getNumCols());
+  inequalities.reserve(numReservedEqualities * getNumCols());
+
+  for (unsigned r = 0, e = other.getNumInequalities(); r < e; r++) {
+    addInequality(other.getInequality(r));
+  }
+  for (unsigned r = 0, e = other.getNumEqualities(); r < e; r++) {
+    addEquality(other.getEquality(r));
+  }
+}
+
+// Clones this object.
+std::unique_ptr<FlatAffineConstraints> FlatAffineConstraints::clone() const {
+  return std::make_unique<FlatAffineConstraints>(*this);
+}
+
+// Construct from an IntegerSet.
 FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
-    : numReservedEqualities(0), numReservedInequalities(0), numReservedIds(0),
+    : numReservedEqualities(set.getNumEqualities()),
+      numReservedInequalities(set.getNumInequalities()), numReservedIds(0),
       numIds(set.getNumDims() + set.getNumSymbols()), numDims(set.getNumDims()),
       numSymbols(set.getNumSymbols()) {
-  unsigned numConstraints = set.getNumConstraints();
-  for (unsigned i = 0; i < numConstraints; ++i) {
+  equalities.reserve(set.getNumEqualities() * getNumCols());
+  inequalities.reserve(set.getNumInequalities() * getNumCols());
+
+  for (unsigned i = 0, e = set.getNumConstraints(); i < e; ++i) {
     AffineExpr expr = set.getConstraint(i);
     SmallVector<int64_t, 4> flattenedExpr;
     getFlattenedAffineExpr(expr, set.getNumDims(), set.getNumSymbols(),
@@ -457,7 +491,6 @@ FlatAffineConstraints::FlatAffineConstraints(IntegerSet set)
     }
   }
 }
-
 // Searches for a constraint with a non-zero coefficient at 'colIdx' in
 // equality (isEq=true) or inequality (isEq=false) constraints.
 // Returns true and sets row found in search in 'rowIdx'.
@@ -658,12 +691,16 @@ void FlatAffineConstraints::removeColumnRange(unsigned colStart,
 // all equality constraint rows, and checks the constraint validity.
 // Returns 'true' if the GCD test fails on any row, or if any invalid
 // constraint is detected. Returns 'false' otherwise.
-bool FlatAffineConstraints::isEmpty() {
-  if (eliminateIdentifiers(0, numIds) == 0)
-    return false;
-  if (isEmptyByGCDTest(*this))
+bool FlatAffineConstraints::isEmpty() const {
+  auto tmpCst = clone();
+  if (tmpCst->gaussianEliminateIds(0, numIds) < numIds) {
+    for (unsigned i = 0, e = tmpCst->getNumIds(); i < e; i++)
+      if (!tmpCst->FourierMotzkinEliminate(0))
+        return false;
+  }
+  if (isEmptyByGCDTest(*tmpCst))
     return true;
-  if (hasInvalidConstraint(*this))
+  if (hasInvalidConstraint(*tmpCst))
     return true;
   return false;
 }
@@ -671,13 +708,13 @@ bool FlatAffineConstraints::isEmpty() {
 // Eliminates a single identifier at 'position' from equality and inequality
 // constraints. Returns 'true' if the identifier was eliminated.
 // Returns 'false' otherwise.
-bool FlatAffineConstraints::eliminateIdentifier(unsigned position) {
-  return eliminateIdentifiers(position, position + 1) == 1;
+bool FlatAffineConstraints::gaussianEliminateId(unsigned position) {
+  return gaussianEliminateIds(position, position + 1) == 1;
 }
 
 // Eliminates all identifer variables in column range [posStart, posLimit).
 // Returns the number of variables eliminated.
-unsigned FlatAffineConstraints::eliminateIdentifiers(unsigned posStart,
+unsigned FlatAffineConstraints::gaussianEliminateIds(unsigned posStart,
                                                      unsigned posLimit) {
   // Return if identifier positions to eliminate are out of range.
   if (posStart >= posLimit || posLimit > numIds)
@@ -768,3 +805,243 @@ void FlatAffineConstraints::print(raw_ostream &os) const {
 }
 
 void FlatAffineConstraints::dump() const { print(llvm::errs()); }
+
+void FlatAffineConstraints::removeDuplicates() {
+  // TODO: remove redundant constraints.
+}
+
+void FlatAffineConstraints::clearAndCopyFrom(
+    const FlatAffineConstraints &other) {
+  FlatAffineConstraints copy(other);
+  std::swap(*this, copy);
+}
+
+void FlatAffineConstraints::removeId(unsigned pos) {
+  assert(pos >= 0 && pos < getNumIds());
+
+  for (unsigned r = 0; r < getNumInequalities(); r++) {
+    for (unsigned c = pos; c < getNumCols() - 1; c++) {
+      atIneq(r, c) = atIneq(r, c + 1);
+    }
+  }
+
+  for (unsigned r = 0; r < getNumEqualities(); r++) {
+    for (unsigned c = pos; c < getNumCols() - 1; c++) {
+      atEq(r, c) = atEq(r, c + 1);
+    }
+  }
+
+  if (pos < numDims)
+    numDims--;
+  else if (pos < numSymbols)
+    numSymbols--;
+  numIds--;
+}
+
+static std::pair<unsigned, unsigned>
+getNewNumDimsSymbols(unsigned pos, const FlatAffineConstraints &cst) {
+  unsigned numDims = cst.getNumDimIds();
+  unsigned numSymbols = cst.getNumSymbolIds();
+  unsigned newNumDims, newNumSymbols;
+  if (pos < numDims) {
+    newNumDims = numDims - 1;
+    newNumSymbols = numSymbols;
+  } else if (pos < numDims + numSymbols) {
+    assert(numSymbols >= 1);
+    newNumDims = numDims;
+    newNumSymbols = numSymbols - 1;
+  } else {
+    newNumDims = numDims;
+    newNumSymbols = numSymbols;
+  }
+  return {newNumDims, newNumSymbols};
+}
+
+/// Eliminates identifier at the specified position using Fourier-Motzkin
+/// variable elimination. This technique is exact for rational spaces but
+/// conservative (in "rare" cases) for integer spaces. The operation corresponds
+/// to a projection operation yielding the (convex) set of integer points
+/// contained in the rational shadow of the set. An emptiness test that relies
+/// on this method will guarantee emptiness, i.e., it disproves the existence of
+/// a solution if it says it's empty.
+/// If a non-null isResultIntegerExact is passed, it is set to true if the
+/// result is also integer exact. If it's set to false, the obtained solution
+/// *may* not be exact, i.e., it may contain integer points that do not have an
+/// integer pre-image in the original set.
+///
+/// Eg:
+/// j >= 0, j <= i + 1
+/// i >= 0, i <= N + 1
+/// Eliminating i yields,
+///   j >= 0, 0 <= N + 1, j - 1 <= N + 1
+///
+/// If darkShadow = true, this method computes the dark shadow on elimination;
+/// the dark shadow is a convex integer subset of the exact integer shadow. A
+/// non-empty dark shadow proves the existence of an integer solution. The
+/// elimination in such a case could however be an under-approximation, and thus
+/// should not be used for scanning sets or used by itself for dependence
+/// checking.
+///
+/// Eg: 2-d set, * represents grid points, 'o' represents a point in the set.
+///            ^
+///            |
+///            | * * * * o o
+///         i  | * * o o o o
+///            | o * * * * *
+///            --------------->
+///                 j ->
+///
+/// Eliminating i from this system (projecting on the j dimension):
+/// rational shadow / integer light shadow:  1 <= j <= 6
+/// dark shadow:                             3 <= j <= 6
+/// exact integer shadow:                    j = 1 \union  3 <= j <= 6
+/// holes/splinters:                         j = 2
+///
+/// darkShadow = false, isResultIntegerExact = nullptr are default values.
+// TODO(bondhugula): a slight modification to yield dark shadow version of FM
+// (tightened), which can prove the existence of a solution if there is one.
+bool FlatAffineConstraints::FourierMotzkinEliminate(
+    unsigned pos, bool darkShadow, bool *isResultIntegerExact) {
+  assert(pos < getNumIds() && "invalid position");
+  LLVM_DEBUG(llvm::dbgs() << "FM Input:\n");
+  LLVM_DEBUG(dump());
+
+  // Check if this identifier can be eliminated through a substitution.
+  for (unsigned r = 0; r < getNumEqualities(); r++) {
+    if (atIneq(r, pos) != 0) {
+      // Use Gaussian elimination here (since we have an equality).
+      bool ret = gaussianEliminateId(pos);
+      assert(ret && "Gaussian elimination guaranteed to succeed");
+      return ret;
+    }
+  }
+
+  // Check if the identifier appears at all in any of the inequalities.
+  unsigned r, e;
+  for (r = 0, e = getNumInequalities(); r < e; r++) {
+    if (atIneq(r, pos) != 0)
+      break;
+  }
+  if (r == getNumInequalities()) {
+    // If it doesn't appear, just remove the column and return.
+    // TODO(andydavis,bondhugula): refactor removeColumns to use it from here.
+    removeId(pos);
+    LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
+    LLVM_DEBUG(dump());
+    return true;
+  }
+
+  // Positions of constraints that are lower bounds on the variable.
+  SmallVector<unsigned, 4> lbIndices;
+  // Positions of constraints that are lower bounds on the variable.
+  SmallVector<unsigned, 4> ubIndices;
+  // Positions of constraints that do not involve the variable.
+  std::vector<unsigned> nbIndices;
+  nbIndices.reserve(getNumInequalities());
+
+  // Gather all lower bounds and upper bounds of the variable. Since the
+  // canonical form c_1*x_1 + c_2*x_2 + ... + c_0 >= 0, a constraint is a lower
+  // bound for x_i if c_i >= 1, and an upper bound if c_i <= -1.
+  for (unsigned r = 0, e = getNumInequalities(); r < e; r++) {
+    if (atIneq(r, pos) == 0) {
+      // Id does not appear in bound.
+      nbIndices.push_back(r);
+    } else if (atIneq(r, pos) >= 1) {
+      // Lower bound.
+      lbIndices.push_back(r);
+    } else {
+      // Upper bound.
+      ubIndices.push_back(r);
+    }
+  }
+
+  // Set the number of dimensions, symbols in the resulting system.
+  const auto &dimsSymbols = getNewNumDimsSymbols(pos, *this);
+  unsigned newNumDims = dimsSymbols.first;
+  unsigned newNumSymbols = dimsSymbols.second;
+
+  /// Create the new system which has one identifier less.
+  FlatAffineConstraints newFac(
+      lbIndices.size() * ubIndices.size() + nbIndices.size(),
+      getNumEqualities(), getNumCols() - 1, newNumDims, newNumSymbols,
+      /*numLocals=*/getNumIds() - 1 - newNumDims - newNumSymbols);
+
+  // This will be used to check if the elimination was integer exact.
+  unsigned lcmProducts = 1;
+
+  // Let x be the variable we are eliminating.
+  // For each lower bound, lb <= c_l*x, and each upper bound c_u*x <= ub, (note
+  // that c_l, c_u >= 1) we have:
+  // lb*lcm(c_l, c_u)/c_l <= lcm(c_l, c_u)*x <= ub*lcm(c_l, c_u)/c_u
+  // We thus generate a constraint:
+  // lcm(c_l, c_u)/c_l*lb <= lcm(c_l, c_u)/c_u*ub.
+  // Note if c_l = c_u = 1, all integer points captured by the resulting
+  // constraint correspond to integer points in the original system (i.e., they
+  // have integer pre-images). Hence, if the lcm's are all 1, the elimination is
+  // integer exact.
+  for (auto ubPos : ubIndices) {
+    for (auto lbPos : lbIndices) {
+      SmallVector<int64_t, 4> ineq;
+      ineq.reserve(newFac.getNumCols());
+      int64_t lbCoeff = atIneq(lbPos, pos);
+      // Note that in the comments above, ubCoeff is the negation of the
+      // coefficient in the canonical form as the view taken here is that of the
+      // term being moved to the other size of '>='.
+      int64_t ubCoeff = -atIneq(ubPos, pos);
+      // TODO(bondhugula): refactor this loop to avoid all branches inside.
+      for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+        if (l == pos)
+          continue;
+        assert(lbCoeff >= 1 && ubCoeff >= 1 && "bounds wrongly identified");
+        int64_t lcm = mlir::lcm(lbCoeff, ubCoeff);
+        ineq.push_back(atIneq(ubPos, l) * (lcm / ubCoeff) +
+                       atIneq(lbPos, l) * (lcm / lbCoeff));
+        lcmProducts *= lcm;
+      }
+      if (darkShadow) {
+        // The dark shadow is a convex subset of the exact integer shadow. If
+        // there is a point here, it proves the existence of a solution.
+        ineq[ineq.size() - 1] += lbCoeff * ubCoeff - lbCoeff - ubCoeff + 1;
+      }
+      // TODO: we need to have a way to add inequalities in-place in
+      // FlatAffineConstraints instead of creating and copying over.
+      newFac.addInequality(ineq);
+    }
+  }
+
+  if (lcmProducts == 1 && isResultIntegerExact)
+    *isResultIntegerExact = 1;
+
+  // Copy over the constraints not involving this variable.
+  for (auto nbPos : nbIndices) {
+    SmallVector<int64_t, 4> ineq;
+    ineq.reserve(getNumCols() - 1);
+    for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+      if (l == pos)
+        continue;
+      ineq.push_back(atIneq(nbPos, l));
+    }
+    newFac.addInequality(ineq);
+  }
+
+  assert(newFac.getNumConstraints() ==
+         lbIndices.size() * ubIndices.size() + nbIndices.size());
+
+  // Copy over the equalities.
+  for (unsigned r = 0, e = getNumEqualities(); r < e; r++) {
+    SmallVector<int64_t, 4> eq;
+    eq.reserve(newFac.getNumCols());
+    for (unsigned l = 0, e = getNumCols(); l < e; l++) {
+      if (l == pos)
+        continue;
+      eq.push_back(atEq(r, l));
+    }
+    newFac.addEquality(eq);
+  }
+
+  newFac.removeDuplicates();
+  clearAndCopyFrom(newFac);
+  LLVM_DEBUG(llvm::dbgs() << "FM output:\n");
+  LLVM_DEBUG(dump());
+  return true;
+}
index dbf5126e22ea3a31dc9b99104139eea944c7706d..5dc79824517e75cfe8882e8d7ce8b4389ce79a1d 100644 (file)
@@ -121,7 +121,7 @@ int64_t AffineMap::getSingleConstantResult() const {
 
 unsigned AffineMap::getNumDims() const { return map->numDims; }
 unsigned AffineMap::getNumSymbols() const { return map->numSymbols; }
-unsigned AffineMap::getNumResults() const { return map->numResults; }
+unsigned AffineMap::getNumResults() const { return map->results.size(); }
 unsigned AffineMap::getNumInputs() const {
   return map->numDims + map->numSymbols;
 }
index 0db06601964b5a3843e644a5a30bfad066a8971c..edbc714f00b52ddc6c88dceca1327eb67184c9c1 100644 (file)
@@ -32,7 +32,6 @@ namespace detail {
 struct AffineMapStorage {
   unsigned numDims;
   unsigned numSymbols;
-  unsigned numResults;
 
   /// The affine expressions for this (multi-dimensional) map.
   /// TODO: use trailing objects for this.
index 5759d8477cadab53ace5f049ff7d8ac7c5c257bb..7a751ee38351b0d8c8a5cdddde4630a45728f6df 100644 (file)
@@ -194,7 +194,7 @@ AffineExpr Builder::getAffineConstantExpr(int64_t constant) {
 IntegerSet Builder::getIntegerSet(unsigned dimCount, unsigned symbolCount,
                                   ArrayRef<AffineExpr> constraints,
                                   ArrayRef<bool> isEq) {
-  return IntegerSet::get(dimCount, symbolCount, constraints, isEq, context);
+  return IntegerSet::get(dimCount, symbolCount, constraints, isEq);
 }
 
 AffineMap Builder::getConstantAffineMap(int64_t val) {
index ff70128ad64188b611b7d29982c88c070f227bc7..74a1297dcdd2c128cd9eb89efc34fdaaeb4aab9d 100644 (file)
@@ -27,7 +27,10 @@ unsigned IntegerSet::getNumSymbols() const { return set->symbolCount; }
 unsigned IntegerSet::getNumOperands() const {
   return set->dimCount + set->symbolCount;
 }
-unsigned IntegerSet::getNumConstraints() const { return set->numConstraints; }
+
+unsigned IntegerSet::getNumConstraints() const {
+  return set->constraints.size();
+}
 
 unsigned IntegerSet::getNumEqualities() const {
   unsigned numEqualities = 0;
@@ -41,6 +44,13 @@ unsigned IntegerSet::getNumInequalities() const {
   return getNumConstraints() - getNumEqualities();
 }
 
+bool IntegerSet::isEmptyIntegerSet() const {
+  // This will only work if uniqui'ing is on.
+  static_assert(kUniquingThreshold >= 1,
+                "uniquing threshold should be at least one");
+  return *this == getEmptySet(set->dimCount, set->symbolCount, getContext());
+}
+
 ArrayRef<AffineExpr> IntegerSet::getConstraints() const {
   return set->constraints;
 }
index 59b3f87ec296ee9e2734eb23a2f4a2ab8ef853c9..b3eda5205fb0ce9d64a3b3e911032ffdd905fb3d 100644 (file)
@@ -31,7 +31,6 @@ namespace detail {
 struct IntegerSetStorage {
   unsigned dimCount;
   unsigned symbolCount;
-  unsigned numConstraints;
 
   /// Array of affine constraints: a constraint is either an equality
   /// (affine_expr == 0) or an inequality (affine_expr >= 0).
index 6705619696eb0c169af4b0f2922d12fe29b7ac6c..d32cabfa6ac8656c17bd598177e7c2604388b360 100644 (file)
@@ -84,6 +84,29 @@ struct AffineMapKeyInfo : DenseMapInfo<AffineMap> {
   }
 };
 
+struct IntegerSetKeyInfo : DenseMapInfo<IntegerSet> {
+  // Integer sets are uniqued based on their dim/symbol counts, affine
+  // expressions appearing in the LHS of constraints, and eqFlags.
+  using KeyTy =
+      std::tuple<unsigned, unsigned, ArrayRef<AffineExpr>, ArrayRef<bool>>;
+  using DenseMapInfo<IntegerSet>::getHashValue;
+  using DenseMapInfo<IntegerSet>::isEqual;
+
+  static unsigned getHashValue(KeyTy key) {
+    return hash_combine(
+        std::get<0>(key), std::get<1>(key),
+        hash_combine_range(std::get<2>(key).begin(), std::get<2>(key).end()),
+        hash_combine_range(std::get<3>(key).begin(), std::get<3>(key).end()));
+  }
+
+  static bool isEqual(const KeyTy &lhs, IntegerSet rhs) {
+    if (rhs == getEmptyKey() || rhs == getTombstoneKey())
+      return false;
+    return lhs == std::make_tuple(rhs.getNumDims(), rhs.getNumSymbols(),
+                                  rhs.getConstraints(), rhs.getEqFlags());
+  }
+};
+
 struct VectorTypeKeyInfo : DenseMapInfo<VectorType *> {
   // Vectors are uniqued based on their element type and shape.
   using KeyTy = std::pair<Type *, ArrayRef<int>>;
@@ -280,6 +303,10 @@ public:
   using AffineMapSet = DenseSet<AffineMap, AffineMapKeyInfo>;
   AffineMapSet affineMaps;
 
+  // Integer set uniquing.
+  using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
+  IntegerSets integerSets;
+
   // Affine binary op expression uniquing. Figure out uniquing of dimensional
   // or symbolic identifiers.
   DenseMap<std::tuple<unsigned, AffineExpr, AffineExpr>, AffineExpr>
@@ -1132,9 +1159,8 @@ AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
   rangeSizes = impl.copyInto(rangeSizes);
 
   // Initialize the memory using placement new.
-  new (res) detail::AffineMapStorage{dimCount, symbolCount,
-                                     static_cast<unsigned>(results.size()),
-                                     results, rangeSizes};
+  new (res)
+      detail::AffineMapStorage{dimCount, symbolCount, results, rangeSizes};
 
   // Cache and return it.
   return *existing.first = AffineMap(res);
@@ -1412,27 +1438,45 @@ AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
 
 //===----------------------------------------------------------------------===//
 // Integer Sets: these are allocated into the bump pointer, and are immutable.
-// But they aren't uniqued like AffineMap's; there isn't an advantage to.
+// Unlike AffineMap's, these are uniqued only if they are small.
 //===----------------------------------------------------------------------===//
 
 IntegerSet IntegerSet::get(unsigned dimCount, unsigned symbolCount,
                            ArrayRef<AffineExpr> constraints,
-                           ArrayRef<bool> eqFlags, MLIRContext *context) {
-  assert(eqFlags.size() == constraints.size());
+                           ArrayRef<bool> eqFlags) {
+  // The number of constraints can't be zero.
+  assert(!constraints.empty());
+  assert(constraints.size() == eqFlags.size());
 
-  auto &impl = context->getImpl();
+  bool unique = constraints.size() < IntegerSet::kUniquingThreshold;
 
-  // Allocate them into the bump pointer.
-  auto *res = impl.allocator.Allocate<IntegerSetStorage>();
+  auto &impl = constraints[0].getContext()->getImpl();
 
-  // Copy the equalities and inequalities into the bump pointer.
-  constraints = impl.copyInto(ArrayRef<AffineExpr>(constraints));
-  eqFlags = impl.copyInto(ArrayRef<bool>(eqFlags));
+  std::pair<DenseSet<IntegerSet, IntegerSetKeyInfo>::Iterator, bool> existing;
+  if (unique) {
+    // Check if we already have this integer set.
+    auto key = std::make_tuple(dimCount, symbolCount, constraints, eqFlags);
+    existing = impl.integerSets.insert_as(IntegerSet(nullptr), key);
+
+    // If we already have it, return that value.
+    if (!existing.second)
+      return *existing.first;
+  }
+
+  // On the first use, we allocate them into the bump pointer.
+  auto *res = impl.allocator.Allocate<detail::IntegerSetStorage>();
+
+  // Copy the results and equality flags into the bump pointer.
+  constraints = impl.copyInto(constraints);
+  eqFlags = impl.copyInto(eqFlags);
 
   // Initialize the memory using placement new.
-  res = new (res) IntegerSetStorage{dimCount, symbolCount,
-                                    static_cast<unsigned>(constraints.size()),
-                                    constraints, eqFlags};
+  new (res)
+      detail::IntegerSetStorage{dimCount, symbolCount, constraints, eqFlags};
+
+  if (unique)
+    // Cache and return it.
+    return *existing.first = IntegerSet(res);
 
   return IntegerSet(res);
 }
index e35c20c6d43b7703a98b84f065c44eb0fc641748..7f9f870ad2da01c773d128f563d9a56ff1584956 100644 (file)
@@ -1571,17 +1571,17 @@ AffineMap AffineParser::parseAffineMapInline() {
 
   // List of dimensional identifiers.
   if (parseDimIdList(numDims))
-    return AffineMap::Invalid();
+    return AffineMap::Null();
 
   // Symbols are optional.
   if (getToken().is(Token::l_square)) {
     if (parseSymbolIdList(numSymbols))
-      return AffineMap::Invalid();
+      return AffineMap::Null();
   }
 
   if (parseToken(Token::arrow, "expected '->' or '['") ||
       parseToken(Token::l_paren, "expected '(' at start of affine map range"))
-    return AffineMap::Invalid();
+    return AffineMap::Null();
 
   SmallVector<AffineExpr, 4> exprs;
   auto parseElt = [&]() -> ParseResult {
@@ -1595,7 +1595,7 @@ AffineMap AffineParser::parseAffineMapInline() {
   // 1-d affine expressions); the list cannot be empty. Grammar:
   // multi-dim-affine-expr ::= `(` affine-expr (`,` affine-expr)* `)
   if (parseCommaSeparatedListUntil(Token::r_paren, parseElt, false))
-    return AffineMap::Invalid();
+    return AffineMap::Null();
 
   // Parse optional range sizes.
   //  range-sizes ::= (`size` `(` dim-size (`,` dim-size)* `)`)?
@@ -1607,7 +1607,7 @@ AffineMap AffineParser::parseAffineMapInline() {
     // Location of the l_paren token (if it exists) for error reporting later.
     auto loc = getToken().getLoc();
     if (parseToken(Token::l_paren, "expected '(' at start of affine map range"))
-      return AffineMap::Invalid();
+      return AffineMap::Null();
 
     auto parseRangeSize = [&]() -> ParseResult {
       auto loc = getToken().getLoc();
@@ -1624,13 +1624,13 @@ AffineMap AffineParser::parseAffineMapInline() {
     };
 
     if (parseCommaSeparatedListUntil(Token::r_paren, parseRangeSize, false))
-      return AffineMap::Invalid();
+      return AffineMap::Null();
     if (exprs.size() > rangeSizes.size())
       return (emitError(loc, "fewer range sizes than range expressions"),
-              AffineMap::Invalid());
+              AffineMap::Null());
     if (exprs.size() < rangeSizes.size())
       return (emitError(loc, "more range sizes than range expressions"),
-              AffineMap::Invalid());
+              AffineMap::Null());
   }
 
   // Parsed a valid affine map.
@@ -1647,7 +1647,7 @@ AffineMap Parser::parseAffineMapReference() {
     StringRef affineMapId = getTokenSpelling().drop_front();
     if (getState().affineMapDefinitions.count(affineMapId) == 0)
       return (emitError("undefined affine map id '" + affineMapId + "'"),
-              AffineMap::Invalid());
+              AffineMap::Null());
     consumeToken(Token::hash_identifier);
     return getState().affineMapDefinitions[affineMapId];
   }
index cc3de09ebb683f9a6edee7b89be62380ba90eb43..a6a850280bb39497188b68ef816553a8386d79c1 100644 (file)
@@ -46,12 +46,12 @@ AffineMap mlir::getUnrolledLoopUpperBound(const ForStmt &forStmt,
 
   // Single result lower bound map only.
   if (lbMap.getNumResults() != 1)
-    return AffineMap::Invalid();
+    return AffineMap::Null();
 
   // Sometimes, the trip count cannot be expressed as an affine expression.
   auto tripCount = getTripCountExpr(forStmt);
   if (!tripCount)
-    return AffineMap::Invalid();
+    return AffineMap::Null();
 
   AffineExpr lb(lbMap.getResult(0));
   unsigned step = forStmt.getStep();
@@ -72,12 +72,12 @@ AffineMap mlir::getCleanupLoopLowerBound(const ForStmt &forStmt,
 
   // Single result lower bound map only.
   if (lbMap.getNumResults() != 1)
-    return AffineMap::Invalid();
+    return AffineMap::Null();
 
   // Sometimes the trip count cannot be expressed as an affine expression.
   AffineExpr tripCount(getTripCountExpr(forStmt));
   if (!tripCount)
-    return AffineMap::Invalid();
+    return AffineMap::Null();
 
   AffineExpr lb(lbMap.getResult(0));
   unsigned step = forStmt.getStep();
index edb60c9bf232649a28daa44f8e1ca21128d3f73c..3ab799a40282a19b710f7182bbcde5948e9c1fd7 100644 (file)
@@ -46,13 +46,13 @@ struct SimplifyAffineStructures : public FunctionPass,
   // for this yet? TODO(someone).
   PassResult runOnCFGFunction(CFGFunction *f) { return success(); }
 
-  void visitOperationStmt(OperationStmt *stmt);
   void visitIfStmt(IfStmt *ifStmt);
+  void visitOperationStmt(OperationStmt *opStmt);
 };
 
 } // end anonymous namespace
 
-FunctionPass *mlir::createSimplifyAffineExprPass() {
+FunctionPass *mlir::createSimplifyAffineStructuresPass() {
   return new SimplifyAffineStructures();
 }
 
@@ -65,9 +65,8 @@ static IntegerSet simplifyIntegerSet(IntegerSet set) {
 }
 
 void SimplifyAffineStructures::visitIfStmt(IfStmt *ifStmt) {
-  auto set = ifStmt->getCondition().getSet();
-  IntegerSet simplified = simplifyIntegerSet(set);
-  ifStmt->setIntegerSet(simplified);
+  auto set = ifStmt->getCondition().getIntegerSet();
+  ifStmt->setIntegerSet(simplifyIntegerSet(set));
 }
 
 void SimplifyAffineStructures::visitOperationStmt(OperationStmt *opStmt) {
index 5ac9cf8f392c0a8c8e75e6509b69de12822c34ab..dfb5c90f2b36719dcc0ea3c081a116d8be5af6bb 100644 (file)
@@ -1,4 +1,4 @@
-// RUN: mlir-opt %s -simplify-affine-expr | FileCheck %s
+// RUN: mlir-opt %s -simplify-affine-structures | FileCheck %s
 
 // CHECK: #map{{[0-9]+}} = (d0, d1) -> (0, 0)
 #map0 = (d0, d1) -> ((d0 - d0 mod 4) mod 4, (d0 - d0 mod 128 - 64) mod 64)
 // CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 - (d0 floordiv 8) * 8, (d1 floordiv 8) * 8)
 #map6 = (d0, d1) -> (d0 mod 8, d1 - d1 mod 8)
 
-// Set for test case: test_gaussian_elimination_empty_set0
 // CHECK: @@set0 = (d0, d1) : (1 == 0)
-@@set0 = (d0, d1) : (2 == 0)
-
-// Set for test case: test_gaussian_elimination_empty_set1
-// CHECK: @@set1 = (d0, d1) : (1 == 0)
-@@set1 = (d0, d1) : (1 >= 0, -1 >= 0)
+// CHECK: @@set1 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, d0 * -1 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)
+// CHECK: @@set2 = (d0, d1)[s0, s1] : (1 == 0)
+// CHECK: @@set3 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)
+// CHECK: @@set4 = (d0) : (1 == 0)
+// CHECK: @@set5 = (d0)[s0, s1] : (1 == 0)
+// CHECK: @@set6 = (d0, d1, d2) : (1 == 0)
 
 // Set for test case: test_gaussian_elimination_non_empty_set2
-// CHECK: @@set2 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, d0 * -1 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)
+// @@set2 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, d0 * -1 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)
 @@set2 = (d0, d1) : (d0 - 100 == 0, d1 - 10 == 0, -d0 + 100 >= 0, d1 >= 0, d1 + 101 >= 0)
 
 // Set for test case: test_gaussian_elimination_empty_set3
-// CHECK: @@set3 = (d0, d1)[s0, s1] : (1 == 0)
+// @@set3 = (d0, d1)[s0, s1] : (1 == 0)
 @@set3 = (d0, d1)[s0, s1] : (d0 - s0 == 0, d0 + s0 == 0, s0 - 1 == 0)
 
 // Set for test case: test_gaussian_elimination_non_empty_set4
-// CHECK: @@set4 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0, d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0, d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0, d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)
 @@set4 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0,
                              d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0,
                             d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0,
                             d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0)
 
-// Add invalide constraints to previous non-empty set to make it empty.
+// Add invalid constraints to previous non-empty set to make it empty.
 // Set for test case: test_gaussian_elimination_empty_set5
-// CHECK: @@set5 = (d0, d1)[s0, s1] : (1 == 0)
 @@set5 = (d0, d1)[s0, s1] : (d0 * 7 + d1 * 5 + s0 * 11 + s1 == 0,
                              d0 * 5 - d1 * 11 + s0 * 7 + s1 == 0,
                             d0 * 11 + d1 * 7 - s0 * 5 + s1 == 0,
@@ -74,7 +72,7 @@ mlfunc @test_gaussian_elimination_empty_set0() {
   for %i0 = 1 to 10 {
     for %i1 = 1 to 100 {
       // CHECK: @@set0(%i0, %i1)
-      if @@set0(%i0, %i1) {
+      if (d0, d1) : (2 == 0)(%i0, %i1) {
       }
     }
   }
@@ -85,8 +83,8 @@ mlfunc @test_gaussian_elimination_empty_set0() {
 mlfunc @test_gaussian_elimination_empty_set1() {
   for %i0 = 1 to 10 {
     for %i1 = 1 to 100 {
-      // CHECK: @@set1(%i0, %i1)
-      if @@set1(%i0, %i1) {
+      // CHECK: @@set0(%i0, %i1)
+      if (d0, d1) : (1 >= 0, -1 >= 0) (%i0, %i1) {
       }
     }
   }
@@ -97,7 +95,7 @@ mlfunc @test_gaussian_elimination_empty_set1() {
 mlfunc @test_gaussian_elimination_non_empty_set2() {
   for %i0 = 1 to 10 {
     for %i1 = 1 to 100 {
-      // CHECK: @@set2(%i0, %i1)
+      // CHECK: @@set1(%i0, %i1)
       if @@set2(%i0, %i1) {
       }
     }
@@ -111,7 +109,7 @@ mlfunc @test_gaussian_elimination_empty_set3() {
   %c11 = constant 11 : index
   for %i0 = 1 to 10 {
     for %i1 = 1 to 100 {
-      // CHECK: @@set3(%i0, %i1)[%c7, %c11]
+      // CHECK: @@set2(%i0, %i1)[%c7, %c11]
       if @@set3(%i0, %i1)[%c7, %c11] {
       }
     }
@@ -125,7 +123,7 @@ mlfunc @test_gaussian_elimination_non_empty_set4() {
   %c11 = constant 11 : index
   for %i0 = 1 to 10 {
     for %i1 = 1 to 100 {
-      // CHECK: @@set4(%i0, %i1)[%c7, %c11]
+      // CHECK: @@set3(%i0, %i1)[%c7, %c11]
       if @@set4(%i0, %i1)[%c7, %c11] {
       }
     }
@@ -139,10 +137,40 @@ mlfunc @test_gaussian_elimination_empty_set5() {
   %c11 = constant 11 : index
   for %i0 = 1 to 10 {
     for %i1 = 1 to 100 {
-      // CHECK: @@set5(%i0, %i1)[%c7, %c11]
+      // CHECK: @@set2(%i0, %i1)[%c7, %c11]
       if @@set5(%i0, %i1)[%c7, %c11] {
       }
     }
   }
   return
-}
\ No newline at end of file
+}
+
+// CHECK-LABEL: mlfunc @test_fourier_motzkin(%arg0 : index) {
+mlfunc @test_fourier_motzkin(%N : index) {
+  for %i = 0 to 10 {
+    for %j = 0 to 10 {
+      // CHECK: if @@set0(%i0, %i1)
+      if (d0, d1) : (d0 - d1 >= 0, d1 - d0 - 1 >= 0)(%i, %j) {
+        "foo"() : () -> ()
+      }
+      // CHECK: if @@set4(%i0)
+      if (d0) : (d0 >= 0, -d0 - 1 >= 0)(%i) {
+        "bar"() : () -> ()
+      }
+      // CHECK: if @@set4(%i0)
+      if (d0) : (d0 >= 0, -d0 - 1 >= 0)(%i) {
+        "foo"() : () -> ()
+      }
+      // CHECK: if @@set5(%i0)[%arg0, %arg0]
+      if (d0)[s0, s1] : (d0 >= 0, -d0 + s0 - 1 >= 0, -s0 >= 0)(%i)[%N, %N] {
+        "bar"() : () -> ()
+      }
+      // CHECK: if @@set6(%i0, %i1, %arg0)
+      // The set below implies d0 = d1; so d1 >= d0, but d0 >= d1 + 1.
+      if (d0, d1, d2) : (d0 - d1 == 0, d2 - d0 >= 0, d0 - d1 - 1 >= 0)(%i, %j, %N) {
+        "foo"() : () -> ()
+      }
+    }
+  }
+  return
+}
index 596456d9504d6aebe05c1c1852126ed8c13e7089..38d72c6c6524cf6ee07a939b501ef59ada850fbb 100644 (file)
@@ -75,7 +75,7 @@ enum Passes {
   LoopUnrollAndJam,
   PipelineDataTransfer,
   PrintCFGGraph,
-  SimplifyAffineExpr,
+  SimplifyAffineStructures,
   TFRaiseControlFlow,
   XLALower,
 };
@@ -99,7 +99,7 @@ static cl::list<Passes> passList(
                    "explicitly managed levels of the memory hierarchy"),
         clEnumValN(PrintCFGGraph, "print-cfg-graph",
                    "Print CFG graph per function"),
-        clEnumValN(SimplifyAffineExpr, "simplify-affine-expr",
+        clEnumValN(SimplifyAffineStructures, "simplify-affine-structures",
                    "Simplify affine expressions"),
         clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
                    "Dynamic TensorFlow Switch/Match nodes to a CFG"),
@@ -206,8 +206,8 @@ static OptResult performActions(SourceMgr &sourceMgr, MLIRContext *context) {
     case PrintCFGGraph:
       pass = createPrintCFGGraphPass();
       break;
-    case SimplifyAffineExpr:
-      pass = createSimplifyAffineExprPass();
+    case SimplifyAffineStructures:
+      pass = createSimplifyAffineStructuresPass();
       break;
     case TFRaiseControlFlow:
       pass = createRaiseTFControlFlowPass();