[MLIR][Presburger] Add support for piece-wise multi-affine functions
authorArjun P <arjunpitchanathan@gmail.com>
Mon, 7 Feb 2022 19:01:27 +0000 (00:31 +0530)
committerArjun P <arjunpitchanathan@gmail.com>
Mon, 7 Feb 2022 19:13:59 +0000 (00:43 +0530)
Add the class MultiAffineFunction which represents functions whose domain is an
IntegerPolyhedron and which produce an output given by a tuple of affine
expressions in the IntegerPolyhedron's ids.

Also add support for piece-wise MultiAffineFunctions, which are defined on a
union of IntegerPolyhedrons, and may have different output affine expressions
on each IntegerPolyhedron. Thus the function is affine on each individual
IntegerPolyhedron piece in the domain.

This is part of a series of patches leading up to parametric integer programming.

Depends on D118778.

Reviewed By: Groverkss

Differential Revision: https://reviews.llvm.org/D118779

mlir/include/mlir/Analysis/Presburger/IntegerPolyhedron.h
mlir/include/mlir/Analysis/Presburger/PWMAFunction.h [new file with mode: 0644]
mlir/lib/Analysis/Presburger/CMakeLists.txt
mlir/lib/Analysis/Presburger/IntegerPolyhedron.cpp
mlir/lib/Analysis/Presburger/PWMAFunction.cpp [new file with mode: 0644]
mlir/unittests/Analysis/Presburger/CMakeLists.txt
mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp [new file with mode: 0644]

index 4cdccef..28fb6c0 100644 (file)
@@ -56,6 +56,7 @@ public:
   enum class Kind {
     FlatAffineConstraints,
     FlatAffineValueConstraints,
+    MultiAffineFunction,
     IntegerPolyhedron
   };
 
@@ -194,6 +195,11 @@ public:
   /// Adds an equality from the coefficients specified in `eq`.
   void addEquality(ArrayRef<int64_t> eq);
 
+  /// Eliminate the `posB^th` local identifier, replacing every instance of it
+  /// with the `posA^th` local identifier. This should be used when the two
+  /// local variables are known to always take the same values.
+  virtual void eliminateRedundantLocalId(unsigned posA, unsigned posB);
+
   /// Removes identifiers of the specified kind with the specified pos (or
   /// within the specified range) from the system. The specified location is
   /// relative to the first identifier of the specified kind.
@@ -273,6 +279,9 @@ public:
 
   /// Returns true if the given point satisfies the constraints, or false
   /// otherwise.
+  ///
+  /// Note: currently, if the polyhedron contains local ids, the values of
+  /// the local ids must also be provided.
   bool containsPoint(ArrayRef<int64_t> point) const;
 
   /// Find equality and pairs of inequality contraints identified by their
diff --git a/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h b/mlir/include/mlir/Analysis/Presburger/PWMAFunction.h
new file mode 100644 (file)
index 0000000..a01c3ef
--- /dev/null
@@ -0,0 +1,195 @@
+//===- PWMAFunction.h - MLIR PWMAFunction Class------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Support for piece-wise multi-affine functions. These are functions that are
+// defined on a domain that is a union of IntegerPolyhedrons, and on each domain
+// the value of the function is a tuple of integers, with each value in the
+// tuple being an affine expression in the ids of the IntegerPolyhedron.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
+#define MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
+
+#include "mlir/Analysis/Presburger/IntegerPolyhedron.h"
+#include "mlir/Analysis/Presburger/PresburgerSet.h"
+
+namespace mlir {
+
+/// This class represents a multi-affine function whose domain is given by an
+/// IntegerPolyhedron. This can be thought of as an IntegerPolyhedron with a
+/// tuple of integer values attached to every point in the polyhedron, with the
+/// value of each element of the tuple given by an affine expression in the ids
+/// of the polyhedron. For example we could have the domain
+///
+/// (x, y) : (x >= 5, y >= x)
+///
+/// and a tuple of three integers defined at every point in the polyhedron:
+///
+/// (x, y) -> (x + 2, 2*x - 3y + 5, 2*x + y).
+///
+/// In this way every point in the polyhedron has a tuple of integers associated
+/// with it. If the integer polyhedron has local ids, then the output
+/// expressions can use them as well. The output expressions are represented as
+/// a matrix with one row for every element in the output vector one column for
+/// each id, and an extra column at the end for the constant term.
+///
+/// Checking equality of two such functions is supported, as well as finding the
+/// value of the function at a specified point. Note that local ids in the
+/// domain are not yet supported for finding the value at a point.
+class MultiAffineFunction : protected IntegerPolyhedron {
+public:
+  /// We use protected inheritance to avoid inheriting the whole public
+  /// interface of IntegerPolyhedron. These using declarations explicitly make
+  /// only the relevant functions part of the public interface.
+  using IntegerPolyhedron::getNumDimAndSymbolIds;
+  using IntegerPolyhedron::getNumDimIds;
+  using IntegerPolyhedron::getNumIds;
+  using IntegerPolyhedron::getNumLocalIds;
+  using IntegerPolyhedron::getNumSymbolIds;
+
+  MultiAffineFunction(const IntegerPolyhedron &domain, const Matrix &output)
+      : IntegerPolyhedron(domain), output(output) {}
+  MultiAffineFunction(const Matrix &output, unsigned numDims,
+                      unsigned numSymbols = 0, unsigned numLocals = 0)
+      : IntegerPolyhedron(numDims, numSymbols, numLocals), output(output) {}
+
+  ~MultiAffineFunction() override = default;
+  Kind getKind() const override { return Kind::MultiAffineFunction; }
+  bool classof(const IntegerPolyhedron *poly) const {
+    return poly->getKind() == Kind::MultiAffineFunction;
+  }
+
+  unsigned getNumInputs() const { return getNumDimAndSymbolIds(); }
+  unsigned getNumOutputs() const { return output.getNumRows(); }
+  bool isConsistent() const { return output.getNumColumns() == numIds + 1; }
+  const IntegerPolyhedron &getDomain() const { return *this; }
+
+  bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
+
+  /// Insert `num` identifiers of the specified kind at position `pos`.
+  /// Positions are relative to the kind of identifier. The coefficient columns
+  /// corresponding to the added identifiers are initialized to zero. Return the
+  /// absolute column position (i.e., not relative to the kind of identifier)
+  /// of the first added identifier.
+  unsigned insertId(IdKind kind, unsigned pos, unsigned num = 1) override;
+
+  /// Swap the posA^th identifier with the posB^th identifier.
+  void swapId(unsigned posA, unsigned posB) override;
+
+  /// Remove the specified range of ids.
+  void removeIdRange(unsigned idStart, unsigned idLimit) override;
+
+  /// Eliminate the `posB^th` local identifier, replacing every instance of it
+  /// with the `posA^th` local identifier. This should be used when the two
+  /// local variables are known to always take the same values.
+  void eliminateRedundantLocalId(unsigned posA, unsigned posB) override;
+
+  /// Return whether the outputs of `this` and `other` agree wherever both
+  /// functions are defined, i.e., the outputs should be equal for all points in
+  /// the intersection of the domains.
+  bool isEqualWhereDomainsOverlap(MultiAffineFunction other) const;
+
+  /// Return whether the `this` and `other` are equal. This is the case if
+  /// they lie in the same space, i.e. have the same dimensions, and their
+  /// domains are identical and their outputs are equal on their domain.
+  bool isEqual(const MultiAffineFunction &other) const;
+
+  /// Get the value of the function at the specified point. If the point lies
+  /// outside the domain, an empty optional is returned.
+  ///
+  /// Note: domains with local ids are not yet supported, and will assert-fail.
+  Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+
+  void print(raw_ostream &os) const;
+
+  void dump() const;
+
+private:
+  /// The function's output is a tuple of integers, with the ith element of the
+  /// tuple defined by the affine expression given by the ith row of this output
+  /// matrix.
+  Matrix output;
+};
+
+/// This class represents a piece-wise MultiAffineFunction. This can be thought
+/// of as a list of MultiAffineFunction with disjoint domains, with each having
+/// their own affine expressions for their output tuples. For example, we could
+/// have a function with two input variables (x, y), defined as
+///
+/// f(x, y) = (2*x + y, y - 4)  if x >= 0, y >= 0
+///         = (-2*x + y, y + 4) if x < 0,  y < 0
+///         = (4, 1)            if x < 0,  y >= 0
+///
+/// Note that the domains all have to be *disjoint*. Otherwise, the behaviour of
+/// this class is undefined. The domains need not cover all possible points;
+/// this represents a partial function and so could be undefined at some points.
+///
+/// As in PresburgerSets, the input ids are partitioned into dimension ids and
+/// symbolic ids.
+///
+/// Support is provided to compare equality of two such functions as well as
+/// finding the value of the function at a point. Note that local ids in the
+/// piece are not supported for the latter.
+class PWMAFunction {
+public:
+  PWMAFunction(unsigned numDims, unsigned numSymbols, unsigned numOutputs)
+      : numDims(numDims), numSymbols(numSymbols), numOutputs(numOutputs) {
+    assert(numOutputs >= 1 && "The function must output something!");
+  }
+
+  void addPiece(const MultiAffineFunction &piece);
+  void addPiece(const IntegerPolyhedron &domain, const Matrix &output);
+
+  const MultiAffineFunction &getPiece(unsigned i) const { return pieces[i]; }
+  unsigned getNumPieces() const { return pieces.size(); }
+  unsigned getNumOutputs() const { return numOutputs; }
+  unsigned getNumInputs() const { return numDims + numSymbols; }
+  unsigned getNumDimIds() const { return numDims; }
+  unsigned getNumSymbolIds() const { return numSymbols; }
+  MultiAffineFunction &getPiece(unsigned i) { return pieces[i]; }
+
+  /// Return the domain of this piece-wise MultiAffineFunction. This is the
+  /// union of the domains of all the pieces.
+  PresburgerSet getDomain() const;
+
+  /// Check whether the `this` and the given function have compatible
+  /// dimensions, i.e., the same number of dimension inputs, symbol inputs, and
+  /// outputs.
+  bool hasCompatibleDimensions(const MultiAffineFunction &f) const;
+  bool hasCompatibleDimensions(const PWMAFunction &f) const;
+
+  /// Return the value at the specified point and an empty optional if the
+  /// point does not lie in the domain.
+  ///
+  /// Note: domains with local ids are not yet supported, and will assert-fail.
+  Optional<SmallVector<int64_t, 8>> valueAt(ArrayRef<int64_t> point) const;
+
+  /// Return whether `this` and `other` are equal as PWMAFunctions, i.e. whether
+  /// they have the same dimensions, the same domain and they take the same
+  /// value at every point in the domain.
+  bool isEqual(const PWMAFunction &other) const;
+
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+private:
+  /// The list of pieces in this piece-wise MultiAffineFunction.
+  SmallVector<MultiAffineFunction, 4> pieces;
+
+  /// The number of dimensions ids in the domains.
+  unsigned numDims;
+  /// The number of symbol ids in the domains.
+  unsigned numSymbols;
+  /// The number of output ids.
+  unsigned numOutputs;
+};
+
+} // namespace mlir
+
+#endif // MLIR_ANALYSIS_PRESBURGER_PWMAFUNCTION_H
index c2458b1..313742f 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_library(MLIRPresburger
   LinearTransform.cpp
   Matrix.cpp
   PresburgerSet.cpp
+  PWMAFunction.cpp
   Simplex.cpp
   Utils.cpp
 
index 837830e..e9a082e 100644 (file)
@@ -1065,24 +1065,17 @@ void IntegerPolyhedron::removeRedundantConstraints() {
   equalities.resizeVertically(pos);
 }
 
-/// Eliminate `pos2^th` local identifier, replacing its every instance with
-/// `pos1^th` local identifier. This function is intended to be used to remove
-/// redundancy when local variables at position `pos1` and `pos2` are restricted
-/// to have the same value.
-static void eliminateRedundantLocalId(IntegerPolyhedron &poly, unsigned pos1,
-                                      unsigned pos2) {
-
-  assert(pos1 < poly.getNumLocalIds() && "Invalid local id position");
-  assert(pos2 < poly.getNumLocalIds() && "Invalid local id position");
-
-  unsigned localOffset = poly.getNumDimAndSymbolIds();
-  pos1 += localOffset;
-  pos2 += localOffset;
-  for (unsigned i = 0, e = poly.getNumInequalities(); i < e; ++i)
-    poly.atIneq(i, pos1) += poly.atIneq(i, pos2);
-  for (unsigned i = 0, e = poly.getNumEqualities(); i < e; ++i)
-    poly.atEq(i, pos1) += poly.atEq(i, pos2);
-  poly.removeId(pos2);
+void IntegerPolyhedron::eliminateRedundantLocalId(unsigned posA,
+                                                  unsigned posB) {
+  assert(posA < getNumLocalIds() && "Invalid local id position");
+  assert(posB < getNumLocalIds() && "Invalid local id position");
+
+  unsigned localOffset = getIdKindOffset(IdKind::Local);
+  posA += localOffset;
+  posB += localOffset;
+  inequalities.addToColumn(posB, posA, 1);
+  equalities.addToColumn(posB, posA, 1);
+  removeId(posB);
 }
 
 /// Adds additional local ids to the sets such that they both have the union
@@ -1129,8 +1122,8 @@ void IntegerPolyhedron::mergeLocalIds(IntegerPolyhedron &other) {
   // Merge function that merges the local variables in both sets by treating
   // them as the same identifier.
   auto merge = [&polyA, &polyB](unsigned i, unsigned j) -> bool {
-    eliminateRedundantLocalId(polyA, i, j);
-    eliminateRedundantLocalId(polyB, i, j);
+    polyA.eliminateRedundantLocalId(i, j);
+    polyB.eliminateRedundantLocalId(i, j);
     return true;
   };
 
diff --git a/mlir/lib/Analysis/Presburger/PWMAFunction.cpp b/mlir/lib/Analysis/Presburger/PWMAFunction.cpp
new file mode 100644 (file)
index 0000000..385f135
--- /dev/null
@@ -0,0 +1,198 @@
+//===- PWMAFunction.cpp - MLIR PWMAFunction Class -------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
+#include "mlir/Analysis/Presburger/Simplex.h"
+
+using namespace mlir;
+
+// Return the result of subtracting the two given vectors pointwise.
+// The vectors must be of the same size.
+// e.g., [3, 4, 6] - [2, 5, 1] = [1, -1, 5].
+static SmallVector<int64_t, 8> subtract(ArrayRef<int64_t> vecA,
+                                        ArrayRef<int64_t> vecB) {
+  assert(vecA.size() == vecB.size() &&
+         "Cannot subtract vectors of differing lengths!");
+  SmallVector<int64_t, 8> result;
+  result.reserve(vecA.size());
+  for (unsigned i = 0, e = vecA.size(); i < e; ++i)
+    result.push_back(vecA[i] - vecB[i]);
+  return result;
+}
+
+PresburgerSet PWMAFunction::getDomain() const {
+  PresburgerSet domain =
+      PresburgerSet::getEmptySet(getNumDimIds(), getNumSymbolIds());
+  for (const MultiAffineFunction &piece : pieces)
+    domain.unionPolyInPlace(piece.getDomain());
+  return domain;
+}
+
+Optional<SmallVector<int64_t, 8>>
+MultiAffineFunction::valueAt(ArrayRef<int64_t> point) const {
+  assert(getNumLocalIds() == 0 && "Local ids are not yet supported!");
+  assert(point.size() == getNumIds() && "Point has incorrect dimensionality!");
+
+  if (!getDomain().containsPoint(point))
+    return {};
+
+  // The point lies in the domain, so we need to compute the output value.
+  // The matrix `output` has an affine expression in the ith row, corresponding
+  // to the expression for the ith value in the output vector. The last column
+  // of the matrix contains the constant term. Let v be the input point with
+  // a 1 appended at the end. We can see that output * v gives the desired
+  // output vector.
+  SmallVector<int64_t, 8> pointHomogenous{llvm::to_vector(point)};
+  pointHomogenous.push_back(1);
+  SmallVector<int64_t, 8> result =
+      output.postMultiplyWithColumn(pointHomogenous);
+  assert(result.size() == getNumOutputs());
+  return result;
+}
+
+Optional<SmallVector<int64_t, 8>>
+PWMAFunction::valueAt(ArrayRef<int64_t> point) const {
+  assert(point.size() == getNumInputs() &&
+         "Point has incorrect dimensionality!");
+  for (const MultiAffineFunction &piece : pieces)
+    if (Optional<SmallVector<int64_t, 8>> output = piece.valueAt(point))
+      return output;
+  return {};
+}
+
+void MultiAffineFunction::print(raw_ostream &os) const {
+  os << "Domain:";
+  IntegerPolyhedron::print(os);
+  os << "Output:\n";
+  output.print(os);
+  os << "\n";
+}
+
+void MultiAffineFunction::dump() const { print(llvm::errs()); }
+
+bool MultiAffineFunction::isEqual(const MultiAffineFunction &other) const {
+  return hasCompatibleDimensions(other) &&
+         getDomain().isEqual(other.getDomain()) &&
+         isEqualWhereDomainsOverlap(other);
+}
+
+unsigned MultiAffineFunction::insertId(IdKind kind, unsigned pos,
+                                       unsigned num) {
+  unsigned absolutePos = getIdKindOffset(kind) + pos;
+  output.insertColumns(absolutePos, num);
+  return IntegerPolyhedron::insertId(kind, pos, num);
+}
+
+void MultiAffineFunction::swapId(unsigned posA, unsigned posB) {
+  output.swapColumns(posA, posB);
+  IntegerPolyhedron::swapId(posA, posB);
+}
+
+void MultiAffineFunction::removeIdRange(unsigned idStart, unsigned idLimit) {
+  output.removeColumns(idStart, idLimit - idStart);
+  IntegerPolyhedron::removeIdRange(idStart, idLimit);
+}
+
+void MultiAffineFunction::eliminateRedundantLocalId(unsigned posA,
+                                                    unsigned posB) {
+  output.addToColumn(posB, posA, /*scale=*/1);
+  IntegerPolyhedron::eliminateRedundantLocalId(posA, posB);
+}
+
+bool MultiAffineFunction::isEqualWhereDomainsOverlap(
+    MultiAffineFunction other) const {
+  if (!hasCompatibleDimensions(other))
+    return false;
+
+  // `commonFunc` has the same output as `this`.
+  MultiAffineFunction commonFunc = *this;
+  // After this merge, `commonFunc` and `other` have the same local ids; they
+  // are merged.
+  commonFunc.mergeLocalIds(other);
+  // After this, the domain of `commonFunc` will be the intersection of the
+  // domains of `this` and `other`.
+  commonFunc.IntegerPolyhedron::append(other);
+
+  // `commonDomainMatching` contains the subset of the common domain
+  // where the outputs of `this` and `other` match.
+  //
+  // We want to add constraints equating the outputs of `this` and `other`.
+  // However, `this` may have difference local ids from `other`, whereas we
+  // need both to have the same locals. Accordingly, we use `commonFunc.output`
+  // in place of `this->output`, since `commonFunc` has the same output but also
+  // has its locals merged.
+  IntegerPolyhedron commonDomainMatching = commonFunc.getDomain();
+  for (unsigned row = 0, e = getNumOutputs(); row < e; ++row)
+    commonDomainMatching.addEquality(
+        subtract(commonFunc.output.getRow(row), other.output.getRow(row)));
+
+  // If the whole common domain is a subset of commonDomainMatching, then they
+  // are equal and the two functions match on the whole common domain.
+  return commonFunc.getDomain().isSubsetOf(commonDomainMatching);
+}
+
+/// Two PWMAFunctions are equal if they have the same dimensionalities,
+/// the same domain, and take the same value at every point in the domain.
+bool PWMAFunction::isEqual(const PWMAFunction &other) const {
+  if (!hasCompatibleDimensions(other))
+    return false;
+
+  if (!this->getDomain().isEqual(other.getDomain()))
+    return false;
+
+  // Check if, whenever the domains of a piece of `this` and a piece of `other`
+  // overlap, they take the same output value. If `this` and `other` have the
+  // same domain (checked above), then this check passes iff the two functions
+  // have the same output at every point in the domain.
+  for (const MultiAffineFunction &aPiece : this->pieces)
+    for (const MultiAffineFunction &bPiece : other.pieces)
+      if (!aPiece.isEqualWhereDomainsOverlap(bPiece))
+        return false;
+  return true;
+}
+
+void PWMAFunction::addPiece(const MultiAffineFunction &piece) {
+  assert(hasCompatibleDimensions(piece) &&
+         "Piece to be added is not compatible with this PWMAFunction!");
+  assert(piece.isConsistent() && "Piece is internally inconsistent!");
+  assert(this->getDomain()
+             .intersect(PresburgerSet(piece.getDomain()))
+             .isIntegerEmpty() &&
+         "New piece's domain overlaps with that of existing pieces!");
+  pieces.push_back(piece);
+}
+
+void PWMAFunction::addPiece(const IntegerPolyhedron &domain,
+                            const Matrix &output) {
+  addPiece(MultiAffineFunction(domain, output));
+}
+
+void PWMAFunction::print(raw_ostream &os) const {
+  os << pieces.size() << " pieces:\n";
+  for (const MultiAffineFunction &piece : pieces)
+    piece.print(os);
+}
+
+/// The hasCompatibleDimensions functions don't check the number of local ids;
+/// functions are still compatible if they have differing number of locals.
+bool MultiAffineFunction::hasCompatibleDimensions(
+    const MultiAffineFunction &f) const {
+  return getNumDimIds() == f.getNumDimIds() &&
+         getNumSymbolIds() == f.getNumSymbolIds() &&
+         getNumOutputs() == f.getNumOutputs();
+}
+bool PWMAFunction::hasCompatibleDimensions(const MultiAffineFunction &f) const {
+  return getNumDimIds() == f.getNumDimIds() &&
+         getNumSymbolIds() == f.getNumSymbolIds() &&
+         getNumOutputs() == f.getNumOutputs();
+}
+bool PWMAFunction::hasCompatibleDimensions(const PWMAFunction &f) const {
+  return getNumDimIds() == f.getNumDimIds() &&
+         getNumSymbolIds() == f.getNumSymbolIds() &&
+         getNumOutputs() == f.getNumOutputs();
+}
index 2bdb2e7..e7142a7 100644 (file)
@@ -3,6 +3,7 @@ add_mlir_unittest(MLIRPresburgerTests
   LinearTransformTest.cpp
   MatrixTest.cpp
   PresburgerSetTest.cpp
+  PWMAFunctionTest.cpp
   SimplexTest.cpp
   ../../Dialect/Affine/Analysis/AffineStructuresParser.cpp
 )
diff --git a/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp b/mlir/unittests/Analysis/Presburger/PWMAFunctionTest.cpp
new file mode 100644 (file)
index 0000000..614f19c
--- /dev/null
@@ -0,0 +1,183 @@
+//===- PWMAFunctionTest.cpp - Tests for PWMAFunction ----------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains tests for PWMAFunction.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/Presburger/PWMAFunction.h"
+#include "../../Dialect/Affine/Analysis/AffineStructuresParser.h"
+#include "mlir/Analysis/Presburger/PresburgerSet.h"
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+
+namespace mlir {
+using testing::ElementsAre;
+
+/// Parses an IntegerPolyhedron from a StringRef. It is expected that the
+/// string represents a valid IntegerSet, otherwise it will violate a gtest
+/// assertion.
+static IntegerPolyhedron parsePoly(StringRef str, MLIRContext *context) {
+  FailureOr<IntegerPolyhedron> poly = parseIntegerSetToFAC(str, context);
+  EXPECT_TRUE(succeeded(poly));
+  return *poly;
+}
+
+static Matrix makeMatrix(unsigned numRow, unsigned numColumns,
+                         ArrayRef<SmallVector<int64_t, 8>> matrix) {
+  Matrix results(numRow, numColumns);
+  assert(matrix.size() == numRow);
+  for (unsigned i = 0; i < numRow; ++i) {
+    assert(matrix[i].size() == numColumns &&
+           "Output expression has incorrect dimensionality!");
+    for (unsigned j = 0; j < numColumns; ++j)
+      results(i, j) = matrix[i][j];
+  }
+  return results;
+}
+
+/// Construct a PWMAFunction given the dimensionalities and an array describing
+/// the list of pieces. Each piece is given by a string describing the domain
+/// and a 2D array that represents the output.
+static PWMAFunction parsePWMAF(
+    unsigned numInputs, unsigned numOutputs,
+    ArrayRef<std::pair<StringRef, SmallVector<SmallVector<int64_t, 8>, 8>>>
+        data,
+    unsigned numSymbols = 0) {
+  static MLIRContext context;
+
+  PWMAFunction result(numInputs - numSymbols, numSymbols, numOutputs);
+  for (const auto &pair : data) {
+    IntegerPolyhedron domain = parsePoly(pair.first, &context);
+    result.addPiece(
+        domain, makeMatrix(numOutputs, domain.getNumIds() + 1, pair.second));
+  }
+  return result;
+}
+
+TEST(PWAFunctionTest, isEqual) {
+  MLIRContext context;
+
+  // The output expressions are different but it doesn't matter because they are
+  // equal in this domain.
+  PWMAFunction idAtZeros = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},             // (x, y).
+          {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
+          {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
+      });
+  PWMAFunction idAtZeros2 = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 20, 0}}}, // (x, 20y).
+          {"(x, y) : (y - 1 >= 0, x == 0)", {{30, 0, 0}, {0, 1, 0}}}, //(30x, y)
+          {"(x, y) : (-y - 1 > =0, x == 0)", {{30, 0, 0}, {0, 1, 0}}} //(30x, y)
+      });
+  EXPECT_TRUE(idAtZeros.isEqual(idAtZeros2));
+
+  PWMAFunction notIdAtZeros = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (y == 0)", {{1, 0, 0}, {0, 1, 0}}},              // (x, y).
+          {"(x, y) : (y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}},  // (x, 2y)
+          {"(x, y) : (-y - 1 >= 0, x == 0)", {{1, 0, 0}, {0, 2, 0}}}, // (x, 2y)
+      });
+  EXPECT_FALSE(idAtZeros.isEqual(notIdAtZeros));
+
+  // These match at their intersection but one has a bigger domain.
+  PWMAFunction idNoNegNegQuadrant = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}},             // (x, y).
+          {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 0, 0}, {0, 1, 0}}} // (x, y).
+      });
+  PWMAFunction idOnlyPosX =
+      parsePWMAF(/*numInputs=*/2, /*numOutputs=*/2,
+                 {
+                     {"(x, y) : (x >= 0)", {{1, 0, 0}, {0, 1, 0}}}, // (x, y).
+                 });
+  EXPECT_FALSE(idNoNegNegQuadrant.isEqual(idOnlyPosX));
+
+  // Different representations of the same domain.
+  PWMAFunction sumPlusOne = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/1,
+      {
+          {"(x, y) : (x >= 0)", {{1, 1, 1}}},                   // x + y + 1.
+          {"(x, y) : (-x - 1 >= 0, -y - 1 >= 0)", {{1, 1, 1}}}, // x + y + 1.
+          {"(x, y) : (-x - 1 >= 0, y >= 0)", {{1, 1, 1}}}       // x + y + 1.
+      });
+  PWMAFunction sumPlusOne2 =
+      parsePWMAF(/*numInputs=*/2, /*numOutputs=*/1,
+                 {
+                     {"(x, y) : ()", {{1, 1, 1}}}, // x + y + 1.
+                 });
+  EXPECT_TRUE(sumPlusOne.isEqual(sumPlusOne2));
+
+  // Functions with zero input dimensions.
+  PWMAFunction noInputs1 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
+                                      {
+                                          {"() : ()", {{1}}}, // 1.
+                                      });
+  PWMAFunction noInputs2 = parsePWMAF(/*numInputs=*/0, /*numOutputs=*/1,
+                                      {
+                                          {"() : ()", {{2}}}, // 1.
+                                      });
+  EXPECT_TRUE(noInputs1.isEqual(noInputs1));
+  EXPECT_FALSE(noInputs1.isEqual(noInputs2));
+
+  // Mismatched dimensionalities.
+  EXPECT_FALSE(noInputs1.isEqual(sumPlusOne));
+  EXPECT_FALSE(idOnlyPosX.isEqual(sumPlusOne));
+
+  // Divisions.
+  // Domain is only multiples of 6; x = 6k for some k.
+  // x + 4(x/2) + 4(x/3) == 26k.
+  PWMAFunction mul2AndMul3 = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 2*(x floordiv 2) == 0, x - 3*(x floordiv 3) == 0)",
+           {{1, 4, 4, 0}}}, // x + 4(x/2) + 4(x/3).
+      });
+  PWMAFunction mul6 = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 6*(x floordiv 6) == 0)", {{0, 26, 0}}}, // 26(x/6).
+      });
+  EXPECT_TRUE(mul2AndMul3.isEqual(mul6));
+
+  PWMAFunction mul6diff = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 52, 0}}}, // 52(x/6).
+      });
+  EXPECT_FALSE(mul2AndMul3.isEqual(mul6diff));
+
+  PWMAFunction mul5 = parsePWMAF(
+      /*numInputs=*/1, /*numOutputs=*/1,
+      {
+          {"(x) : (x - 5*(x floordiv 5) == 0)", {{0, 26, 0}}}, // 26(x/5).
+      });
+  EXPECT_FALSE(mul2AndMul3.isEqual(mul5));
+}
+
+TEST(PWMAFunction, valueAt) {
+  PWMAFunction nonNegPWAF = parsePWMAF(
+      /*numInputs=*/2, /*numOutputs=*/2,
+      {
+          {"(x, y) : (x >= 0)", {{1, 2, 3}, {3, 4, 5}}}, // (x, y).
+          {"(x, y) : (y >= 0, -x - 1 >= 0)", {{-1, 2, 3}, {-3, 4, 5}}} // (x, y)
+      });
+  EXPECT_THAT(*nonNegPWAF.valueAt({2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWAF.valueAt({-2, 3}), ElementsAre(11, 23));
+  EXPECT_THAT(*nonNegPWAF.valueAt({2, -3}), ElementsAre(-1, -1));
+  EXPECT_FALSE(nonNegPWAF.valueAt({-2, -3}).hasValue());
+}
+
+} // namespace mlir