From aeac932943bc7e0c54903d2c4b754e3a87e90fb0 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 19 Apr 2023 15:31:34 +0200 Subject: [PATCH] [SCF] Clean up ForOpTensorCastFolder and harden it against nop tensor casts The code was inserting a new cast, discarding it, then inserting it again. The self-cast issue is the root of #62135 because it would end up dropping the loop and inserting an invalid cast to itself. As far as I can tell tensor.cast with the same src and dst types is not invalid but it can't really be tested in isolation as it's immediately folded. Fixes #62135 Differential Revision: https://reviews.llvm.org/D148714 --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 30 +++++++++++------------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 06d4add..76b3589 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -886,9 +886,9 @@ struct SimplifyTrivialLoops : public OpRewritePattern { /// Perform a replacement of one iter OpOperand of an scf.for to the /// `replacement` value which is expected to be the source of a tensor.cast. /// tensor.cast ops are inserted inside the block to account for the type cast. -static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter, - OpOperand &operand, - Value replacement) { +static SmallVector +replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand, + Value replacement) { Type oldType = operand.get().getType(), newType = replacement.getType(); assert(oldType.isa() && newType.isa() && "expected ranked tensor types"); @@ -897,8 +897,8 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter, ForOp forOp = cast(operand.getOwner()); assert(operand.getOperandNumber() >= forOp.getNumControlOperands() && "expected an iter OpOperand"); - if (operand.get().getType() == replacement.getType()) - return forOp; + assert(operand.get().getType() != replacement.getType() && + "Expected a different type"); SmallVector newIterOperands; for (OpOperand &opOperand : forOp.getIterOpOperands()) { if (opOperand.getOperandNumber() == operand.getOperandNumber()) { @@ -949,7 +949,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter, newResults[yieldIdx] = rewriter.create( newForOp.getLoc(), oldType, newResults[yieldIdx]); - return newForOp; + return newResults; } /// Fold scf.for iter_arg/result pairs that go through incoming/ougoing @@ -986,7 +986,8 @@ struct ForOpTensorCastFolder : public OpRewritePattern { for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) { OpOperand &iterOpOperand = std::get<0>(it); auto incomingCast = iterOpOperand.get().getDefiningOp(); - if (!incomingCast) + if (!incomingCast || + incomingCast.getSource().getType() == incomingCast.getType()) continue; // If the dest type of the cast does not preserve static information in // the source type. @@ -998,18 +999,9 @@ struct ForOpTensorCastFolder : public OpRewritePattern { continue; // Create a new ForOp with that iter operand replaced. - auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand, - incomingCast.getSource()); - - // Insert outgoing cast and use it to replace the corresponding result. - rewriter.setInsertionPointAfter(newForOp); - SmallVector replacements = newForOp.getResults(); - unsigned returnIdx = - iterOpOperand.getOperandNumber() - op.getNumControlOperands(); - replacements[returnIdx] = rewriter.create( - op.getLoc(), incomingCast.getDest().getType(), - replacements[returnIdx]); - rewriter.replaceOp(op, replacements); + rewriter.replaceOp( + op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand, + incomingCast.getSource())); return success(); } return failure(); -- 2.7.4