};
struct UnitExtentReplacementInfo {
- RankedTensorType type;
+ Type type;
AffineMap indexMap;
ArrayAttr reassociation;
};
AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
ArrayRef<int64_t> shape = genericOp.getShape(opOperand);
ArrayRef<AffineExpr> exprs = indexingMap.getResults();
- SmallVector<AffineExpr, 2> reassociations;
- SmallVector<Attribute, 4> reassociationMaps;
- SmallVector<AffineExpr, 4> newIndexExprs;
- SmallVector<int64_t, 4> newShape;
+ SmallVector<AffineExpr> reassociations;
+ SmallVector<Attribute> reassociationMaps;
+ SmallVector<AffineExpr> newIndexExprs;
+ SmallVector<int64_t> newShape;
int64_t origRank = genericOp.getRank(opOperand);
AffineExpr zeroExpr = getAffineConstantExpr(0, context);
return shape[dim] == 1 && exprs[dim] == zeroExpr;
};
- unsigned dim = 0;
+ int64_t dim = 0;
// Fold dimensions that are unit-extent at the beginning of the tensor.
while (dim < origRank && isUnitExtent(dim))
reassociations.push_back(getAffineDimExpr(dim++, context));
reassociations.clear();
++dim;
}
- UnitExtentReplacementInfo info = {
- RankedTensorType::get(newShape,
- getElementTypeOrSelf(opOperand->get().getType())),
- AffineMap::get(indexingMap.getNumDims(), indexingMap.getNumSymbols(),
- newIndexExprs, context),
- ArrayAttr::get(context, reassociationMaps)};
+ // Compute the tensor or scalar replacement type.
+ Type elementType = getElementTypeOrSelf(opOperand->get().getType());
+ Type replacementType = elementType == opOperand->get().getType()
+ ? elementType
+ : RankedTensorType::get(newShape, elementType);
+ UnitExtentReplacementInfo info = {replacementType,
+ AffineMap::get(indexingMap.getNumDims(),
+ indexingMap.getNumSymbols(),
+ newIndexExprs, context),
+ ArrayAttr::get(context, reassociationMaps)};
return info;
}
MLIRContext *context = rewriter.getContext();
Location loc = genericOp.getLoc();
- SmallVector<AffineMap, 4> newIndexingMaps;
- SmallVector<ArrayAttr, 4> reassociationMaps;
- SmallVector<ShapedType, 4> newInputOutputTypes;
+ SmallVector<AffineMap> newIndexingMaps;
+ SmallVector<ArrayAttr> reassociationMaps;
+ SmallVector<Type> newInputOutputTypes;
bool doCanonicalization = false;
for (OpOperand *opOperand : genericOp.getInputAndOutputOperands()) {
- auto replacementInfo = replaceUnitExtents(genericOp, opOperand, context);
+ UnitExtentReplacementInfo replacementInfo =
+ replaceUnitExtents(genericOp, opOperand, context);
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);