Start building SDBM infrastructure
authorAlex Zinenko <zinenko@google.com>
Fri, 26 Apr 2019 08:05:24 +0000 (01:05 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:17:42 +0000 (08:17 -0700)
    Striped difference-bound matrix expressions are a subset of affine expressions
    supporting low-complexity algorithms that can be useful for loop
    transformations.  This introduces the basic data data structures for building
    such expressions and unique'ing them in a MLIRContext.

--

PiperOrigin-RevId: 245380206

mlir/include/mlir/IR/SDBMExpr.h [new file with mode: 0644]
mlir/lib/IR/MLIRContext.cpp
mlir/lib/IR/SDBMExpr.cpp [new file with mode: 0644]
mlir/lib/IR/SDBMExprDetail.h [new file with mode: 0644]
mlir/unittests/IR/CMakeLists.txt
mlir/unittests/IR/SDBMTest.cpp [new file with mode: 0644]

diff --git a/mlir/include/mlir/IR/SDBMExpr.h b/mlir/include/mlir/IR/SDBMExpr.h
new file mode 100644 (file)
index 0000000..02e1de3
--- /dev/null
@@ -0,0 +1,353 @@
+//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) expression is a constant expression,
+// an identifier, a binary expression with constant RHS and +, stripe operators
+// or a difference expression between two identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SDBMEXPR_H
+#define MLIR_IR_SDBMEXPR_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+
+namespace mlir {
+
+class MLIRContext;
+
+enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
+
+namespace detail {
+struct SDBMExprStorage;
+struct SDBMBinaryExprStorage;
+struct SDBMDiffExprStorage;
+struct SDBMPositiveExprStorage;
+struct SDBMConstantExprStorage;
+struct SDBMNegExprStorage;
+} // namespace detail
+
+/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
+/// expression for the SDBM framework.  SDBM expressions are a subset of affine
+/// expressions supporting low-complexity algorithms for the operations used in
+/// loop transformations.  In particular, are supported:
+///   - constant expressions;
+///   - single variables (dimensions and symbols) with +1 or -1 coefficient;
+///   - stripe expressions: "x # C", where "x" is a single variable or another
+///     stripe expression, "#" is the stripe operator, and "C" is a constant
+///     expression; "#" is defined as x - x mod C.
+///   - sum expressions between single variable/stripe expressions and constant
+///     expressions;
+///   - difference expressions between single variable/stripe expressions.
+/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing
+/// and operating on SDBM expressions.  For example, it requires the LHS of a
+/// sum expression to be a single variable or a stripe expression.  These
+/// restrictions are intended to force the caller to perform the necessary
+/// simplifications to stay within the SDBM domain, because SDBM expressions do
+/// not combine in more cases than they do.  This choice may be reconsidered in
+/// the future.
+///
+/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
+/// an MLIRContext, and should be used by-value.  They are uniqued in the
+/// MLIRContext and immortal.
+class SDBMExpr {
+public:
+  using ImplType = detail::SDBMExprStorage;
+  SDBMExpr() : impl(nullptr) {}
+  /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
+
+  /// SDBM expressions are thin wrappers around a unique'ed immutable pointer,
+  /// which makes them trivially assignable and trivially copyable.
+  SDBMExpr(const SDBMExpr &) = default;
+  SDBMExpr &operator=(const SDBMExpr &) = default;
+
+  /// SDBM expressions can be compared straight-forwardly.
+  bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
+  bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
+
+  /// SDBM expressions are convertible to `bool`: null expressions are converted
+  /// to false, non-null expressions are converted to true.
+  explicit operator bool() const { return impl != nullptr; }
+  bool operator!() const { return !static_cast<bool>(*this); }
+
+  /// Prints the SDBM expression.
+  void print(raw_ostream &os) const;
+  void dump() const;
+
+  /// LLVM-style casts.
+  template <typename U> bool isa() const { return U::isClassFor(*this); }
+  template <typename U> U dyn_cast() const {
+    if (!isa<U>())
+      return {};
+    return U(const_cast<SDBMExpr *>(this)->impl);
+  }
+  template <typename U> U cast() const {
+    assert(isa<U>() && "cast to incorrect subtype");
+    return U(const_cast<SDBMExpr *>(this)->impl);
+  }
+
+  /// Support for LLVM hashing.
+  ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
+
+  /// Returns the kind of the SDBM expression.
+  SDBMExprKind getKind() const;
+
+  /// Returns the MLIR context in which this expression lives.
+  MLIRContext *getContext() const;
+
+protected:
+  ImplType *impl;
+};
+
+/// SDBM constant expression, wraps a 64-bit integer.
+class SDBMConstantExpr : public SDBMExpr {
+public:
+  using ImplType = detail::SDBMConstantExprStorage;
+
+  using SDBMExpr::SDBMExpr;
+
+  /// Obtain or create a constant expression unique'ed in the given context.
+  static SDBMConstantExpr get(MLIRContext *context, int64_t value);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Constant;
+  }
+
+  int64_t getValue() const;
+};
+
+/// SDBM varying expression can be one of:
+///   - input variable expression;
+///   - stripe expression;
+///   - negation (product with -1) of either of the above.
+///   - sum of a varying and a constant expression
+///   - difference between varying expressions
+class SDBMVaryingExpr : public SDBMExpr {
+public:
+  using ImplType = detail::SDBMExprStorage;
+  using SDBMExpr::SDBMExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId ||
+           expr.getKind() == SDBMExprKind::SymbolId ||
+           expr.getKind() == SDBMExprKind::Neg ||
+           expr.getKind() == SDBMExprKind::Stripe ||
+           expr.getKind() == SDBMExprKind::Add ||
+           expr.getKind() == SDBMExprKind::Diff;
+  }
+};
+
+/// SDBM positive variable expression can be one of:
+///  - single variable expression;
+///  - stripe expression.
+class SDBMPositiveExpr : public SDBMVaryingExpr {
+public:
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId ||
+           expr.getKind() == SDBMExprKind::SymbolId ||
+           expr.getKind() == SDBMExprKind::Stripe;
+  }
+};
+
+/// SDBM sum expression.  LHS is a varying expression and RHS is always a
+/// constant expression.
+class SDBMSumExpr : public SDBMVaryingExpr {
+public:
+  using ImplType = detail::SDBMBinaryExprStorage;
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  /// Obtain or create a sum expression unique'ed in the given context.
+  static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    SDBMExprKind kind = expr.getKind();
+    return kind == SDBMExprKind::Add;
+  }
+
+  SDBMVaryingExpr getLHS() const;
+  SDBMConstantExpr getRHS() const;
+};
+
+/// SDBM difference expression.  Both LHS and RHS are positive variable
+/// expressions.
+class SDBMDiffExpr : public SDBMVaryingExpr {
+public:
+  using ImplType = detail::SDBMDiffExprStorage;
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  /// Obtain or create a difference expression unique'ed in the given context.
+  static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Diff;
+  }
+
+  SDBMPositiveExpr getLHS() const;
+  SDBMPositiveExpr getRHS() const;
+};
+
+/// SDBM stripe expression "x # C" where "x" is a positive variable expression,
+/// "C" is a constant expression and "#" is the stripe operator defined as:
+///   x # C = x - x mod C.
+class SDBMStripeExpr : public SDBMPositiveExpr {
+public:
+  using ImplType = detail::SDBMBinaryExprStorage;
+  using SDBMPositiveExpr::SDBMPositiveExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Stripe;
+  }
+
+  static SDBMStripeExpr get(SDBMPositiveExpr var,
+                            SDBMConstantExpr stripeFactor);
+
+  SDBMPositiveExpr getVar() const;
+  SDBMConstantExpr getStripeFactor() const;
+};
+
+/// SDBM "input" variable expression can be either a dimension identifier or
+/// a symbol identifier.  When used to define SDBM functions, dimensions are
+/// interpreted as function arguments while symbols are treated as unknown but
+/// constant values, hence the name.
+class SDBMInputExpr : public SDBMPositiveExpr {
+public:
+  using ImplType = detail::SDBMPositiveExprStorage;
+  using SDBMPositiveExpr::SDBMPositiveExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId ||
+           expr.getKind() == SDBMExprKind::SymbolId;
+  }
+
+  unsigned getPosition() const;
+};
+
+/// SDBM dimension expression.  Dimensions correspond to function arguments
+/// when defining functions using SDBM expressions.
+class SDBMDimExpr : public SDBMInputExpr {
+public:
+  using ImplType = detail::SDBMPositiveExprStorage;
+  using SDBMInputExpr::SDBMInputExpr;
+
+  /// Obtain or create a dimension expression unique'ed in the given context.
+  static SDBMDimExpr get(MLIRContext *context, unsigned position);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId;
+  }
+};
+
+/// SDBM symbol expression.  Symbols correspond to symbolic constants when
+/// defining functions using SDBM expressions.
+class SDBMSymbolExpr : public SDBMInputExpr {
+public:
+  using ImplType = detail::SDBMPositiveExprStorage;
+  using SDBMInputExpr::SDBMInputExpr;
+
+  /// Obtain or create a symbol expression unique'ed in the given context.
+  static SDBMSymbolExpr get(MLIRContext *context, unsigned position);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::SymbolId;
+  }
+};
+
+/// Negation of an SDBM variable expression.  Equivalent to multiplying the
+/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
+class SDBMNegExpr : public SDBMVaryingExpr {
+public:
+  using ImplType = detail::SDBMNegExprStorage;
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  /// Obtain or create a negation expression unique'ed in the given context.
+  static SDBMNegExpr get(SDBMPositiveExpr var);
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::Neg;
+  }
+
+  SDBMPositiveExpr getVar() const;
+};
+
+} // end namespace mlir
+
+namespace llvm {
+// SDBMVaryingExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> {
+  static mlir::SDBMVaryingExpr getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::SDBMVaryingExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static mlir::SDBMVaryingExpr getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::SDBMVaryingExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::SDBMVaryingExpr expr) {
+    return expr.hash_value();
+  }
+  static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) {
+    return lhs == rhs;
+  }
+};
+
+// SDBMPositiveExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMPositiveExpr> {
+  static mlir::SDBMPositiveExpr getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::SDBMPositiveExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static mlir::SDBMPositiveExpr getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::SDBMPositiveExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::SDBMPositiveExpr expr) {
+    return expr.hash_value();
+  }
+  static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) {
+    return lhs == rhs;
+  }
+};
+
+// SDBMConstantExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
+  static mlir::SDBMConstantExpr getEmptyKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+    return mlir::SDBMConstantExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static mlir::SDBMConstantExpr getTombstoneKey() {
+    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    return mlir::SDBMConstantExpr(
+        static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+  }
+  static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
+    return expr.hash_value();
+  }
+  static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
+    return lhs == rhs;
+  }
+};
+} // namespace llvm
+
+#endif // MLIR_IR_SDBMEXPR_H
index e630914..29ac989 100644 (file)
@@ -21,6 +21,7 @@
 #include "AttributeDetail.h"
 #include "IntegerSetDetail.h"
 #include "LocationDetail.h"
+#include "SDBMExprDetail.h"
 #include "TypeDetail.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
@@ -462,6 +463,26 @@ public:
   DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
 
   //===--------------------------------------------------------------------===//
+  // SDBM uniquing
+  //===--------------------------------------------------------------------===//
+  llvm::BumpPtrAllocator SDBMAllocator;
+  llvm::sys::SmartRWMutex<true> SDBMMutex;
+
+  DenseMap<std::tuple<SDBMVaryingExpr, SDBMConstantExpr>,
+           SDBMBinaryExprStorage *>
+      SDBMSumExprs;
+  DenseMap<std::tuple<SDBMPositiveExpr, SDBMConstantExpr>,
+           SDBMBinaryExprStorage *>
+      SDBMStripeExprs;
+  DenseMap<std::tuple<SDBMPositiveExpr, SDBMPositiveExpr>,
+           SDBMDiffExprStorage *>
+      SDBMDiffExprs;
+  std::vector<SDBMPositiveExprStorage *> SDBMDimExprs;
+  std::vector<SDBMPositiveExprStorage *> SDBMSymbolExprs;
+  DenseMap<SDBMPositiveExpr, SDBMNegExprStorage *> SDBMNegExprs;
+  DenseMap<int64_t, SDBMConstantExprStorage *> SDBMConstExprs;
+
+  //===--------------------------------------------------------------------===//
   // Type uniquing
   //===--------------------------------------------------------------------===//
   StorageUniquer typeUniquer;
@@ -844,6 +865,103 @@ Location FusedLoc::get(ArrayRef<Location> locs, Attribute metadata,
 }
 
 //===----------------------------------------------------------------------===//
+// SDBMExpr uniquing
+//===----------------------------------------------------------------------===//
+
+SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
+  assert(lhs && "expected SDBM variable expression");
+  assert(rhs && "expected SDBM constant");
+
+  MLIRContextImpl &impl = lhs.getContext()->getImpl();
+
+  // If LHS of a sum is another sum, fold the constant RHS parts.
+  if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
+    lhs = lhsSum.getLHS();
+    rhs = SDBMConstantExpr::get(rhs.getContext(),
+                                rhs.getValue() + lhsSum.getRHS().getValue());
+  }
+
+  auto key = std::make_tuple(lhs, rhs);
+  return safeGetOrCreate(
+      impl.SDBMSumExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
+        auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
+        return new (mem) SDBMBinaryExprStorage(SDBMExprKind::Add,
+                                               lhs.getContext(), lhs, rhs);
+      });
+}
+
+SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
+  assert(lhs && "expected SDBM dimension");
+  assert(rhs && "expected SDBM dimension");
+
+  MLIRContextImpl &impl = lhs.getContext()->getImpl();
+  auto key = std::make_tuple(lhs, rhs);
+  return safeGetOrCreate(
+      impl.SDBMDiffExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
+        auto *mem = impl.SDBMAllocator.Allocate<SDBMDiffExprStorage>();
+        return new (mem) SDBMDiffExprStorage(lhs.getContext(), lhs, rhs);
+      });
+}
+
+SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
+                                   SDBMConstantExpr stripeFactor) {
+  assert(var && "expected SDBM variable expression");
+  assert(stripeFactor && "expected non-null stripe factor");
+  assert(stripeFactor.getValue() > 0 && "non-positive stripe factor");
+
+  MLIRContextImpl &impl = var.getContext()->getImpl();
+  auto key = std::make_tuple(var, stripeFactor);
+  return safeGetOrCreate(
+      impl.SDBMStripeExprs, key, impl.SDBMMutex, [&impl, var, stripeFactor] {
+        auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
+        return new (mem) SDBMBinaryExprStorage(
+            SDBMExprKind::Stripe, var.getContext(), var, stripeFactor);
+      });
+}
+
+SDBMDimExpr SDBMDimExpr::get(MLIRContext *context, unsigned position) {
+  assert(context && "expected non-null context");
+  MLIRContextImpl &impl = context->getImpl();
+  return safeGetOrCreate(
+      impl.SDBMDimExprs, position, impl.SDBMMutex, [&impl, context, position] {
+        auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
+        return new (mem)
+            SDBMPositiveExprStorage(SDBMExprKind::DimId, context, position);
+      });
+}
+
+SDBMSymbolExpr SDBMSymbolExpr::get(MLIRContext *context, unsigned position) {
+  assert(context && "expected non-null context");
+  MLIRContextImpl &impl = context->getImpl();
+  return safeGetOrCreate(
+      impl.SDBMSymbolExprs, position, impl.SDBMMutex,
+      [&impl, context, position] {
+        auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
+        return new (mem)
+            SDBMPositiveExprStorage(SDBMExprKind::SymbolId, context, position);
+      });
+}
+
+SDBMConstantExpr SDBMConstantExpr::get(MLIRContext *context, int64_t value) {
+  assert(context && "expected non-null context");
+  MLIRContextImpl &impl = context->getImpl();
+  return safeGetOrCreate(
+      impl.SDBMConstExprs, value, impl.SDBMMutex, [&impl, context, value] {
+        auto *mem = impl.SDBMAllocator.Allocate<SDBMConstantExprStorage>();
+        return new (mem) SDBMConstantExprStorage(context, value);
+      });
+}
+
+SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
+  assert(var && "expected non-null SDBM variable expression");
+  MLIRContextImpl &impl = var.getContext()->getImpl();
+  return safeGetOrCreate(impl.SDBMNegExprs, var, impl.SDBMMutex, [&impl, var] {
+    auto *mem = impl.SDBMAllocator.Allocate<SDBMNegExprStorage>();
+    return new (mem) SDBMNegExprStorage(var);
+  });
+}
+
+//===----------------------------------------------------------------------===//
 // Type uniquing
 //===----------------------------------------------------------------------===//
 
diff --git a/mlir/lib/IR/SDBMExpr.cpp b/mlir/lib/IR/SDBMExpr.cpp
new file mode 100644 (file)
index 0000000..2d94dd8
--- /dev/null
@@ -0,0 +1,202 @@
+//===- SDBMExpr.h - MLIR SDBM Expression implementation -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) expression is a constant expression,
+// an identifier, a binary expression with constant RHS and +, stripe operators
+// or a difference expression between two identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/SDBMExpr.h"
+#include "SDBMExprDetail.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// SDBMExpr
+//===----------------------------------------------------------------------===//
+
+SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
+
+MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); }
+
+template <typename Derived> class SDBMVisitor {
+public:
+  /// Visit the given SDBM expression, dispatching to kind-specific functions.
+  void visit(SDBMExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    switch (expr.getKind()) {
+    case SDBMExprKind::Add:
+    case SDBMExprKind::Diff:
+    case SDBMExprKind::DimId:
+    case SDBMExprKind::SymbolId:
+    case SDBMExprKind::Neg:
+    case SDBMExprKind::Stripe:
+      return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
+    case SDBMExprKind::Constant:
+      return derived->visitConstant(expr.cast<SDBMConstantExpr>());
+    }
+  }
+
+protected:
+  /// Default visitors do nothing.
+  void visitSum(SDBMSumExpr) {}
+  void visitDiff(SDBMDiffExpr) {}
+  void visitStripe(SDBMStripeExpr) {}
+  void visitDim(SDBMDimExpr) {}
+  void visitSymbol(SDBMSymbolExpr) {}
+  void visitNeg(SDBMNegExpr) {}
+  void visitConstant(SDBMConstantExpr) {}
+
+  /// Default implementation of visitPositive dispatches to the special
+  /// functions for stripes and other variables.  Concrete visitors can override
+  /// it.
+  void visitPositive(SDBMPositiveExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    if (expr.getKind() == SDBMExprKind::Stripe)
+      derived->visitStripe(expr.cast<SDBMStripeExpr>());
+    else
+      derived->visitInput(expr.cast<SDBMInputExpr>());
+  }
+
+  /// Default implementation of visitInput dispatches to the special
+  /// functions for dimensions or symbols.  Concrete visitors can override it to
+  /// visit all variables instead.
+  void visitInput(SDBMInputExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    if (expr.getKind() == SDBMExprKind::DimId)
+      derived->visitDim(expr.cast<SDBMDimExpr>());
+    else
+      derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
+  }
+
+  /// Default implementation of visitVarying dispatches to the special
+  /// functions for variables and negations thereof.  Concerete visitors can
+  /// override it to visit all variables and negations isntead.
+  void visitVarying(SDBMVaryingExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
+      derived->visitPositive(var);
+    else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
+      derived->visitNeg(neg);
+    else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
+      derived->visitSum(sum);
+    else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
+      derived->visitDiff(diff);
+
+    llvm_unreachable("unhandled subtype of varying SDBM expression");
+  }
+};
+
+void SDBMExpr::print(raw_ostream &os) const {
+  struct Printer : public SDBMVisitor<Printer> {
+    Printer(raw_ostream &ostream) : prn(ostream) {}
+
+    void visitSum(SDBMSumExpr expr) {
+      visitVarying(expr.getLHS());
+      prn << " + ";
+      visitConstant(expr.getRHS());
+    }
+    void visitDiff(SDBMDiffExpr expr) {
+      visitPositive(expr.getLHS());
+      prn << " - ";
+      visitPositive(expr.getRHS());
+    }
+    void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
+    void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
+    void visitStripe(SDBMStripeExpr expr) {
+      visitPositive(expr.getVar());
+      prn << " # ";
+      visitConstant(expr.getStripeFactor());
+    }
+    void visitNeg(SDBMNegExpr expr) {
+      prn << '-';
+      visitPositive(expr.getVar());
+    }
+    void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
+
+    raw_ostream &prn;
+  };
+  Printer printer(os);
+  printer.visit(*this);
+}
+
+void SDBMExpr::dump() const { print(llvm::errs()); }
+
+//===----------------------------------------------------------------------===//
+// SDBMSumExpr
+//===----------------------------------------------------------------------===//
+
+SDBMVaryingExpr SDBMSumExpr::getLHS() const {
+  return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMConstantExpr SDBMSumExpr::getRHS() const {
+  return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMDiffExpr
+//===----------------------------------------------------------------------===//
+
+SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
+  return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
+  return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMStripeExpr
+//===----------------------------------------------------------------------===//
+
+SDBMPositiveExpr SDBMStripeExpr::getVar() const {
+  if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
+    return lhs.cast<SDBMPositiveExpr>();
+  return {};
+}
+
+SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
+  return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMInputExpr
+//===----------------------------------------------------------------------===//
+
+unsigned SDBMInputExpr::getPosition() const {
+  return static_cast<ImplType *>(impl)->position;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMConstantExpr
+//===----------------------------------------------------------------------===//
+
+int64_t SDBMConstantExpr::getValue() const {
+  return static_cast<ImplType *>(impl)->constant;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMNegExpr
+//===----------------------------------------------------------------------===//
+
+SDBMPositiveExpr SDBMNegExpr::getVar() const {
+  return static_cast<ImplType *>(impl)->dim;
+}
diff --git a/mlir/lib/IR/SDBMExprDetail.h b/mlir/lib/IR/SDBMExprDetail.h
new file mode 100644 (file)
index 0000000..f661d2a
--- /dev/null
@@ -0,0 +1,84 @@
+//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of SDBMExpr, in particular underlying
+// storage types.  MLIRContext.cpp needs to know the storage layout for
+// allocation and unique'ing purposes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SDBMEXPRDETAIL_H
+#define MLIR_IR_SDBMEXPRDETAIL_H
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SDBMExpr.h"
+#include "llvm/ADT/PointerIntPair.h"
+
+namespace mlir {
+namespace detail {
+
+struct SDBMExprStorage {
+  SDBMExprStorage(SDBMExprKind kind, MLIRContext *context)
+      : contextAndKind(context, kind) {}
+
+  SDBMExprKind getKind() { return contextAndKind.getInt(); }
+
+  MLIRContext *getContext() { return contextAndKind.getPointer(); }
+
+  // This needs to know the layout of MLIRContext so the relevant file is
+  // included.
+  llvm::PointerIntPair<MLIRContext *, 3, SDBMExprKind> contextAndKind;
+};
+
+struct SDBMBinaryExprStorage : public SDBMExprStorage {
+  SDBMBinaryExprStorage(SDBMExprKind kind, MLIRContext *context,
+                        SDBMVaryingExpr left, SDBMConstantExpr right)
+      : SDBMExprStorage(kind, context), lhs(left), rhs(right) {}
+  SDBMVaryingExpr lhs;
+  SDBMConstantExpr rhs;
+};
+
+struct SDBMDiffExprStorage : public SDBMExprStorage {
+  SDBMDiffExprStorage(MLIRContext *context, SDBMPositiveExpr left,
+                      SDBMPositiveExpr right)
+      : SDBMExprStorage(SDBMExprKind::Diff, context), lhs(left), rhs(right) {}
+  SDBMPositiveExpr lhs;
+  SDBMPositiveExpr rhs;
+};
+
+struct SDBMConstantExprStorage : public SDBMExprStorage {
+  SDBMConstantExprStorage(MLIRContext *context, int64_t value)
+      : SDBMExprStorage(SDBMExprKind::Constant, context), constant(value) {}
+  int64_t constant;
+};
+
+struct SDBMPositiveExprStorage : public SDBMExprStorage {
+  SDBMPositiveExprStorage(SDBMExprKind kind, MLIRContext *context, unsigned pos)
+      : SDBMExprStorage(kind, context), position(pos) {}
+  unsigned position;
+};
+
+struct SDBMNegExprStorage : public SDBMExprStorage {
+  SDBMNegExprStorage(SDBMPositiveExpr expr)
+      : SDBMExprStorage(SDBMExprKind::Neg, expr.getContext()), dim(expr) {}
+  SDBMPositiveExpr dim;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_IR_SDBMEXPRDETAIL_H
index 5236ea7..5d72c1f 100644 (file)
@@ -1,6 +1,7 @@
 add_mlir_unittest(MLIRIRTests
   DialectTest.cpp
   OperationSupportTest.cpp
+  SDBMTest.cpp
 )
 target_link_libraries(MLIRIRTests
   PRIVATE
diff --git a/mlir/unittests/IR/SDBMTest.cpp b/mlir/unittests/IR/SDBMTest.cpp
new file mode 100644 (file)
index 0000000..2075db5
--- /dev/null
@@ -0,0 +1,182 @@
+//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SDBMExpr.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+static MLIRContext *ctx() {
+  static thread_local MLIRContext context;
+  return &context;
+}
+
+namespace {
+
+TEST(SDBMExpr, Constant) {
+  // We can create consants and query them.
+  auto expr = SDBMConstantExpr::get(ctx(), 42);
+  EXPECT_EQ(expr.getValue(), 42);
+
+  // Two separately created constants with identical values are trivially equal.
+  auto expr2 = SDBMConstantExpr::get(ctx(), 42);
+  EXPECT_EQ(expr, expr2);
+
+  // Hierarchy is okay.
+  auto generic = static_cast<SDBMExpr>(expr);
+  EXPECT_TRUE(generic.isa<SDBMConstantExpr>());
+}
+
+TEST(SDBMExpr, Dim) {
+  // We can create dimension expressions and query them.
+  auto expr = SDBMDimExpr::get(ctx(), 0);
+  EXPECT_EQ(expr.getPosition(), 0);
+
+  // Two separately created dimensions with the same position are trivially
+  // equal.
+  auto expr2 = SDBMDimExpr::get(ctx(), 0);
+  EXPECT_EQ(expr, expr2);
+
+  // Hierarchy is okay.
+  auto generic = static_cast<SDBMExpr>(expr);
+  EXPECT_TRUE(generic.isa<SDBMDimExpr>());
+  EXPECT_TRUE(generic.isa<SDBMInputExpr>());
+  EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
+  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+
+  // Dimensions are not Symbols.
+  auto symbol = SDBMSymbolExpr::get(ctx(), 0);
+  EXPECT_NE(expr, symbol);
+  EXPECT_FALSE(expr.isa<SDBMSymbolExpr>());
+}
+
+TEST(SDBMExpr, Symbol) {
+  // We can create symbol expressions and query them.
+  auto expr = SDBMSymbolExpr::get(ctx(), 0);
+  EXPECT_EQ(expr.getPosition(), 0);
+
+  // Two separately created symbols with the same position are trivially equal.
+  auto expr2 = SDBMSymbolExpr::get(ctx(), 0);
+  EXPECT_EQ(expr, expr2);
+
+  // Hierarchy is okay.
+  auto generic = static_cast<SDBMExpr>(expr);
+  EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
+  EXPECT_TRUE(generic.isa<SDBMInputExpr>());
+  EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
+  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+
+  // Dimensions are not Symbols.
+  auto symbol = SDBMDimExpr::get(ctx(), 0);
+  EXPECT_NE(expr, symbol);
+  EXPECT_FALSE(expr.isa<SDBMDimExpr>());
+}
+
+TEST(SDBMExpr, Stripe) {
+  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+  auto cst0 = SDBMConstantExpr::get(ctx(), 0);
+  auto var = SDBMSymbolExpr::get(ctx(), 0);
+
+  // We can create stripe expressions and query them.
+  auto expr = SDBMStripeExpr::get(var, cst2);
+  EXPECT_EQ(expr.getVar(), var);
+  EXPECT_EQ(expr.getStripeFactor(), cst2);
+
+  // Two separately created stripe expressions with the same LHS and RHS are
+  // trivially equal.
+  auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(ctx(), 0), cst2);
+  EXPECT_EQ(expr, expr2);
+
+  // Stripes can be nested.
+  SDBMStripeExpr::get(expr, SDBMConstantExpr::get(ctx(), 4));
+
+  // Non-positive stripe factors are not allowed.
+  EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
+
+  // Hierarchy is okay.
+  auto generic = static_cast<SDBMExpr>(expr);
+  EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
+  EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
+  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+TEST(SDBMExpr, Neg) {
+  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto stripe = SDBMStripeExpr::get(var, cst2);
+
+  // We can create negation expressions and query them.
+  auto expr = SDBMNegExpr::get(var);
+  EXPECT_EQ(expr.getVar(), var);
+  auto expr2 = SDBMNegExpr::get(stripe);
+  EXPECT_EQ(expr2.getVar(), stripe);
+
+  // Neg expressions are trivially comparable.
+  EXPECT_EQ(expr, SDBMNegExpr::get(var));
+
+  // Hierarchy is okay.
+  auto generic = static_cast<SDBMExpr>(expr);
+  EXPECT_TRUE(generic.isa<SDBMNegExpr>());
+  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+TEST(SDBMExpr, Sum) {
+  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto stripe = SDBMStripeExpr::get(var, cst2);
+
+  // We can create sum expressions and query them.
+  auto expr = SDBMSumExpr::get(var, cst2);
+  EXPECT_EQ(expr.getLHS(), var);
+  EXPECT_EQ(expr.getRHS(), cst2);
+  auto expr2 = SDBMSumExpr::get(stripe, cst2);
+  EXPECT_EQ(expr2.getLHS(), stripe);
+  EXPECT_EQ(expr2.getRHS(), cst2);
+
+  // Sum expressions are trivially comparable.
+  EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2));
+
+  // Hierarchy is okay.
+  auto generic = static_cast<SDBMExpr>(expr);
+  EXPECT_TRUE(generic.isa<SDBMSumExpr>());
+  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+TEST(SDBMExpr, Diff) {
+  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto stripe = SDBMStripeExpr::get(var, cst2);
+
+  // We can create sum expressions and query them.
+  auto expr = SDBMDiffExpr::get(var, stripe);
+  EXPECT_EQ(expr.getLHS(), var);
+  EXPECT_EQ(expr.getRHS(), stripe);
+  auto expr2 = SDBMDiffExpr::get(stripe, var);
+  EXPECT_EQ(expr2.getLHS(), stripe);
+  EXPECT_EQ(expr2.getRHS(), var);
+
+  // Sum expressions are trivially comparable.
+  EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));
+
+  // Hierarchy is okay.
+  auto generic = static_cast<SDBMExpr>(expr);
+  EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
+  EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+} // end namespace