[mlir] refactor common idiom into AffineMap method
authorAart Bik <ajcbik@google.com>
Sat, 14 Nov 2020 02:11:47 +0000 (18:11 -0800)
committerAart Bik <ajcbik@google.com>
Sat, 14 Nov 2020 03:18:13 +0000 (19:18 -0800)
motivated by a refactoring in the new sparse code (yet to be merged), this avoids some lengthy code dup

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D91465

mlir/include/mlir/IR/AffineMap.h
mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
mlir/lib/Dialect/Vector/VectorOps.cpp
mlir/lib/Dialect/Vector/VectorTransforms.cpp
mlir/lib/IR/AffineMap.cpp

index c450024..f1f267f 100644 (file)
@@ -125,6 +125,10 @@ public:
   ArrayRef<AffineExpr> getResults() const;
   AffineExpr getResult(unsigned idx) const;
 
+  /// Extracts the position of the dimensional expression at the given result,
+  /// when the caller knows it is safe to do so.
+  unsigned getDimPosition(unsigned idx) const;
+
   /// Walk all of the AffineExpr's in this mapping. Each node in an expression
   /// tree is visited in postorder.
   void walkExprs(std::function<void(AffineExpr)> callback) const;
index abc10e8..8e1dbf1 100644 (file)
@@ -466,9 +466,8 @@ fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
     numFoldedDims[pos] = foldedDims.getNumResults();
-    ArrayRef<int64_t> shape = expandedShape.slice(
-        foldedDims.getResult(0).cast<AffineDimExpr>().getPosition(),
-        numFoldedDims[pos]);
+    ArrayRef<int64_t> shape =
+        expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
     expandedDimsShape[pos].assign(shape.begin(), shape.end());
   }
 
index 39aed71..0cc1e7c 100644 (file)
@@ -336,7 +336,7 @@ static LogicalResult verifyOutputShape(
       VectorType v = pair.first;
       auto map = pair.second;
       for (unsigned idx = 0, e = v.getRank(); idx < e; ++idx) {
-        unsigned pos = map.getResult(idx).cast<AffineDimExpr>().getPosition();
+        unsigned pos = map.getDimPosition(idx);
         if (!extents[pos])
           extents[pos] = getAffineConstantExpr(v.getShape()[idx], ctx);
       }
@@ -785,8 +785,7 @@ static Value foldExtractOpFromInsertChainAndTranspose(ExtractOp extractOp) {
     if (insertedPos.size() == extractedPos.size()) {
       bool fold = true;
       for (unsigned idx = 0, sz = extractedPos.size(); idx < sz; ++idx) {
-        auto pos =
-            permutationMap.getResult(idx).cast<AffineDimExpr>().getPosition();
+        auto pos = permutationMap.getDimPosition(idx);
         if (pos >= sz || insertedPos[pos] != extractedPos[idx]) {
           fold = false;
           break;
index 49865fd..e488db6 100644 (file)
@@ -50,7 +50,7 @@ using llvm::dbgs;
 // Helper to find an index in an affine map.
 static Optional<int64_t> getResultIndex(AffineMap map, int64_t index) {
   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-    int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+    int64_t idx = map.getDimPosition(i);
     if (idx == index)
       return i;
   }
@@ -76,7 +76,7 @@ static AffineMap adjustMap(AffineMap map, int64_t index,
   auto *ctx = rewriter.getContext();
   SmallVector<AffineExpr, 4> results;
   for (int64_t i = 0, e = map.getNumResults(); i < e; ++i) {
-    int64_t idx = map.getResult(i).cast<AffineDimExpr>().getPosition();
+    int64_t idx = map.getDimPosition(i);
     if (idx == index)
       continue;
     // Re-insert remaining indices, but renamed when occurring
@@ -2016,16 +2016,13 @@ Value ContractionOpLowering::lowerParallel(vector::ContractionOp op,
   int64_t iterIndex = -1;
   int64_t dimSize = -1;
   if (lhsIndex >= 0) {
-    iterIndex = iMap[0].getResult(lhsIndex).cast<AffineDimExpr>().getPosition();
-    assert(
-        (rhsIndex < 0 ||
-         iterIndex ==
-             iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition()) &&
-        "parallel index should be free in LHS or batch in LHS/RHS");
+    iterIndex = iMap[0].getDimPosition(lhsIndex);
+    assert((rhsIndex < 0 || iterIndex == iMap[1].getDimPosition(rhsIndex)) &&
+           "parallel index should be free in LHS or batch in LHS/RHS");
     dimSize = lhsType.getDimSize(lhsIndex);
   } else {
     assert(rhsIndex >= 0 && "missing parallel index");
-    iterIndex = iMap[1].getResult(rhsIndex).cast<AffineDimExpr>().getPosition();
+    iterIndex = iMap[1].getDimPosition(rhsIndex);
     dimSize = rhsType.getDimSize(rhsIndex);
   }
   assert(iterIndex >= 0 && "parallel index not listed in operand mapping");
index 1f73d07..cc2cb8b 100644 (file)
@@ -227,6 +227,10 @@ AffineExpr AffineMap::getResult(unsigned idx) const {
   return map->results[idx];
 }
 
+unsigned AffineMap::getDimPosition(unsigned idx) const {
+  return getResult(idx).cast<AffineDimExpr>().getPosition();
+}
+
 /// Folds the results of the application of an affine map on the provided
 /// operands to a constant if possible. Returns false if the folding happens,
 /// true otherwise.