srcType.getRank() < dstType.getRank() ||
parentSrcType.getRank() == dstType.getRank())
return failure();
+
// Check if the result tensor_reshape after folding the reshapeOp and
// parentReshapeOp are combined.
// If the final tensor_reshape is folding, the parentReshapeOp is
// introducing unit-dims, and the reshapeOp does an actual reshape.
- // If the final tensor_reshape op is expanding, the reshapeOp is introducing
- // unit-dims, and the parentReshapeOp does an actual reshape.
+ // If the final tensor_reshape op is expanding, the reshapeOp is
+ // introducing unit-dims, and the parentReshapeOp does an actual reshape.
bool isFoldingPattern = parentSrcType.getRank() > dstType.getRank();
- auto reassociationMaps = isFoldingPattern
- ? reshapeOp.getReassociationMaps()
- : parentReshapeOp.getReassociationMaps();
- DenseSet<unsigned> conservedDimensions;
- for (auto &map : reassociationMaps) {
- if (map.getNumResults() == 1) {
- conservedDimensions.insert(
- map.getResult(0).cast<AffineDimExpr>().getPosition());
- }
- }
-
- // Find positions at which the unit-dims exist.
- int64_t nonUnitDimPos = 0;
- DenseMap<unsigned, unsigned> nonUnitSrcDims;
- ArrayRef<int64_t> nonUnitShape =
+ ArrayRef<int64_t> expandedShape =
isFoldingPattern ? parentSrcType.getShape() : dstType.getShape();
- for (auto shape : enumerate(srcType.getShape())) {
- // Case 1 : It is a conserved dimension.
- if (conservedDimensions.count(shape.index())) {
- nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
- continue;
+ ArrayRef<int64_t> foldedShape =
+ isFoldingPattern ? dstType.getShape() : parentSrcType.getShape();
+
+ unsigned expandedDim = 0, foldedDim = 0;
+ SmallVector<SmallVector<AffineExpr, 4>, 4> reassociationExprs(
+ foldedShape.size());
+ while (expandedDim < expandedShape.size() &&
+ foldedDim < foldedShape.size()) {
+ int64_t dstSize = foldedShape[foldedDim];
+ int64_t srcSize = expandedShape[expandedDim];
+ while (srcSize < dstSize && expandedDim < expandedShape.size()) {
+ reassociationExprs[foldedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim++));
+ srcSize *= expandedShape[expandedDim];
}
- // Case 2 : Dimensions dont match but the intermediate tensor is unit-dim.
- if (shape.value() == 1)
- continue;
- // Case 3 : Dimensions match, treat it as a non-unit src dim.
- if (nonUnitDimPos < static_cast<int64_t>(nonUnitShape.size()) &&
- nonUnitShape[nonUnitDimPos] == shape.value()) {
- nonUnitSrcDims[shape.index()] = nonUnitDimPos++;
- continue;
+ if (srcSize == dstSize) {
+ reassociationExprs[foldedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim++));
+ // If the next dim in foldedShape is not 1, treat subsequent dims in
+ // expandedShape which are 1 to be collapsed.
+ if (foldedDim == foldedShape.size() - 1 ||
+ foldedShape[foldedDim + 1] != 1) {
+ while (expandedDim < expandedShape.size() &&
+ expandedShape[expandedDim] == 1) {
+ reassociationExprs[foldedDim].push_back(
+ rewriter.getAffineDimExpr(expandedDim++));
+ }
+ }
+ } else {
+ return failure();
}
- return failure();
+ foldedDim++;
}
+ if (expandedDim != expandedShape.size())
+ return failure();
- // Compute reassociation maps for the final operation. Use the reassociation
- // maps that is actually doing a reshape (and not just introducing
- // unit-dims). From these maps, prune the unit-extent dimensions.
- for (AffineMap &map : reassociationMaps) {
- SmallVector<AffineExpr, 4> exprs;
- exprs.reserve(nonUnitSrcDims.size());
- for (auto result : map.getResults()) {
- unsigned dim = result.cast<AffineDimExpr>().getPosition();
- if (nonUnitSrcDims.count(dim))
- exprs.push_back(rewriter.getAffineDimExpr(nonUnitSrcDims[dim]));
- }
- map = AffineMap::get(nonUnitSrcDims.size(), 0, exprs,
- rewriter.getContext());
- }
+ SmallVector<AffineMap, 4> reassociationMaps =
+ llvm::to_vector<4>(llvm::map_range(
+ reassociationExprs, [&](ArrayRef<AffineExpr> exprs) -> AffineMap {
+ return AffineMap::get(expandedShape.size(), 0, exprs,
+ rewriter.getContext());
+ }));
rewriter.replaceOpWithNewOp<TensorReshapeOp>(
reshapeOp, dstType, parentReshapeOp.src(),
rewriter.getAffineMapArrayAttr(reassociationMaps));