Move SDBM uniquer into the SDBM dialect
authorAlex Zinenko <zinenko@google.com>
Tue, 21 May 2019 14:25:55 +0000 (07:25 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Sun, 2 Jun 2019 02:54:43 +0000 (19:54 -0700)
    MLIRContext does not have to be aware of the SDBM unique data structures
    directly.  Move the SDBM storage uniquer from MLIRContext to the SDBM dialect
    instance.  Expressions that previously required a context to be constructed now
    require an instance of the dialect in order to access the uniquer.  While they
    could look up the dialect in the context, it would have introduced a rather
    expensive lookup into each construction.  Instead, the caller is expected to
    obtain the dialect instance and cache it.

--

PiperOrigin-RevId: 249245199

mlir/include/mlir/IR/MLIRContext.h
mlir/include/mlir/SDBM/SDBM.h
mlir/include/mlir/SDBM/SDBMDialect.h
mlir/include/mlir/SDBM/SDBMExpr.h
mlir/lib/IR/MLIRContext.cpp
mlir/lib/SDBM/SDBM.cpp
mlir/lib/SDBM/SDBMExpr.cpp
mlir/lib/SDBM/SDBMExprDetail.h
mlir/test/SDBM/sdbm-api-test.cpp
mlir/unittests/SDBM/SDBMTest.cpp

index 9309d51..864fd96 100644 (file)
@@ -93,10 +93,6 @@ public:
   /// instances. This should not be used directly.
   StorageUniquer &getAttributeUniquer();
 
-  /// Returns the storage uniquer used for constructing SDBM expressions, which
-  /// are not attributes by themselves. This should not be used directly.
-  StorageUniquer &getSDBMUniquer();
-
 private:
   const std::unique_ptr<MLIRContextImpl> impl;
 
index 894165b..b1c2723 100644 (file)
@@ -29,6 +29,7 @@
 namespace mlir {
 
 class MLIRContext;
+class SDBMDialect;
 class SDBMExpr;
 class SDBMPositiveExpr;
 
@@ -110,7 +111,7 @@ public:
   static SDBM get(ArrayRef<SDBMExpr> inequalities,
                   ArrayRef<SDBMExpr> equalities);
 
-  void getSDBMExpressions(MLIRContext *context,
+  void getSDBMExpressions(SDBMDialect *dialect,
                           SmallVectorImpl<SDBMExpr> &inequalities,
                           SmallVectorImpl<SDBMExpr> &equalities);
 
@@ -129,15 +130,14 @@ private:
   /// and at(col,row) of the DBM.  Depending on the values being finite and
   /// being subsumed by stripe expressions, this may or may not add elements to
   /// the lists of equalities and inequalities.
-  void convertDBMElement(MLIRContext *context, unsigned row, unsigned col,
-                         SDBMPositiveExpr rowExpr, SDBMPositiveExpr colExpr,
+  void convertDBMElement(unsigned row, unsigned col, SDBMPositiveExpr rowExpr,
+                         SDBMPositiveExpr colExpr,
                          SmallVectorImpl<SDBMExpr> &inequalities,
                          SmallVectorImpl<SDBMExpr> &equalities);
 
   /// Populate `inequalities` based on the value at(pos,pos) of the DBM. Only
   /// adds new inequalities if the inequality is not trivially true.
-  void convertDBMDiagonalElement(MLIRContext *context, unsigned pos,
-                                 SDBMPositiveExpr expr,
+  void convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr,
                                  SmallVectorImpl<SDBMExpr> &inequalities);
 
   /// Get the total number of elements in the matrix.
index ba645b3..12086dc 100644 (file)
@@ -19,6 +19,7 @@
 #define MLIR_SDBM_SDBMDIALECT_H
 
 #include "mlir/IR/Dialect.h"
+#include "mlir/Support/StorageUniquer.h"
 
 namespace mlir {
 class MLIRContext;
@@ -28,6 +29,12 @@ public:
   SDBMDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) {}
 
   static StringRef getDialectNamespace() { return "sdbm"; }
+
+  /// Get the uniquer for SDBM expressions. This should not be used directly.
+  StorageUniquer &getUniquer() { return uniquer; }
+
+private:
+  StorageUniquer uniquer;
 };
 } // namespace mlir
 
index b29eb3a..afbeda1 100644 (file)
@@ -44,6 +44,7 @@ struct SDBMNegExprStorage;
 } // namespace detail
 
 class SDBMConstantExpr;
+class SDBMDialect;
 class SDBMDimExpr;
 class SDBMSymbolExpr;
 
@@ -118,6 +119,9 @@ public:
   /// Returns the MLIR context in which this expression lives.
   MLIRContext *getContext() const;
 
+  /// Returns the SDBM dialect instance.
+  SDBMDialect *getDialect() const;
+
   /// Convert the SDBM expression into an Affine expression.  This always
   /// succeeds because SDBM are a subset of affine.
   AffineExpr getAsAffineExpr() const;
@@ -140,8 +144,9 @@ public:
 
   using SDBMExpr::SDBMExpr;
 
-  /// Obtain or create a constant expression unique'ed in the given context.
-  static SDBMConstantExpr get(MLIRContext *context, int64_t value);
+  /// Obtain or create a constant expression unique'ed in the given dialect
+  /// (which belongs to a context).
+  static SDBMConstantExpr get(SDBMDialect *dialect, int64_t value);
 
   static bool isClassFor(const SDBMExpr &expr) {
     return expr.getKind() == SDBMExprKind::Constant;
@@ -265,8 +270,9 @@ public:
   using ImplType = detail::SDBMPositiveExprStorage;
   using SDBMInputExpr::SDBMInputExpr;
 
-  /// Obtain or create a dimension expression unique'ed in the given context.
-  static SDBMDimExpr get(MLIRContext *context, unsigned position);
+  /// Obtain or create a dimension expression unique'ed in the given dialect
+  /// (which belongs to a context).
+  static SDBMDimExpr get(SDBMDialect *dialect, unsigned position);
 
   static bool isClassFor(const SDBMExpr &expr) {
     return expr.getKind() == SDBMExprKind::DimId;
@@ -280,8 +286,9 @@ public:
   using ImplType = detail::SDBMPositiveExprStorage;
   using SDBMInputExpr::SDBMInputExpr;
 
-  /// Obtain or create a symbol expression unique'ed in the given context.
-  static SDBMSymbolExpr get(MLIRContext *context, unsigned position);
+  /// Obtain or create a symbol expression unique'ed in the given dialect (which
+  /// belongs to a context).
+  static SDBMSymbolExpr get(SDBMDialect *dialect, unsigned position);
 
   static bool isClassFor(const SDBMExpr &expr) {
     return expr.getKind() == SDBMExprKind::SymbolId;
@@ -414,27 +421,27 @@ namespace ops_assertions {
 /// simultaneously.
 SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs);
 inline SDBMExpr operator+(SDBMExpr lhs, int64_t rhs) {
-  return lhs + SDBMConstantExpr::get(lhs.getContext(), rhs);
+  return lhs + SDBMConstantExpr::get(lhs.getDialect(), rhs);
 }
 inline SDBMExpr operator+(int64_t lhs, SDBMExpr rhs) {
-  return SDBMConstantExpr::get(rhs.getContext(), lhs) + rhs;
+  return SDBMConstantExpr::get(rhs.getDialect(), lhs) + rhs;
 }
 
 /// Subtract an SDBM expression from another SDBM expression.  Both expressions
 /// must not be difference expressions.
 SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs);
 inline SDBMExpr operator-(SDBMExpr lhs, int64_t rhs) {
-  return lhs - SDBMConstantExpr::get(lhs.getContext(), rhs);
+  return lhs - SDBMConstantExpr::get(lhs.getDialect(), rhs);
 }
 inline SDBMExpr operator-(int64_t lhs, SDBMExpr rhs) {
-  return SDBMConstantExpr::get(rhs.getContext(), lhs) - rhs;
+  return SDBMConstantExpr::get(rhs.getDialect(), lhs) - rhs;
 }
 
 /// Construct a stripe expression from a positive expression and a positive
 /// constant stripe factor.
 SDBMExpr stripe(SDBMExpr expr, SDBMExpr factor);
 inline SDBMExpr stripe(SDBMExpr expr, int64_t factor) {
-  return stripe(expr, SDBMConstantExpr::get(expr.getContext(), factor));
+  return stripe(expr, SDBMConstantExpr::get(expr.getDialect(), factor));
 }
 } // namespace ops_assertions
 
index c41d72b..9c11f05 100644 (file)
@@ -354,11 +354,6 @@ public:
   StorageUniquer affineUniquer;
 
   //===--------------------------------------------------------------------===//
-  // SDBM uniquing
-  //===--------------------------------------------------------------------===//
-  StorageUniquer sdbmUniquer;
-
-  //===--------------------------------------------------------------------===//
   // Type uniquing
   //===--------------------------------------------------------------------===//
   StorageUniquer typeUniquer;
@@ -683,12 +678,6 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
 }
 
 //===----------------------------------------------------------------------===//
-// SDBM uniquing
-//===----------------------------------------------------------------------===//
-
-StorageUniquer &MLIRContext::getSDBMUniquer() { return getImpl().sdbmUniquer; }
-
-//===----------------------------------------------------------------------===//
 // Type uniquing
 //===----------------------------------------------------------------------===//
 
index 080a132..db32834 100644 (file)
@@ -353,7 +353,7 @@ SDBM SDBM::get(ArrayRef<SDBMExpr> inequalities, ArrayRef<SDBMExpr> equalities) {
 // If one of the expressions is derived from another using a stripe operation,
 // check if the inequalities induced by the stripe operation subsume the
 // inequalities defined in the DBM and if so, elide these inequalities.
-void SDBM::convertDBMElement(MLIRContext *context, unsigned row, unsigned col,
+void SDBM::convertDBMElement(unsigned row, unsigned col,
                              SDBMPositiveExpr rowExpr, SDBMPositiveExpr colExpr,
                              SmallVectorImpl<SDBMExpr> &inequalities,
                              SmallVectorImpl<SDBMExpr> &equalities) {
@@ -417,18 +417,17 @@ void SDBM::convertDBMElement(MLIRContext *context, unsigned row, unsigned col,
 // to -C <= 0.  Only construct the inequalities when C is negative, which
 // are trivially false but necessary for the returned system of inequalities
 // to indicate that the set it defines is empty.
-void SDBM::convertDBMDiagonalElement(MLIRContext *context, unsigned pos,
-                                     SDBMPositiveExpr expr,
+void SDBM::convertDBMDiagonalElement(unsigned pos, SDBMPositiveExpr expr,
                                      SmallVectorImpl<SDBMExpr> &inequalities) {
   auto selfDifference = at(pos, pos);
   if (selfDifference.isFinite() && selfDifference < 0) {
     auto selfDifferenceValueExpr =
-        SDBMConstantExpr::get(context, -selfDifference.getValue());
+        SDBMConstantExpr::get(expr.getDialect(), -selfDifference.getValue());
     inequalities.push_back(selfDifferenceValueExpr);
   }
 }
 
-void SDBM::getSDBMExpressions(MLIRContext *context,
+void SDBM::getSDBMExpressions(SDBMDialect *dialect,
                               SmallVectorImpl<SDBMExpr> &inequalities,
                               SmallVectorImpl<SDBMExpr> &equalities) {
   using ops_assertions::operator-;
@@ -436,10 +435,10 @@ void SDBM::getSDBMExpressions(MLIRContext *context,
 
   // Helper function that creates an SDBMInputExpr given the linearized position
   // of variable in the DBM.
-  auto getInput = [context, this](unsigned matrixPos) -> SDBMInputExpr {
+  auto getInput = [dialect, this](unsigned matrixPos) -> SDBMInputExpr {
     if (matrixPos < numDims)
-      return SDBMDimExpr::get(context, matrixPos);
-    return SDBMSymbolExpr::get(context, matrixPos - numDims);
+      return SDBMDimExpr::get(dialect, matrixPos);
+    return SDBMSymbolExpr::get(dialect, matrixPos - numDims);
   };
 
   // The top-left value corresponds to inequality 0 <= C.  If C is negative, the
@@ -449,7 +448,7 @@ void SDBM::getSDBMExpressions(MLIRContext *context,
   auto difference = at(0, 0);
   if (difference.isFinite() && difference < 0) {
     inequalities.push_back(
-        SDBMConstantExpr::get(context, -difference.getValue()));
+        SDBMConstantExpr::get(dialect, -difference.getValue()));
   }
 
   // Traverse the segment of the matrix that involves non-temporary variables.
@@ -470,12 +469,12 @@ void SDBM::getSDBMExpressions(MLIRContext *context,
     }
 
     // Introduce trivially false inequalities if required by diagonal elements.
-    convertDBMDiagonalElement(context, 1 + i, inputExpr, inequalities);
+    convertDBMDiagonalElement(1 + i, inputExpr, inequalities);
 
     // Introduce equalities or inequalities between non-temporary variables.
     for (unsigned j = 0; j < i; ++j) {
-      convertDBMElement(context, 1 + i, 1 + j, getInput(i), getInput(j),
-                        inequalities, equalities);
+      convertDBMElement(1 + i, 1 + j, getInput(i), getInput(j), inequalities,
+                        equalities);
     }
   }
 
@@ -495,21 +494,20 @@ void SDBM::getSDBMExpressions(MLIRContext *context,
     // Mixed constraints involving one temporary (j) and one non-temporary (i)
     // variable.
     for (unsigned j = 0; j < numTrueVariables; ++j) {
-      convertDBMElement(context, i, 1 + j,
-                        stripeToPoint[i].cast<SDBMStripeExpr>(), getInput(j),
-                        inequalities, equalities);
+      convertDBMElement(i, 1 + j, stripeToPoint[i].cast<SDBMStripeExpr>(),
+                        getInput(j), inequalities, equalities);
     }
 
     // Constraints involving only temporary variables.
     for (unsigned j = 1 + numTrueVariables; j < i; ++j) {
-      convertDBMElement(context, i, j, stripeToPoint[i].cast<SDBMStripeExpr>(),
+      convertDBMElement(i, j, stripeToPoint[i].cast<SDBMStripeExpr>(),
                         stripeToPoint[j].cast<SDBMStripeExpr>(), inequalities,
                         equalities);
     }
 
     // Introduce trivially false inequalities if required by diagonal elements.
-    convertDBMDiagonalElement(
-        context, i, stripeToPoint[i].cast<SDBMStripeExpr>(), inequalities);
+    convertDBMDiagonalElement(i, stripeToPoint[i].cast<SDBMStripeExpr>(),
+                              inequalities);
   }
 }
 
index fc8662d..5757ebe 100644 (file)
@@ -25,6 +25,7 @@
 #include "SDBMExprDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/SDBM/SDBMDialect.h"
 
 #include "llvm/Support/raw_ostream.h"
 
@@ -149,7 +150,11 @@ AffineExprMatcher::AffineExprMatcher(AffineExprKind k, AffineExprMatcher a,
 
 SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
 
-MLIRContext *SDBMExpr::getContext() const { return impl->context; }
+MLIRContext *SDBMExpr::getContext() const {
+  return impl->dialect->getContext();
+}
+
+SDBMDialect *SDBMExpr::getDialect() const { return impl->dialect; }
 
 void SDBMExpr::print(raw_ostream &os) const {
   struct Printer : public SDBMVisitor<Printer> {
@@ -202,7 +207,7 @@ struct SDBMNegator : public SDBMVisitor<SDBMNegator, SDBMExpr> {
   SDBMExpr visitNeg(SDBMNegExpr expr) { return expr.getVar(); }
   // The value of the constant is negated.
   SDBMExpr visitConstant(SDBMConstantExpr expr) {
-    return SDBMConstantExpr::get(expr.getContext(), -expr.getValue());
+    return SDBMConstantExpr::get(expr.getDialect(), -expr.getValue());
   }
   // Both terms of the sum are negated recursively.
   SDBMExpr visitSum(SDBMSumExpr expr) {
@@ -230,11 +235,11 @@ SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
   // If LHS of a sum is another sum, fold the constant RHS parts.
   if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
     lhs = lhsSum.getLHS();
-    rhs = SDBMConstantExpr::get(rhs.getContext(),
+    rhs = SDBMConstantExpr::get(rhs.getDialect(),
                                 rhs.getValue() + lhsSum.getRHS().getValue());
   }
 
-  StorageUniquer &uniquer = lhs.getContext()->getSDBMUniquer();
+  StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMBinaryExprStorage>(
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Add), lhs, rhs);
 }
@@ -346,7 +351,7 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
             // TODO(ntv): return varConverted.stripe(C.getConstantValue());
             return SDBMStripeExpr::get(
                 varConverted,
-                SDBMConstantExpr::get(varConverted.getContext(),
+                SDBMConstantExpr::get(dialect,
                                       C.getMatchedConstantValue().getValue()));
         }
       }
@@ -393,15 +398,18 @@ Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
 
     // Dimensions, symbols and constants are converted trivially.
     SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
-      return SDBMConstantExpr::get(expr.getContext(), expr.getValue());
+      return SDBMConstantExpr::get(dialect, expr.getValue());
     }
     SDBMExpr visitDimExpr(AffineDimExpr expr) {
-      return SDBMDimExpr::get(expr.getContext(), expr.getPosition());
+      return SDBMDimExpr::get(dialect, expr.getPosition());
     }
     SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
-      return SDBMSymbolExpr::get(expr.getContext(), expr.getPosition());
+      return SDBMSymbolExpr::get(dialect, expr.getPosition());
     }
+
+    SDBMDialect *dialect;
   } converter;
+  converter.dialect = affine.getContext()->getRegisteredDialect<SDBMDialect>();
 
   if (auto result = converter.visit(affine))
     return result;
@@ -416,7 +424,7 @@ SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
   assert(lhs && "expected SDBM dimension");
   assert(rhs && "expected SDBM dimension");
 
-  StorageUniquer &uniquer = lhs.getContext()->getSDBMUniquer();
+  StorageUniquer &uniquer = lhs.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMDiffExprStorage>(
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Diff), lhs, rhs);
 }
@@ -440,7 +448,7 @@ SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
   if (stripeFactor.getValue() <= 0)
     llvm::report_fatal_error("non-positive stripe factor");
 
-  StorageUniquer &uniquer = var.getContext()->getSDBMUniquer();
+  StorageUniquer &uniquer = var.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMBinaryExprStorage>(
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Stripe), var,
       stripeFactor);
@@ -468,46 +476,46 @@ unsigned SDBMInputExpr::getPosition() const {
 // SDBMDimExpr
 //===----------------------------------------------------------------------===//
 
-SDBMDimExpr SDBMDimExpr::get(MLIRContext *context, unsigned position) {
-  assert(context && "expected non-null context");
+SDBMDimExpr SDBMDimExpr::get(SDBMDialect *dialect, unsigned position) {
+  assert(dialect && "expected non-null dialect");
 
-  auto assignCtx = [context](detail::SDBMPositiveExprStorage *storage) {
-    storage->context = context;
+  auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
+    storage->dialect = dialect;
   };
 
-  StorageUniquer &uniquer = context->getSDBMUniquer();
+  StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMPositiveExprStorage>(
-      assignCtx, static_cast<unsigned>(SDBMExprKind::DimId), position);
+      assignDialect, static_cast<unsigned>(SDBMExprKind::DimId), position);
 }
 
 //===----------------------------------------------------------------------===//
 // SDBMSymbolExpr
 //===----------------------------------------------------------------------===//
 
-SDBMSymbolExpr SDBMSymbolExpr::get(MLIRContext *context, unsigned position) {
-  assert(context && "expected non-null context");
+SDBMSymbolExpr SDBMSymbolExpr::get(SDBMDialect *dialect, unsigned position) {
+  assert(dialect && "expected non-null dialect");
 
-  auto assignCtx = [context](detail::SDBMPositiveExprStorage *storage) {
-    storage->context = context;
+  auto assignDialect = [dialect](detail::SDBMPositiveExprStorage *storage) {
+    storage->dialect = dialect;
   };
 
-  StorageUniquer &uniquer = context->getSDBMUniquer();
+  StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMPositiveExprStorage>(
-      assignCtx, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
+      assignDialect, static_cast<unsigned>(SDBMExprKind::SymbolId), position);
 }
 
 //===----------------------------------------------------------------------===//
 // SDBMConstantExpr
 //===----------------------------------------------------------------------===//
 
-SDBMConstantExpr SDBMConstantExpr::get(MLIRContext *context, int64_t value) {
-  assert(context && "expected non-null context");
+SDBMConstantExpr SDBMConstantExpr::get(SDBMDialect *dialect, int64_t value) {
+  assert(dialect && "expected non-null dialect");
 
-  auto assignCtx = [context](detail::SDBMConstantExprStorage *storage) {
-    storage->context = context;
+  auto assignCtx = [dialect](detail::SDBMConstantExprStorage *storage) {
+    storage->dialect = dialect;
   };
 
-  StorageUniquer &uniquer = context->getSDBMUniquer();
+  StorageUniquer &uniquer = dialect->getUniquer();
   return uniquer.get<detail::SDBMConstantExprStorage>(
       assignCtx, static_cast<unsigned>(SDBMExprKind::Constant), value);
 }
@@ -523,7 +531,7 @@ int64_t SDBMConstantExpr::getValue() const {
 SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
   assert(var && "expected non-null SDBM variable expression");
 
-  StorageUniquer &uniquer = var.getContext()->getSDBMUniquer();
+  StorageUniquer &uniquer = var.getDialect()->getUniquer();
   return uniquer.get<detail::SDBMNegExprStorage>(
       /*initFn=*/{}, static_cast<unsigned>(SDBMExprKind::Neg), var);
 }
@@ -567,7 +575,7 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
 
   // Constant-fold if LHS is a constant.
   if (lhsConstant)
-    return SDBMConstantExpr::get(lhs.getContext(), lhsConstant.getValue() +
+    return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() +
                                                        rhsConstant.getValue());
 
   // Fold x + 0 == x.
@@ -581,7 +589,7 @@ SDBMExpr operator+(SDBMExpr lhs, SDBMExpr rhs) {
 SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
   // Fold x - x == 0.
   if (lhs == rhs)
-    return SDBMConstantExpr::get(lhs.getContext(), 0);
+    return SDBMConstantExpr::get(lhs.getDialect(), 0);
 
   // LHS and RHS may be constants.
   auto lhsConstant = lhs.dyn_cast<SDBMConstantExpr>();
@@ -589,7 +597,7 @@ SDBMExpr operator-(SDBMExpr lhs, SDBMExpr rhs) {
 
   // Constant fold if both LHS and RHS are constants.
   if (lhsConstant && rhsConstant)
-    return SDBMConstantExpr::get(lhs.getContext(), lhsConstant.getValue() -
+    return SDBMConstantExpr::get(lhs.getDialect(), lhsConstant.getValue() -
                                                        rhsConstant.getValue());
 
   // Replace a difference with a sum with a negated value if one of LHS and RHS
index 0ce8ea3..d2c241e 100644 (file)
 // =============================================================================
 //
 // This holds implementation details of SDBMExpr, in particular underlying
-// storage types.  MLIRContext.cpp needs to know the storage layout for
-// allocation and unique'ing purposes.
+// storage types.
 //
 //===----------------------------------------------------------------------===//
 
 #ifndef MLIR_IR_SDBMEXPRDETAIL_H
 #define MLIR_IR_SDBMEXPRDETAIL_H
 
-#include "mlir/IR/MLIRContext.h"
 #include "mlir/SDBM/SDBMExpr.h"
 #include "mlir/Support/StorageUniquer.h"
 
 namespace mlir {
+
+class SDBMDialect;
+
 namespace detail {
 
 // Base storage class for SDBMExpr.
@@ -37,7 +38,7 @@ struct SDBMExprStorage : public StorageUniquer::BaseStorage {
     return static_cast<SDBMExprKind>(BaseStorage::getKind());
   }
 
-  MLIRContext *context;
+  SDBMDialect *dialect;
 };
 
 // Storage class for SDBM sum and stripe expressions.
@@ -53,7 +54,7 @@ struct SDBMBinaryExprStorage : public SDBMExprStorage {
     auto *result = allocator.allocate<SDBMBinaryExprStorage>();
     result->lhs = std::get<0>(key);
     result->rhs = std::get<1>(key);
-    result->context = result->lhs.getContext();
+    result->dialect = result->lhs.getDialect();
     return result;
   }
 
@@ -74,7 +75,7 @@ struct SDBMDiffExprStorage : public SDBMExprStorage {
     auto *result = allocator.allocate<SDBMDiffExprStorage>();
     result->lhs = std::get<0>(key);
     result->rhs = std::get<1>(key);
-    result->context = result->lhs.getContext();
+    result->dialect = result->lhs.getDialect();
     return result;
   }
 
@@ -124,7 +125,7 @@ struct SDBMNegExprStorage : public SDBMExprStorage {
   construct(StorageUniquer::StorageAllocator &allocator, const KeyTy &key) {
     auto *result = allocator.allocate<SDBMNegExprStorage>();
     result->dim = key;
-    result->context = key.getContext();
+    result->dialect = key.getDialect();
     return result;
   }
 
index 08075fd..0777bec 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/SDBM/SDBM.h"
+#include "mlir/SDBM/SDBMDialect.h"
 #include "mlir/SDBM/SDBMExpr.h"
 
 #include "llvm/Support/raw_ostream.h"
@@ -32,9 +33,19 @@ static MLIRContext *ctx() {
   return &context;
 }
 
-static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(ctx(), pos); }
+static SDBMDialect *dialect() {
+  static thread_local SDBMDialect *d = nullptr;
+  if (!d) {
+    d = ctx()->getRegisteredDialect<SDBMDialect>();
+  }
+  return d;
+}
+
+static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
 
-static SDBMExpr symb(unsigned pos) { return SDBMSymbolExpr::get(ctx(), pos); }
+static SDBMExpr symb(unsigned pos) {
+  return SDBMSymbolExpr::get(dialect(), pos);
+}
 
 namespace {
 
@@ -129,7 +140,7 @@ TEST_FUNC(SDBM_StripeTightening) {
   auto sdbm = SDBM::get({tight}, {s - dim(0), s - dim(1) + 42});
 
   SmallVector<SDBMExpr, 4> eqs, ineqs;
-  sdbm.getSDBMExpressions(ctx(), ineqs, eqs);
+  sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
   // CHECK-DAG: d0 - s0 # 3 + -2
   // CHECK-DAG: d1 - d0 + -42
   // CHEKC-DAG: d0 - s0 # 3 # 5
index 22dd6c3..c4c1c56 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/SDBM/SDBM.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/MLIRContext.h"
+#include "mlir/SDBM/SDBMDialect.h"
 #include "mlir/SDBM/SDBMExpr.h"
 #include "gtest/gtest.h"
 
@@ -30,9 +31,19 @@ static MLIRContext *ctx() {
   return &context;
 }
 
-static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(ctx(), pos); }
+static SDBMDialect *dialect() {
+  static thread_local SDBMDialect *d = nullptr;
+  if (!d) {
+    d = ctx()->getRegisteredDialect<SDBMDialect>();
+  }
+  return d;
+}
+
+static SDBMExpr dim(unsigned pos) { return SDBMDimExpr::get(dialect(), pos); }
 
-static SDBMExpr symb(unsigned pos) { return SDBMSymbolExpr::get(ctx(), pos); }
+static SDBMExpr symb(unsigned pos) {
+  return SDBMSymbolExpr::get(dialect(), pos);
+}
 
 namespace {
 
@@ -47,7 +58,7 @@ TEST(SDBMOperators, Add) {
 }
 
 TEST(SDBMOperators, AddFolding) {
-  auto constant = SDBMConstantExpr::get(ctx(), 2) + 42;
+  auto constant = SDBMConstantExpr::get(dialect(), 2) + 42;
   auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
   ASSERT_TRUE(constantExpr);
   EXPECT_EQ(constantExpr.getValue(), 44);
@@ -57,13 +68,13 @@ TEST(SDBMOperators, AddFolding) {
   ASSERT_TRUE(sumExpr);
   EXPECT_EQ(sumExpr.getRHS().getValue(), 42);
 
-  expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(ctx(), 1));
+  expr = dim(0) + SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1));
   auto diffExpr = expr.dyn_cast<SDBMDiffExpr>();
   ASSERT_TRUE(diffExpr);
   EXPECT_EQ(diffExpr.getLHS(), dim(0));
   EXPECT_EQ(diffExpr.getRHS(), dim(1));
 
-  auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(ctx(), 1)) + dim(0);
+  auto inverted = SDBMNegExpr::get(SDBMDimExpr::get(dialect(), 1)) + dim(0);
   EXPECT_EQ(inverted, expr);
 }
 
@@ -76,7 +87,7 @@ TEST(SDBMOperators, Diff) {
 }
 
 TEST(SDBMOperators, DiffFolding) {
-  auto constant = SDBMConstantExpr::get(ctx(), 10) - 3;
+  auto constant = SDBMConstantExpr::get(dialect(), 10) - 3;
   auto constantExpr = constant.dyn_cast<SDBMConstantExpr>();
   ASSERT_TRUE(constantExpr);
   EXPECT_EQ(constantExpr.getValue(), 7);
@@ -118,13 +129,13 @@ TEST(SDBM, RoundTripEqs) {
   auto s = stripe(stripe(symb(0), 3), 5);
   auto sdbm = SDBM::get(llvm::None, {s - dim(0), s - dim(1) + 42});
   SmallVector<SDBMExpr, 4> eqs, ineqs;
-  sdbm.getSDBMExpressions(ctx(), ineqs, eqs);
+  sdbm.getSDBMExpressions(dialect(), ineqs, eqs);
   ASSERT_TRUE(ineqs.empty());
 
   // Do the second round-trip.
   auto sdbm2 = SDBM::get(llvm::None, eqs);
   SmallVector<SDBMExpr, 4> eqs2, ineqs2;
-  sdbm2.getSDBMExpressions(ctx(), ineqs2, eqs2);
+  sdbm2.getSDBMExpressions(dialect(), ineqs2, eqs2);
   ASSERT_EQ(eqs.size(), eqs2.size());
 
   // Convert that the sets of equalities are equal, their order is not relevant.
@@ -136,11 +147,11 @@ TEST(SDBM, RoundTripEqs) {
 
 TEST(SDBMExpr, Constant) {
   // We can create consants and query them.
-  auto expr = SDBMConstantExpr::get(ctx(), 42);
+  auto expr = SDBMConstantExpr::get(dialect(), 42);
   EXPECT_EQ(expr.getValue(), 42);
 
   // Two separately created constants with identical values are trivially equal.
-  auto expr2 = SDBMConstantExpr::get(ctx(), 42);
+  auto expr2 = SDBMConstantExpr::get(dialect(), 42);
   EXPECT_EQ(expr, expr2);
 
   // Hierarchy is okay.
@@ -150,12 +161,12 @@ TEST(SDBMExpr, Constant) {
 
 TEST(SDBMExpr, Dim) {
   // We can create dimension expressions and query them.
-  auto expr = SDBMDimExpr::get(ctx(), 0);
+  auto expr = SDBMDimExpr::get(dialect(), 0);
   EXPECT_EQ(expr.getPosition(), 0u);
 
   // Two separately created dimensions with the same position are trivially
   // equal.
-  auto expr2 = SDBMDimExpr::get(ctx(), 0);
+  auto expr2 = SDBMDimExpr::get(dialect(), 0);
   EXPECT_EQ(expr, expr2);
 
   // Hierarchy is okay.
@@ -166,18 +177,18 @@ TEST(SDBMExpr, Dim) {
   EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
 
   // Dimensions are not Symbols.
-  auto symbol = SDBMSymbolExpr::get(ctx(), 0);
+  auto symbol = SDBMSymbolExpr::get(dialect(), 0);
   EXPECT_NE(expr, symbol);
   EXPECT_FALSE(expr.isa<SDBMSymbolExpr>());
 }
 
 TEST(SDBMExpr, Symbol) {
   // We can create symbol expressions and query them.
-  auto expr = SDBMSymbolExpr::get(ctx(), 0);
+  auto expr = SDBMSymbolExpr::get(dialect(), 0);
   EXPECT_EQ(expr.getPosition(), 0u);
 
   // Two separately created symbols with the same position are trivially equal.
-  auto expr2 = SDBMSymbolExpr::get(ctx(), 0);
+  auto expr2 = SDBMSymbolExpr::get(dialect(), 0);
   EXPECT_EQ(expr, expr2);
 
   // Hierarchy is okay.
@@ -188,15 +199,15 @@ TEST(SDBMExpr, Symbol) {
   EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
 
   // Dimensions are not Symbols.
-  auto symbol = SDBMDimExpr::get(ctx(), 0);
+  auto symbol = SDBMDimExpr::get(dialect(), 0);
   EXPECT_NE(expr, symbol);
   EXPECT_FALSE(expr.isa<SDBMDimExpr>());
 }
 
 TEST(SDBMExpr, Stripe) {
-  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
-  auto cst0 = SDBMConstantExpr::get(ctx(), 0);
-  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
+  auto cst0 = SDBMConstantExpr::get(dialect(), 0);
+  auto var = SDBMSymbolExpr::get(dialect(), 0);
 
   // We can create stripe expressions and query them.
   auto expr = SDBMStripeExpr::get(var, cst2);
@@ -205,11 +216,11 @@ TEST(SDBMExpr, Stripe) {
 
   // Two separately created stripe expressions with the same LHS and RHS are
   // trivially equal.
-  auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(ctx(), 0), cst2);
+  auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(dialect(), 0), cst2);
   EXPECT_EQ(expr, expr2);
 
   // Stripes can be nested.
-  SDBMStripeExpr::get(expr, SDBMConstantExpr::get(ctx(), 4));
+  SDBMStripeExpr::get(expr, SDBMConstantExpr::get(dialect(), 4));
 
   // Non-positive stripe factors are not allowed.
   EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
@@ -222,8 +233,8 @@ TEST(SDBMExpr, Stripe) {
 }
 
 TEST(SDBMExpr, Neg) {
-  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
-  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
+  auto var = SDBMSymbolExpr::get(dialect(), 0);
   auto stripe = SDBMStripeExpr::get(var, cst2);
 
   // We can create negation expressions and query them.
@@ -242,8 +253,8 @@ TEST(SDBMExpr, Neg) {
 }
 
 TEST(SDBMExpr, Sum) {
-  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
-  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
+  auto var = SDBMSymbolExpr::get(dialect(), 0);
   auto stripe = SDBMStripeExpr::get(var, cst2);
 
   // We can create sum expressions and query them.
@@ -264,8 +275,8 @@ TEST(SDBMExpr, Sum) {
 }
 
 TEST(SDBMExpr, Diff) {
-  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
-  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
+  auto var = SDBMSymbolExpr::get(dialect(), 0);
   auto stripe = SDBMStripeExpr::get(var, cst2);
 
   // We can create sum expressions and query them.
@@ -287,8 +298,8 @@ TEST(SDBMExpr, Diff) {
 
 TEST(SDBMExpr, AffineRoundTrip) {
   // Build an expression (s0 - s0 # 2)
-  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
-  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto cst2 = SDBMConstantExpr::get(dialect(), 2);
+  auto var = SDBMSymbolExpr::get(dialect(), 0);
   auto stripe = SDBMStripeExpr::get(var, cst2);
   auto expr = SDBMDiffExpr::get(var, stripe);
 
@@ -301,7 +312,7 @@ TEST(SDBMExpr, AffineRoundTrip) {
 
   // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
   // detection supports nested expressions.
-  auto cst5 = SDBMConstantExpr::get(ctx(), 5);
+  auto cst5 = SDBMConstantExpr::get(dialect(), 5);
   auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
   roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
   ASSERT_TRUE(roundtripped.hasValue());