llvm::DenseMap<int64_t, int64_t> tileToPointMapping;
// The permutation of outer dims (on domain).
SmallVector<int64_t> outerDimsOnDomainPerm;
- std::optional<Value> paddingValue;
};
-static PackInfo getPackingInfoFromConsumer(
- AffineMap indexingMap, ArrayRef<OpFoldResult> innerTileSizes,
- ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm,
- std::optional<Value> paddingValue = std::nullopt) {
+static PackInfo getPackingInfoFromConsumer(AffineMap indexingMap,
+ tensor::PackOp packOp) {
LLVM_DEBUG(
{ llvm::dbgs() << "--- Construct PackInfo From A Consumer ---\n"; });
PackInfo packInfo;
- packInfo.paddingValue = paddingValue;
int64_t origNumDims = indexingMap.getNumDims();
SmallVector<AffineExpr> exprs(indexingMap.getResults());
+ ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
for (auto [index, innerDimPos, tileSize] :
llvm::zip_equal(llvm::seq<unsigned>(0, innerDimsPos.size()),
- innerDimsPos, innerTileSizes)) {
+ innerDimsPos, packOp.getMixedTiles())) {
int64_t domainDimPos =
exprs[innerDimPos].cast<AffineDimExpr>().getPosition();
packInfo.tiledDimsPos.push_back(domainDimPos);
});
}
- for (auto dim : outerDimsPerm)
+ for (auto dim : packOp.getOuterDimsPerm())
packInfo.outerDimsOnDomainPerm.push_back(indexingMap.getDimPosition(dim));
if (!packInfo.outerDimsOnDomainPerm.empty()) {
LLVM_DEBUG({
b, loc, opOperand->get(), innerTileSizes, innerDimsPos, outerDimsPerm);
auto packedOperand = b.create<tensor::PackOp>(
loc, opOperand->get(), empty, innerDimsPos, innerTileSizes,
- packInfo.paddingValue, outerDimsPerm);
+ /*padding=*/std::nullopt, outerDimsPerm);
return std::make_tuple(packedOperand, indexingMap);
}
OpOperand *opOperand = genericOp.getDpsInitOperand(0);
auto packInfo = getPackingInfoFromConsumer(
- genericOp.getMatchingIndexingMap(opOperand), packOp.getMixedTiles(),
- packOp.getInnerDimsPos(), packOp.getOuterDimsPerm(),
- packOp.getPaddingValue());
+ genericOp.getMatchingIndexingMap(opOperand), packOp);
Location loc = packOp.getLoc();
SmallVector<Value> inputOperands;