Use AffineMap::getSliceMap where applicable. NFCI.
authorBenjamin Kramer <benny.kra@googlemail.com>
Sat, 12 Feb 2022 13:19:35 +0000 (14:19 +0100)
committerBenjamin Kramer <benny.kra@googlemail.com>
Sat, 12 Feb 2022 13:22:05 +0000 (14:22 +0100)
mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
mlir/lib/IR/AffineMap.cpp

index 7604d14..86ef121 100644 (file)
@@ -15,7 +15,7 @@
 #include "mlir/IR/AffineExprVisitor.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/TypeUtilities.h"
-#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallBitVector.h"
 
 using namespace mlir;
 using namespace mlir::linalg;
@@ -484,15 +484,15 @@ SmallVector<int64_t, 4> LinalgOp::computeStaticLoopSizes() {
 /// are used within an AffineExpr.
 struct HasAffineDimExprVisitor
     : public AffineExprVisitor<HasAffineDimExprVisitor, bool> {
-  HasAffineDimExprVisitor(llvm::SmallSet<unsigned, 4> &positions)
-      : positions(positions) {}
+  HasAffineDimExprVisitor(llvm::SmallBitVector positions)
+      : positions(std::move(positions)) {}
 
   bool visitAffineBinaryOpExpr(AffineBinaryOpExpr binaryOpExpr) {
     return visit(binaryOpExpr.getLHS()) || visit(binaryOpExpr.getRHS());
   }
 
   bool visitDimExpr(AffineDimExpr dimExpr) {
-    return positions.count(dimExpr.getPosition());
+    return positions.test(dimExpr.getPosition());
   }
 
   bool visitConstantExpr(AffineConstantExpr constExpr) { return false; }
@@ -500,7 +500,7 @@ struct HasAffineDimExprVisitor
   bool visitSymbolExpr(AffineSymbolExpr symbolExpr) { return false; }
 
 private:
-  llvm::SmallSet<unsigned, 4> positions;
+  llvm::SmallBitVector positions;
 };
 
 LogicalResult
@@ -523,19 +523,17 @@ LinalgOp::reifyResultShapes(OpBuilder &b,
 
   /// From loopsToShapesMap extract the submap that represents the shape of the
   /// (resultIdx, dim) needed.
-  SmallVector<unsigned, 4> resultPosRange =
-      llvm::to_vector<4>(llvm::seq<unsigned>(resultShapesSubMapPos.first,
-                                             resultShapesSubMapPos.second));
-  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSubMap(resultPosRange);
+  AffineMap loopToResultsShapeMap = loopsToShapesMap.getSliceMap(
+      resultShapesSubMapPos.first,
+      resultShapesSubMapPos.second - resultShapesSubMapPos.first);
   AffineMap resultShapesFromInputShapesMap =
       loopToResultsShapeMap.compose(getShapesToLoopsMap());
 
   // Check that the result dim map does not contain the positions corresponding
   // to the outputs.
-  llvm::SmallSet<unsigned, 4> outputDims;
-  llvm::for_each(resultPosRange,
-                 [&outputDims](unsigned dim) { outputDims.insert(dim); });
-  HasAffineDimExprVisitor checkDimExpr(outputDims);
+  llvm::SmallBitVector outputDims(resultShapesFromInputShapesMap.getNumDims());
+  outputDims.set(resultShapesSubMapPos.first, resultShapesSubMapPos.second);
+  HasAffineDimExprVisitor checkDimExpr(std::move(outputDims));
   Location loc = getOperation()->getLoc();
   auto allResultDimValues =
       applyMapToValues(b, loc, resultShapesFromInputShapesMap,
index 7fbe748..73cbd3b 100644 (file)
@@ -534,7 +534,7 @@ AffineMap AffineMap::getMajorSubMap(unsigned numResults) const {
     return AffineMap();
   if (numResults > getNumResults())
     return *this;
-  return getSubMap(llvm::to_vector<4>(llvm::seq<unsigned>(0, numResults)));
+  return getSliceMap(0, numResults);
 }
 
 AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
@@ -542,8 +542,7 @@ AffineMap AffineMap::getMinorSubMap(unsigned numResults) const {
     return AffineMap();
   if (numResults > getNumResults())
     return *this;
-  return getSubMap(llvm::to_vector<4>(
-      llvm::seq<unsigned>(getNumResults() - numResults, getNumResults())));
+  return getSliceMap(getNumResults() - numResults, numResults);
 }
 
 AffineMap mlir::compressDims(AffineMap map,