[SCF] Clean up ForOpTensorCastFolder and harden it against nop tensor casts
author Benjamin Kramer <benny.kra@googlemail.com>
Wed, 19 Apr 2023 13:31:34 +0000 (15:31 +0200)
committer Benjamin Kramer <benny.kra@googlemail.com>
Wed, 19 Apr 2023 14:05:59 +0000 (16:05 +0200)
The code was inserting a new cast, discarding it, then inserting it
again.

The self-cast issue is the root of #62135 because it would end up
dropping the loop and inserting an invalid cast to itself. As far as I
can tell, tensor.cast with the same src and dst types is not invalid, but
it can't really be tested in isolation as it's immediately folded.

Fixes #62135

Differential Revision: https://reviews.llvm.org/D148714

mlir/lib/Dialect/SCF/IR/SCF.cpp

index 06d4add..76b3589 100644 (file)
@@ -886,9 +886,9 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
 /// Perform a replacement of one iter OpOperand of an scf.for to the
 /// `replacement` value which is expected to be the source of a tensor.cast.
 /// tensor.cast ops are inserted inside the block to account for the type cast.
-static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
-                                           OpOperand &operand,
-                                           Value replacement) {
+static SmallVector<Value>
+replaceTensorCastForOpIterArg(PatternRewriter &rewriter, OpOperand &operand,
+                              Value replacement) {
   Type oldType = operand.get().getType(), newType = replacement.getType();
   assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
          "expected ranked tensor types");
@@ -897,8 +897,8 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
   ForOp forOp = cast<ForOp>(operand.getOwner());
   assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
          "expected an iter OpOperand");
-  if (operand.get().getType() == replacement.getType())
-    return forOp;
+  assert(operand.get().getType() != replacement.getType() &&
+         "Expected a different type");
   SmallVector<Value> newIterOperands;
   for (OpOperand &opOperand : forOp.getIterOpOperands()) {
     if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
@@ -949,7 +949,7 @@ static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
   newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
       newForOp.getLoc(), oldType, newResults[yieldIdx]);
 
-  return newForOp;
+  return newResults;
 }
 
 /// Fold scf.for iter_arg/result pairs that go through incoming/outgoing
@@ -986,7 +986,8 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
     for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
       OpOperand &iterOpOperand = std::get<0>(it);
       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
-      if (!incomingCast)
+      if (!incomingCast ||
+          incomingCast.getSource().getType() == incomingCast.getType())
         continue;
       // If the dest type of the cast does not preserve static information in
       // the source type.
@@ -998,18 +999,9 @@ struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
         continue;
 
       // Create a new ForOp with that iter operand replaced.
-      auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
-                                                    incomingCast.getSource());
-
-      // Insert outgoing cast and use it to replace the corresponding result.
-      rewriter.setInsertionPointAfter(newForOp);
-      SmallVector<Value> replacements = newForOp.getResults();
-      unsigned returnIdx =
-          iterOpOperand.getOperandNumber() - op.getNumControlOperands();
-      replacements[returnIdx] = rewriter.create<tensor::CastOp>(
-          op.getLoc(), incomingCast.getDest().getType(),
-          replacements[returnIdx]);
-      rewriter.replaceOp(op, replacements);
+      rewriter.replaceOp(
+          op, replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
+                                            incomingCast.getSource()));
       return success();
     }
     return failure();