SDBMExpr: factor uniqu'ing out of MLIRContext.cpp
authorAlex Zinenko <zinenko@google.com>
Tue, 21 May 2019 14:22:00 +0000 (07:22 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:54:22 +0000 (19:54 -0700)
    SDBM expressions are designed as components of an attribute, similarly to
    affine expressions.  As such, they need to be unique'd in the MLIRContext.
    When SDBM expressions were implemented, uniqu'ing objects in a context required
    to modify MLIRContext implementation.  This is no longer the case as generic
    StorageUniquer has been introduced.  Port the SDBMExpr uniqu'ing to use a newly
    introduced uniquer and remove SDBM construction from MLIRContext.cpp.

--

PiperOrigin-RevId: 249244739

mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/Support/StorageUniquer.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/SDBMExpr.cpp
mlir/lib/IR/SDBMExprDetail.h

index 864fd96..9309d51 100644 (file)
@@ -93,6 +93,10 @@ public:
   /// instances. This should not be used directly.
   StorageUniquer &getAttributeUniquer();
 
+  /// Returns the storage uniquer used for constructing SDBM expressions, which
+  /// are not attributes by themselves. This should not be used directly.
+  StorageUniquer &getSDBMUniquer();
+
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
index fc20db2..3112a36 100644 (file)
@@ -50,7 +50,7 @@ struct StorageUniquerImpl;
 ///    - Provide a method, 'bool operator==(const KeyTy &) const', to
 ///      compare the storage instance against an instance of the key type.
 ///
-///    - Provide a construction method:
+///    - Provide a static construction method:
 ///        'DerivedStorage *construct(StorageAllocator &, const KeyTy &key)'
 ///      that builds a unique instance of the derived storage. The arguments to
 ///      this function are an allocator to store any uniqued data and the key
index b54e956..c41d72b 100644 (file)
@@ -21,7 +21,6 @@
 #include "AttributeDetail.h"
 #include "IntegerSetDetail.h"
 #include "LocationDetail.h"
-#include "SDBMExprDetail.h"
 #include "TypeDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -357,22 +356,7 @@ public:
   //===--------------------------------------------------------------------===//
   // SDBM uniquing
   //===--------------------------------------------------------------------===//
-  llvm::BumpPtrAllocator SDBMAllocator;
-  llvm::sys::SmartRWMutex<true> SDBMMutex;
-
-  DenseMap<std::tuple<SDBMVaryingExpr, SDBMConstantExpr>,
-           SDBMBinaryExprStorage *>
-      SDBMSumExprs;
-  DenseMap<std::tuple<SDBMPositiveExpr, SDBMConstantExpr>,
-           SDBMBinaryExprStorage *>
-      SDBMStripeExprs;
-  DenseMap<std::tuple<SDBMPositiveExpr, SDBMPositiveExpr>,
-           SDBMDiffExprStorage *>
-      SDBMDiffExprs;
-  std::vector<SDBMPositiveExprStorage *> SDBMDimExprs;
-  std::vector<SDBMPositiveExprStorage *> SDBMSymbolExprs;
-  DenseMap<SDBMPositiveExpr, SDBMNegExprStorage *> SDBMNegExprs;
-  DenseMap<int64_t, SDBMConstantExprStorage *> SDBMConstExprs;
+  StorageUniquer sdbmUniquer;
 
   //===--------------------------------------------------------------------===//
   // Type uniquing
@@ -699,102 +683,10 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
 }
 
 //===----------------------------------------------------------------------===//
-// SDBMExpr uniquing
+// SDBM uniquing
 //===----------------------------------------------------------------------===//
 
-SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
-  assert(lhs && "expected SDBM variable expression");
-  assert(rhs && "expected SDBM constant");
-
-  MLIRContextImpl &impl = lhs.getContext()->getImpl();
-
-  // If LHS of a sum is another sum, fold the constant RHS parts.
-  if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
-    lhs = lhsSum.getLHS();
-    rhs = SDBMConstantExpr::get(rhs.getContext(),
-                                rhs.getValue() + lhsSum.getRHS().getValue());
-  }
-
-  auto key = std::make_tuple(lhs, rhs);
-  return safeGetOrCreate(
-      impl.SDBMSumExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
-        auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
-        return new (mem) SDBMBinaryExprStorage(SDBMExprKind::Add,
-                                               lhs.getContext(), lhs, rhs);
-      });
-}
-
-SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
-  assert(lhs && "expected SDBM dimension");
-  assert(rhs && "expected SDBM dimension");
-
-  MLIRContextImpl &impl = lhs.getContext()->getImpl();
-  auto key = std::make_tuple(lhs, rhs);
-  return safeGetOrCreate(
-      impl.SDBMDiffExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
-        auto *mem = impl.SDBMAllocator.Allocate<SDBMDiffExprStorage>();
-        return new (mem) SDBMDiffExprStorage(lhs.getContext(), lhs, rhs);
-      });
-}
-
-SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
-                                   SDBMConstantExpr stripeFactor) {
-  assert(var && "expected SDBM variable expression");
-  assert(stripeFactor && "expected non-null stripe factor");
-  if (stripeFactor.getValue() <= 0)
-    llvm::report_fatal_error("non-positive stripe factor");
-
-  MLIRContextImpl &impl = var.getContext()->getImpl();
-  auto key = std::make_tuple(var, stripeFactor);
-  return safeGetOrCreate(
-      impl.SDBMStripeExprs, key, impl.SDBMMutex, [&impl, var, stripeFactor] {
-        auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
-        return new (mem) SDBMBinaryExprStorage(
-            SDBMExprKind::Stripe, var.getContext(), var, stripeFactor);
-      });
-}
-
-SDBMDimExpr SDBMDimExpr::get(MLIRContext *context, unsigned position) {
-  assert(context && "expected non-null context");
-  MLIRContextImpl &impl = context->getImpl();
-  return safeGetOrCreate(
-      impl.SDBMDimExprs, position, impl.SDBMMutex, [&impl, context, position] {
-        auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
-        return new (mem)
-            SDBMPositiveExprStorage(SDBMExprKind::DimId, context, position);
-      });
-}
-
-SDBMSymbolExpr SDBMSymbolExpr::get(MLIRContext *context, unsigned position) {
-  assert(context && "expected non-null context");
-  MLIRContextImpl &impl = context->getImpl();
-  return safeGetOrCreate(
-      impl.SDBMSymbolExprs, position, impl.SDBMMutex,
-      [&impl, context, position] {
-        auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
-        return new (mem)
-            SDBMPositiveExprStorage(SDBMExprKind::SymbolId, context, position);
-      });
-}
-
-SDBMConstantExpr SDBMConstantExpr::get(MLIRContext *context, int64_t value) {
-  assert(context && "expected non-null context");
-  MLIRContextImpl &impl = context->getImpl();
-  return safeGetOrCreate(
-      impl.SDBMConstExprs, value, impl.SDBMMutex, [&impl, context, value] {
-        auto *mem = impl.SDBMAllocator.Allocate<SDBMConstantExprStorage>();
-        return new (mem) SDBMConstantExprStorage(context, value);
-      });
-}
-
-SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
-  assert(var && "expected non-null SDBM variable expression");
-  MLIRContextImpl &impl = var.getContext()->getImpl();
-  return safeGetOrCreate(impl.SDBMNegExprs, var, impl.SDBMMutex, [&impl, var] {
-    auto *mem = impl.SDBMAllocator.Allocate<SDBMNegExprStorage>();
-    return new (mem) SDBMNegExprStorage(var);
-  });
-}
+StorageUniquer &MLIRContext::getSDBMUniquer() { return getImpl().sdbmUniquer; }
 
 //===----------------------------------------------------------------------===//
 // Type uniquing
index abf3352..a95e3fe 100644 (file)
@@ -149,7 +149,7 @@ AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
 
 SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
 
-MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); }
+MLIRContext *SDBMExpr::getContext() const { return impl->context; }
 
 void SDBMExpr::print(raw_ostream &os) const {
   struct Printer : public SDBMVisitor<Printer> {
@@ -223,6 +223,22 @@ SDBMExpr SDBMExpr::operator-() { return SDBMNegator().visit(*this); }
 // SDBMSumExpr
 //===----------------------------------------------------------------------===//
 
+SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
+  assert(lhs && "expected SDBM variable expression");
+  assert(rhs && "expected SDBM constant");
+
+  // If LHS of a sum is another sum, fold the constant RHS parts.
+  if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
+    lhs = lhsSum.getLHS();
+    rhs = SDBMConstantExpr::get(rhs.getContext(),
+                                rhs.getValue() + lhsSum.getRHS().getValue());
+  }
+
+  StorageUniquer &uniquer = lhs.getContext()->getSDBMUniquer();
+  return uniquer.get<detail::SDBMBinaryExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
+}
+
 SDBMVaryingExpr SDBMSumExpr::getLHS() const {
   return static_cast<ImplType *>(impl)->lhs;
 }
@@ -396,6 +412,15 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
 // SDBMDiffExpr
 //===----------------------------------------------------------------------===//
 
+SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
+  assert(lhs && "expected SDBM dimension");
+  assert(rhs && "expected SDBM dimension");
+
+  StorageUniquer &uniquer = lhs.getContext()->getSDBMUniquer();
+  return uniquer.get<detail::SDBMDiffExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
+}
+
 SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
   return static_cast<ImplType *>(impl)->lhs;
 }
@@ -408,6 +433,19 @@ SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
 // SDBMStripeExpr
 //===----------------------------------------------------------------------===//
 
+SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
+                                   SDBMConstantExpr stripeFactor) {
+  assert(var && "expected SDBM variable expression");
+  assert(stripeFactor && "expected non-null stripe factor");
+  if (stripeFactor.getValue() <= 0)
+    llvm::report_fatal_error("non-positive stripe factor");
+
+  StorageUniquer &uniquer = var.getContext()->getSDBMUniquer();
+  return uniquer.get<detail::SDBMBinaryExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
+      stripeFactor);
+}
+
 SDBMPositiveExpr SDBMStripeExpr::getVar() const {
   if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
     return lhs.cast<SDBMPositiveExpr>();
@@ -427,9 +465,53 @@ unsigned SDBMInputExpr::getPosition() const {
 }
 
 //===----------------------------------------------------------------------===//
+// SDBMDimExpr
+//===----------------------------------------------------------------------===//
+
+SDBMDimExpr SDBMDimExpr::get(MLIRContext *context, unsigned position) {
+  assert(context && "expected non-null context");
+
+  auto assignCtx = [context](detail::SDBMPositiveExprStorage *storage) {
+    storage->context = context;
+  };
+
+  StorageUniquer &uniquer = context->getSDBMUniquer();
+  return uniquer.get<detail::SDBMPositiveExprStorage>(
+      assignCtx, static_cast<unsigned>(SDBMExprKind::DimId), position);
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMSymbolExpr
+//===----------------------------------------------------------------------===//
+
+SDBMSymbolExpr SDBMSymbolExpr::get(MLIRContext *context, unsigned position) {
+  assert(context && "expected non-null context");
+
+  auto assignCtx = [context](detail::SDBMPositiveExprStorage *storage) {
+    storage->context = context;
+  };
+
+  StorageUniquer &uniquer = context->getSDBMUniquer();
+  return uniquer.get<detail::SDBMPositiveExprStorage>(
+      assignCtx, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
+}
+
+//===----------------------------------------------------------------------===//
 // SDBMConstantExpr
 //===----------------------------------------------------------------------===//
 
+SDBMConstantExpr SDBMConstantExpr::get(MLIRContext *context, int64_t value) {
+  assert(context && "expected non-null context");
+
+  auto assignCtx = [context](detail::SDBMConstantExprStorage *storage) {
+    storage->context = context;
+  };
+
+  StorageUniquer &uniquer = context->getSDBMUniquer();
+  return uniquer.get<detail::SDBMConstantExprStorage>(
+      assignCtx, static_cast<unsigned>(SDBMExprKind::Constant), value);
+}
+
 int64_t SDBMConstantExpr::getValue() const {
   return static_cast<ImplType *>(impl)->constant;
 }
@@ -438,6 +520,14 @@ int64_t SDBMConstantExpr::getValue() const {
 // SDBMNegExpr
 //===----------------------------------------------------------------------===//
 
+SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
+  assert(var && "expected non-null SDBM variable expression");
+
+  StorageUniquer &uniquer = var.getContext()->getSDBMUniquer();
+  return uniquer.get<detail::SDBMNegExprStorage>(
+      /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
+}
+
 SDBMPositiveExpr SDBMNegExpr::getVar() const {
   return static_cast<ImplType *>(impl)->dim;
 }
index f661d2a..b4e748f 100644 (file)
 
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/SDBMExpr.h"
-#include "llvm/ADT/PointerIntPair.h"
+#include "mlir/Support/StorageUniquer.h"
 
 namespace mlir {
 namespace detail {
 
-struct SDBMExprStorage {
-  SDBMExprStorage(SDBMExprKind kind, MLIRContext *context)
-      : contextAndKind(context, kind) {}
+// Base storage class for SDBMExpr.
+struct SDBMExprStorage : public StorageUniquer::BaseStorage {
+  SDBMExprKind getKind() {
+    return static_cast<SDBMExprKind>(BaseStorage::getKind());
+  }
 
-  SDBMExprKind getKind() { return contextAndKind.getInt(); }
-
-  MLIRContext *getContext() { return contextAndKind.getPointer(); }
-
-  // This needs to know the layout of MLIRContext so the relevant file is
-  // included.
-  llvm::PointerIntPair<MLIRContext *, 3, SDBMExprKind> contextAndKind;
+  MLIRContext *context;
 };
 
+// Storage class for SDBM sum and stripe expressions.
 struct SDBMBinaryExprStorage : public SDBMExprStorage {
-  SDBMBinaryExprStorage(SDBMExprKind kind, MLIRContext *context,
-                        SDBMVaryingExpr left, SDBMConstantExpr right)
-      : SDBMExprStorage(kind, context), lhs(left), rhs(right) {}
+  using KeyTy = std::pair<SDBMVaryingExpr, SDBMConstantExpr>;
+
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
+  }
+
+  static SDBMBinaryExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMBinaryExprStorage>();
+    result->lhs = std::get<0>(key);
+    result->rhs = std::get<1>(key);
+    result->context = result->lhs.getContext();
+    return result;
+  }
+
   SDBMVaryingExpr lhs;
   SDBMConstantExpr rhs;
 };
 
+// Storage class for SDBM difference expressions.
 struct SDBMDiffExprStorage : public SDBMExprStorage {
-  SDBMDiffExprStorage(MLIRContext *context, SDBMPositiveExpr left,
-                      SDBMPositiveExpr right)
-      : SDBMExprStorage(SDBMExprKind::Diff, context), lhs(left), rhs(right) {}
+  using KeyTy = std::pair<SDBMPositiveExpr, SDBMPositiveExpr>;
+
+  bool operator==(const KeyTy &key) const {
+    return std::get<0>(key) == lhs && std::get<1>(key) == rhs;
+  }
+
+  static SDBMDiffExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMDiffExprStorage>();
+    result->lhs = std::get<0>(key);
+    result->rhs = std::get<1>(key);
+    result->context = result->lhs.getContext();
+    return result;
+  }
+
   SDBMPositiveExpr lhs;
   SDBMPositiveExpr rhs;
 };
 
+// Storage class for SDBM constant expressions.
 struct SDBMConstantExprStorage : public SDBMExprStorage {
-  SDBMConstantExprStorage(MLIRContext *context, int64_t value)
-      : SDBMExprStorage(SDBMExprKind::Constant, context), constant(value) {}
+  using KeyTy = int64_t;
+
+  bool operator==(const KeyTy &key) const { return constant == key; }
+
+  static SDBMConstantExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMConstantExprStorage>();
+    result->constant = key;
+    return result;
+  }
+
   int64_t constant;
 };
 
+// Storage class for SDBM dimension and symbol expressions.
 struct SDBMPositiveExprStorage : public SDBMExprStorage {
-  SDBMPositiveExprStorage(SDBMExprKind kind, MLIRContext *context, unsigned pos)
-      : SDBMExprStorage(kind, context), position(pos) {}
+  using KeyTy = unsigned;
+
+  bool operator==(const KeyTy &key) const { return position == key; }
+
+  static SDBMPositiveExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMPositiveExprStorage>();
+    result->position = key;
+    return result;
+  }
+
   unsigned position;
 };
 
+// Storage class for SDBM negation expressions.
 struct SDBMNegExprStorage : public SDBMExprStorage {
-  SDBMNegExprStorage(SDBMPositiveExpr expr)
-      : SDBMExprStorage(SDBMExprKind::Neg, expr.getContext()), dim(expr) {}
+  using KeyTy = SDBMPositiveExpr;
+
+  bool operator==(const KeyTy &key) const { return key == dim; }
+
+  static SDBMNegExprStorage *
+  construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
+    auto *result = allocator.allocate<SDBMNegExprStorage>();
+    result->dim = key;
+    result->context = key.getContext();
+    return result;
+  }
+
   SDBMPositiveExpr dim;
 };