#include "mlir/IR/SDBMExpr.h"
#include "SDBMExprDetail.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "llvm/Support/raw_ostream.h"
MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); }
-template <typename Derived> class SDBMVisitor {
+template <typename Derived, typename Result = void> class SDBMVisitor {
public:
/// Visit the given SDBM expression, dispatching to kind-specific functions.
- void visit(SDBMExpr expr) {
+ Result visit(SDBMExpr expr) {
auto *derived = static_cast<Derived *>(this);
switch (expr.getKind()) {
case SDBMExprKind::Add:
/// Default implementation of visitPositive dispatches to the special
/// functions for stripes and other variables. Concrete visitors can override
/// it.
- void visitPositive(SDBMPositiveExpr expr) {
+ Result visitPositive(SDBMPositiveExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::Stripe)
- derived->visitStripe(expr.cast<SDBMStripeExpr>());
+ return derived->visitStripe(expr.cast<SDBMStripeExpr>());
else
- derived->visitInput(expr.cast<SDBMInputExpr>());
+ return derived->visitInput(expr.cast<SDBMInputExpr>());
}
/// Default implementation of visitInput dispatches to the special
/// functions for dimensions or symbols. Concrete visitors can override it to
/// visit all variables instead.
- void visitInput(SDBMInputExpr expr) {
+ Result visitInput(SDBMInputExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (expr.getKind() == SDBMExprKind::DimId)
- derived->visitDim(expr.cast<SDBMDimExpr>());
+ return derived->visitDim(expr.cast<SDBMDimExpr>());
else
- derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
+ return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
}
/// Default implementation of visitVarying dispatches to the special
/// functions for variables and negations thereof. Concerete visitors can
/// override it to visit all variables and negations isntead.
- void visitVarying(SDBMVaryingExpr expr) {
+ Result visitVarying(SDBMVaryingExpr expr) {
auto *derived = static_cast<Derived *>(this);
if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
- derived->visitPositive(var);
+ return derived->visitPositive(var);
else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
- derived->visitNeg(neg);
+ return derived->visitNeg(neg);
else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
- derived->visitSum(sum);
+ return derived->visitSum(sum);
else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
- derived->visitDiff(diff);
+ return derived->visitDiff(diff);
llvm_unreachable("unhandled subtype of varying SDBM expression");
}
printer.visit(*this);
}
-void SDBMExpr::dump() const { print(llvm::errs()); }
+void SDBMExpr::dump() const {
+ print(llvm::errs());
+ llvm::errs() << '\n';
+}
//===----------------------------------------------------------------------===//
// SDBMSumExpr
return static_cast<ImplType *>(impl)->rhs;
}
+AffineExpr SDBMExpr::getAsAffineExpr() const {
+ struct Converter : public SDBMVisitor<Converter, AffineExpr> {
+ AffineExpr visitSum(SDBMSumExpr expr) {
+ AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ return lhs + rhs;
+ }
+
+ AffineExpr visitStripe(SDBMStripeExpr expr) {
+ AffineExpr lhs = visit(expr.getVar()),
+ rhs = visit(expr.getStripeFactor());
+ return lhs - (lhs % rhs);
+ }
+
+ AffineExpr visitDiff(SDBMDiffExpr expr) {
+ AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ return lhs - rhs;
+ }
+
+ AffineExpr visitDim(SDBMDimExpr expr) {
+ return getAffineDimExpr(expr.getPosition(), expr.getContext());
+ }
+
+ AffineExpr visitSymbol(SDBMSymbolExpr expr) {
+ return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
+ }
+
+ AffineExpr visitNeg(SDBMNegExpr expr) {
+ return getAffineBinaryOpExpr(AffineExprKind::Mul,
+ getAffineConstantExpr(-1, expr.getContext()),
+ visit(expr.getVar()));
+ }
+
+ AffineExpr visitConstant(SDBMConstantExpr expr) {
+ return getAffineConstantExpr(expr.getValue(), expr.getContext());
+ }
+ } converter;
+ return converter.visit(*this);
+}
+
+Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
+ struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
+ // Try matching the definition of the stripe operation as x - x mod C where
+ // `pos` should match "x" and `neg` should match "- (x mod C)".
+ SDBMExpr matchStripeAddPattern(AffineExpr pos, AffineExpr neg) {
+ // Check that the "pos" part is a variable expression and that the "neg"
+ // part is a mul expression.
+ auto convertedLHS = visit(pos);
+ if (!convertedLHS || !convertedLHS.isa<SDBMPositiveExpr>())
+ return {};
+
+ auto outerBinExpr = neg.dyn_cast<AffineBinaryOpExpr>();
+ if (!outerBinExpr || outerBinExpr.getKind() != AffineExprKind::Mul)
+ return {};
+
+ // In affine mul expressions, the constant part is always on the RHS.
+ // If there had been two constants, they would have been folded away.
+ assert(!outerBinExpr.getLHS().isa<AffineConstantExpr>() &&
+ "expected a constant on the RHS of an affine mul expression");
+ // Check if the RHS of mul is -1.
+ auto multiplierExpr =
+ outerBinExpr.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!multiplierExpr || multiplierExpr.getValue() != -1)
+ return {};
+
+ // Check if the LHS of mul is ("pos" mod constant).
+ auto binExpr = outerBinExpr.getLHS().dyn_cast<AffineBinaryOpExpr>();
+ if (!binExpr || binExpr.getKind() != AffineExprKind::Mod ||
+ !binExpr.getRHS().isa<AffineConstantExpr>())
+ return {};
+
+ if (convertedLHS != visit(binExpr.getLHS()))
+ return {};
+
+ // If all checks pass, we have a stripe.
+ return SDBMStripeExpr::get(
+ convertedLHS.cast<SDBMPositiveExpr>(),
+ visit(binExpr.getRHS()).cast<SDBMConstantExpr>());
+ }
+
+ SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
+ // Attempt to recover a stripe expression. Because AffineExprs don't have
+ // a first-class difference kind, we check for both x + -1 * (x mod C) and
+ // -1 * (x mod C) + x cases.
+ if (auto stripe = matchStripeAddPattern(expr.getLHS(), expr.getRHS()))
+ return stripe;
+ if (auto stripe = matchStripeAddPattern(expr.getRHS(), expr.getLHS()))
+ return stripe;
+
+ auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ if (!lhs || !rhs)
+ return {};
+
+ // In a "add" AffineExpr, the constant always appears on the right. If
+ // there were two constants, they would have been folded away.
+ assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+
+ // SDBM accepts LHS variables and RHS constants in a sum.
+ auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
+ auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
+ if (rhsConstant && lhsVar)
+ return SDBMSumExpr::get(lhsVar, rhsConstant);
+
+ // The sum of a negated variable and a non-negated variable is a
+ // difference, supported as a special kind in SDBM. Because AffineExprs
+ // don't have first-class difference kind, check both LHS and RHS for
+ // negation.
+ auto lhsPos = lhs.dyn_cast<SDBMPositiveExpr>();
+ auto rhsPos = rhs.dyn_cast<SDBMPositiveExpr>();
+ auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
+ auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
+ if (lhsNeg && rhsVar)
+ return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
+ if (rhsNeg && lhsVar)
+ return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
+
+ // Other cases don't fit into SDBM.
+ return {};
+ }
+
+ // Try matching the stripe pattern "(x floordiv C) * C" where `lhs`
+ // corresponds to "(x floordiv C)" and `rhs` corresponds to "C".
+ SDBMExpr matchStripeMulPattern(AffineExpr lhs, AffineExpr rhs) {
+ // Check if LHS is a floordiv expression and rhs is a constant.
+ auto lhsBinary = lhs.dyn_cast<AffineBinaryOpExpr>();
+ auto rhsConstant = rhs.dyn_cast<AffineConstantExpr>();
+ if (!lhsBinary || !rhsConstant ||
+ lhsBinary.getKind() != AffineExprKind::FloorDiv)
+ return {};
+
+ // Check if the floordiv divides by the constant equal to RHS.
+ auto lhsRhsConstant = lhsBinary.getRHS().dyn_cast<AffineConstantExpr>();
+ if (!lhsRhsConstant || lhsRhsConstant != rhsConstant)
+ return {};
+
+ // Check if LHS can be converted to a single variable.
+ SDBMExpr converted = visit(lhsBinary.getLHS());
+ if (!converted)
+ return {};
+ auto varConverted = converted.dyn_cast<SDBMPositiveExpr>();
+ if (!varConverted)
+ return {};
+
+ // If all checks pass, we have a stripe.
+ return SDBMStripeExpr::get(
+ varConverted, SDBMConstantExpr::get(varConverted.getContext(),
+ rhsConstant.getValue()));
+ }
+
+ SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
+ // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
+ if (auto stripe = matchStripeMulPattern(expr.getLHS(), expr.getRHS()))
+ return stripe;
+
+ auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ if (!lhs || !rhs)
+ return {};
+
+ // In a "mul" AffineExpr, the constant always appears on the right. If
+ // there were two constants, they would have been folded away.
+ assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+ if (!rhsConstant)
+ return {};
+
+ // The only supported "multiplication" expression is an SDBM is dimension
+ // negation, that is a product of dimension and constant -1.
+ auto lhsVar = lhs.dyn_cast<SDBMPositiveExpr>();
+ if (lhsVar && rhsConstant.getValue() == -1)
+ return SDBMNegExpr::get(lhsVar);
+
+ // Other multiplications are not allowed in SDBM.
+ return {};
+ }
+
+ SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
+ auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+ if (!lhs || !rhs)
+ return {};
+
+ // 'mod' can only be converted to SDBM if its LHS is a variable
+ // and its RHS is a constant. Then it `x mod c = x - x stripe c`.
+ auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+ auto lhsVar = rhs.dyn_cast<SDBMPositiveExpr>();
+ if (!lhsVar || !rhsConstant)
+ return {};
+ return SDBMDiffExpr::get(lhsVar,
+ SDBMStripeExpr::get(lhsVar, rhsConstant));
+ }
+
+ // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
+ SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
+ SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
+
+ // Dimensions, symbols and constants are converted trivially.
+ SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
+ return SDBMConstantExpr::get(expr.getContext(), expr.getValue());
+ }
+ SDBMExpr visitDimExpr(AffineDimExpr expr) {
+ return SDBMDimExpr::get(expr.getContext(), expr.getPosition());
+ }
+ SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
+ return SDBMSymbolExpr::get(expr.getContext(), expr.getPosition());
+ }
+ } converter;
+
+ if (auto result = converter.visit(affine))
+ return result;
+ return None;
+}
+
//===----------------------------------------------------------------------===//
// SDBMDiffExpr
//===----------------------------------------------------------------------===//
// limitations under the License.
// =============================================================================
+#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/SDBMExpr.h"
#include "gtest/gtest.h"
EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
}
+TEST(SDBMExpr, AffineRoundTrip) {
+ // Build an expression (s0 - s0 # 2)
+ auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+ auto var = SDBMSymbolExpr::get(ctx(), 0);
+ auto stripe = SDBMStripeExpr::get(var, cst2);
+ auto expr = SDBMDiffExpr::get(var, stripe);
+
+ // Check that it can be converted to AffineExpr and back, i.e. stripe
+ // detection works correctly.
+ Optional<SDBMExpr> roundtripped =
+ SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr());
+ ASSERT_TRUE(roundtripped.hasValue());
+ EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(expr));
+
+ // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
+ // detection supports nested expressions.
+ auto cst5 = SDBMConstantExpr::get(ctx(), 5);
+ auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
+ roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
+ ASSERT_TRUE(roundtripped.hasValue());
+ EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));
+
+ // Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
+ // deeper expression tree.
+ auto diff = SDBMDiffExpr::get(outerStripe, stripe);
+ auto sum = SDBMSumExpr::get(diff, cst2);
+ roundtripped = SDBMExpr::tryConvertAffineExpr(sum.getAsAffineExpr());
+ ASSERT_TRUE(roundtripped.hasValue());
+ EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(sum));
+}
+
+TEST(SDBMExpr, MatchStripeMulPattern) {
+ // Make sure conversion from AffineExpr recognizes multiplicative stripe
+ // pattern (x floordiv B) * B == x # B.
+ auto cst = getAffineConstantExpr(42, ctx());
+ auto dim = getAffineDimExpr(0, ctx());
+ auto floor = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, dim, cst);
+ auto mul = getAffineBinaryOpExpr(AffineExprKind::Mul, cst, floor);
+ Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
+ ASSERT_TRUE(converted.hasValue());
+ EXPECT_TRUE(converted->isa<SDBMStripeExpr>());
+}
+
+TEST(SDBMExpr, NonSDBM) {
+ auto d0 = getAffineDimExpr(0, ctx());
+ auto d1 = getAffineDimExpr(1, ctx());
+ auto sum = getAffineBinaryOpExpr(AffineExprKind::Add, d0, d1);
+ auto c2 = getAffineConstantExpr(2, ctx());
+ auto prod = getAffineBinaryOpExpr(AffineExprKind::Mul, d0, c2);
+ auto ceildiv = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, d1, c2);
+
+ // The following are not valid SDBM expressions:
+ // - a sum of two variables
+ EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(sum).hasValue());
+ // - a variable with coefficient other than 1 or -1
+ EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(prod).hasValue());
+ // - a ceildiv expression
+ EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceildiv).hasValue());
+}
+
} // end namespace