--- /dev/null
+//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) expression is a constant expression,
+// an identifier, a binary expression with constant RHS and +, stripe operators
+// or a difference expression between two identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SDBMEXPR_H
+#define MLIR_IR_SDBMEXPR_H
+
+#include "mlir/Support/LLVM.h"
+#include "llvm/ADT/DenseMapInfo.h"
+
+namespace mlir {
+
+class MLIRContext;
+
+enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
+
+namespace detail {
+struct SDBMExprStorage;
+struct SDBMBinaryExprStorage;
+struct SDBMDiffExprStorage;
+struct SDBMPositiveExprStorage;
+struct SDBMConstantExprStorage;
+struct SDBMNegExprStorage;
+} // namespace detail
+
+/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side
+/// expression for the SDBM framework. SDBM expressions are a subset of affine
+/// expressions supporting low-complexity algorithms for the operations used in
+/// loop transformations. In particular, are supported:
+/// - constant expressions;
+/// - single variables (dimensions and symbols) with +1 or -1 coefficient;
+/// - stripe expressions: "x # C", where "x" is a single variable or another
+/// stripe expression, "#" is the stripe operator, and "C" is a constant
+/// expression; "#" is defined as x - x mod C.
+/// - sum expressions between single variable/stripe expressions and constant
+/// expressions;
+/// - difference expressions between single variable/stripe expressions.
+/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing
+/// and operating on SDBM expressions. For example, it requires the LHS of a
+/// sum expression to be a single variable or a stripe expression. These
+/// restrictions are intended to force the caller to perform the necessary
+/// simplifications to stay within the SDBM domain, because SDBM expressions do
+/// not combine in more cases than they do. This choice may be reconsidered in
+/// the future.
+///
+/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by
+/// an MLIRContext, and should be used by-value. They are uniqued in the
+/// MLIRContext and immortal.
+class SDBMExpr {
+public:
+ using ImplType = detail::SDBMExprStorage;
+ SDBMExpr() : impl(nullptr) {}
+ /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {}
+
+ /// SDBM expressions are thin wrappers around a unique'ed immutable pointer,
+ /// which makes them trivially assignable and trivially copyable.
+ SDBMExpr(const SDBMExpr &) = default;
+ SDBMExpr &operator=(const SDBMExpr &) = default;
+
+ /// SDBM expressions can be compared straight-forwardly.
+ bool operator==(const SDBMExpr &other) const { return impl == other.impl; }
+ bool operator!=(const SDBMExpr &other) const { return !(*this == other); }
+
+ /// SDBM expressions are convertible to `bool`: null expressions are converted
+ /// to false, non-null expressions are converted to true.
+ explicit operator bool() const { return impl != nullptr; }
+ bool operator!() const { return !static_cast<bool>(*this); }
+
+ /// Prints the SDBM expression.
+ void print(raw_ostream &os) const;
+ void dump() const;
+
+ /// LLVM-style casts.
+ template <typename U> bool isa() const { return U::isClassFor(*this); }
+ template <typename U> U dyn_cast() const {
+ if (!isa<U>())
+ return {};
+ return U(const_cast<SDBMExpr *>(this)->impl);
+ }
+ template <typename U> U cast() const {
+ assert(isa<U>() && "cast to incorrect subtype");
+ return U(const_cast<SDBMExpr *>(this)->impl);
+ }
+
+ /// Support for LLVM hashing.
+ ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); }
+
+ /// Returns the kind of the SDBM expression.
+ SDBMExprKind getKind() const;
+
+ /// Returns the MLIR context in which this expression lives.
+ MLIRContext *getContext() const;
+
+protected:
+ ImplType *impl;
+};
+
+/// SDBM constant expression, wraps a 64-bit integer.
+class SDBMConstantExpr : public SDBMExpr {
+public:
+ using ImplType = detail::SDBMConstantExprStorage;
+
+ using SDBMExpr::SDBMExpr;
+
+ /// Obtain or create a constant expression unique'ed in the given context.
+ static SDBMConstantExpr get(MLIRContext *context, int64_t value);
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::Constant;
+ }
+
+ int64_t getValue() const;
+};
+
+/// SDBM varying expression can be one of:
+/// - input variable expression;
+/// - stripe expression;
+/// - negation (product with -1) of either of the above.
+/// - sum of a varying and a constant expression
+/// - difference between varying expressions
+class SDBMVaryingExpr : public SDBMExpr {
+public:
+ using ImplType = detail::SDBMExprStorage;
+ using SDBMExpr::SDBMExpr;
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::DimId ||
+ expr.getKind() == SDBMExprKind::SymbolId ||
+ expr.getKind() == SDBMExprKind::Neg ||
+ expr.getKind() == SDBMExprKind::Stripe ||
+ expr.getKind() == SDBMExprKind::Add ||
+ expr.getKind() == SDBMExprKind::Diff;
+ }
+};
+
+/// SDBM positive variable expression can be one of:
+/// - single variable expression;
+/// - stripe expression.
+class SDBMPositiveExpr : public SDBMVaryingExpr {
+public:
+ using SDBMVaryingExpr::SDBMVaryingExpr;
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::DimId ||
+ expr.getKind() == SDBMExprKind::SymbolId ||
+ expr.getKind() == SDBMExprKind::Stripe;
+ }
+};
+
+/// SDBM sum expression. LHS is a varying expression and RHS is always a
+/// constant expression.
+class SDBMSumExpr : public SDBMVaryingExpr {
+public:
+ using ImplType = detail::SDBMBinaryExprStorage;
+ using SDBMVaryingExpr::SDBMVaryingExpr;
+
+ /// Obtain or create a sum expression unique'ed in the given context.
+ static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ SDBMExprKind kind = expr.getKind();
+ return kind == SDBMExprKind::Add;
+ }
+
+ SDBMVaryingExpr getLHS() const;
+ SDBMConstantExpr getRHS() const;
+};
+
+/// SDBM difference expression. Both LHS and RHS are positive variable
+/// expressions.
+class SDBMDiffExpr : public SDBMVaryingExpr {
+public:
+ using ImplType = detail::SDBMDiffExprStorage;
+ using SDBMVaryingExpr::SDBMVaryingExpr;
+
+ /// Obtain or create a difference expression unique'ed in the given context.
+ static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs);
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::Diff;
+ }
+
+ SDBMPositiveExpr getLHS() const;
+ SDBMPositiveExpr getRHS() const;
+};
+
+/// SDBM stripe expression "x # C" where "x" is a positive variable expression,
+/// "C" is a constant expression and "#" is the stripe operator defined as:
+/// x # C = x - x mod C.
+class SDBMStripeExpr : public SDBMPositiveExpr {
+public:
+ using ImplType = detail::SDBMBinaryExprStorage;
+ using SDBMPositiveExpr::SDBMPositiveExpr;
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::Stripe;
+ }
+
+ static SDBMStripeExpr get(SDBMPositiveExpr var,
+ SDBMConstantExpr stripeFactor);
+
+ SDBMPositiveExpr getVar() const;
+ SDBMConstantExpr getStripeFactor() const;
+};
+
+/// SDBM "input" variable expression can be either a dimension identifier or
+/// a symbol identifier. When used to define SDBM functions, dimensions are
+/// interpreted as function arguments while symbols are treated as unknown but
+/// constant values, hence the name.
+class SDBMInputExpr : public SDBMPositiveExpr {
+public:
+ using ImplType = detail::SDBMPositiveExprStorage;
+ using SDBMPositiveExpr::SDBMPositiveExpr;
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::DimId ||
+ expr.getKind() == SDBMExprKind::SymbolId;
+ }
+
+ unsigned getPosition() const;
+};
+
+/// SDBM dimension expression. Dimensions correspond to function arguments
+/// when defining functions using SDBM expressions.
+class SDBMDimExpr : public SDBMInputExpr {
+public:
+ using ImplType = detail::SDBMPositiveExprStorage;
+ using SDBMInputExpr::SDBMInputExpr;
+
+ /// Obtain or create a dimension expression unique'ed in the given context.
+ static SDBMDimExpr get(MLIRContext *context, unsigned position);
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::DimId;
+ }
+};
+
+/// SDBM symbol expression. Symbols correspond to symbolic constants when
+/// defining functions using SDBM expressions.
+class SDBMSymbolExpr : public SDBMInputExpr {
+public:
+ using ImplType = detail::SDBMPositiveExprStorage;
+ using SDBMInputExpr::SDBMInputExpr;
+
+ /// Obtain or create a symbol expression unique'ed in the given context.
+ static SDBMSymbolExpr get(MLIRContext *context, unsigned position);
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::SymbolId;
+ }
+};
+
+/// Negation of an SDBM variable expression. Equivalent to multiplying the
+/// expression with -1 (SDBM does not support other coefficients that 1 and -1).
+class SDBMNegExpr : public SDBMVaryingExpr {
+public:
+ using ImplType = detail::SDBMNegExprStorage;
+ using SDBMVaryingExpr::SDBMVaryingExpr;
+
+ /// Obtain or create a negation expression unique'ed in the given context.
+ static SDBMNegExpr get(SDBMPositiveExpr var);
+
+ static bool isClassFor(const SDBMExpr &expr) {
+ return expr.getKind() == SDBMExprKind::Neg;
+ }
+
+ SDBMPositiveExpr getVar() const;
+};
+
+} // end namespace mlir
+
+namespace llvm {
+// SDBMVaryingExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMVaryingExpr> {
+ static mlir::SDBMVaryingExpr getEmptyKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::SDBMVaryingExpr(
+ static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+ }
+ static mlir::SDBMVaryingExpr getTombstoneKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::SDBMVaryingExpr(
+ static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+ }
+ static unsigned getHashValue(mlir::SDBMVaryingExpr expr) {
+ return expr.hash_value();
+ }
+ static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) {
+ return lhs == rhs;
+ }
+};
+
+// SDBMPositiveExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMPositiveExpr> {
+ static mlir::SDBMPositiveExpr getEmptyKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::SDBMPositiveExpr(
+ static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+ }
+ static mlir::SDBMPositiveExpr getTombstoneKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::SDBMPositiveExpr(
+ static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+ }
+ static unsigned getHashValue(mlir::SDBMPositiveExpr expr) {
+ return expr.hash_value();
+ }
+ static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) {
+ return lhs == rhs;
+ }
+};
+
+// SDBMConstantExpr hash just like pointers.
+template <> struct DenseMapInfo<mlir::SDBMConstantExpr> {
+ static mlir::SDBMConstantExpr getEmptyKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
+ return mlir::SDBMConstantExpr(
+ static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+ }
+ static mlir::SDBMConstantExpr getTombstoneKey() {
+ auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+ return mlir::SDBMConstantExpr(
+ static_cast<mlir::SDBMExpr::ImplType *>(pointer));
+ }
+ static unsigned getHashValue(mlir::SDBMConstantExpr expr) {
+ return expr.hash_value();
+ }
+ static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) {
+ return lhs == rhs;
+ }
+};
+} // namespace llvm
+
+#endif // MLIR_IR_SDBMEXPR_H
#include "AttributeDetail.h"
#include "IntegerSetDetail.h"
#include "LocationDetail.h"
+#include "SDBMExprDetail.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
DenseMap<int64_t, AffineConstantExprStorage *> constExprs;
//===--------------------------------------------------------------------===//
+ // SDBM uniquing
+ //===--------------------------------------------------------------------===//
+ llvm::BumpPtrAllocator SDBMAllocator;
+ llvm::sys::SmartRWMutex<true> SDBMMutex;
+
+ DenseMap<std::tuple<SDBMVaryingExpr, SDBMConstantExpr>,
+ SDBMBinaryExprStorage *>
+ SDBMSumExprs;
+ DenseMap<std::tuple<SDBMPositiveExpr, SDBMConstantExpr>,
+ SDBMBinaryExprStorage *>
+ SDBMStripeExprs;
+ DenseMap<std::tuple<SDBMPositiveExpr, SDBMPositiveExpr>,
+ SDBMDiffExprStorage *>
+ SDBMDiffExprs;
+ std::vector<SDBMPositiveExprStorage *> SDBMDimExprs;
+ std::vector<SDBMPositiveExprStorage *> SDBMSymbolExprs;
+ DenseMap<SDBMPositiveExpr, SDBMNegExprStorage *> SDBMNegExprs;
+ DenseMap<int64_t, SDBMConstantExprStorage *> SDBMConstExprs;
+
+ //===--------------------------------------------------------------------===//
// Type uniquing
//===--------------------------------------------------------------------===//
StorageUniquer typeUniquer;
}
//===----------------------------------------------------------------------===//
+// SDBMExpr uniquing
+//===----------------------------------------------------------------------===//
+
+SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) {
+ assert(lhs && "expected SDBM variable expression");
+ assert(rhs && "expected SDBM constant");
+
+ MLIRContextImpl &impl = lhs.getContext()->getImpl();
+
+ // If LHS of a sum is another sum, fold the constant RHS parts.
+ if (auto lhsSum = lhs.dyn_cast<SDBMSumExpr>()) {
+ lhs = lhsSum.getLHS();
+ rhs = SDBMConstantExpr::get(rhs.getContext(),
+ rhs.getValue() + lhsSum.getRHS().getValue());
+ }
+
+ auto key = std::make_tuple(lhs, rhs);
+ return safeGetOrCreate(
+ impl.SDBMSumExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
+ auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
+ return new (mem) SDBMBinaryExprStorage(SDBMExprKind::Add,
+ lhs.getContext(), lhs, rhs);
+ });
+}
+
+SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) {
+ assert(lhs && "expected SDBM dimension");
+ assert(rhs && "expected SDBM dimension");
+
+ MLIRContextImpl &impl = lhs.getContext()->getImpl();
+ auto key = std::make_tuple(lhs, rhs);
+ return safeGetOrCreate(
+ impl.SDBMDiffExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] {
+ auto *mem = impl.SDBMAllocator.Allocate<SDBMDiffExprStorage>();
+ return new (mem) SDBMDiffExprStorage(lhs.getContext(), lhs, rhs);
+ });
+}
+
+SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var,
+ SDBMConstantExpr stripeFactor) {
+ assert(var && "expected SDBM variable expression");
+ assert(stripeFactor && "expected non-null stripe factor");
+ assert(stripeFactor.getValue() > 0 && "non-positive stripe factor");
+
+ MLIRContextImpl &impl = var.getContext()->getImpl();
+ auto key = std::make_tuple(var, stripeFactor);
+ return safeGetOrCreate(
+ impl.SDBMStripeExprs, key, impl.SDBMMutex, [&impl, var, stripeFactor] {
+ auto *mem = impl.SDBMAllocator.Allocate<SDBMBinaryExprStorage>();
+ return new (mem) SDBMBinaryExprStorage(
+ SDBMExprKind::Stripe, var.getContext(), var, stripeFactor);
+ });
+}
+
+SDBMDimExpr SDBMDimExpr::get(MLIRContext *context, unsigned position) {
+ assert(context && "expected non-null context");
+ MLIRContextImpl &impl = context->getImpl();
+ return safeGetOrCreate(
+ impl.SDBMDimExprs, position, impl.SDBMMutex, [&impl, context, position] {
+ auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
+ return new (mem)
+ SDBMPositiveExprStorage(SDBMExprKind::DimId, context, position);
+ });
+}
+
+SDBMSymbolExpr SDBMSymbolExpr::get(MLIRContext *context, unsigned position) {
+ assert(context && "expected non-null context");
+ MLIRContextImpl &impl = context->getImpl();
+ return safeGetOrCreate(
+ impl.SDBMSymbolExprs, position, impl.SDBMMutex,
+ [&impl, context, position] {
+ auto *mem = impl.SDBMAllocator.Allocate<SDBMPositiveExprStorage>();
+ return new (mem)
+ SDBMPositiveExprStorage(SDBMExprKind::SymbolId, context, position);
+ });
+}
+
+SDBMConstantExpr SDBMConstantExpr::get(MLIRContext *context, int64_t value) {
+ assert(context && "expected non-null context");
+ MLIRContextImpl &impl = context->getImpl();
+ return safeGetOrCreate(
+ impl.SDBMConstExprs, value, impl.SDBMMutex, [&impl, context, value] {
+ auto *mem = impl.SDBMAllocator.Allocate<SDBMConstantExprStorage>();
+ return new (mem) SDBMConstantExprStorage(context, value);
+ });
+}
+
+SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) {
+ assert(var && "expected non-null SDBM variable expression");
+ MLIRContextImpl &impl = var.getContext()->getImpl();
+ return safeGetOrCreate(impl.SDBMNegExprs, var, impl.SDBMMutex, [&impl, var] {
+ auto *mem = impl.SDBMAllocator.Allocate<SDBMNegExprStorage>();
+ return new (mem) SDBMNegExprStorage(var);
+ });
+}
+
+//===----------------------------------------------------------------------===//
// Type uniquing
//===----------------------------------------------------------------------===//
--- /dev/null
+//===- SDBMExpr.h - MLIR SDBM Expression implementation -------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// A striped difference-bound matrix (SDBM) expression is a constant expression,
+// an identifier, a binary expression with constant RHS and +, stripe operators
+// or a difference expression between two identifiers.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/SDBMExpr.h"
+#include "SDBMExprDetail.h"
+
+#include "llvm/Support/raw_ostream.h"
+
+using namespace mlir;
+
+//===----------------------------------------------------------------------===//
+// SDBMExpr
+//===----------------------------------------------------------------------===//
+
+SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
+
+MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); }
+
+template <typename Derived> class SDBMVisitor {
+public:
+ /// Visit the given SDBM expression, dispatching to kind-specific functions.
+ void visit(SDBMExpr expr) {
+ auto *derived = static_cast<Derived *>(this);
+ switch (expr.getKind()) {
+ case SDBMExprKind::Add:
+ case SDBMExprKind::Diff:
+ case SDBMExprKind::DimId:
+ case SDBMExprKind::SymbolId:
+ case SDBMExprKind::Neg:
+ case SDBMExprKind::Stripe:
+ return derived->visitVarying(expr.cast<SDBMVaryingExpr>());
+ case SDBMExprKind::Constant:
+ return derived->visitConstant(expr.cast<SDBMConstantExpr>());
+ }
+ }
+
+protected:
+ /// Default visitors do nothing.
+ void visitSum(SDBMSumExpr) {}
+ void visitDiff(SDBMDiffExpr) {}
+ void visitStripe(SDBMStripeExpr) {}
+ void visitDim(SDBMDimExpr) {}
+ void visitSymbol(SDBMSymbolExpr) {}
+ void visitNeg(SDBMNegExpr) {}
+ void visitConstant(SDBMConstantExpr) {}
+
+ /// Default implementation of visitPositive dispatches to the special
+ /// functions for stripes and other variables. Concrete visitors can override
+ /// it.
+ void visitPositive(SDBMPositiveExpr expr) {
+ auto *derived = static_cast<Derived *>(this);
+ if (expr.getKind() == SDBMExprKind::Stripe)
+ derived->visitStripe(expr.cast<SDBMStripeExpr>());
+ else
+ derived->visitInput(expr.cast<SDBMInputExpr>());
+ }
+
+ /// Default implementation of visitInput dispatches to the special
+ /// functions for dimensions or symbols. Concrete visitors can override it to
+ /// visit all variables instead.
+ void visitInput(SDBMInputExpr expr) {
+ auto *derived = static_cast<Derived *>(this);
+ if (expr.getKind() == SDBMExprKind::DimId)
+ derived->visitDim(expr.cast<SDBMDimExpr>());
+ else
+ derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
+ }
+
+ /// Default implementation of visitVarying dispatches to the special
+ /// functions for variables and negations thereof. Concerete visitors can
+ /// override it to visit all variables and negations isntead.
+ void visitVarying(SDBMVaryingExpr expr) {
+ auto *derived = static_cast<Derived *>(this);
+ if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
+ derived->visitPositive(var);
+ else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
+ derived->visitNeg(neg);
+ else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
+ derived->visitSum(sum);
+ else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
+ derived->visitDiff(diff);
+
+ llvm_unreachable("unhandled subtype of varying SDBM expression");
+ }
+};
+
+void SDBMExpr::print(raw_ostream &os) const {
+ struct Printer : public SDBMVisitor<Printer> {
+ Printer(raw_ostream &ostream) : prn(ostream) {}
+
+ void visitSum(SDBMSumExpr expr) {
+ visitVarying(expr.getLHS());
+ prn << " + ";
+ visitConstant(expr.getRHS());
+ }
+ void visitDiff(SDBMDiffExpr expr) {
+ visitPositive(expr.getLHS());
+ prn << " - ";
+ visitPositive(expr.getRHS());
+ }
+ void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); }
+ void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); }
+ void visitStripe(SDBMStripeExpr expr) {
+ visitPositive(expr.getVar());
+ prn << " # ";
+ visitConstant(expr.getStripeFactor());
+ }
+ void visitNeg(SDBMNegExpr expr) {
+ prn << '-';
+ visitPositive(expr.getVar());
+ }
+ void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); }
+
+ raw_ostream &prn;
+ };
+ Printer printer(os);
+ printer.visit(*this);
+}
+
+void SDBMExpr::dump() const { print(llvm::errs()); }
+
+//===----------------------------------------------------------------------===//
+// SDBMSumExpr
+//===----------------------------------------------------------------------===//
+
+SDBMVaryingExpr SDBMSumExpr::getLHS() const {
+ return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMConstantExpr SDBMSumExpr::getRHS() const {
+ return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMDiffExpr
+//===----------------------------------------------------------------------===//
+
+SDBMPositiveExpr SDBMDiffExpr::getLHS() const {
+ return static_cast<ImplType *>(impl)->lhs;
+}
+
+SDBMPositiveExpr SDBMDiffExpr::getRHS() const {
+ return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMStripeExpr
+//===----------------------------------------------------------------------===//
+
+SDBMPositiveExpr SDBMStripeExpr::getVar() const {
+ if (SDBMVaryingExpr lhs = static_cast<ImplType *>(impl)->lhs)
+ return lhs.cast<SDBMPositiveExpr>();
+ return {};
+}
+
+SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const {
+ return static_cast<ImplType *>(impl)->rhs;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMInputExpr
+//===----------------------------------------------------------------------===//
+
+unsigned SDBMInputExpr::getPosition() const {
+ return static_cast<ImplType *>(impl)->position;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMConstantExpr
+//===----------------------------------------------------------------------===//
+
+int64_t SDBMConstantExpr::getValue() const {
+ return static_cast<ImplType *>(impl)->constant;
+}
+
+//===----------------------------------------------------------------------===//
+// SDBMNegExpr
+//===----------------------------------------------------------------------===//
+
+SDBMPositiveExpr SDBMNegExpr::getVar() const {
+ return static_cast<ImplType *>(impl)->dim;
+}
--- /dev/null
+//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This holds implementation details of SDBMExpr, in particular underlying
+// storage types. MLIRContext.cpp needs to know the storage layout for
+// allocation and unique'ing purposes.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_SDBMEXPRDETAIL_H
+#define MLIR_IR_SDBMEXPRDETAIL_H
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SDBMExpr.h"
+#include "llvm/ADT/PointerIntPair.h"
+
+namespace mlir {
+namespace detail {
+
+struct SDBMExprStorage {
+ SDBMExprStorage(SDBMExprKind kind, MLIRContext *context)
+ : contextAndKind(context, kind) {}
+
+ SDBMExprKind getKind() { return contextAndKind.getInt(); }
+
+ MLIRContext *getContext() { return contextAndKind.getPointer(); }
+
+ // This needs to know the layout of MLIRContext so the relevant file is
+ // included.
+ llvm::PointerIntPair<MLIRContext *, 3, SDBMExprKind> contextAndKind;
+};
+
+struct SDBMBinaryExprStorage : public SDBMExprStorage {
+ SDBMBinaryExprStorage(SDBMExprKind kind, MLIRContext *context,
+ SDBMVaryingExpr left, SDBMConstantExpr right)
+ : SDBMExprStorage(kind, context), lhs(left), rhs(right) {}
+ SDBMVaryingExpr lhs;
+ SDBMConstantExpr rhs;
+};
+
+struct SDBMDiffExprStorage : public SDBMExprStorage {
+ SDBMDiffExprStorage(MLIRContext *context, SDBMPositiveExpr left,
+ SDBMPositiveExpr right)
+ : SDBMExprStorage(SDBMExprKind::Diff, context), lhs(left), rhs(right) {}
+ SDBMPositiveExpr lhs;
+ SDBMPositiveExpr rhs;
+};
+
+struct SDBMConstantExprStorage : public SDBMExprStorage {
+ SDBMConstantExprStorage(MLIRContext *context, int64_t value)
+ : SDBMExprStorage(SDBMExprKind::Constant, context), constant(value) {}
+ int64_t constant;
+};
+
+struct SDBMPositiveExprStorage : public SDBMExprStorage {
+ SDBMPositiveExprStorage(SDBMExprKind kind, MLIRContext *context, unsigned pos)
+ : SDBMExprStorage(kind, context), position(pos) {}
+ unsigned position;
+};
+
+struct SDBMNegExprStorage : public SDBMExprStorage {
+ SDBMNegExprStorage(SDBMPositiveExpr expr)
+ : SDBMExprStorage(SDBMExprKind::Neg, expr.getContext()), dim(expr) {}
+ SDBMPositiveExpr dim;
+};
+
+} // end namespace detail
+} // end namespace mlir
+
+#endif // MLIR_IR_SDBMEXPRDETAIL_H
add_mlir_unittest(MLIRIRTests
DialectTest.cpp
OperationSupportTest.cpp
+ SDBMTest.cpp
)
target_link_libraries(MLIRIRTests
PRIVATE
--- /dev/null
+//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/SDBMExpr.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+static MLIRContext *ctx() {
+ static thread_local MLIRContext context;
+ return &context;
+}
+
+namespace {
+
+TEST(SDBMExpr, Constant) {
+ // We can create consants and query them.
+ auto expr = SDBMConstantExpr::get(ctx(), 42);
+ EXPECT_EQ(expr.getValue(), 42);
+
+ // Two separately created constants with identical values are trivially equal.
+ auto expr2 = SDBMConstantExpr::get(ctx(), 42);
+ EXPECT_EQ(expr, expr2);
+
+ // Hierarchy is okay.
+ auto generic = static_cast<SDBMExpr>(expr);
+ EXPECT_TRUE(generic.isa<SDBMConstantExpr>());
+}
+
+TEST(SDBMExpr, Dim) {
+ // We can create dimension expressions and query them.
+ auto expr = SDBMDimExpr::get(ctx(), 0);
+ EXPECT_EQ(expr.getPosition(), 0);
+
+ // Two separately created dimensions with the same position are trivially
+ // equal.
+ auto expr2 = SDBMDimExpr::get(ctx(), 0);
+ EXPECT_EQ(expr, expr2);
+
+ // Hierarchy is okay.
+ auto generic = static_cast<SDBMExpr>(expr);
+ EXPECT_TRUE(generic.isa<SDBMDimExpr>());
+ EXPECT_TRUE(generic.isa<SDBMInputExpr>());
+ EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
+ EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+
+ // Dimensions are not Symbols.
+ auto symbol = SDBMSymbolExpr::get(ctx(), 0);
+ EXPECT_NE(expr, symbol);
+ EXPECT_FALSE(expr.isa<SDBMSymbolExpr>());
+}
+
+TEST(SDBMExpr, Symbol) {
+ // We can create symbol expressions and query them.
+ auto expr = SDBMSymbolExpr::get(ctx(), 0);
+ EXPECT_EQ(expr.getPosition(), 0);
+
+ // Two separately created symbols with the same position are trivially equal.
+ auto expr2 = SDBMSymbolExpr::get(ctx(), 0);
+ EXPECT_EQ(expr, expr2);
+
+ // Hierarchy is okay.
+ auto generic = static_cast<SDBMExpr>(expr);
+ EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
+ EXPECT_TRUE(generic.isa<SDBMInputExpr>());
+ EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
+ EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+
+ // Dimensions are not Symbols.
+ auto symbol = SDBMDimExpr::get(ctx(), 0);
+ EXPECT_NE(expr, symbol);
+ EXPECT_FALSE(expr.isa<SDBMDimExpr>());
+}
+
+TEST(SDBMExpr, Stripe) {
+ auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+ auto cst0 = SDBMConstantExpr::get(ctx(), 0);
+ auto var = SDBMSymbolExpr::get(ctx(), 0);
+
+ // We can create stripe expressions and query them.
+ auto expr = SDBMStripeExpr::get(var, cst2);
+ EXPECT_EQ(expr.getVar(), var);
+ EXPECT_EQ(expr.getStripeFactor(), cst2);
+
+ // Two separately created stripe expressions with the same LHS and RHS are
+ // trivially equal.
+ auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(ctx(), 0), cst2);
+ EXPECT_EQ(expr, expr2);
+
+ // Stripes can be nested.
+ SDBMStripeExpr::get(expr, SDBMConstantExpr::get(ctx(), 4));
+
+ // Non-positive stripe factors are not allowed.
+ EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive");
+
+ // Hierarchy is okay.
+ auto generic = static_cast<SDBMExpr>(expr);
+ EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
+ EXPECT_TRUE(generic.isa<SDBMPositiveExpr>());
+ EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+TEST(SDBMExpr, Neg) {
+ auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+ auto var = SDBMSymbolExpr::get(ctx(), 0);
+ auto stripe = SDBMStripeExpr::get(var, cst2);
+
+ // We can create negation expressions and query them.
+ auto expr = SDBMNegExpr::get(var);
+ EXPECT_EQ(expr.getVar(), var);
+ auto expr2 = SDBMNegExpr::get(stripe);
+ EXPECT_EQ(expr2.getVar(), stripe);
+
+ // Neg expressions are trivially comparable.
+ EXPECT_EQ(expr, SDBMNegExpr::get(var));
+
+ // Hierarchy is okay.
+ auto generic = static_cast<SDBMExpr>(expr);
+ EXPECT_TRUE(generic.isa<SDBMNegExpr>());
+ EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+TEST(SDBMExpr, Sum) {
+ auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+ auto var = SDBMSymbolExpr::get(ctx(), 0);
+ auto stripe = SDBMStripeExpr::get(var, cst2);
+
+ // We can create sum expressions and query them.
+ auto expr = SDBMSumExpr::get(var, cst2);
+ EXPECT_EQ(expr.getLHS(), var);
+ EXPECT_EQ(expr.getRHS(), cst2);
+ auto expr2 = SDBMSumExpr::get(stripe, cst2);
+ EXPECT_EQ(expr2.getLHS(), stripe);
+ EXPECT_EQ(expr2.getRHS(), cst2);
+
+ // Sum expressions are trivially comparable.
+ EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2));
+
+ // Hierarchy is okay.
+ auto generic = static_cast<SDBMExpr>(expr);
+ EXPECT_TRUE(generic.isa<SDBMSumExpr>());
+ EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+TEST(SDBMExpr, Diff) {
+ auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+ auto var = SDBMSymbolExpr::get(ctx(), 0);
+ auto stripe = SDBMStripeExpr::get(var, cst2);
+
+ // We can create sum expressions and query them.
+ auto expr = SDBMDiffExpr::get(var, stripe);
+ EXPECT_EQ(expr.getLHS(), var);
+ EXPECT_EQ(expr.getRHS(), stripe);
+ auto expr2 = SDBMDiffExpr::get(stripe, var);
+ EXPECT_EQ(expr2.getLHS(), stripe);
+ EXPECT_EQ(expr2.getRHS(), var);
+
+ // Sum expressions are trivially comparable.
+ EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe));
+
+ // Hierarchy is okay.
+ auto generic = static_cast<SDBMExpr>(expr);
+ EXPECT_TRUE(generic.isa<SDBMDiffExpr>());
+ EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
+}
+
+} // end namespace