From: Lorenzo Chelini Date: Thu, 12 Jan 2023 14:22:59 +0000 (+0100) Subject: Reduce input arguments for `getPackingInfoFromConsumer` (NFC) X-Git-Tag: upstream/17.0.6~20536 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=ff6f4ae7b7a93e97a9a7ab0516f77baebc960d1c;p=platform%2Fupstream%2Fllvm.git Reduce input arguments for `getPackingInfoFromConsumer` (NFC) Pass to `getPackingInfoFromConsumer` the tensor.pack op instead of all the arguments derived from it. Additionally, remove the padding from the struct, as we currently don't handle propagation when `tensor.pack` requires padding. We will add back the field when we will need it. Reviewed By: hanchung Differential Revision: https://reviews.llvm.org/D141837 --- diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 4805354..df2b15d 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -44,22 +44,19 @@ struct PackInfo { llvm::DenseMap tileToPointMapping; // The permutation of outer dims (on domain). SmallVector outerDimsOnDomainPerm; - std::optional paddingValue; }; -static PackInfo getPackingInfoFromConsumer( - AffineMap indexingMap, ArrayRef innerTileSizes, - ArrayRef innerDimsPos, ArrayRef outerDimsPerm, - std::optional paddingValue = std::nullopt) { +static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap, + tensor::PackOp packOp) { LLVM_DEBUG( { llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; }); PackInfo packInfo; - packInfo.paddingValue = paddingValue; int64_t origNumDims = indexingMap.getNumDims(); SmallVector exprs(indexingMap.getResults()); + ArrayRef innerDimsPos = packOp.getInnerDimsPos(); for (auto [index, innerDimPos, tileSize] : llvm::zip_equal(llvm::seq(0, innerDimsPos.size()), - innerDimsPos, innerTileSizes)) { + innerDimsPos, packOp.getMixedTiles())) { int64_t domainDimPos = exprs[innerDimPos].cast().getPosition(); packInfo.tiledDimsPos.push_back(domainDimPos); @@ -74,7 +71,7 @@ static PackInfo getPackingInfoFromConsumer( }); } - for (auto dim : outerDimsPerm) + for (auto dim : packOp.getOuterDimsPerm()) packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim)); if (!packInfo.outerDimsOnDomainPerm.empty()) { LLVM_DEBUG({ @@ -208,7 +205,7 @@ getOrCreatePackedViewOfOperand(OpBuilder &b, Location loc, PackInfo packInfo, b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm); auto packedOperand = b.create( loc, opOperand->get(), empty, innerDimsPos, innerTileSizes, - packInfo.paddingValue, outerDimsPerm); + /*padding=*/std::nullopt, outerDimsPerm); return std::make_tuple(packedOperand, indexingMap); } @@ -279,9 +276,7 @@ bubbleUpPackOpThroughElemGenericOp(RewriterBase &rewriter, OpOperand *opOperand = genericOp.getDpsInitOperand(0); auto packInfo = getPackingInfoFromConsumer( - genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(), - packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(), - packOp.getPaddingValue()); + genericOp.getMatchingIndexingMap(opOperand), packOp); Location loc = packOp.getLoc(); SmallVector inputOperands;