Reduce input arguments for `getPackingInfoFromConsumer` (NFC)
authorLorenzo Chelini <l.chelini@icloud.com>
Thu, 12 Jan 2023 14:22:59 +0000 (15:22 +0100)
committerLorenzo Chelini <l.chelini@icloud.com>
Wed, 18 Jan 2023 18:03:38 +0000 (19:03 +0100)
Pass to `getPackingInfoFromConsumer` the tensor.pack op instead of all
the arguments derived from it. Additionally, remove the padding from the
struct, as we currently don't handle propagation when `tensor.pack`
requires padding. We will add back the field when we will need it.

Reviewed By: hanchung

Differential Revision: https://reviews.llvm.org/D141837

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

index 4805354..df2b15d 100644 (file)
@@ -44,22 +44,19 @@ struct PackInfo {
   llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
   // The permutation of outer dims (on domain).
   SmallVector<int64_t> outerDimsOnDomainPerm;
-  std::optional<Value> paddingValue;
 };
 
-static PackInfo getPackingInfoFromConsumer(
-    AffineMap indexingMap, ArrayRef<OpFoldResult> innerTileSizes,
-    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
-    std::optional<Value> paddingValue = std::nullopt) {
+static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
+                                           tensor::PackOp packOp) {
   LLVM_DEBUG(
       { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
   PackInfo packInfo;
-  packInfo.paddingValue = paddingValue;
   int64_t origNumDims = indexingMap.getNumDims();
   SmallVector<AffineExpr> exprs(indexingMap.getResults());
+  ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
   for (auto [index, innerDimPos, tileSize] :
        llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
-                       innerDimsPos, innerTileSizes)) {
+                       innerDimsPos, packOp.getMixedTiles())) {
     int64_t domainDimPos =
         exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
     packInfo.tiledDimsPos.push_back(domainDimPos);
@@ -74,7 +71,7 @@ static PackInfo getPackingInfoFromConsumer(
     });
   }
 
-  for (auto dim : outerDimsPerm)
+  for (auto dim : packOp.getOuterDimsPerm())
     packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
   if (!packInfo.outerDimsOnDomainPerm.empty()) {
     LLVM_DEBUG({
@@ -208,7 +205,7 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo,
       b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
   auto packedOperand = b.create<tensor::PackOp>(
       loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
-      packInfo.paddingValue, outerDimsPerm);
+      /*padding=*/std::nullopt, outerDimsPerm);
   return std::make_tuple(packedOperand, indexingMap);
 }
 
@@ -279,9 +276,7 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter,
 
   OpOperand *opOperand = genericOp.getDpsInitOperand(0);
   auto packInfo = getPackingInfoFromConsumer(
-      genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(),
-      packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(),
-      packOp.getPaddingValue());
+      genericOp.getMatchingIndexingMap(opOperand), packOp);
 
   Location loc = packOp.getLoc();
   SmallVector<Value> inputOperands;