From: Alex Zinenko Date: Fri, 26 Apr 2019 08:05:24 +0000 (-0700) Subject: Start building SDBM infrastructure X-Git-Tag: llvmorg-11-init~1466^2~1882 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=24d0f60d312cd7ad66c86e60e416b76f782010ed;p=platform%2Fupstream%2Fllvm.git Start building SDBM infrastructure Striped difference-bound matrix expressions are a subset of affine expressions supporting low-complexity algorithms that can be useful for loop transformations. This introduces the basic data data structures for building such expressions and unique'ing them in a MLIRContext. -- PiperOrigin-RevId: 245380206 --- diff --git a/mlir/include/mlir/IR/SDBMExpr.h b/mlir/include/mlir/IR/SDBMExpr.h new file mode 100644 index 0000000..02e1de3 --- /dev/null +++ b/mlir/include/mlir/IR/SDBMExpr.h @@ -0,0 +1,353 @@ +//===- SDBMExpr.h - MLIR SDBM Expression ------------------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// A striped difference-bound matrix (SDBM) expression is a constant expression, +// an identifier, a binary expression with constant RHS and +, stripe operators +// or a difference expression between two identifiers. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_SDBMEXPR_H +#define MLIR_IR_SDBMEXPR_H + +#include "mlir/Support/LLVM.h" +#include "llvm/ADT/DenseMapInfo.h" + +namespace mlir { + +class MLIRContext; + +enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg }; + +namespace detail { +struct SDBMExprStorage; +struct SDBMBinaryExprStorage; +struct SDBMDiffExprStorage; +struct SDBMPositiveExprStorage; +struct SDBMConstantExprStorage; +struct SDBMNegExprStorage; +} // namespace detail + +/// Striped Difference-Bounded Matrix (SDBM) expression is a base left-hand side +/// expression for the SDBM framework. SDBM expressions are a subset of affine +/// expressions supporting low-complexity algorithms for the operations used in +/// loop transformations. In particular, are supported: +/// - constant expressions; +/// - single variables (dimensions and symbols) with +1 or -1 coefficient; +/// - stripe expressions: "x # C", where "x" is a single variable or another +/// stripe expression, "#" is the stripe operator, and "C" is a constant +/// expression; "#" is defined as x - x mod C. +/// - sum expressions between single variable/stripe expressions and constant +/// expressions; +/// - difference expressions between single variable/stripe expressions. +/// `SDBMExpr` class hierarchy provides a type-safe interface to constructing +/// and operating on SDBM expressions. For example, it requires the LHS of a +/// sum expression to be a single variable or a stripe expression. These +/// restrictions are intended to force the caller to perform the necessary +/// simplifications to stay within the SDBM domain, because SDBM expressions do +/// not combine in more cases than they do. This choice may be reconsidered in +/// the future. +/// +/// `SDBMExpr` and derived classes are thin wrappers around a pointer owned by +/// an MLIRContext, and should be used by-value. They are uniqued in the +/// MLIRContext and immortal. +class SDBMExpr { +public: + using ImplType = detail::SDBMExprStorage; + SDBMExpr() : impl(nullptr) {} + /* implicit */ SDBMExpr(ImplType *expr) : impl(expr) {} + + /// SDBM expressions are thin wrappers around a unique'ed immutable pointer, + /// which makes them trivially assignable and trivially copyable. + SDBMExpr(const SDBMExpr &) = default; + SDBMExpr &operator=(const SDBMExpr &) = default; + + /// SDBM expressions can be compared straight-forwardly. + bool operator==(const SDBMExpr &other) const { return impl == other.impl; } + bool operator!=(const SDBMExpr &other) const { return !(*this == other); } + + /// SDBM expressions are convertible to `bool`: null expressions are converted + /// to false, non-null expressions are converted to true. + explicit operator bool() const { return impl != nullptr; } + bool operator!() const { return !static_cast(*this); } + + /// Prints the SDBM expression. + void print(raw_ostream &os) const; + void dump() const; + + /// LLVM-style casts. + template bool isa() const { return U::isClassFor(*this); } + template U dyn_cast() const { + if (!isa()) + return {}; + return U(const_cast(this)->impl); + } + template U cast() const { + assert(isa() && "cast to incorrect subtype"); + return U(const_cast(this)->impl); + } + + /// Support for LLVM hashing. + ::llvm::hash_code hash_value() const { return ::llvm::hash_value(impl); } + + /// Returns the kind of the SDBM expression. + SDBMExprKind getKind() const; + + /// Returns the MLIR context in which this expression lives. + MLIRContext *getContext() const; + +protected: + ImplType *impl; +}; + +/// SDBM constant expression, wraps a 64-bit integer. +class SDBMConstantExpr : public SDBMExpr { +public: + using ImplType = detail::SDBMConstantExprStorage; + + using SDBMExpr::SDBMExpr; + + /// Obtain or create a constant expression unique'ed in the given context. + static SDBMConstantExpr get(MLIRContext *context, int64_t value); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Constant; + } + + int64_t getValue() const; +}; + +/// SDBM varying expression can be one of: +/// - input variable expression; +/// - stripe expression; +/// - negation (product with -1) of either of the above. +/// - sum of a varying and a constant expression +/// - difference between varying expressions +class SDBMVaryingExpr : public SDBMExpr { +public: + using ImplType = detail::SDBMExprStorage; + using SDBMExpr::SDBMExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId || + expr.getKind() == SDBMExprKind::Neg || + expr.getKind() == SDBMExprKind::Stripe || + expr.getKind() == SDBMExprKind::Add || + expr.getKind() == SDBMExprKind::Diff; + } +}; + +/// SDBM positive variable expression can be one of: +/// - single variable expression; +/// - stripe expression. +class SDBMPositiveExpr : public SDBMVaryingExpr { +public: + using SDBMVaryingExpr::SDBMVaryingExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId || + expr.getKind() == SDBMExprKind::Stripe; + } +}; + +/// SDBM sum expression. LHS is a varying expression and RHS is always a +/// constant expression. +class SDBMSumExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMBinaryExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a sum expression unique'ed in the given context. + static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs); + + static bool isClassFor(const SDBMExpr &expr) { + SDBMExprKind kind = expr.getKind(); + return kind == SDBMExprKind::Add; + } + + SDBMVaryingExpr getLHS() const; + SDBMConstantExpr getRHS() const; +}; + +/// SDBM difference expression. Both LHS and RHS are positive variable +/// expressions. +class SDBMDiffExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMDiffExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a difference expression unique'ed in the given context. + static SDBMDiffExpr get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Diff; + } + + SDBMPositiveExpr getLHS() const; + SDBMPositiveExpr getRHS() const; +}; + +/// SDBM stripe expression "x # C" where "x" is a positive variable expression, +/// "C" is a constant expression and "#" is the stripe operator defined as: +/// x # C = x - x mod C. +class SDBMStripeExpr : public SDBMPositiveExpr { +public: + using ImplType = detail::SDBMBinaryExprStorage; + using SDBMPositiveExpr::SDBMPositiveExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Stripe; + } + + static SDBMStripeExpr get(SDBMPositiveExpr var, + SDBMConstantExpr stripeFactor); + + SDBMPositiveExpr getVar() const; + SDBMConstantExpr getStripeFactor() const; +}; + +/// SDBM "input" variable expression can be either a dimension identifier or +/// a symbol identifier. When used to define SDBM functions, dimensions are +/// interpreted as function arguments while symbols are treated as unknown but +/// constant values, hence the name. +class SDBMInputExpr : public SDBMPositiveExpr { +public: + using ImplType = detail::SDBMPositiveExprStorage; + using SDBMPositiveExpr::SDBMPositiveExpr; + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId || + expr.getKind() == SDBMExprKind::SymbolId; + } + + unsigned getPosition() const; +}; + +/// SDBM dimension expression. Dimensions correspond to function arguments +/// when defining functions using SDBM expressions. +class SDBMDimExpr : public SDBMInputExpr { +public: + using ImplType = detail::SDBMPositiveExprStorage; + using SDBMInputExpr::SDBMInputExpr; + + /// Obtain or create a dimension expression unique'ed in the given context. + static SDBMDimExpr get(MLIRContext *context, unsigned position); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::DimId; + } +}; + +/// SDBM symbol expression. Symbols correspond to symbolic constants when +/// defining functions using SDBM expressions. +class SDBMSymbolExpr : public SDBMInputExpr { +public: + using ImplType = detail::SDBMPositiveExprStorage; + using SDBMInputExpr::SDBMInputExpr; + + /// Obtain or create a symbol expression unique'ed in the given context. + static SDBMSymbolExpr get(MLIRContext *context, unsigned position); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::SymbolId; + } +}; + +/// Negation of an SDBM variable expression. Equivalent to multiplying the +/// expression with -1 (SDBM does not support other coefficients that 1 and -1). +class SDBMNegExpr : public SDBMVaryingExpr { +public: + using ImplType = detail::SDBMNegExprStorage; + using SDBMVaryingExpr::SDBMVaryingExpr; + + /// Obtain or create a negation expression unique'ed in the given context. + static SDBMNegExpr get(SDBMPositiveExpr var); + + static bool isClassFor(const SDBMExpr &expr) { + return expr.getKind() == SDBMExprKind::Neg; + } + + SDBMPositiveExpr getVar() const; +}; + +} // end namespace mlir + +namespace llvm { +// SDBMVaryingExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMVaryingExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMVaryingExpr( + static_cast(pointer)); + } + static mlir::SDBMVaryingExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMVaryingExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMVaryingExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMVaryingExpr lhs, mlir::SDBMVaryingExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMPositiveExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMPositiveExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMPositiveExpr( + static_cast(pointer)); + } + static mlir::SDBMPositiveExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMPositiveExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMPositiveExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMPositiveExpr lhs, mlir::SDBMPositiveExpr rhs) { + return lhs == rhs; + } +}; + +// SDBMConstantExpr hash just like pointers. +template <> struct DenseMapInfo { + static mlir::SDBMConstantExpr getEmptyKey() { + auto *pointer = llvm::DenseMapInfo::getEmptyKey(); + return mlir::SDBMConstantExpr( + static_cast(pointer)); + } + static mlir::SDBMConstantExpr getTombstoneKey() { + auto *pointer = llvm::DenseMapInfo::getTombstoneKey(); + return mlir::SDBMConstantExpr( + static_cast(pointer)); + } + static unsigned getHashValue(mlir::SDBMConstantExpr expr) { + return expr.hash_value(); + } + static bool isEqual(mlir::SDBMConstantExpr lhs, mlir::SDBMConstantExpr rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + +#endif // MLIR_IR_SDBMEXPR_H diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index e630914..29ac989 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -21,6 +21,7 @@ #include "AttributeDetail.h" #include "IntegerSetDetail.h" #include "LocationDetail.h" +#include "SDBMExprDetail.h" #include "TypeDetail.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -462,6 +463,26 @@ public: DenseMap constExprs; //===--------------------------------------------------------------------===// + // SDBM uniquing + //===--------------------------------------------------------------------===// + llvm::BumpPtrAllocator SDBMAllocator; + llvm::sys::SmartRWMutex SDBMMutex; + + DenseMap, + SDBMBinaryExprStorage *> + SDBMSumExprs; + DenseMap, + SDBMBinaryExprStorage *> + SDBMStripeExprs; + DenseMap, + SDBMDiffExprStorage *> + SDBMDiffExprs; + std::vector SDBMDimExprs; + std::vector SDBMSymbolExprs; + DenseMap SDBMNegExprs; + DenseMap SDBMConstExprs; + + //===--------------------------------------------------------------------===// // Type uniquing //===--------------------------------------------------------------------===// StorageUniquer typeUniquer; @@ -844,6 +865,103 @@ Location FusedLoc::get(ArrayRef locs, Attribute metadata, } //===----------------------------------------------------------------------===// +// SDBMExpr uniquing +//===----------------------------------------------------------------------===// + +SDBMSumExpr SDBMSumExpr::get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs) { + assert(lhs && "expected SDBM variable expression"); + assert(rhs && "expected SDBM constant"); + + MLIRContextImpl &impl = lhs.getContext()->getImpl(); + + // If LHS of a sum is another sum, fold the constant RHS parts. + if (auto lhsSum = lhs.dyn_cast()) { + lhs = lhsSum.getLHS(); + rhs = SDBMConstantExpr::get(rhs.getContext(), + rhs.getValue() + lhsSum.getRHS().getValue()); + } + + auto key = std::make_tuple(lhs, rhs); + return safeGetOrCreate( + impl.SDBMSumExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] { + auto *mem = impl.SDBMAllocator.Allocate(); + return new (mem) SDBMBinaryExprStorage(SDBMExprKind::Add, + lhs.getContext(), lhs, rhs); + }); +} + +SDBMDiffExpr SDBMDiffExpr::get(SDBMPositiveExpr lhs, SDBMPositiveExpr rhs) { + assert(lhs && "expected SDBM dimension"); + assert(rhs && "expected SDBM dimension"); + + MLIRContextImpl &impl = lhs.getContext()->getImpl(); + auto key = std::make_tuple(lhs, rhs); + return safeGetOrCreate( + impl.SDBMDiffExprs, key, impl.SDBMMutex, [&impl, lhs, rhs] { + auto *mem = impl.SDBMAllocator.Allocate(); + return new (mem) SDBMDiffExprStorage(lhs.getContext(), lhs, rhs); + }); +} + +SDBMStripeExpr SDBMStripeExpr::get(SDBMPositiveExpr var, + SDBMConstantExpr stripeFactor) { + assert(var && "expected SDBM variable expression"); + assert(stripeFactor && "expected non-null stripe factor"); + assert(stripeFactor.getValue() > 0 && "non-positive stripe factor"); + + MLIRContextImpl &impl = var.getContext()->getImpl(); + auto key = std::make_tuple(var, stripeFactor); + return safeGetOrCreate( + impl.SDBMStripeExprs, key, impl.SDBMMutex, [&impl, var, stripeFactor] { + auto *mem = impl.SDBMAllocator.Allocate(); + return new (mem) SDBMBinaryExprStorage( + SDBMExprKind::Stripe, var.getContext(), var, stripeFactor); + }); +} + +SDBMDimExpr SDBMDimExpr::get(MLIRContext *context, unsigned position) { + assert(context && "expected non-null context"); + MLIRContextImpl &impl = context->getImpl(); + return safeGetOrCreate( + impl.SDBMDimExprs, position, impl.SDBMMutex, [&impl, context, position] { + auto *mem = impl.SDBMAllocator.Allocate(); + return new (mem) + SDBMPositiveExprStorage(SDBMExprKind::DimId, context, position); + }); +} + +SDBMSymbolExpr SDBMSymbolExpr::get(MLIRContext *context, unsigned position) { + assert(context && "expected non-null context"); + MLIRContextImpl &impl = context->getImpl(); + return safeGetOrCreate( + impl.SDBMSymbolExprs, position, impl.SDBMMutex, + [&impl, context, position] { + auto *mem = impl.SDBMAllocator.Allocate(); + return new (mem) + SDBMPositiveExprStorage(SDBMExprKind::SymbolId, context, position); + }); +} + +SDBMConstantExpr SDBMConstantExpr::get(MLIRContext *context, int64_t value) { + assert(context && "expected non-null context"); + MLIRContextImpl &impl = context->getImpl(); + return safeGetOrCreate( + impl.SDBMConstExprs, value, impl.SDBMMutex, [&impl, context, value] { + auto *mem = impl.SDBMAllocator.Allocate(); + return new (mem) SDBMConstantExprStorage(context, value); + }); +} + +SDBMNegExpr SDBMNegExpr::get(SDBMPositiveExpr var) { + assert(var && "expected non-null SDBM variable expression"); + MLIRContextImpl &impl = var.getContext()->getImpl(); + return safeGetOrCreate(impl.SDBMNegExprs, var, impl.SDBMMutex, [&impl, var] { + auto *mem = impl.SDBMAllocator.Allocate(); + return new (mem) SDBMNegExprStorage(var); + }); +} + +//===----------------------------------------------------------------------===// // Type uniquing //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/SDBMExpr.cpp b/mlir/lib/IR/SDBMExpr.cpp new file mode 100644 index 0000000..2d94dd83 --- /dev/null +++ b/mlir/lib/IR/SDBMExpr.cpp @@ -0,0 +1,202 @@ +//===- SDBMExpr.h - MLIR SDBM Expression implementation -------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// A striped difference-bound matrix (SDBM) expression is a constant expression, +// an identifier, a binary expression with constant RHS and +, stripe operators +// or a difference expression between two identifiers. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/SDBMExpr.h" +#include "SDBMExprDetail.h" + +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// SDBMExpr +//===----------------------------------------------------------------------===// + +SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); } + +MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); } + +template class SDBMVisitor { +public: + /// Visit the given SDBM expression, dispatching to kind-specific functions. + void visit(SDBMExpr expr) { + auto *derived = static_cast(this); + switch (expr.getKind()) { + case SDBMExprKind::Add: + case SDBMExprKind::Diff: + case SDBMExprKind::DimId: + case SDBMExprKind::SymbolId: + case SDBMExprKind::Neg: + case SDBMExprKind::Stripe: + return derived->visitVarying(expr.cast()); + case SDBMExprKind::Constant: + return derived->visitConstant(expr.cast()); + } + } + +protected: + /// Default visitors do nothing. + void visitSum(SDBMSumExpr) {} + void visitDiff(SDBMDiffExpr) {} + void visitStripe(SDBMStripeExpr) {} + void visitDim(SDBMDimExpr) {} + void visitSymbol(SDBMSymbolExpr) {} + void visitNeg(SDBMNegExpr) {} + void visitConstant(SDBMConstantExpr) {} + + /// Default implementation of visitPositive dispatches to the special + /// functions for stripes and other variables. Concrete visitors can override + /// it. + void visitPositive(SDBMPositiveExpr expr) { + auto *derived = static_cast(this); + if (expr.getKind() == SDBMExprKind::Stripe) + derived->visitStripe(expr.cast()); + else + derived->visitInput(expr.cast()); + } + + /// Default implementation of visitInput dispatches to the special + /// functions for dimensions or symbols. Concrete visitors can override it to + /// visit all variables instead. + void visitInput(SDBMInputExpr expr) { + auto *derived = static_cast(this); + if (expr.getKind() == SDBMExprKind::DimId) + derived->visitDim(expr.cast()); + else + derived->visitSymbol(expr.cast()); + } + + /// Default implementation of visitVarying dispatches to the special + /// functions for variables and negations thereof. Concerete visitors can + /// override it to visit all variables and negations isntead. + void visitVarying(SDBMVaryingExpr expr) { + auto *derived = static_cast(this); + if (auto var = expr.dyn_cast()) + derived->visitPositive(var); + else if (auto neg = expr.dyn_cast()) + derived->visitNeg(neg); + else if (auto sum = expr.dyn_cast()) + derived->visitSum(sum); + else if (auto diff = expr.dyn_cast()) + derived->visitDiff(diff); + + llvm_unreachable("unhandled subtype of varying SDBM expression"); + } +}; + +void SDBMExpr::print(raw_ostream &os) const { + struct Printer : public SDBMVisitor { + Printer(raw_ostream &ostream) : prn(ostream) {} + + void visitSum(SDBMSumExpr expr) { + visitVarying(expr.getLHS()); + prn << " + "; + visitConstant(expr.getRHS()); + } + void visitDiff(SDBMDiffExpr expr) { + visitPositive(expr.getLHS()); + prn << " - "; + visitPositive(expr.getRHS()); + } + void visitDim(SDBMDimExpr expr) { prn << 'd' << expr.getPosition(); } + void visitSymbol(SDBMSymbolExpr expr) { prn << 's' << expr.getPosition(); } + void visitStripe(SDBMStripeExpr expr) { + visitPositive(expr.getVar()); + prn << " # "; + visitConstant(expr.getStripeFactor()); + } + void visitNeg(SDBMNegExpr expr) { + prn << '-'; + visitPositive(expr.getVar()); + } + void visitConstant(SDBMConstantExpr expr) { prn << expr.getValue(); } + + raw_ostream &prn; + }; + Printer printer(os); + printer.visit(*this); +} + +void SDBMExpr::dump() const { print(llvm::errs()); } + +//===----------------------------------------------------------------------===// +// SDBMSumExpr +//===----------------------------------------------------------------------===// + +SDBMVaryingExpr SDBMSumExpr::getLHS() const { + return static_cast(impl)->lhs; +} + +SDBMConstantExpr SDBMSumExpr::getRHS() const { + return static_cast(impl)->rhs; +} + +//===----------------------------------------------------------------------===// +// SDBMDiffExpr +//===----------------------------------------------------------------------===// + +SDBMPositiveExpr SDBMDiffExpr::getLHS() const { + return static_cast(impl)->lhs; +} + +SDBMPositiveExpr SDBMDiffExpr::getRHS() const { + return static_cast(impl)->rhs; +} + +//===----------------------------------------------------------------------===// +// SDBMStripeExpr +//===----------------------------------------------------------------------===// + +SDBMPositiveExpr SDBMStripeExpr::getVar() const { + if (SDBMVaryingExpr lhs = static_cast(impl)->lhs) + return lhs.cast(); + return {}; +} + +SDBMConstantExpr SDBMStripeExpr::getStripeFactor() const { + return static_cast(impl)->rhs; +} + +//===----------------------------------------------------------------------===// +// SDBMInputExpr +//===----------------------------------------------------------------------===// + +unsigned SDBMInputExpr::getPosition() const { + return static_cast(impl)->position; +} + +//===----------------------------------------------------------------------===// +// SDBMConstantExpr +//===----------------------------------------------------------------------===// + +int64_t SDBMConstantExpr::getValue() const { + return static_cast(impl)->constant; +} + +//===----------------------------------------------------------------------===// +// SDBMNegExpr +//===----------------------------------------------------------------------===// + +SDBMPositiveExpr SDBMNegExpr::getVar() const { + return static_cast(impl)->dim; +} diff --git a/mlir/lib/IR/SDBMExprDetail.h b/mlir/lib/IR/SDBMExprDetail.h new file mode 100644 index 0000000..f661d2a --- /dev/null +++ b/mlir/lib/IR/SDBMExprDetail.h @@ -0,0 +1,84 @@ +//===- SDBMExprDetail.h - MLIR SDBM Expression storage details --*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This holds implementation details of SDBMExpr, in particular underlying +// storage types. MLIRContext.cpp needs to know the storage layout for +// allocation and unique'ing purposes. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_IR_SDBMEXPRDETAIL_H +#define MLIR_IR_SDBMEXPRDETAIL_H + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/SDBMExpr.h" +#include "llvm/ADT/PointerIntPair.h" + +namespace mlir { +namespace detail { + +struct SDBMExprStorage { + SDBMExprStorage(SDBMExprKind kind, MLIRContext *context) + : contextAndKind(context, kind) {} + + SDBMExprKind getKind() { return contextAndKind.getInt(); } + + MLIRContext *getContext() { return contextAndKind.getPointer(); } + + // This needs to know the layout of MLIRContext so the relevant file is + // included. + llvm::PointerIntPair contextAndKind; +}; + +struct SDBMBinaryExprStorage : public SDBMExprStorage { + SDBMBinaryExprStorage(SDBMExprKind kind, MLIRContext *context, + SDBMVaryingExpr left, SDBMConstantExpr right) + : SDBMExprStorage(kind, context), lhs(left), rhs(right) {} + SDBMVaryingExpr lhs; + SDBMConstantExpr rhs; +}; + +struct SDBMDiffExprStorage : public SDBMExprStorage { + SDBMDiffExprStorage(MLIRContext *context, SDBMPositiveExpr left, + SDBMPositiveExpr right) + : SDBMExprStorage(SDBMExprKind::Diff, context), lhs(left), rhs(right) {} + SDBMPositiveExpr lhs; + SDBMPositiveExpr rhs; +}; + +struct SDBMConstantExprStorage : public SDBMExprStorage { + SDBMConstantExprStorage(MLIRContext *context, int64_t value) + : SDBMExprStorage(SDBMExprKind::Constant, context), constant(value) {} + int64_t constant; +}; + +struct SDBMPositiveExprStorage : public SDBMExprStorage { + SDBMPositiveExprStorage(SDBMExprKind kind, MLIRContext *context, unsigned pos) + : SDBMExprStorage(kind, context), position(pos) {} + unsigned position; +}; + +struct SDBMNegExprStorage : public SDBMExprStorage { + SDBMNegExprStorage(SDBMPositiveExpr expr) + : SDBMExprStorage(SDBMExprKind::Neg, expr.getContext()), dim(expr) {} + SDBMPositiveExpr dim; +}; + +} // end namespace detail +} // end namespace mlir + +#endif // MLIR_IR_SDBMEXPRDETAIL_H diff --git a/mlir/unittests/IR/CMakeLists.txt b/mlir/unittests/IR/CMakeLists.txt index 5236ea7..5d72c1f 100644 --- a/mlir/unittests/IR/CMakeLists.txt +++ b/mlir/unittests/IR/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_unittest(MLIRIRTests DialectTest.cpp OperationSupportTest.cpp + SDBMTest.cpp ) target_link_libraries(MLIRIRTests PRIVATE diff --git a/mlir/unittests/IR/SDBMTest.cpp b/mlir/unittests/IR/SDBMTest.cpp new file mode 100644 index 0000000..2075db5 --- /dev/null +++ b/mlir/unittests/IR/SDBMTest.cpp @@ -0,0 +1,182 @@ +//===- SDBMTest.cpp - SDBM expression unit tests --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/SDBMExpr.h" +#include "gtest/gtest.h" + +using namespace mlir; + +static MLIRContext *ctx() { + static thread_local MLIRContext context; + return &context; +} + +namespace { + +TEST(SDBMExpr, Constant) { + // We can create consants and query them. + auto expr = SDBMConstantExpr::get(ctx(), 42); + EXPECT_EQ(expr.getValue(), 42); + + // Two separately created constants with identical values are trivially equal. + auto expr2 = SDBMConstantExpr::get(ctx(), 42); + EXPECT_EQ(expr, expr2); + + // Hierarchy is okay. + auto generic = static_cast(expr); + EXPECT_TRUE(generic.isa()); +} + +TEST(SDBMExpr, Dim) { + // We can create dimension expressions and query them. + auto expr = SDBMDimExpr::get(ctx(), 0); + EXPECT_EQ(expr.getPosition(), 0); + + // Two separately created dimensions with the same position are trivially + // equal. + auto expr2 = SDBMDimExpr::get(ctx(), 0); + EXPECT_EQ(expr, expr2); + + // Hierarchy is okay. + auto generic = static_cast(expr); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); + + // Dimensions are not Symbols. + auto symbol = SDBMSymbolExpr::get(ctx(), 0); + EXPECT_NE(expr, symbol); + EXPECT_FALSE(expr.isa()); +} + +TEST(SDBMExpr, Symbol) { + // We can create symbol expressions and query them. + auto expr = SDBMSymbolExpr::get(ctx(), 0); + EXPECT_EQ(expr.getPosition(), 0); + + // Two separately created symbols with the same position are trivially equal. + auto expr2 = SDBMSymbolExpr::get(ctx(), 0); + EXPECT_EQ(expr, expr2); + + // Hierarchy is okay. + auto generic = static_cast(expr); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); + + // Dimensions are not Symbols. + auto symbol = SDBMDimExpr::get(ctx(), 0); + EXPECT_NE(expr, symbol); + EXPECT_FALSE(expr.isa()); +} + +TEST(SDBMExpr, Stripe) { + auto cst2 = SDBMConstantExpr::get(ctx(), 2); + auto cst0 = SDBMConstantExpr::get(ctx(), 0); + auto var = SDBMSymbolExpr::get(ctx(), 0); + + // We can create stripe expressions and query them. + auto expr = SDBMStripeExpr::get(var, cst2); + EXPECT_EQ(expr.getVar(), var); + EXPECT_EQ(expr.getStripeFactor(), cst2); + + // Two separately created stripe expressions with the same LHS and RHS are + // trivially equal. + auto expr2 = SDBMStripeExpr::get(SDBMSymbolExpr::get(ctx(), 0), cst2); + EXPECT_EQ(expr, expr2); + + // Stripes can be nested. + SDBMStripeExpr::get(expr, SDBMConstantExpr::get(ctx(), 4)); + + // Non-positive stripe factors are not allowed. + EXPECT_DEATH(SDBMStripeExpr::get(var, cst0), "non-positive"); + + // Hierarchy is okay. + auto generic = static_cast(expr); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); +} + +TEST(SDBMExpr, Neg) { + auto cst2 = SDBMConstantExpr::get(ctx(), 2); + auto var = SDBMSymbolExpr::get(ctx(), 0); + auto stripe = SDBMStripeExpr::get(var, cst2); + + // We can create negation expressions and query them. + auto expr = SDBMNegExpr::get(var); + EXPECT_EQ(expr.getVar(), var); + auto expr2 = SDBMNegExpr::get(stripe); + EXPECT_EQ(expr2.getVar(), stripe); + + // Neg expressions are trivially comparable. + EXPECT_EQ(expr, SDBMNegExpr::get(var)); + + // Hierarchy is okay. + auto generic = static_cast(expr); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); +} + +TEST(SDBMExpr, Sum) { + auto cst2 = SDBMConstantExpr::get(ctx(), 2); + auto var = SDBMSymbolExpr::get(ctx(), 0); + auto stripe = SDBMStripeExpr::get(var, cst2); + + // We can create sum expressions and query them. + auto expr = SDBMSumExpr::get(var, cst2); + EXPECT_EQ(expr.getLHS(), var); + EXPECT_EQ(expr.getRHS(), cst2); + auto expr2 = SDBMSumExpr::get(stripe, cst2); + EXPECT_EQ(expr2.getLHS(), stripe); + EXPECT_EQ(expr2.getRHS(), cst2); + + // Sum expressions are trivially comparable. + EXPECT_EQ(expr, SDBMSumExpr::get(var, cst2)); + + // Hierarchy is okay. + auto generic = static_cast(expr); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); +} + +TEST(SDBMExpr, Diff) { + auto cst2 = SDBMConstantExpr::get(ctx(), 2); + auto var = SDBMSymbolExpr::get(ctx(), 0); + auto stripe = SDBMStripeExpr::get(var, cst2); + + // We can create sum expressions and query them. + auto expr = SDBMDiffExpr::get(var, stripe); + EXPECT_EQ(expr.getLHS(), var); + EXPECT_EQ(expr.getRHS(), stripe); + auto expr2 = SDBMDiffExpr::get(stripe, var); + EXPECT_EQ(expr2.getLHS(), stripe); + EXPECT_EQ(expr2.getRHS(), var); + + // Sum expressions are trivially comparable. + EXPECT_EQ(expr, SDBMDiffExpr::get(var, stripe)); + + // Hierarchy is okay. + auto generic = static_cast(expr); + EXPECT_TRUE(generic.isa()); + EXPECT_TRUE(generic.isa()); +} + +} // end namespace