}));
SmallVector<Value> expandedOpOperands;
+ expandedOpOperands.reserve(genericOp.getNumInputs());
for (OpOperand *opOperand : genericOp.getInputOperands()) {
if (opOperand == fusableOpOperand) {
expandedOpOperands.push_back(isExpanding ? expandingReshapeOp.src()
: collapsingReshapeOp.src());
continue;
}
- AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
- RankedTensorType expandedOperandType =
- getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
- indexingMap, expansionInfo);
- if (expandedOperandType != opOperand->get().getType()) {
- // Reshape the operand to get the right type.
- SmallVector<ReassociationIndices> reassociation =
- getReassociationForExpansion(indexingMap, expansionInfo);
- expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
- genericOp.getLoc(), expandedOperandType, opOperand->get(),
- reassociation));
- continue;
+ if (genericOp.isInputTensor(opOperand)) {
+ AffineMap indexingMap = genericOp.getTiedIndexingMap(opOperand);
+ RankedTensorType expandedOperandType =
+ getExpandedType(opOperand->get().getType().cast<RankedTensorType>(),
+ indexingMap, expansionInfo);
+ if (expandedOperandType != opOperand->get().getType()) {
+ // Reshape the operand to get the right type.
+ SmallVector<ReassociationIndices> reassociation =
+ getReassociationForExpansion(indexingMap, expansionInfo);
+ expandedOpOperands.push_back(rewriter.create<TensorExpandShapeOp>(
+ genericOp.getLoc(), expandedOperandType, opOperand->get(),
+ reassociation));
+ continue;
+ }
}
expandedOpOperands.push_back(opOperand->get());
}
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
- for (OpOperand *opOperand : genericOp.getInputOperands()) {
+ for (OpOperand *opOperand : genericOp.getInputTensorOperands()) {
TensorCollapseShapeOp reshapeOp =
opOperand->get().getDefiningOp<TensorCollapseShapeOp>();
if (!reshapeOp)