AffineExpr: factor uniqu'ing out of MLIRContext.cpp
authorAlex Zinenko <zinenko@google.com>
Tue, 21 May 2019 08:34:13 +0000 (01:34 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:53:43 +0000 (19:53 -0700)
    Affine expressions are designed as components of an attribute and are unique'd
    in the MLIRContext.  When affine expressions were implemented, uniqu'ing
    objects in a context required to modify MLIRContext implementation.  This is no
    longer the case as generic StorageUniquer has been introduced.  Port the
    AffineExpr construction to use the new infrastructure by introducing an
    affineUniquer into the MLIRContext.

--

PiperOrigin-RevId: 249207539

mlir/include/mlir/IR/AffineExpr.h
mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/IR/AffineExpr.cpp
mlir/lib/IR/AffineExprDetail.h
mlir/lib/IR/MLIRContext.cpp
mlir/unittests/IR/SDBMTest.cpp

index ded953f..58b4fbc 100644 (file)
@@ -189,7 +189,7 @@ public:
 /// A symbolic identifier appearing in an affine expression.
 class AffineSymbolExpr : public AffineExpr {
 public:
-  using ImplType = detail::AffineSymbolExprStorage;
+  using ImplType = detail::AffineDimExprStorage;
   /* implicit */ AffineSymbolExpr(AffineExpr::ImplType *ptr);
   unsigned getPosition() const;
 };
index c74c814..864fd96 100644 (file)
@@ -82,6 +82,9 @@ public:
   /// Returns the diagnostic engine for this context.
   DiagnosticEngine &getDiagEngine();
 
+  /// Returns the storage uniquer used for creating affine constructs.
+  StorageUniquer &getAffineUniquer();
+
   /// Returns the storage uniquer used for constructing type storage instances.
   /// This should not be used directly.
   StorageUniquer &getTypeUniquer();
index 5b408f3..fc20db2 100644 (file)
@@ -35,9 +35,9 @@ struct StorageUniquerImpl;
 ///
 /// For non-parametric storage classes, i.e. those that are solely uniqued by
 /// their kind, nothing else is needed. Instances of these classes can be
-/// queried with 'getSimple'.
+/// created by calling `get` without trailing arguments.
 ///
-/// Otherwise, the parametric storage classes may be queried with 'getComplex',
+/// Otherwise, the parametric storage classes may be created with `get`,
 /// and must respect the following:
 ///    - Define a type alias, KeyTy, to a type that uniquely identifies the
 ///      instance of the storage class within its kind.
index 739be15..03dd4b8 100644 (file)
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/IntegerSet.h"
+#include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/STLExtras.h"
 
 using namespace mlir;
 using namespace mlir::detail;
 
-MLIRContext *AffineExpr::getContext() const {
-  return expr->contextAndKind.getPointer();
-}
+MLIRContext *AffineExpr::getContext() const { return expr->context; }
 
 AffineExprKind AffineExpr::getKind() const {
-  return expr->contextAndKind.getInt();
+  return static_cast<AffineExprKind>(expr->getKind());
 }
 
 /// Walk all of the AffineExprs in this subgraph in postorder.
@@ -51,6 +50,23 @@ void AffineExpr::walk(std::function<void(AffineExpr)> callback) const {
   AffineExprWalker(callback).walkPostOrder(*this);
 }
 
+// Dispatch affine expression construction based on kind.
+AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
+                                       AffineExpr rhs) {
+  if (kind == AffineExprKind::Add)
+    return lhs + rhs;
+  if (kind == AffineExprKind::Mul)
+    return lhs * rhs;
+  if (kind == AffineExprKind::FloorDiv)
+    return lhs.floorDiv(rhs);
+  if (kind == AffineExprKind::CeilDiv)
+    return lhs.ceilDiv(rhs);
+  if (kind == AffineExprKind::Mod)
+    return lhs % rhs;
+
+  llvm_unreachable("unknown binary operation on affine expressions");
+}
+
 /// This method substitutes any uses of dimensions and symbols (e.g.
 /// dim#0 with dimReplacements[0]) and returns the modified expression tree.
 AffineExpr
@@ -231,65 +247,313 @@ unsigned AffineDimExpr::getPosition() const {
   return static_cast<ImplType *>(expr)->position;
 }
 
+static AffineExpr getAffineDimOrSymbol(AffineExprKind kind, unsigned position,
+                                       MLIRContext *context) {
+  auto assignCtx = [context](AffineDimExprStorage *storage) {
+    storage->context = context;
+  };
+
+  StorageUniquer &uniquer = context->getAffineUniquer();
+  return uniquer.get<AffineDimExprStorage>(
+      assignCtx, static_cast<unsigned>(kind), position);
+}
+
+AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
+  return getAffineDimOrSymbol(AffineExprKind::DimId, position, context);
+}
+
 AffineSymbolExpr::AffineSymbolExpr(AffineExpr::ImplType *ptr)
     : AffineExpr(ptr) {}
 unsigned AffineSymbolExpr::getPosition() const {
   return static_cast<ImplType *>(expr)->position;
 }
 
+AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
+  return getAffineDimOrSymbol(AffineExprKind::SymbolId, position, context);
+  ;
+}
+
 AffineConstantExpr::AffineConstantExpr(AffineExpr::ImplType *ptr)
     : AffineExpr(ptr) {}
 int64_t AffineConstantExpr::getValue() const {
   return static_cast<ImplType *>(expr)->constant;
 }
 
+AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
+  auto assignCtx = [context](AffineConstantExprStorage *storage) {
+    storage->context = context;
+  };
+
+  StorageUniquer &uniquer = context->getAffineUniquer();
+  return uniquer.get<AffineConstantExprStorage>(
+      assignCtx, static_cast<unsigned>(AffineExprKind::Constant), constant);
+}
+
+/// Simplify add expression. Return nullptr if it can't be simplified.
+static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+  // Fold if both LHS, RHS are a constant.
+  if (lhsConst && rhsConst)
+    return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
+                                 lhs.getContext());
+
+  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
+  // If only one of them is a symbolic expressions, make it the RHS.
+  if (lhs.isa<AffineConstantExpr>() ||
+      (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
+    return rhs + lhs;
+  }
+
+  // At this point, if there was a constant, it would be on the right.
+
+  // Addition with a zero is a noop, return the other input.
+  if (rhsConst) {
+    if (rhsConst.getValue() == 0)
+      return lhs;
+  }
+  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
+      return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
+  }
+
+  // When doing successive additions, bring constant to the right: turn (d0 + 2)
+  // + d1 into (d0 + d1) + 2.
+  if (lBin && lBin.getKind() == AffineExprKind::Add) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      return lBin.getLHS() + rhs + lrhs;
+    }
+  }
+
+  // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
+  // leads to a much more efficient form when 'c' is a power of two, and in
+  // general a more compact and readable form.
+
+  // Process '(expr floordiv c) * (-c)'.
+  AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
+  if (!rBinOpExpr)
+    return nullptr;
+
+  auto lrhs = rBinOpExpr.getLHS();
+  auto rrhs = rBinOpExpr.getRHS();
+
+  // Process lrhs, which is 'expr floordiv c'.
+  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
+  if (!lrBinOpExpr)
+    return nullptr;
+
+  auto llrhs = lrBinOpExpr.getLHS();
+  auto rlrhs = lrBinOpExpr.getRHS();
+
+  if (lhs == llrhs && rlrhs == -rrhs) {
+    return lhs % rlrhs;
+  }
+  return nullptr;
+}
+
 AffineExpr AffineExpr::operator+(int64_t v) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::Add, expr,
-                                        getAffineConstantExpr(v, getContext()));
+  return *this + getAffineConstantExpr(v, getContext());
 }
 AffineExpr AffineExpr::operator+(AffineExpr other) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::Add, expr, other.expr);
+  if (auto simplified = simplifyAdd(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Add), *this, other);
+}
+
+/// Simplify a multiply expression. Return nullptr if it can't be simplified.
+static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (lhsConst && rhsConst)
+    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
+                                 lhs.getContext());
+
+  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
+
+  // Canonicalize the mul expression so that the constant/symbolic term is the
+  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
+  // constant. (Note that a constant is trivially symbolic).
+  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
+    // At least one of them has to be symbolic.
+    return rhs * lhs;
+  }
+
+  // At this point, if there was a constant, it would be on the right.
+
+  // Multiplication with a one is a noop, return the other input.
+  if (rhsConst) {
+    if (rhsConst.getValue() == 1)
+      return lhs;
+    // Multiplication with zero.
+    if (rhsConst.getValue() == 0)
+      return rhsConst;
+  }
+
+  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
+      return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
+  }
+
+  // When doing successive multiplication, bring constant to the right: turn (d0
+  // * 2) * d1 into (d0 * d1) * 2.
+  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      return (lBin.getLHS() * rhs) * lrhs;
+    }
+  }
+
+  return nullptr;
 }
+
 AffineExpr AffineExpr::operator*(int64_t v) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::Mul, expr,
-                                        getAffineConstantExpr(v, getContext()));
+  return *this * getAffineConstantExpr(v, getContext());
 }
 AffineExpr AffineExpr::operator*(AffineExpr other) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::Mul, expr, other.expr);
+  if (auto simplified = simplifyMul(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mul), *this, other);
 }
+
 // Unary minus, delegate to operator*.
 AffineExpr AffineExpr::operator-() const {
-  return AffineBinaryOpExprStorage::get(
-      AffineExprKind::Mul, expr, getAffineConstantExpr(-1, getContext()));
+  return *this * getAffineConstantExpr(-1, getContext());
 }
+
 // Delegate to operator+.
 AffineExpr AffineExpr::operator-(int64_t v) const { return *this + (-v); }
 AffineExpr AffineExpr::operator-(AffineExpr other) const {
   return *this + (-other);
 }
+
+static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (!rhsConst || rhsConst.getValue() < 1)
+    return nullptr;
+
+  if (lhsConst)
+    return getAffineConstantExpr(
+        floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
+
+  // Fold floordiv of a multiply with a constant that is a multiple of the
+  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
+  if (rhsConst.getValue() == 1)
+    return lhs;
+
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      // rhsConst is known to be positive if a constant.
+      if (lrhs.getValue() % rhsConst.getValue() == 0)
+        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
+    }
+  }
+
+  return nullptr;
+}
+
 AffineExpr AffineExpr::floorDiv(uint64_t v) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::FloorDiv, expr,
-                                        getAffineConstantExpr(v, getContext()));
+  return floorDiv(getAffineConstantExpr(v, getContext()));
 }
 AffineExpr AffineExpr::floorDiv(AffineExpr other) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::FloorDiv, expr,
-                                        other.expr);
+  if (auto simplified = simplifyFloorDiv(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::FloorDiv), *this,
+      other);
 }
+
+static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (!rhsConst || rhsConst.getValue() < 1)
+    return nullptr;
+
+  if (lhsConst)
+    return getAffineConstantExpr(
+        ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
+
+  // Fold ceildiv of a multiply with a constant that is a multiple of the
+  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
+  if (rhsConst.getValue() == 1)
+    return lhs;
+
+  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
+  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
+    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
+      // rhsConst is known to be positive if a constant.
+      if (lrhs.getValue() % rhsConst.getValue() == 0)
+        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
+    }
+  }
+
+  return nullptr;
+}
+
 AffineExpr AffineExpr::ceilDiv(uint64_t v) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::CeilDiv, expr,
-                                        getAffineConstantExpr(v, getContext()));
+  return ceilDiv(getAffineConstantExpr(v, getContext()));
 }
 AffineExpr AffineExpr::ceilDiv(AffineExpr other) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::CeilDiv, expr,
-                                        other.expr);
+  if (auto simplified = simplifyCeilDiv(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::CeilDiv), *this,
+      other);
+}
+
+static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
+  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
+  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
+
+  if (!rhsConst || rhsConst.getValue() < 1)
+    return nullptr;
+
+  if (lhsConst)
+    return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
+                                 lhs.getContext());
+
+  // Fold modulo of an expression that is known to be a multiple of a constant
+  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
+  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
+  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
+    return getAffineConstantExpr(0, lhs.getContext());
+
+  return nullptr;
+  // TODO(bondhugula): In general, this can be simplified more by using the GCD
+  // test, or in general using quantifier elimination (add two new variables q
+  // and r, and eliminate all variables from the linear system other than r. All
+  // of this can be done through mlir/Analysis/'s FlatAffineConstraints.
 }
+
 AffineExpr AffineExpr::operator%(uint64_t v) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::Mod, expr,
-                                        getAffineConstantExpr(v, getContext()));
+  return *this % getAffineConstantExpr(v, getContext());
 }
 AffineExpr AffineExpr::operator%(AffineExpr other) const {
-  return AffineBinaryOpExprStorage::get(AffineExprKind::Mod, expr, other.expr);
+  if (auto simplified = simplifyMod(*this, other))
+    return simplified;
+
+  StorageUniquer &uniquer = getContext()->getAffineUniquer();
+  return uniquer.get<AffineBinaryOpExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(AffineExprKind::Mod), *this, other);
 }
+
 AffineExpr AffineExpr::compose(AffineMap map) const {
   SmallVector<AffineExpr, 8> dimReplacements(map.getResults().begin(),
                                              map.getResults().end());
index bca0957..214fee6 100644 (file)
@@ -25,7 +25,7 @@
 
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/MLIRContext.h"
-#include "llvm/ADT/PointerIntPair.h"
+#include "mlir/Support/StorageUniquer.h"
 
 namespace mlir {
 
@@ -34,42 +34,61 @@ class MLIRContext;
 namespace detail {
 
 /// Base storage class appearing in an affine expression.
-struct AffineExprStorage {
-  AffineExprStorage(AffineExprKind kind, MLIRContext *context)
-      : contextAndKind(context, kind) {}
-  llvm::PointerIntPair<MLIRContext *, 3, AffineExprKind> contextAndKind;
+struct AffineExprStorage : public StorageUniquer::BaseStorage {
+  MLIRContext *context;
 };
 
 /// A binary operation appearing in an affine expression.
 struct AffineBinaryOpExprStorage : public AffineExprStorage {
-  AffineBinaryOpExprStorage(AffineExprStorage base, AffineExpr lhs,
-                            AffineExpr rhs)
-      : AffineExprStorage(base), lhs(lhs), rhs(rhs) {}
-  static AffineExpr get(AffineExprKind kind, AffineExpr lhs, AffineExpr rhs);
+  using KeyTy = std::pair<AffineExpr, AffineExpr>;
+
+  bool operator==(const KeyTy &key) const {
+    return key.first == lhs && key.second == rhs;
+  }
+
+  static AffineBinaryOpExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<AffineBinaryOpExprStorage>();
+    result->lhs = key.first;
+    result->rhs = key.second;
+    result->context = result->lhs.getContext();
+    return result;
+  }
+
   AffineExpr lhs;
   AffineExpr rhs;
 };
 
-/// A dimensional identifier appearing in an affine expression.
+/// A dimensional or symbolic identifier appearing in an affine expression.
 struct AffineDimExprStorage : public AffineExprStorage {
-  AffineDimExprStorage(AffineExprStorage base, unsigned position)
-      : AffineExprStorage(base), position(position) {}
-  /// Position of this identifier in the argument list.
-  unsigned position;
-};
+  using KeyTy = unsigned;
 
-/// A symbolic identifier appearing in an affine expression.
-struct AffineSymbolExprStorage : public AffineExprStorage {
-  AffineSymbolExprStorage(AffineExprStorage base, unsigned position)
-      : AffineExprStorage(base), position(position) {}
-  /// Position of this identifier in the symbol list.
+  bool operator==(const KeyTy &key) const { return position == key; }
+
+  static AffineDimExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<AffineDimExprStorage>();
+    result->position = key;
+    return result;
+  }
+
+  /// Position of this identifier in the argument list.
   unsigned position;
 };
 
 /// An integer constant appearing in affine expression.
 struct AffineConstantExprStorage : public AffineExprStorage {
-  AffineConstantExprStorage(AffineExprStorage base, int64_t constant)
-      : AffineExprStorage(base), constant(constant) {}
+  using KeyTy = int64_t;
+
+  bool operator==(const KeyTy &key) const { return constant == key; }
+
+  static AffineConstantExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<AffineConstantExprStorage>();
+    result->constant = key;
+    return result;
+  }
+
   // The constant.
   int64_t constant;
 };
index ab05954..b54e956 100644 (file)
@@ -33,7 +33,6 @@
 #include "mlir/IR/IntegerSet.h"
 #include "mlir/IR/Location.h"
 #include "mlir/IR/Types.h"
-#include "mlir/Support/MathExtras.h"
 #include "mlir/Support/STLExtras.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/SetVector.h"
@@ -352,17 +351,8 @@ public:
   using IntegerSets = DenseSet<IntegerSet, IntegerSetKeyInfo>;
   IntegerSets integerSets;
 
-  // Affine binary op expression uniquing. Figure out uniquing of dimensional
-  // or symbolic identifiers.
-  DenseMap<std::tuple<unsigned, AffineExpr, AffineExpr>, AffineExpr>
-      affineExprs;
-
-  // Uniqui'ing of AffineDimExpr, AffineSymbolExpr's by their position.
-  std::vector<AffineDimExprStorage *> dimExprs;
-  std::vector<AffineSymbolExprStorage *> symbolExprs;
-
-  // Uniqui'ing of AffineConstantExprStorage using constant value as key.
-  DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
+  // Affine expression uniqui'ing.
+  StorageUniquer affineUniquer;
 
   //===--------------------------------------------------------------------===//
   // SDBM uniquing
@@ -918,9 +908,13 @@ AttributeListStorage::get(ArrayRef<NamedAttribute> attrs) {
 }
 
 //===----------------------------------------------------------------------===//
-// AffineMap and AffineExpr uniquing
+// AffineMap uniquing
 //===----------------------------------------------------------------------===//
 
+StorageUniquer &MLIRContext::getAffineUniquer() {
+  return getImpl().affineUniquer;
+}
+
 AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
                          ArrayRef<AffineExpr> results,
                          ArrayRef<AffineExpr> rangeSizes) {
@@ -947,300 +941,6 @@ AffineMap AffineMap::get(unsigned dimCount, unsigned symbolCount,
   });
 }
 
-/// Simplify add expression. Return nullptr if it can't be simplified.
-static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
-  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
-  // Fold if both LHS, RHS are a constant.
-  if (lhsConst && rhsConst)
-    return getAffineConstantExpr(lhsConst.getValue() + rhsConst.getValue(),
-                                 lhs.getContext());
-
-  // Canonicalize so that only the RHS is a constant. (4 + d0 becomes d0 + 4).
-  // If only one of them is a symbolic expressions, make it the RHS.
-  if (lhs.isa<AffineConstantExpr>() ||
-      (lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant())) {
-    return rhs + lhs;
-  }
-
-  // At this point, if there was a constant, it would be on the right.
-
-  // Addition with a zero is a noop, return the other input.
-  if (rhsConst) {
-    if (rhsConst.getValue() == 0)
-      return lhs;
-  }
-  // Fold successive additions like (d0 + 2) + 3 into d0 + 5.
-  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
-  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Add) {
-    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
-      return lBin.getLHS() + (lrhs.getValue() + rhsConst.getValue());
-  }
-
-  // When doing successive additions, bring constant to the right: turn (d0 + 2)
-  // + d1 into (d0 + d1) + 2.
-  if (lBin && lBin.getKind() == AffineExprKind::Add) {
-    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
-      return lBin.getLHS() + rhs + lrhs;
-    }
-  }
-
-  // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
-  // leads to a much more efficient form when 'c' is a power of two, and in
-  // general a more compact and readable form.
-
-  // Process '(expr floordiv c) * (-c)'.
-  AffineBinaryOpExpr rBinOpExpr = rhs.dyn_cast<AffineBinaryOpExpr>();
-  if (!rBinOpExpr)
-    return nullptr;
-
-  auto lrhs = rBinOpExpr.getLHS();
-  auto rrhs = rBinOpExpr.getRHS();
-
-  // Process lrhs, which is 'expr floordiv c'.
-  AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
-  if (!lrBinOpExpr)
-    return nullptr;
-
-  auto llrhs = lrBinOpExpr.getLHS();
-  auto rlrhs = lrBinOpExpr.getRHS();
-
-  if (lhs == llrhs && rlrhs == -rrhs) {
-    return lhs % rlrhs;
-  }
-  return nullptr;
-}
-
-/// Simplify a multiply expression. Return nullptr if it can't be simplified.
-static AffineExpr simplifyMul(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
-  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
-
-  if (lhsConst && rhsConst)
-    return getAffineConstantExpr(lhsConst.getValue() * rhsConst.getValue(),
-                                 lhs.getContext());
-
-  assert(lhs.isSymbolicOrConstant() || rhs.isSymbolicOrConstant());
-
-  // Canonicalize the mul expression so that the constant/symbolic term is the
-  // RHS. If both the lhs and rhs are symbolic, swap them if the lhs is a
-  // constant. (Note that a constant is trivially symbolic).
-  if (!rhs.isSymbolicOrConstant() || lhs.isa<AffineConstantExpr>()) {
-    // At least one of them has to be symbolic.
-    return rhs * lhs;
-  }
-
-  // At this point, if there was a constant, it would be on the right.
-
-  // Multiplication with a one is a noop, return the other input.
-  if (rhsConst) {
-    if (rhsConst.getValue() == 1)
-      return lhs;
-    // Multiplication with zero.
-    if (rhsConst.getValue() == 0)
-      return rhsConst;
-  }
-
-  // Fold successive multiplications: eg: (d0 * 2) * 3 into d0 * 6.
-  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
-  if (lBin && rhsConst && lBin.getKind() == AffineExprKind::Mul) {
-    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>())
-      return lBin.getLHS() * (lrhs.getValue() * rhsConst.getValue());
-  }
-
-  // When doing successive multiplication, bring constant to the right: turn (d0
-  // * 2) * d1 into (d0 * d1) * 2.
-  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
-    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
-      return (lBin.getLHS() * rhs) * lrhs;
-    }
-  }
-
-  return nullptr;
-}
-
-static AffineExpr simplifyFloorDiv(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
-  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
-
-  if (!rhsConst || rhsConst.getValue() < 1)
-    return nullptr;
-
-  if (lhsConst)
-    return getAffineConstantExpr(
-        floorDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
-
-  // Fold floordiv of a multiply with a constant that is a multiple of the
-  // divisor. Eg: (i * 128) floordiv 64 = i * 2.
-  if (rhsConst.getValue() == 1)
-    return lhs;
-
-  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
-  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
-    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
-      // rhsConst is known to be positive if a constant.
-      if (lrhs.getValue() % rhsConst.getValue() == 0)
-        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
-    }
-  }
-
-  return nullptr;
-}
-
-static AffineExpr simplifyCeilDiv(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
-  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
-
-  if (!rhsConst || rhsConst.getValue() < 1)
-    return nullptr;
-
-  if (lhsConst)
-    return getAffineConstantExpr(
-        ceilDiv(lhsConst.getValue(), rhsConst.getValue()), lhs.getContext());
-
-  // Fold ceildiv of a multiply with a constant that is a multiple of the
-  // divisor. Eg: (i * 128) ceildiv 64 = i * 2.
-  if (rhsConst.getValue() == 1)
-    return lhs;
-
-  auto lBin = lhs.dyn_cast<AffineBinaryOpExpr>();
-  if (lBin && lBin.getKind() == AffineExprKind::Mul) {
-    if (auto lrhs = lBin.getRHS().dyn_cast<AffineConstantExpr>()) {
-      // rhsConst is known to be positive if a constant.
-      if (lrhs.getValue() % rhsConst.getValue() == 0)
-        return lBin.getLHS() * (lrhs.getValue() / rhsConst.getValue());
-    }
-  }
-
-  return nullptr;
-}
-
-static AffineExpr simplifyMod(AffineExpr lhs, AffineExpr rhs) {
-  auto lhsConst = lhs.dyn_cast<AffineConstantExpr>();
-  auto rhsConst = rhs.dyn_cast<AffineConstantExpr>();
-
-  if (!rhsConst || rhsConst.getValue() < 1)
-    return nullptr;
-
-  if (lhsConst)
-    return getAffineConstantExpr(mod(lhsConst.getValue(), rhsConst.getValue()),
-                                 lhs.getContext());
-
-  // Fold modulo of an expression that is known to be a multiple of a constant
-  // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
-  // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
-  if (lhs.getLargestKnownDivisor() % rhsConst.getValue() == 0)
-    return getAffineConstantExpr(0, lhs.getContext());
-
-  return nullptr;
-  // TODO(bondhugula): In general, this can be simplified more by using the GCD
-  // test, or in general using quantifier elimination (add two new variables q
-  // and r, and eliminate all variables from the linear system other than r. All
-  // of this can be done through mlir/Analysis/'s FlatAffineConstraints.
-}
-
-/// Return a binary affine op expression with the specified op type and
-/// operands: if it doesn't exist, create it and store it; if it is already
-/// present, return from the list. The stored expressions are unique: they are
-/// constructed and stored in a simplified/canonicalized form. The result after
-/// simplification could be any form of affine expression.
-AffineExpr AffineBinaryOpExprStorage::get(AffineExprKind kind, AffineExpr lhs,
-                                          AffineExpr rhs) {
-  auto &impl = lhs.getContext()->getImpl();
-
-  // Check if we already have this affine expression, and return it if we do.
-  auto keyValue = std::make_tuple((unsigned)kind, lhs, rhs);
-
-  { // Check for an existing instance in read-only mode.
-    llvm::sys::SmartScopedReader<true> affineLock(impl.affineMutex);
-    auto cached = impl.affineExprs.find(keyValue);
-    if (cached != impl.affineExprs.end())
-      return cached->second;
-  }
-
-  // Simplify the expression if possible.
-  AffineExpr simplified;
-  switch (kind) {
-  case AffineExprKind::Add:
-    simplified = simplifyAdd(lhs, rhs);
-    break;
-  case AffineExprKind::Mul:
-    simplified = simplifyMul(lhs, rhs);
-    break;
-  case AffineExprKind::FloorDiv:
-    simplified = simplifyFloorDiv(lhs, rhs);
-    break;
-  case AffineExprKind::CeilDiv:
-    simplified = simplifyCeilDiv(lhs, rhs);
-    break;
-  case AffineExprKind::Mod:
-    simplified = simplifyMod(lhs, rhs);
-    break;
-  default:
-    llvm_unreachable("unexpected binary affine expr");
-  }
-
-  // The simplified one would have already been cached; just return it.
-  if (simplified)
-    return simplified;
-
-  // Aquire a writer-lock so that we can safely create the new instance.
-  llvm::sys::SmartScopedWriter<true> affineLock(impl.affineMutex);
-
-  // Check for an existing instance again here, because another writer thread
-  // may have already created one.
-  auto &result = impl.affineExprs.insert({keyValue, nullptr}).first->second;
-  if (!result) {
-    // An expression with these operands will already be in the
-    // simplified/canonical form. Create and store it.
-    result = new (impl.affineAllocator.Allocate<AffineBinaryOpExprStorage>())
-        AffineBinaryOpExprStorage{{kind, lhs.getContext()}, lhs, rhs};
-  }
-  return result;
-}
-
-AffineExpr mlir::getAffineBinaryOpExpr(AffineExprKind kind, AffineExpr lhs,
-                                       AffineExpr rhs) {
-  return AffineBinaryOpExprStorage::get(kind, lhs, rhs);
-}
-
-AffineExpr mlir::getAffineDimExpr(unsigned position, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  return safeGetOrCreate(
-      impl.dimExprs, position, impl.affineMutex, [&impl, context, position] {
-        auto *result = impl.affineAllocator.Allocate<AffineDimExprStorage>();
-        // Initialize the memory using placement new.
-        new (result)
-            AffineDimExprStorage{{AffineExprKind::DimId, context}, position};
-        return result;
-      });
-}
-
-AffineExpr mlir::getAffineSymbolExpr(unsigned position, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  return safeGetOrCreate(
-      impl.symbolExprs, position, impl.affineMutex, [&impl, context, position] {
-        auto *result = impl.affineAllocator.Allocate<AffineSymbolExprStorage>();
-        // Initialize the memory using placement new.
-        new (result) AffineSymbolExprStorage{
-            {AffineExprKind::SymbolId, context}, position};
-        return result;
-      });
-}
-
-AffineExpr mlir::getAffineConstantExpr(int64_t constant, MLIRContext *context) {
-  auto &impl = context->getImpl();
-
-  // Safely get or create an AffineConstantExpr instance.
-  return safeGetOrCreate(impl.constExprs, constant, impl.affineMutex, [&] {
-    auto *result = impl.affineAllocator.Allocate<AffineConstantExprStorage>();
-    return new (result) AffineConstantExprStorage{
-        {AffineExprKind::Constant, context}, constant};
-  });
-}
-
 //===----------------------------------------------------------------------===//
 // Integer Sets: these are allocated into the bump pointer, and are immutable.
 // Unlike AffineMap's, these are uniqued only if they are small.
index 307bb38..2676513 100644 (file)
@@ -321,8 +321,8 @@ TEST(SDBMExpr, MatchStripeMulPattern) {
   // pattern (x floordiv B) * B == x # B.
   auto cst = getAffineConstantExpr(42, ctx());
   auto dim = getAffineDimExpr(0, ctx());
-  auto floor = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, dim, cst);
-  auto mul = getAffineBinaryOpExpr(AffineExprKind::Mul, cst, floor);
+  auto floor = dim.floorDiv(cst);
+  auto mul = cst * floor;
   Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
   ASSERT_TRUE(converted.hasValue());
   EXPECT_TRUE(converted->isa<SDBMStripeExpr>());
@@ -331,10 +331,10 @@ TEST(SDBMExpr, MatchStripeMulPattern) {
 TEST(SDBMExpr, NonSDBM) {
   auto d0 = getAffineDimExpr(0, ctx());
   auto d1 = getAffineDimExpr(1, ctx());
-  auto sum = getAffineBinaryOpExpr(AffineExprKind::Add, d0, d1);
+  auto sum = d0 + d1;
   auto c2 = getAffineConstantExpr(2, ctx());
-  auto prod = getAffineBinaryOpExpr(AffineExprKind::Mul, d0, c2);
-  auto ceildiv = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, d1, c2);
+  auto prod = d0 * c2;
+  auto ceildiv = d1.ceilDiv(c2);
 
   // The following are not valid SDBM expressions:
   // - a sum of two variables