[MLIR] Simplify Semi-affine expressions by rule based matching and replacing "expr...
authorArnab Dutta <arnab.dutta@cerebras.net>
Sat, 20 Nov 2021 15:34:59 +0000 (21:04 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Sat, 20 Nov 2021 15:35:36 +0000 (21:05 +0530)
Add rule based matching for detecting and transforming "expr - q * (expr floordiv q)"
to "expr mod q", where q is a symbolic exxpression, in simplifyAdd function.

Reviewed By: bondhugula, dcaballe

Differential Revision: https://reviews.llvm.org/D112985

mlir/lib/IR/AffineExpr.cpp
mlir/test/Dialect/Affine/simplify-affine-structures.mlir

index 36be0d4..f0f54ce 100644 (file)
@@ -591,9 +591,10 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
     }
   }
 
-  // Detect and transform "expr - c * (expr floordiv c)" to "expr mod c". This
-  // leads to a much more efficient form when 'c' is a power of two, and in
-  // general a more compact and readable form.
+  // Detect and transform "expr - q * (expr floordiv q)" to "expr mod q", where
+  // q may be a constant or symbolic expression. This leads to a much more
+  // efficient form when 'c' is a power of two, and in general a more compact
+  // and readable form.
 
   // Process '(expr floordiv c) * (-c)'.
   if (!rBinOpExpr)
@@ -602,13 +603,33 @@ static AffineExpr simplifyAdd(AffineExpr lhs, AffineExpr rhs) {
   auto lrhs = rBinOpExpr.getLHS();
   auto rrhs = rBinOpExpr.getRHS();
 
+  AffineExpr llrhs, rlrhs;
+
+  // Check if lrhsBinOpExpr is of the form (expr floordiv q) * q, where q is a
+  // symbolic expression.
+  auto lrhsBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
+  // Check rrhsConstOpExpr = -1.
+  auto rrhsConstOpExpr = rrhs.dyn_cast<AffineConstantExpr>();
+  if (rrhsConstOpExpr && rrhsConstOpExpr.getValue() == -1 && lrhsBinOpExpr &&
+      lrhsBinOpExpr.getKind() == AffineExprKind::Mul) {
+    // Check llrhs = expr floordiv q.
+    llrhs = lrhsBinOpExpr.getLHS();
+    // Check rlrhs = q.
+    rlrhs = lrhsBinOpExpr.getRHS();
+    auto llrhsBinOpExpr = llrhs.dyn_cast<AffineBinaryOpExpr>();
+    if (!llrhsBinOpExpr || llrhsBinOpExpr.getKind() != AffineExprKind::FloorDiv)
+      return nullptr;
+    if (llrhsBinOpExpr.getRHS() == rlrhs && lhs == llrhsBinOpExpr.getLHS())
+      return lhs % rlrhs;
+  }
+
   // Process lrhs, which is 'expr floordiv c'.
   AffineBinaryOpExpr lrBinOpExpr = lrhs.dyn_cast<AffineBinaryOpExpr>();
   if (!lrBinOpExpr || lrBinOpExpr.getKind() != AffineExprKind::FloorDiv)
     return nullptr;
 
-  auto llrhs = lrBinOpExpr.getLHS();
-  auto rlrhs = lrBinOpExpr.getRHS();
+  llrhs = lrBinOpExpr.getLHS();
+  rlrhs = lrBinOpExpr.getRHS();
 
   if (lhs == llrhs && rlrhs == -rrhs) {
     return lhs % rlrhs;
index 8372262..39abea5 100644 (file)
@@ -533,3 +533,17 @@ func @semiaffine_simplification_product(%arg0: index, %arg1: index, %arg2: index
 // CHECK-NEXT: %[[RESULT0:.*]] = affine.apply #[[$PRODUCT]]()[%[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG0]]]
 // CHECK-NEXT: %[[RESULT1:.*]] = affine.apply #[[$SUM_OF_PRODUCTS]]()[%[[ARG3]], %[[ARG4]], %[[ARG0]], %[[ARG1]], %[[ARG2]]]
 // CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]]
+
+// -----
+
+// CHECK-DAG: #[[$SIMPLIFIED_MAP:.*]] = affine_map<()[s0, s1, s2, s3] -> ((-s0 + s2 + s3) mod (s0 + s1))>
+// CHECK-LABEL: func @semi_affine_simplification_euclidean_lemma
+// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index)
+func @semi_affine_simplification_euclidean_lemma(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index, %arg5: index) -> (index, index) {
+  %a = affine.apply affine_map<(d0, d1)[s0, s1] -> ((d0 + d1) - ((d0 + d1) floordiv (s0 - s1)) * (s0 - s1) - (d0 + d1) mod (s0 - s1))>(%arg0, %arg1)[%arg2, %arg3]
+  %b = affine.apply affine_map<(d0, d1)[s0, s1] -> ((d0 + d1 - s0) - ((d0 + d1 - s0) floordiv (s0 + s1)) * (s0 + s1))>(%arg0, %arg1)[%arg2, %arg3]
+  return %a, %b : index, index
+}
+// CHECK-NEXT: %[[ZERO:.*]] = arith.constant 0 : index
+// CHECK-NEXT: %[[RESULT:.*]] = affine.apply #[[$SIMPLIFIED_MAP]]()[%[[ARG2]], %[[ARG3]], %[[ARG0]], %[[ARG1]]]
+// CHECK-NEXT: return %[[ZERO]], %[[RESULT]]