/// Perform a replacement of one iter OpOperand of an scf.for to the
/// `replacement` value which is expected to be the source of a tensor.cast.
/// tensor.cast ops are inserted inside the block to account for the type cast.
-static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
- OpOperand &operand,
- Value replacement) {
+static SmallVector<Value>
+replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
+ Value replacement) {
Type oldType = operand.get().getType(), newType = replacement.getType();
assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
"expected ranked tensor types");
ForOp forOp = cast<ForOp>(operand.getOwner());
assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
"expected an iter OpOperand");
- if (operand.get().getType() == replacement.getType())
- return forOp;
+ assert(operand.get().getType() != replacement.getType() &&
+ "Expected a different type");
SmallVector<Value> newIterOperands;
for (OpOperand &opOperand : forOp.getIterOpOperands()) {
if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
newForOp.getLoc(), oldType, newResults[yieldIdx]);
- return newForOp;
+ return newResults;
}
/// Fold scf.for iter_arg/result pairs that go through incoming/ougoing
for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
OpOperand &iterOpOperand = std::get<0>(it);
auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
- if (!incomingCast)
+ if (!incomingCast ||
+ incomingCast.getSource().getType() == incomingCast.getType())
continue;
// If the dest type of the cast does not preserve static information in
// the source type.
continue;
// Create a new ForOp with that iter operand replaced.
- auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
- incomingCast.getSource());
-
- // Insert outgoing cast and use it to replace the corresponding result.
- rewriter.setInsertionPointAfter(newForOp);
- SmallVector<Value> replacements = newForOp.getResults();
- unsigned returnIdx =
- iterOpOperand.getOperandNumber() - op.getNumControlOperands();
- replacements[returnIdx] = rewriter.create<tensor::CastOp>(
- op.getLoc(), incomingCast.getDest().getType(),
- replacements[returnIdx]);
- rewriter.replaceOp(op, replacements);
+ rewriter.replaceOp(
+ op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
+ incomingCast.getSource()));
return success();
}
return failure();