[MLIR] Fix bug in the method constructing semi affine expression from flattened form
authorArnab Dutta <arnab@polymagelabs.com>
Sun, 6 Nov 2022 06:59:10 +0000 (12:29 +0530)
committerUday Bondhugula <uday@polymagelabs.com>
Sun, 6 Nov 2022 06:59:17 +0000 (12:29 +0530)
Set proper offset to the second element of the index pair when either
lhs or rhs of a local expression is a dimensional identifier, so that
we do not have same index values for more than one local expression.

Reviewed By: springerm, hanchung

Differential Revision: https://reviews.llvm.org/D137389

mlir/lib/IR/AffineExpr.cpp
mlir/test/Dialect/Affine/simplify-structures.mlir

index e0f4547..00778cd 100644 (file)
@@ -986,18 +986,9 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
   // constant coefficient corresponding to the indices in `coefficients` map,
   // and affine expression corresponding to indices in `indexToExprMap` map.
 
-  for (unsigned j = 0; j < numDims; ++j) {
-    if (flatExprs[j] == 0)
-      continue;
-    // For dimensional expressions we set the index as <position number of the
-    // dimension, 0>, as we want dimensional expressions to appear before
-    // symbolic ones and products of dimensional and symbolic expressions
-    // having the dimension with the same position number.
-    std::pair<unsigned, signed> indexEntry(j, -1);
-    addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
-  }
   // Ensure we do not have duplicate keys in `indexToExpr` map.
-  unsigned offset = 0;
+  unsigned offsetSym = 0;
+  signed offsetDim = -1;
   for (unsigned j = numDims; j < numDims + numSymbols; ++j) {
     if (flatExprs[j] == 0)
       continue;
@@ -1006,7 +997,7 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
     // as we want symbolic expressions with the same positional number to
     // appear after dimensional expressions having the same positional number.
     std::pair<unsigned, signed> indexEntry(
-        j - numDims, std::max(numDims, numSymbols) + offset++);
+        j - numDims, std::max(numDims, numSymbols) + offsetSym++);
     addEntry(indexEntry, flatExprs[j],
              getAffineSymbolExpr(j - numDims, context));
   }
@@ -1038,13 +1029,13 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
       // constructing. When rhs is constant, we place 0 in place of keyB.
       if (lhs.isa<AffineDimExpr>()) {
         lhsPos = lhs.cast<AffineDimExpr>().getPosition();
-        std::pair<unsigned, signed> indexEntry(lhsPos, -1);
+        std::pair<unsigned, signed> indexEntry(lhsPos, offsetDim--);
         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
                  expr);
       } else {
         lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
         std::pair<unsigned, signed> indexEntry(
-            lhsPos, std::max(numDims, numSymbols) + offset++);
+            lhsPos, std::max(numDims, numSymbols) + offsetSym++);
         addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()],
                  expr);
       }
@@ -1066,12 +1057,23 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
       lhsPos = lhs.cast<AffineSymbolExpr>().getPosition();
       rhsPos = rhs.cast<AffineSymbolExpr>().getPosition();
       std::pair<unsigned, signed> indexEntry(
-          lhsPos, std::max(numDims, numSymbols) + offset++);
+          lhsPos, std::max(numDims, numSymbols) + offsetSym++);
       addEntry(indexEntry, flatExprs[numDims + numSymbols + it.index()], expr);
     }
     addedToMap[it.index()] = true;
   }
 
+  for (unsigned j = 0; j < numDims; ++j) {
+    if (flatExprs[j] == 0)
+      continue;
+    // For dimensional expressions we set the index as <position number of the
+    // dimension, 0>, as we want dimensional expressions to appear before
+    // symbolic ones and products of dimensional and symbolic expressions
+    // having the dimension with the same position number.
+    std::pair<unsigned, signed> indexEntry(j, offsetDim--);
+    addEntry(indexEntry, flatExprs[j], getAffineDimExpr(j, context));
+  }
+
   // Constructing the simplified semi-affine sum of product/division/mod
   // expression from the flattened form in the desired sorted order of indices
   // of the various individual product/division/mod expressions.
index 903d11e..2c693ea 100644 (file)
@@ -557,3 +557,13 @@ func.func @semiaffine_modulo(%arg0: index) -> index {
   // CHECK: affine.apply #[[$MAP]]()[%{{.*}}]
   return %a : index
 }
+
+// -----
+
+// CHECK-DAG: #[[$MAP:.*]] = affine_map<()[s0, s1, s2] -> (s2 mod 2 + (s1 floordiv 2) * 2 + ((s2 floordiv 2) * s0) * 2)>
+// CHECK-LABEL: func @semiaffine_modulo_dim
+func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> index {
+  %a = affine.apply affine_map<(d0)[s0, s1] -> (((d0 floordiv 2) * s0 + s1 floordiv 2) * 2 + d0 mod 2)> (%arg0)[%arg1, %arg2]
+  //CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
+  return %a : index
+}