Introduce SDBMDirect expression into the SDBM expression hierarchy
authorAlex Zinenko <zinenko@google.com>
Mon, 16 Sep 2019 15:08:22 +0000 (08:08 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Mon, 16 Sep 2019 15:09:01 +0000 (08:09 -0700)
Direct expressions are those that do not negate any of the variables they
involve.  They include input expressions (dimensions and symbols), stripe and
sum expressions, and combinations of those.  Reifying direct expressions as a
class is a precondition for enabling additions on the LHS of a stripe
expression.

PiperOrigin-RevId: 269336031

mlir/include/mlir/Dialect/SDBM/SDBMExpr.h
mlir/unittests/SDBM/SDBMTest.cpp

index fdb914d..f643d92 100644 (file)
@@ -176,14 +176,30 @@ public:
   }
 };
 
+/// SDBM direct expression includes exactly one variable (symbol or dimension),
+/// which is not negated in the expression.  It can be one of:
+///   - term expression;
+///   - sum expression.
+class SDBMDirectExpr : public SDBMVaryingExpr {
+public:
+  using SDBMVaryingExpr::SDBMVaryingExpr;
+
+  static bool isClassFor(const SDBMExpr &expr) {
+    return expr.getKind() == SDBMExprKind::DimId ||
+           expr.getKind() == SDBMExprKind::SymbolId ||
+           expr.getKind() == SDBMExprKind::Stripe ||
+           expr.getKind() == SDBMExprKind::Add;
+  }
+};
+
 /// SDBM term expression can be one of:
 ///  - single variable expression;
 ///  - stripe expression.
 /// Stripe expressions are treated as terms since, in the SDBM domain, they are
 /// attached to temporary variables and can appear anywhere a variable can.
-class SDBMTermExpr : public SDBMVaryingExpr {
+class SDBMTermExpr : public SDBMDirectExpr {
 public:
-  using SDBMVaryingExpr::SDBMVaryingExpr;
+  using SDBMDirectExpr::SDBMDirectExpr;
 
   static bool isClassFor(const SDBMExpr &expr) {
     return expr.getKind() == SDBMExprKind::DimId ||
@@ -194,10 +210,10 @@ public:
 
 /// SDBM sum expression.  LHS is a varying expression and RHS is always a
 /// constant expression.
-class SDBMSumExpr : public SDBMVaryingExpr {
+class SDBMSumExpr : public SDBMDirectExpr {
 public:
   using ImplType = detail::SDBMBinaryExprStorage;
-  using SDBMVaryingExpr::SDBMVaryingExpr;
+  using SDBMDirectExpr::SDBMDirectExpr;
 
   /// Obtain or create a sum expression unique'ed in the given context.
   static SDBMSumExpr get(SDBMVaryingExpr lhs, SDBMConstantExpr rhs);
@@ -352,9 +368,19 @@ protected:
   void visitNeg(SDBMNegExpr) {}
   void visitConstant(SDBMConstantExpr) {}
 
-  /// Default implementation of visitTerm dispatches to the special
-  /// functions for stripes and other variables.  Concrete visitors can override
-  /// it.
+  /// Default implementation of visitDirect dispatches to the dedicated for sums
+  /// or delegates to visitTerm for the other expression kinds.  Concrete
+  /// visitors can overload it.
+  Result visitDirect(SDBMDirectExpr expr) {
+    auto *derived = static_cast<Derived *>(this);
+    if (auto sum = expr.dyn_cast<SDBMSumExpr>())
+      return derived->visitSum(sum);
+    else
+      return derived->visitTerm(expr.cast<SDBMTermExpr>());
+  }
+
+  /// Default implementation of visitTerm dispatches to the special functions
+  /// for stripes and other variables.  Concrete visitors can override it.
   Result visitTerm(SDBMTermExpr expr) {
     auto *derived = static_cast<Derived *>(this);
     if (expr.getKind() == SDBMExprKind::Stripe)
@@ -379,12 +405,10 @@ protected:
   /// override it to visit all variables and negations instead.
   Result visitVarying(SDBMVaryingExpr expr) {
     auto *derived = static_cast<Derived *>(this);
-    if (auto var = expr.dyn_cast<SDBMTermExpr>())
-      return derived->visitTerm(var);
+    if (auto var = expr.dyn_cast<SDBMDirectExpr>())
+      return derived->visitDirect(var);
     else if (auto neg = expr.dyn_cast<SDBMNegExpr>())
       return derived->visitNeg(neg);
-    else if (auto sum = expr.dyn_cast<SDBMSumExpr>())
-      return derived->visitSum(sum);
     else if (auto diff = expr.dyn_cast<SDBMDiffExpr>())
       return derived->visitDiff(diff);
 
index b6f8080..af44c80 100644 (file)
@@ -174,6 +174,7 @@ TEST(SDBMExpr, Dim) {
   EXPECT_TRUE(generic.isa<SDBMDimExpr>());
   EXPECT_TRUE(generic.isa<SDBMInputExpr>());
   EXPECT_TRUE(generic.isa<SDBMTermExpr>());
+  EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
   EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
 
   // Dimensions are not Symbols.
@@ -196,6 +197,7 @@ TEST(SDBMExpr, Symbol) {
   EXPECT_TRUE(generic.isa<SDBMSymbolExpr>());
   EXPECT_TRUE(generic.isa<SDBMInputExpr>());
   EXPECT_TRUE(generic.isa<SDBMTermExpr>());
+  EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
   EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
 
   // Dimensions are not Symbols.
@@ -229,6 +231,7 @@ TEST(SDBMExpr, Stripe) {
   auto generic = static_cast<SDBMExpr>(expr);
   EXPECT_TRUE(generic.isa<SDBMStripeExpr>());
   EXPECT_TRUE(generic.isa<SDBMTermExpr>());
+  EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
   EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
 }
 
@@ -271,6 +274,7 @@ TEST(SDBMExpr, Sum) {
   // Hierarchy is okay.
   auto generic = static_cast<SDBMExpr>(expr);
   EXPECT_TRUE(generic.isa<SDBMSumExpr>());
+  EXPECT_TRUE(generic.isa<SDBMDirectExpr>());
   EXPECT_TRUE(generic.isa<SDBMVaryingExpr>());
 }