Implement the conversion between SDBMExpr and AffineExpr
authorAlex Zinenko <zinenko@google.com>
Fri, 26 Apr 2019 22:07:13 +0000 (15:07 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:18:31 +0000 (08:18 -0700)
    Since SDBM expressions are a subset of affine expressions, they can be
    converted to affine expressions in a straightforward way.  The inverse
    conversion may fail when the affine expression is not an SDBM.  Implement the
    inverse convresion assuming affine expressions are simplified and
    canonicalizied, detect subtractive and multiplicative forms of the stripe
    operation.

--

PiperOrigin-RevId: 245494735

mlir/include/mlir/IR/AffineExpr.h
mlir/include/mlir/IR/SDBMExpr.h
mlir/lib/IR/SDBMExpr.cpp
mlir/unittests/IR/SDBMTest.cpp

index 55b6d04..c976cbf 100644 (file)
@@ -48,7 +48,8 @@ enum class AffineExprKind {
   Add,
   /// RHS of mul is always a constant or a symbolic expression.
   Mul,
-  /// RHS of mod is always a constant or a symbolic expression.
+  /// RHS of mod is always a constant or a symbolic expression with a positive
+  /// value.
   Mod,
   /// RHS of floordiv is always a constant or a symbolic expression.
   FloorDiv,
index 02e1de3..fa80e2f 100644 (file)
@@ -29,6 +29,7 @@
 
 namespace mlir {
 
+class AffineExpr;
 class MLIRContext;
 
 enum class SDBMExprKind { Add, Stripe, Diff, Constant, DimId, SymbolId, Neg };
@@ -110,6 +111,17 @@ public:
   /// Returns the MLIR context in which this expression lives.
   MLIRContext *getContext() const;
 
+  /// Convert the SDBM expression into an Affine expression.  This always
+  /// succeeds because SDBM are a subset of affine.
+  AffineExpr getAsAffineExpr() const;
+
+  /// Try constructing an SDBM expression from the given affine expression.
+  /// This may fail if the affine expression is not representable as SDBM, in
+  /// which case llvm::None is returned.  The conversion procedure recognizes
+  /// (nested) multiplicative ((x floordiv B) * B) and additive (x - x mod B)
+  /// patterns for the stripe expression.
+  static Optional<SDBMExpr> tryConvertAffineExpr(AffineExpr affine);
+
 protected:
   ImplType *impl;
 };
index 2d94dd8..bb7c199 100644 (file)
@@ -23,6 +23,8 @@
 
 #include "mlir/IR/SDBMExpr.h"
 #include "SDBMExprDetail.h"
+#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
 
 #include "llvm/Support/raw_ostream.h"
 
@@ -36,10 +38,10 @@ SDBMExprKind SDBMExpr::getKind() const { return impl->getKind(); }
 
 MLIRContext *SDBMExpr::getContext() const { return impl->getContext(); }
 
-template <typename Derived> class SDBMVisitor {
+template <typename Derived, typename Result = void> class SDBMVisitor {
 public:
   /// Visit the given SDBM expression, dispatching to kind-specific functions.
-  void visit(SDBMExpr expr) {
+  Result visit(SDBMExpr expr) {
     auto *derived = static_cast<Derived *>(this);
     switch (expr.getKind()) {
     case SDBMExprKind::Add:
@@ -67,38 +69,38 @@ protected:
   /// Default implementation of visitPositive dispatches to the special
   /// functions for stripes and other variables.  Concrete visitors can override
   /// it.
-  void visitPositive(SDBMPositiveExpr expr) {
+  Result visitPositive(SDBMPositiveExpr expr) {
     auto *derived = static_cast<Derived *>(this);
     if (expr.getKind() == SDBMExprKind::Stripe)
-      derived->visitStripe(expr.cast<SDBMStripeExpr>());
+      return derived->visitStripe(expr.cast<SDBMStripeExpr>());
     else
-      derived->visitInput(expr.cast<SDBMInputExpr>());
+      return derived->visitInput(expr.cast<SDBMInputExpr>());
   }
 
   /// Default implementation of visitInput dispatches to the special
   /// functions for dimensions or symbols.  Concrete visitors can override it to
   /// visit all variables instead.
-  void visitInput(SDBMInputExpr expr) {
+  Result visitInput(SDBMInputExpr expr) {
     auto *derived = static_cast<Derived *>(this);
     if (expr.getKind() == SDBMExprKind::DimId)
-      derived->visitDim(expr.cast<SDBMDimExpr>());
+      return derived->visitDim(expr.cast<SDBMDimExpr>());
     else
-      derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
+      return derived->visitSymbol(expr.cast<SDBMSymbolExpr>());
   }
 
   /// Default implementation of visitVarying dispatches to the special
   /// functions for variables and negations thereof.  Concerete visitors can
   /// override it to visit all variables and negations isntead.
-  void visitVarying(SDBMVaryingExpr expr) {
+  Result visitVarying(SDBMVaryingExpr expr) {
     auto *derived = static_cast<Derived *>(this);
     if (auto var = expr.dyn_cast<SDBMPositiveExpr>())
-      derived->visitPositive(var);
+      return derived->visitPositive(var);
     else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
-      derived->visitNeg(neg);
+      return derived->visitNeg(neg);
     else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
-      derived->visitSum(sum);
+      return derived->visitSum(sum);
     else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
-      derived->visitDiff(diff);
+      return derived->visitDiff(diff);
 
     llvm_unreachable("unhandled subtype of varying SDBM expression");
   }
@@ -137,7 +139,10 @@ void SDBMExpr::print(raw_ostream &os) const {
   printer.visit(*this);
 }
 
-void SDBMExpr::dump() const { print(llvm::errs()); }
+void SDBMExpr::dump() const {
+  print(llvm::errs());
+  llvm::errs() << '\n';
+}
 
 //===----------------------------------------------------------------------===//
 // SDBMSumExpr
@@ -151,6 +156,217 @@ SDBMConstantExpr SDBMSumExpr::getRHS() const {
   return static_cast<ImplType *>(impl)->rhs;
 }
 
+AffineExpr SDBMExpr::getAsAffineExpr() const {
+  struct Converter : public SDBMVisitor<Converter, AffineExpr> {
+    AffineExpr visitSum(SDBMSumExpr expr) {
+      AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      return lhs + rhs;
+    }
+
+    AffineExpr visitStripe(SDBMStripeExpr expr) {
+      AffineExpr lhs = visit(expr.getVar()),
+                 rhs = visit(expr.getStripeFactor());
+      return lhs - (lhs % rhs);
+    }
+
+    AffineExpr visitDiff(SDBMDiffExpr expr) {
+      AffineExpr lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      return lhs - rhs;
+    }
+
+    AffineExpr visitDim(SDBMDimExpr expr) {
+      return getAffineDimExpr(expr.getPosition(), expr.getContext());
+    }
+
+    AffineExpr visitSymbol(SDBMSymbolExpr expr) {
+      return getAffineSymbolExpr(expr.getPosition(), expr.getContext());
+    }
+
+    AffineExpr visitNeg(SDBMNegExpr expr) {
+      return getAffineBinaryOpExpr(AffineExprKind::Mul,
+                                   getAffineConstantExpr(-1, expr.getContext()),
+                                   visit(expr.getVar()));
+    }
+
+    AffineExpr visitConstant(SDBMConstantExpr expr) {
+      return getAffineConstantExpr(expr.getValue(), expr.getContext());
+    }
+  } converter;
+  return converter.visit(*this);
+}
+
+Optional<SDBMExpr> SDBMExpr::tryConvertAffineExpr(AffineExpr affine) {
+  struct Converter : public AffineExprVisitor<Converter, SDBMExpr> {
+    // Try matching the definition of the stripe operation as x - x mod C where
+    // `pos` should match "x" and `neg` should match "- (x mod C)".
+    SDBMExpr matchStripeAddPattern(AffineExpr pos, AffineExpr neg) {
+      // Check that the "pos" part is a variable expression and that the "neg"
+      // part is a mul expression.
+      auto convertedLHS = visit(pos);
+      if (!convertedLHS || !convertedLHS.isa<SDBMPositiveExpr>())
+        return {};
+
+      auto outerBinExpr = neg.dyn_cast<AffineBinaryOpExpr>();
+      if (!outerBinExpr || outerBinExpr.getKind() != AffineExprKind::Mul)
+        return {};
+
+      // In affine mul expressions, the constant part is always on the RHS.
+      // If there had been two constants, they would have been folded away.
+      assert(!outerBinExpr.getLHS().isa<AffineConstantExpr>() &&
+             "expected a constant on the RHS of an affine mul expression");
+      // Check if the RHS of mul is -1.
+      auto multiplierExpr =
+          outerBinExpr.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!multiplierExpr || multiplierExpr.getValue() != -1)
+        return {};
+
+      // Check if the LHS of mul is ("pos" mod constant).
+      auto binExpr = outerBinExpr.getLHS().dyn_cast<AffineBinaryOpExpr>();
+      if (!binExpr || binExpr.getKind() != AffineExprKind::Mod ||
+          !binExpr.getRHS().isa<AffineConstantExpr>())
+        return {};
+
+      if (convertedLHS != visit(binExpr.getLHS()))
+        return {};
+
+      // If all checks pass, we have a stripe.
+      return SDBMStripeExpr::get(
+          convertedLHS.cast<SDBMPositiveExpr>(),
+          visit(binExpr.getRHS()).cast<SDBMConstantExpr>());
+    }
+
+    SDBMExpr visitAddExpr(AffineBinaryOpExpr expr) {
+      // Attempt to recover a stripe expression.  Because AffineExprs don't have
+      // a first-class difference kind, we check for both x + -1 * (x mod C) and
+      // -1 * (x mod C) + x cases.
+      if (auto stripe = matchStripeAddPattern(expr.getLHS(), expr.getRHS()))
+        return stripe;
+      if (auto stripe = matchStripeAddPattern(expr.getRHS(), expr.getLHS()))
+        return stripe;
+
+      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      if (!lhs || !rhs)
+        return {};
+
+      // In a "add" AffineExpr, the constant always appears on the right.  If
+      // there were two constants, they would have been folded away.
+      assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+
+      // SDBM accepts LHS variables and RHS constants in a sum.
+      auto lhsVar = lhs.dyn_cast<SDBMVaryingExpr>();
+      auto rhsVar = rhs.dyn_cast<SDBMVaryingExpr>();
+      if (rhsConstant && lhsVar)
+        return SDBMSumExpr::get(lhsVar, rhsConstant);
+
+      // The sum of a negated variable and a non-negated variable is a
+      // difference, supported as a special kind in SDBM.  Because AffineExprs
+      // don't have first-class difference kind, check both LHS and RHS for
+      // negation.
+      auto lhsPos = lhs.dyn_cast<SDBMPositiveExpr>();
+      auto rhsPos = rhs.dyn_cast<SDBMPositiveExpr>();
+      auto lhsNeg = lhs.dyn_cast<SDBMNegExpr>();
+      auto rhsNeg = rhs.dyn_cast<SDBMNegExpr>();
+      if (lhsNeg && rhsVar)
+        return SDBMDiffExpr::get(rhsPos, lhsNeg.getVar());
+      if (rhsNeg && lhsVar)
+        return SDBMDiffExpr::get(lhsPos, rhsNeg.getVar());
+
+      // Other cases don't fit into SDBM.
+      return {};
+    }
+
+    // Try matching the stripe pattern "(x floordiv C) * C" where `lhs`
+    // corresponds to "(x floordiv C)" and `rhs` corresponds to "C".
+    SDBMExpr matchStripeMulPattern(AffineExpr lhs, AffineExpr rhs) {
+      // Check if LHS is a floordiv expression and rhs is a constant.
+      auto lhsBinary = lhs.dyn_cast<AffineBinaryOpExpr>();
+      auto rhsConstant = rhs.dyn_cast<AffineConstantExpr>();
+      if (!lhsBinary || !rhsConstant ||
+          lhsBinary.getKind() != AffineExprKind::FloorDiv)
+        return {};
+
+      // Check if the floordiv divides by the constant equal to RHS.
+      auto lhsRhsConstant = lhsBinary.getRHS().dyn_cast<AffineConstantExpr>();
+      if (!lhsRhsConstant || lhsRhsConstant != rhsConstant)
+        return {};
+
+      // Check if LHS can be converted to a single variable.
+      SDBMExpr converted = visit(lhsBinary.getLHS());
+      if (!converted)
+        return {};
+      auto varConverted = converted.dyn_cast<SDBMPositiveExpr>();
+      if (!varConverted)
+        return {};
+
+      // If all checks pass, we have a stripe.
+      return SDBMStripeExpr::get(
+          varConverted, SDBMConstantExpr::get(varConverted.getContext(),
+                                              rhsConstant.getValue()));
+    }
+
+    SDBMExpr visitMulExpr(AffineBinaryOpExpr expr) {
+      // Attempt to recover a stripe expression "x # C = (x floordiv C) * C".
+      if (auto stripe = matchStripeMulPattern(expr.getLHS(), expr.getRHS()))
+        return stripe;
+
+      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      if (!lhs || !rhs)
+        return {};
+
+      // In a "mul" AffineExpr, the constant always appears on the right.  If
+      // there were two constants, they would have been folded away.
+      assert(!lhs.isa<SDBMConstantExpr>() && "non-canonical affine expression");
+      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+      if (!rhsConstant)
+        return {};
+
+      // The only supported "multiplication" expression is an SDBM is dimension
+      // negation, that is a product of dimension and constant -1.
+      auto lhsVar = lhs.dyn_cast<SDBMPositiveExpr>();
+      if (lhsVar && rhsConstant.getValue() == -1)
+        return SDBMNegExpr::get(lhsVar);
+
+      // Other multiplications are not allowed in SDBM.
+      return {};
+    }
+
+    SDBMExpr visitModExpr(AffineBinaryOpExpr expr) {
+      auto lhs = visit(expr.getLHS()), rhs = visit(expr.getRHS());
+      if (!lhs || !rhs)
+        return {};
+
+      // 'mod' can only be converted to SDBM if its LHS is a variable
+      // and its RHS is a constant.  Then it `x mod c = x - x stripe c`.
+      auto rhsConstant = rhs.dyn_cast<SDBMConstantExpr>();
+      auto lhsVar = rhs.dyn_cast<SDBMPositiveExpr>();
+      if (!lhsVar || !rhsConstant)
+        return {};
+      return SDBMDiffExpr::get(lhsVar,
+                               SDBMStripeExpr::get(lhsVar, rhsConstant));
+    }
+
+    // `a floordiv b = (a stripe b) / b`, but we have no division in SDBM
+    SDBMExpr visitFloorDivExpr(AffineBinaryOpExpr expr) { return {}; }
+    SDBMExpr visitCeilDivExpr(AffineBinaryOpExpr expr) { return {}; }
+
+    // Dimensions, symbols and constants are converted trivially.
+    SDBMExpr visitConstantExpr(AffineConstantExpr expr) {
+      return SDBMConstantExpr::get(expr.getContext(), expr.getValue());
+    }
+    SDBMExpr visitDimExpr(AffineDimExpr expr) {
+      return SDBMDimExpr::get(expr.getContext(), expr.getPosition());
+    }
+    SDBMExpr visitSymbolExpr(AffineSymbolExpr expr) {
+      return SDBMSymbolExpr::get(expr.getContext(), expr.getPosition());
+    }
+  } converter;
+
+  if (auto result = converter.visit(affine))
+    return result;
+  return None;
+}
+
 //===----------------------------------------------------------------------===//
 // SDBMDiffExpr
 //===----------------------------------------------------------------------===//
index 2075db5..91c7e3c 100644 (file)
@@ -15,6 +15,7 @@
 // limitations under the License.
 // =============================================================================
 
+#include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/SDBMExpr.h"
 #include "gtest/gtest.h"
@@ -179,4 +180,64 @@ TEST(SDBMExpr, Diff) {
   EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
 }
 
+TEST(SDBMExpr, AffineRoundTrip) {
+  // Build an expression (s0 - s0 # 2)
+  auto cst2 = SDBMConstantExpr::get(ctx(), 2);
+  auto var = SDBMSymbolExpr::get(ctx(), 0);
+  auto stripe = SDBMStripeExpr::get(var, cst2);
+  auto expr = SDBMDiffExpr::get(var, stripe);
+
+  // Check that it can be converted to AffineExpr and back, i.e. stripe
+  // detection works correctly.
+  Optional<SDBMExpr> roundtripped =
+      SDBMExpr::tryConvertAffineExpr(expr.getAsAffineExpr());
+  ASSERT_TRUE(roundtripped.hasValue());
+  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(expr));
+
+  // Check that (s0 # 2 # 5) can be converted to AffineExpr, i.e. stripe
+  // detection supports nested expressions.
+  auto cst5 = SDBMConstantExpr::get(ctx(), 5);
+  auto outerStripe = SDBMStripeExpr::get(stripe, cst5);
+  roundtripped = SDBMExpr::tryConvertAffineExpr(outerStripe.getAsAffineExpr());
+  ASSERT_TRUE(roundtripped.hasValue());
+  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(outerStripe));
+
+  // Check that (s0 # 2 # 5 - s0 # 2) + 2 can be converted as an example of a
+  // deeper expression tree.
+  auto diff = SDBMDiffExpr::get(outerStripe, stripe);
+  auto sum = SDBMSumExpr::get(diff, cst2);
+  roundtripped = SDBMExpr::tryConvertAffineExpr(sum.getAsAffineExpr());
+  ASSERT_TRUE(roundtripped.hasValue());
+  EXPECT_EQ(roundtripped, static_cast<SDBMExpr>(sum));
+}
+
+TEST(SDBMExpr, MatchStripeMulPattern) {
+  // Make sure conversion from AffineExpr recognizes multiplicative stripe
+  // pattern (x floordiv B) * B == x # B.
+  auto cst = getAffineConstantExpr(42, ctx());
+  auto dim = getAffineDimExpr(0, ctx());
+  auto floor = getAffineBinaryOpExpr(AffineExprKind::FloorDiv, dim, cst);
+  auto mul = getAffineBinaryOpExpr(AffineExprKind::Mul, cst, floor);
+  Optional<SDBMExpr> converted = SDBMStripeExpr::tryConvertAffineExpr(mul);
+  ASSERT_TRUE(converted.hasValue());
+  EXPECT_TRUE(converted->isa<SDBMStripeExpr>());
+}
+
+TEST(SDBMExpr, NonSDBM) {
+  auto d0 = getAffineDimExpr(0, ctx());
+  auto d1 = getAffineDimExpr(1, ctx());
+  auto sum = getAffineBinaryOpExpr(AffineExprKind::Add, d0, d1);
+  auto c2 = getAffineConstantExpr(2, ctx());
+  auto prod = getAffineBinaryOpExpr(AffineExprKind::Mul, d0, c2);
+  auto ceildiv = getAffineBinaryOpExpr(AffineExprKind::CeilDiv, d1, c2);
+
+  // The following are not valid SDBM expressions:
+  // - a sum of two variables
+  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(sum).hasValue());
+  // - a variable with coefficient other than 1 or -1
+  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(prod).hasValue());
+  // - a ceildiv expression
+  EXPECT_FALSE(SDBMExpr::tryConvertAffineExpr(ceildiv).hasValue());
+}
+
 } // end namespace